{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

-- | A generalisation of
-- <https://hackage.haskell.org/package/base/docs/Control-Concurrent.html Control.Concurrent>
-- API to both 'IO' and <https://hackage.haskell.org/package/io-sim IOSim>.
--
module Control.Monad.Class.MonadFork
  ( MonadThread (..)
  , labelThisThread
  , MonadFork (..)
  ) where

import Control.Concurrent qualified as IO
import Control.Exception (AsyncException (ThreadKilled), Exception,
           SomeException)
import Control.Monad.Reader (ReaderT (..), lift)
import Data.Kind (Type)
import GHC.Conc.Sync qualified as IO


-- | Monads with a notion of threads: a type of thread identifiers, a way to
-- obtain the identifier of the running thread, and thread labels that can be
-- attached and read back.
class ( Monad m
      , Eq   (ThreadId m)
      , Ord  (ThreadId m)
      , Show (ThreadId m)
      ) => MonadThread m where

  -- | Thread identifiers of the monad @m@.
  type ThreadId m :: Type

  -- | Identifier of the calling thread.
  myThreadId  :: m (ThreadId m)

  -- | Attach a label to the given thread.
  labelThread :: ThreadId m -> String -> m ()

  -- | Look up the label of the given thread, if any.
  --
  -- Requires ghc-9.6.1 or newer; when built with an older compiler the 'IO'
  -- instance always returns 'Nothing'.
  --
  -- @since 1.8.0.0
  threadLabel :: ThreadId m -> m (Maybe String)

-- | Apply the given label to the calling thread.
--
-- Looks up the current thread with 'myThreadId' and passes it, together with
-- the label, to 'labelThread'.
labelThisThread :: MonadThread m => String -> m ()
labelThisThread label = myThreadId >>= \tid -> labelThread tid label


-- | Monads that can spawn and control threads; a generalisation of the
-- corresponding "Control.Concurrent" operations.
class MonadThread m => MonadFork m where

  -- | Generalisation of 'Control.Concurrent.forkIO': run the action in a new
  -- thread and return that thread's identifier.
  forkIO           :: m () -> m (ThreadId m)

  -- | Generalisation of 'Control.Concurrent.forkOn'; the 'Int' selects the
  -- capability the new thread is placed on.
  forkOn           :: Int -> m () -> m (ThreadId m)

  -- | Generalisation of 'Control.Concurrent.forkIOWithUnmask': the action is
  -- handed a function it can use to run sub-actions unmasked.
  forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)

  -- | Generalisation of 'Control.Concurrent.forkFinally': fork the action and
  -- pass its outcome (result or exception) to the handler.
  forkFinally      :: m a -> (Either SomeException a -> m ()) -> m (ThreadId m)

  -- | Generalisation of 'Control.Concurrent.throwTo': raise an exception in
  -- the target thread.
  throwTo          :: Exception e => ThreadId m -> e -> m ()

  -- | Terminate a thread. The default implementation throws 'ThreadKilled'
  -- to it via 'throwTo'.
  killThread       :: ThreadId m -> m ()
  killThread tid = throwTo tid ThreadKilled

  -- | Generalisation of 'Control.Concurrent.yield'.
  yield            :: m ()


-- | 'IO' threads are GHC runtime threads; every method delegates to
-- "Control.Concurrent" / "GHC.Conc.Sync".
instance MonadThread IO where
  type ThreadId IO = IO.ThreadId
  myThreadId  = IO.myThreadId
  labelThread = IO.labelThread
#if MIN_VERSION_base(4,18,0)
  threadLabel = IO.threadLabel
#else
  -- Reading a label back is only available from base-4.18 (ghc-9.6) on;
  -- with older libraries no label can be observed, so report none.
  threadLabel = \_ -> pure Nothing
#endif

-- | The 'IO' instance is a direct mapping onto "Control.Concurrent".
instance MonadFork IO where
  forkIO           = IO.forkIO
  forkOn           = IO.forkOn
  forkIOWithUnmask = IO.forkIOWithUnmask
  forkFinally      = IO.forkFinally
  throwTo          = IO.throwTo
  killThread       = IO.killThread
  yield            = IO.yield

-- | Thread operations lift through 'ReaderT' unchanged; thread identifiers
-- are those of the underlying monad.
instance MonadThread m => MonadThread (ReaderT r m) where
  type ThreadId (ReaderT r m) = ThreadId m
  myThreadId        = lift myThreadId
  labelThread tid l = lift (labelThread tid l)
  threadLabel       = lift . threadLabel

-- | Forking under 'ReaderT': the forked computation is run in the underlying
-- monad with the environment that is current at the point of the fork.
-- 'killThread' uses the class default, i.e. this instance's 'throwTo'.
instance MonadFork m => MonadFork (ReaderT e m) where
  forkIO (ReaderT f)   = ReaderT $ \e -> forkIO (f e)
  forkOn n (ReaderT f) = ReaderT $ \e -> forkOn n (f e)
  forkIOWithUnmask k   = ReaderT $ \e -> forkIOWithUnmask $ \restore ->
      -- 'restore' only unmasks actions of the underlying monad; wrap it so
      -- that 'k' can apply it to 'ReaderT' actions as well.
      let restore' :: ReaderT e m a -> ReaderT e m a
          restore' (ReaderT f) = ReaderT $ restore . f
      in runReaderT (k restore') e
  forkFinally f k      = ReaderT $ \e ->
      -- Both the action and its finaliser see the same environment.
      forkFinally (runReaderT f e) $ \err -> runReaderT (k err) e
  throwTo tid ex       = lift (throwTo tid ex)
  yield                = lift yield