{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE DataKinds                #-}
{-# LANGUAGE FlexibleContexts         #-}
{-# LANGUAGE GADTs                    #-}
{-# LANGUAGE NamedFieldPuns           #-}
{-# LANGUAGE PolyKinds                #-}
{-# LANGUAGE RankNTypes               #-}
{-# LANGUAGE ScopedTypeVariables      #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators            #-}

-- | Actions for running 'Peer's with a 'Driver'.  This module should be
-- imported qualified.
--
module Network.TypedProtocol.Stateful.Driver
  ( -- * Driver interface
    Driver (..)
    -- * Running a peer
  , runPeerWithDriver
    -- * Re-exports
  , SomeMessage (..)
  , DecodeStep (..)
  ) where

import Control.Monad.Class.MonadSTM

import Data.Kind (Type)

import Network.TypedProtocol.Codec (DecodeStep (..), SomeMessage (..))
import Network.TypedProtocol.Core
import Network.TypedProtocol.Stateful.Peer

-- | A record of the operations needed to run a 'Peer': sending a message,
-- receiving a message, and the initial driver state.
--
-- The @bytes@ and @failure@ parameters do not occur in any field of this
-- record; they only index the 'Driver' type.
--
data Driver ps (pr :: PeerRole) bytes failure dstate f m =
        Driver {
          -- | Send a message.
          --
          -- Takes evidence that we have agency in state @st@, the @f st'@
          -- value for the state the message transitions to, and the message
          -- itself.
          --
          sendMessage   :: forall (st :: ps) (st' :: ps).
                           StateTokenI st
                        => StateTokenI st'
                        => ActiveState st
                        => ReflRelativeAgency (StateAgency st)
                                               WeHaveAgency
                                              (Relative pr (StateAgency st))
                        -> f st'
                        -> Message ps st st'
                        -> m ()

        , -- | Receive a message, a blocking action which reads from the network
          -- and runs the incremental decoder until a full message is decoded.
          --
          -- Takes evidence that the remote side has agency in state @st@, the
          -- @f st@ value for the current state and the current driver state;
          -- returns the decoded message together with the updated driver
          -- state.
          --
          recvMessage   :: forall (st :: ps).
                           StateTokenI st
                        => ActiveState st
                        => ReflRelativeAgency (StateAgency st)
                                               TheyHaveAgency
                                              (Relative pr (StateAgency st))
                        -> f st
                        -> dstate
                        -> m (SomeMessage st, dstate)

        , -- | The driver state that 'runPeerWithDriver' starts from and
          -- threads through successive 'recvMessage' calls.
          --
          initialDState :: dstate
        }


--
-- Running peers
--

-- | Run a peer with the given driver.
--
-- This runs the peer to completion (if the protocol allows for termination).
--
-- The driver state starts at 'initialDState' and is replaced by the value
-- returned from each 'recvMessage' call.  The result is the peer's final
-- value paired with the driver state at the point the peer finished.
--
runPeerWithDriver
  :: forall ps (st :: ps) pr bytes failure dstate (f :: ps -> Type) m a.
     MonadSTM m
  => Driver ps pr bytes failure dstate f m
  -> f st
  -> Peer ps pr st f m a
  -> m (a, dstate)
runPeerWithDriver Driver{ sendMessage
                        , recvMessage
                        , initialDState
                        } =
    go initialDState
  where
    -- Interpreter loop.  Its arguments are the current driver state, the
    -- @f@ value for the current protocol state and the remaining peer.
    -- Bang patterns force the driver state (and, where marked, the @f@
    -- value) to WHNF on every step.
    go :: forall st'.
          dstate
       -> f st'
       -> Peer ps pr st' f m a
       -> m (a, dstate)
    -- Run the monadic action and continue with the peer it produces; the
    -- driver state and @f@ value are unchanged.
    go !dstate !f (Effect k) = k >>= go dstate f

    -- The peer has terminated: return its result with the driver state.
    go !dstate  _ (Done _ x) = return (x, dstate)

    -- Send the message using the @f@ value carried by 'Yield' (the one
    -- passed to 'go' is ignored), then continue with that same @f@ value.
    go !dstate  _ (Yield refl !f msg k) = do
      sendMessage refl f msg
      go dstate f k

    -- Receive a message, then let the continuation pick the next peer and
    -- the next @f@ value; continue with the updated driver state.
    go !dstate !f (Await refl k) = do
      (SomeMessage msg, dstate') <- recvMessage refl f dstate
      case k f msg of
        (k', f') -> go dstate' f' k'