{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Concurrent.JobPool
( JobPool
, Job (..)
, withJobPool
, forkJob
, readSize
, readGroupSize
, waitForJob
, cancelGroup
) where
import Data.Functor (($>))
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Control.Concurrent.Class.MonadSTM
import Control.Exception (SomeAsyncException (..))
import Control.Monad (void, when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork (MonadThread (..))
import Control.Monad.Class.MonadThrow
-- | A pool of jobs, each running in its own thread and registered under a
-- user-chosen @group@ key together with its thread id.
data JobPool group m a = JobPool
  { jobsVar         :: !(TVar m (Map (group, ThreadId m) (Async m ())))
    -- ^ Jobs currently registered with the pool, keyed by group and the
    -- job's thread id.
  , completionQueue :: !(TQueue m a)
    -- ^ Results published by finished jobs, read back in FIFO order.
  }
-- | A unit of work to be run by a 'JobPool'.
data Job group m a =
    Job (m a)
        -- ^ the action to run in the job's own thread
        (SomeException -> m a)
        -- ^ handler for synchronous exceptions thrown by the action; its
        -- result is used in place of the action's result
        group
        -- ^ group the job is registered under
        String
        -- ^ label given to the job's thread
-- | Run an action with a freshly allocated 'JobPool'.
--
-- The pool is allocated before the action and released afterwards via
-- 'bracket', so release also happens when the action throws. Releasing the
-- pool cancels, with 'uninterruptibleCancel', every job found in the pool
-- at that moment.
withJobPool :: forall group m a b.
               (MonadAsync m, MonadThrow m, MonadLabelledSTM m)
            => (JobPool group m a -> m b) -> m b
withJobPool = bracket acquire release
  where
    -- Allocate the (labelled) job map and the completion queue in one
    -- transaction.
    acquire :: m (JobPool group m a)
    acquire = atomically $ do
      registry <- newTVar Map.empty
      labelTVar registry "job-pool"
      results <- newTQueue
      return (JobPool registry results)

    -- Cancel each job in a snapshot of the job map, one after another.
    release :: JobPool group m a -> m ()
    release pool = do
      running <- readTVarIO (jobsVar pool)
      mapM_ uninterruptibleCancel running
-- | Start a job in the pool.
--
-- The job's action runs in a new thread (created with 'async' and labelled
-- with the job's label). The job is registered in the pool under its group
-- and thread id.
--
-- When the action returns, its result is written to the completion queue
-- (see 'waitForJob'), and the job is removed from the pool in the same
-- transaction. If the action throws a synchronous exception, the handler's
-- result is written instead.
--
-- Asynchronous exceptions (e.g. from 'cancel') are not given to the handler.
-- They terminate the job, which is then removed from the pool without
-- publishing a result. The same holds if the handler itself throws.
forkJob :: forall group m a.
           ( MonadAsync m, MonadMask m
           , Ord group
           )
        => JobPool group m a
        -> Job group m a
        -> m ()
forkJob JobPool{jobsVar, completionQueue} (Job action handler group label) =
    mask $ \restore -> do
      jobAsync <- async $ do
        tid <- myThreadId
        labelThread tid label
        let key :: (group, ThreadId m)
            key = (group, tid)

            -- Remove this job from the pool. The parent registers the job
            -- only after 'async' has returned, so a job that finishes
            -- quickly could otherwise deregister first (a no-op) and the
            -- later registration would leave a stale entry. Block until
            -- the registration is visible instead.
            deregister :: STM m ()
            deregister = do
              jobs <- readTVar jobsVar
              when (Map.notMember key jobs) retry
              writeTVar jobsVar $! Map.delete key jobs

            -- Normal path: run the action (with async exceptions restored),
            -- then publish the result and deregister atomically.
            run = do
              !res <- handleJust notAsyncExceptions handler (restore action)
              atomically $ do
                writeTQueue completionQueue res
                deregister

        -- Abnormal path (async exception, or the handler throwing): still
        -- deregister so the pool's size stays accurate.
        run `onException` atomically deregister

      let !tid = asyncThreadId jobAsync
      atomically $ modifyTVar' jobsVar (Map.insert (group, tid) jobAsync)
  where
    -- Only synchronous exceptions reach the job's handler; asynchronous
    -- ones propagate and terminate the job thread.
    notAsyncExceptions :: SomeException -> Maybe SomeException
    notAsyncExceptions e
      | Just (SomeAsyncException _) <- fromException e
                  = Nothing
      | otherwise = Just e
-- | Number of jobs currently registered in the pool, across all groups.
readSize :: MonadSTM m => JobPool group m a -> STM m Int
readSize pool = do
  registered <- readTVar (jobsVar pool)
  return (Map.size registered)
-- | Number of jobs currently registered in the pool under the given group.
readGroupSize :: ( MonadSTM m
                 , Eq group
                 )
              => JobPool group m a -> group -> STM m Int
readGroupSize pool group = countInGroup <$> readTVar (jobsVar pool)
  where
    -- Count the keys whose group component equals the requested group.
    countInGroup = length . filter ((== group) . fst) . Map.keys
-- | Take the next result from the pool's completion queue, retrying the
-- transaction while the queue is empty. Results come out in FIFO order.
waitForJob :: MonadSTM m => JobPool group m a -> STM m a
waitForJob = readTQueue . completionQueue
-- | Cancel, one after another, every job registered under the given group.
--
-- The set of jobs is a snapshot of the pool taken when the call starts;
-- jobs forked into the group afterwards are not cancelled.
cancelGroup :: ( MonadAsync m
               , Eq group
               )
            => JobPool group m a -> group -> m ()
cancelGroup pool group = do
  snapshot <- readTVarIO (jobsVar pool)
  mapM_ cancelIfInGroup (Map.toList snapshot)
  where
    -- Cancel a single entry when its group matches.
    cancelIfInGroup ((group', _), job) = when (group' == group) (cancel job)