{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE InstanceSigs        #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE NumericUnderscores  #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Control.Monad.Class.MonadTimer.SI
  ( -- * Type classes
    MonadDelay (..)
  , MonadTimer (..)
    -- * Auxiliary functions
  , diffTimeToMicrosecondsAsInt
  , microsecondsAsIntToDiffTime
    -- * Re-exports
  , DiffTime
  , MonadFork
  , MonadMonotonicTime
  , MonadTime
  , TimeoutState (..)
    -- * Default implementations
  , defaultRegisterDelay
  , defaultRegisterDelayCancellable
  ) where

import Control.Concurrent.Class.MonadSTM
import Control.Exception (assert)
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadTime.SI
import Control.Monad.Class.MonadTimer qualified as MonadTimer
import Control.Monad.Class.MonadTimer.NonStandard (TimeoutState (..))
import Control.Monad.Class.MonadTimer.NonStandard qualified as NonStandard

import Control.Monad.Reader

import Data.Bifunctor (bimap)
import Data.Functor (($>))
import Data.Time.Clock (diffTimeToPicoseconds)

-- | Convert a 'DiffTime' (measured in seconds) to a number of microseconds
-- represented by an 'Int'.  Sub-microsecond precision is discarded by
-- rounding towards negative infinity ('div').
--
-- Note that on 32bit systems an 'Int' can only represent @2^31-1@
-- microseconds, which is only ~35 minutes.
--
-- It doesn't prevent under- or overflows; when assertions are enabled it will
-- throw an assertion exception instead.
--
diffTimeToMicrosecondsAsInt :: DiffTime -> Int
diffTimeToMicrosecondsAsInt d =
    let usec :: Integer
        usec = diffTimeToPicoseconds d `div` 1_000_000 in
    -- 'fromIntegral' below silently wraps around when 'usec' does not fit
    -- into an 'Int'; the assertion makes that visible in debug builds.
    assert (usec <= fromIntegral (maxBound :: Int)
         && usec >= fromIntegral (minBound :: Int)) $
    fromIntegral usec

-- | Convert a number of microseconds, given as an 'Int', to a 'DiffTime'
-- (measured in seconds).  The conversion is exact, so
-- @diffTimeToMicrosecondsAsInt . microsecondsAsIntToDiffTime@ is the identity.
--
microsecondsAsIntToDiffTime :: Int -> DiffTime
microsecondsAsIntToDiffTime = (/ 1_000_000) . fromIntegral

-- | Monads which can suspend the current thread for a 'DiffTime' (measured in
-- seconds), rather than for an 'Int' number of microseconds as
-- 'MonadTimer.threadDelay' does.
class (MonadTimer.MonadDelay m, MonadMonotonicTime m) => MonadDelay m where
  -- | Suspend the current thread for the given amount of time.
  threadDelay :: DiffTime -> m ()

-- | Thread delay.  This implementation will not over- or underflow:
--
-- * a non-positive delay returns immediately;
-- * a delay of at most @maxBound :: Int@ microseconds (see
--   'diffTimeToMicrosecondsAsInt') is passed straight to
--   'Control.Monad.Class.MonadTimer.threadDelay';
-- * a longer delay is served in chunks of at most @maxBound :: Int@
--   microseconds, re-reading the monotonic clock after each chunk, so the
--   remaining time is always measured against the original deadline.
--
instance MonadDelay IO where
  threadDelay :: DiffTime -> IO ()
  threadDelay d
      | d <= 0        = return ()
      | d <= maxDelay = MonadTimer.threadDelay (diffTimeToMicrosecondsAsInt d)
      | otherwise     = do
          c <- getMonotonicTime
          go c (d `addTime` c)
    where
      -- The longest delay 'MonadTimer.threadDelay' can express.
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

      -- @go c u@: sleep until the monotonic clock reaches the deadline @u@,
      -- where @c@ is the most recently observed time.
      go :: Time -> Time -> IO ()
      go c u
          | remaining >= maxDelay = do
              MonadTimer.threadDelay maxBound
              c' <- getMonotonicTime
              go c' u
          | otherwise =
              MonadTimer.threadDelay (diffTimeToMicrosecondsAsInt remaining)
        where
          remaining = u `diffTime` c

-- | Delays in a 'ReaderT' are delegated to the underlying monad.
instance MonadDelay m => MonadDelay (ReaderT r m) where
  threadDelay :: DiffTime -> ReaderT r m ()
  threadDelay = lift . threadDelay

-- | Timer API taking delays as 'DiffTime' (measured in seconds) rather than
-- as an 'Int' number of microseconds like 'MonadTimer.MonadTimer'.
class ( MonadTimer.MonadTimer m
      , MonadMonotonicTime m
      ) => MonadTimer m where

  -- | A register delay function which is safe on 32-bit systems.
  registerDelay            :: DiffTime -> m (TVar m Bool)

  -- | A cancellable register delay which is safe on 32-bit systems and
  -- efficient for delays smaller than what `Int` can represent (especially on
  -- systems which support native timer manager).
  --
  registerDelayCancellable :: DiffTime -> m (STM m TimeoutState, m ())

  -- | A timeout function.
  --
  -- __TODO__: /'IO' instance is not safe on 32-bit systems./
  timeout                  :: DiffTime -> m a -> m (Maybe a)

-- | A default implementation of `registerDelay` which supports delays longer
-- than what an `Int` can represent; this is especially important on 32-bit
-- systems where the maximum delay expressed in microseconds is around 35
-- minutes.
--
-- It forks a thread (labelled @\"delay-thread\"@) which waits for the
-- deadline using timeouts of at most @maxBound :: Int@ microseconds at a
-- time, and then writes 'True' to the returned 'TVar'.
--
defaultRegisterDelay :: forall m timeout.
                        ( MonadFork m
                        , MonadMonotonicTime m
                        , MonadSTM m
                        )
                     => NonStandard.NewTimeout m timeout
                     -> NonStandard.AwaitTimeout m timeout
                     -> DiffTime
                     -> m (TVar m Bool)
defaultRegisterDelay newTimeout awaitTimeout d = do
    c <- getMonotonicTime
    v <- atomically $ newTVar False
    tid <- forkIO $ go v c (d `addTime` c)
    labelThread tid "delay-thread"
    return v
  where
    -- The longest delay a single timeout can express.
    maxDelay :: DiffTime
    maxDelay = microsecondsAsIntToDiffTime maxBound

    -- @go v c u@: wait until the monotonic clock reaches the deadline @u@,
    -- where @c@ is the most recently observed time, then set @v@.
    go :: TVar m Bool -> Time -> Time -> m ()
    go v c u
      | u `diffTime` c >= maxDelay = do
          -- The remaining time does not fit into an 'Int': wait for the
          -- longest representable timeout, re-read the clock and try again.
          _ <- newTimeout maxBound >>= atomically . awaitTimeout
          c' <- getMonotonicTime
          go v c' u
      | otherwise = do
          t <- newTimeout (diffTimeToMicrosecondsAsInt $ u `diffTime` c)
          atomically $ do
            _ <- awaitTimeout t
            writeTVar v True

-- | A cancellable register delay which is safe on 32-bit systems and efficient
-- for delays smaller than what `Int` can represent (especially on systems which
-- support native timer manager).
--
-- * If the delay fits into an 'Int' number of microseconds, a single timeout
--   is created and its read and cancel actions are returned directly.
-- * Otherwise a thread labelled @\"delay-thread\"@ is forked which waits for
--   the deadline using timeouts of at most @maxBound :: Int@ microseconds at a
--   time.  The returned 'STM' action reads the shared 'TimeoutState', and the
--   returned cancel action moves it from 'TimeoutPending' to
--   'TimeoutCancelled'.  Once cancelled, the helper thread also cancels the
--   timeout it was waiting on, so that no timer stays registered until it
--   expires.
--
defaultRegisterDelayCancellable :: forall m timeout.
                                   ( MonadFork m
                                   , MonadMonotonicTime m
                                   , MonadSTM m
                                   )
                                => NonStandard.NewTimeout    m timeout
                                -> NonStandard.ReadTimeout   m timeout
                                -> NonStandard.CancelTimeout m timeout
                                -> NonStandard.AwaitTimeout  m timeout
                                -> DiffTime
                                -> m (STM m TimeoutState, m ())
defaultRegisterDelayCancellable newTimeout readTimeout cancelTimeout awaitTimeout d
    | d <= maxDelay = do
        t <- newTimeout (diffTimeToMicrosecondsAsInt d)
        return (readTimeout t, cancelTimeout t)
    | otherwise = do
        -- current time
        c <- getMonotonicTime
        -- timeout state, shared with the helper thread
        v <- newTVarIO TimeoutPending
        tid <- forkIO $ go v c (d `addTime` c)
        labelThread tid "delay-thread"
        let cancel = atomically $ readTVar v >>= \case
              TimeoutCancelled -> return ()
              TimeoutFired     -> return ()
              TimeoutPending   -> writeTVar v TimeoutCancelled
        return (readTVar v, cancel)
  where
    -- The longest delay a single timeout can express.
    maxDelay :: DiffTime
    maxDelay = microsecondsAsIntToDiffTime maxBound

    -- Blocks (via 'retry') until the delay is cancelled and then returns
    -- 'TimeoutCancelled'.  Only the helper thread sets 'TimeoutFired', and it
    -- stops looping once it has done so, so observing it here is a bug.
    cancelled :: TVar m TimeoutState -> STM m TimeoutState
    cancelled v = readTVar v >>= \case
      a@TimeoutCancelled -> return a
      TimeoutFired       -> error "registerDelayCancellable: invariant violation!"
      TimeoutPending     -> retry

    -- @go v c u@: wait until the monotonic clock reaches the deadline @u@
    -- (where @c@ is the most recently observed time) or until the delay is
    -- cancelled, whichever comes first.
    go :: TVar m TimeoutState
       -> Time
       -> Time
       -> m ()
    go v c u
      | u `diffTime` c >= maxDelay = do
          t <- newTimeout maxBound
          outcome <- atomically $
                       cancelled v
                       `orElse`
                       -- the overall timeout is still pending when 't' fires
                       (awaitTimeout t $> TimeoutPending)
          case outcome of
            TimeoutPending -> do
              c' <- getMonotonicTime
              go v c' u
            -- cancelled: release the chunk timer instead of leaving it armed
            _ -> cancelTimeout t
      | otherwise = do
          t <- newTimeout (diffTimeToMicrosecondsAsInt $ u `diffTime` c)
          outcome <- atomically $ do
            st <- cancelled v
                  `orElse`
                  -- the overall timeout fires when 't' fires
                  (awaitTimeout t $> TimeoutFired)
            case st of
              TimeoutFired -> writeTVar v TimeoutFired
              _            -> return ()
            return st
          case outcome of
            -- cancelled: release the final timer instead of leaving it armed
            TimeoutCancelled -> cancelTimeout t
            _                -> return ()

-- | Like 'GHC.Conc.registerDelay' but safe on 32-bit systems.  When the delay
-- is larger than what `Int` can represent it will fork a thread which will
-- write to the returned 'TVar' once the delay has passed.  When the delay is
-- small enough it will use the `MonadTimer`'s `registerDelay` (e.g. for the
-- `IO` monad it will use `GHC`'s `GHC.Conc.registerDelay`).
--
-- 'registerDelayCancellable' is 'defaultRegisterDelayCancellable' applied to
-- the 'IO' timeouts from "Control.Monad.Class.MonadTimer.NonStandard".
--
-- __TODO__: /'timeout' is not safe on 32-bit systems:/ it converts its delay
-- with 'diffTimeToMicrosecondsAsInt', which does not guard against delays
-- beyond @maxBound :: Int@ microseconds.
instance MonadTimer IO where
  registerDelay :: DiffTime -> IO (TVar IO Bool)
  registerDelay d
      | d <= maxDelay =
        MonadTimer.registerDelay (diffTimeToMicrosecondsAsInt d)
      | otherwise =
        defaultRegisterDelay
          NonStandard.newTimeout
          NonStandard.awaitTimeout
          d
    where
      -- The longest delay 'MonadTimer.registerDelay' can express.
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

  registerDelayCancellable :: DiffTime -> IO (STM IO TimeoutState, IO ())
  registerDelayCancellable =
    defaultRegisterDelayCancellable
      NonStandard.newTimeout
      NonStandard.readTimeout
      NonStandard.cancelTimeout
      NonStandard.awaitTimeout

  timeout :: DiffTime -> IO a -> IO (Maybe a)
  timeout = MonadTimer.timeout . diffTimeToMicrosecondsAsInt

-- | All operations are delegated to the underlying monad.
-- 'registerDelayCancellable' lifts both returned actions ('bimap' 'lift'
-- 'lift'): the read action into the 'ReaderT' 'STM' and the cancel action into
-- the 'ReaderT' monad.  'timeout' runs the 'ReaderT' computation, with the
-- current environment, under the underlying monad's 'timeout'.
instance MonadTimer m => MonadTimer (ReaderT r m) where
  registerDelay            = lift . registerDelay
  registerDelayCancellable = fmap (bimap lift lift) . lift . registerDelayCancellable
  timeout d f              = ReaderT $ \r -> timeout d (runReaderT f r)