{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables   #-}

module Ouroboros.Consensus.Util.IOLike (
    IOLike (..)
    -- * Re-exports
    -- *** MonadThrow
  , Exception (..)
  , ExitCase (..)
  , MonadCatch (..)
  , MonadMask (..)
  , MonadThrow (..)
  , SomeException
    -- *** MonadSTM
  , module Ouroboros.Consensus.Util.MonadSTM.NormalForm
    -- *** MonadFork, TODO: Should we hide this in favour of MonadAsync?
  , MonadFork (..)
  , MonadThread (..)
  , labelThisThread
    -- *** MonadAsync
  , ExceptionInLinkedThread (..)
  , MonadAsync (..)
  , link
  , linkTo
    -- *** MonadST
  , MonadST (..)
    -- *** MonadTime
  , DiffTime
  , MonadMonotonicTime (..)
  , Time (..)
  , addTime
  , diffTime
    -- *** MonadDelay
  , MonadDelay (..)
    -- *** MonadEventlog
  , MonadEventlog (..)
    -- *** MonadEvaluate
  , MonadEvaluate (..)
    -- *** NoThunks
  , NoThunks (..)
  ) where

import           NoThunks.Class (NoThunks (..))

import           Cardano.Crypto.KES (KESAlgorithm, SignKeyKES)
import qualified Cardano.Crypto.KES as KES

import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadEventlog
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadST
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime hiding (MonadTime (..))
import           Control.Monad.Class.MonadTimer

import           Ouroboros.Consensus.Util.MonadSTM.NormalForm
import           Ouroboros.Consensus.Util.Orphans ()

import           Data.Functor (void)

{-------------------------------------------------------------------------------
  IOLike
-------------------------------------------------------------------------------}

-- | The monad capabilities that consensus code requires of the monad it runs
-- in: concurrency, exceptions (in @m@ and in its 'STM'), monotonic time and
-- delays, 'MonadST', event logging and evaluation. It also requires
-- 'NoThunks' instances for actions and for the strict variables exported by
-- "Ouroboros.Consensus.Util.MonadSTM.NormalForm".
class ( -- concurrency
        MonadAsync              m
      , MonadFork               m
      , MonadThread             m
        -- exceptions
      , MonadThrow              m
      , MonadCatch              m
      , MonadMask               m
      , MonadThrow         (STM m)
        -- time
      , MonadMonotonicTime      m
      , MonadDelay              m
        -- miscellaneous effects
      , MonadST                 m
      , MonadEventlog           m
      , MonadEvaluate           m
        -- thunk checking
      , forall a. NoThunks (m a)
      , forall a. NoThunks a => NoThunks (StrictTVar m a)
      , forall a. NoThunks a => NoThunks (StrictMVar m a)
      ) => IOLike m where
  -- | Securely forget a KES signing key.
  --
  -- For 'IO' this is 'KES.forgetSignKeyKES'; for IOSim it is a no-op.
  forgetSignKeyKES :: KESAlgorithm v => SignKeyKES v -> m ()

-- | In 'IO', forgetting a KES signing key delegates to
-- 'KES.forgetSignKeyKES'.
instance IOLike IO where
  forgetSignKeyKES = KES.forgetSignKeyKES

-- | Generalization of 'link' that links an async to an arbitrary thread.
--
-- When the async terminates with an exception other than 'AsyncCancelled',
-- that exception is thrown to the given thread, wrapped in
-- 'ExceptionInLinkedThread' (see 'linkToOnly').
--
-- Non-standard (not in the @async@ library).
linkTo :: (MonadAsync m, MonadFork m, MonadMask m)
       => ThreadId m -> Async m a -> m ()
linkTo tid = linkToOnly tid (not . isCancel)

-- | Generalization of 'linkOnly' that links an async to an arbitrary thread.
--
-- Forks a watcher thread (labelled @linkToOnly@ followed by the shown thread
-- id of the async) that waits for the async to finish. If it finished with
-- an exception @e@ for which the predicate holds, @e@ is thrown to the target
-- thread wrapped in 'ExceptionInLinkedThread'; otherwise the watcher simply
-- returns. The watcher is started with 'forkRepeat', so if waiting is
-- interrupted by an exception, the wait is retried.
--
-- Non-standard (not in the @async@ library).
linkToOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
           => ThreadId m               -- ^ thread to throw to
           -> (SomeException -> Bool)  -- ^ which exceptions to propagate
           -> Async m a                -- ^ async to watch
           -> m ()
linkToOnly tid shouldThrow a =
    void $ forkRepeat ("linkToOnly " <> show linkedThreadId) $ do
      r <- waitCatch a
      case r of
        Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
        _otherwise             -> return ()
  where
    -- Thread running the watched async. It names the watcher thread and
    -- tags the exception delivered to @tid@.
    linkedThreadId :: ThreadId m
    linkedThreadId = asyncThreadId a

    exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
    exceptionInLinkedThread = ExceptionInLinkedThread (show linkedThreadId)

-- | Is this exception 'AsyncCancelled', i.e. the signal that an async was
-- cancelled?
isCancel :: SomeException -> Bool
isCancel e = case fromException e of
    Just AsyncCancelled -> True
    Nothing             -> False

-- | Fork a thread, labelled with the given string, that runs the action
-- repeatedly until one run completes without throwing.
--
-- The thread is forked from inside 'mask'. The action itself is run through
-- @restore@, i.e. with the masking state the caller had. Any exception
-- escaping the action is caught by 'tryAll' (which catches every
-- 'SomeException') and the action is started again. The action's result is
-- discarded.
forkRepeat :: (MonadFork m, MonadMask m) => String -> m a -> m (ThreadId m)
forkRepeat label action =
    mask $ \restore ->
      let go = do
            r <- tryAll (restore action)
            case r of
              Left  _ -> go
              Right _ -> return ()
      in  forkIO (labelThisThread label >> go)

-- | 'try' at type 'SomeException': run the action and return any exception it
-- throws as a 'Left'.
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
tryAll = try