{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}


-- This is already implied by the -Wall in the .cabal file, but let's be
-- completely explicit about it too, since we rely on the completeness
-- checking in the cases below for the completeness of our proofs.
{-# OPTIONS_GHC -Wincomplete-patterns #-}

-- | Proofs about the stateful typed protocol framework.
--
-- It also provides helpful testing utilities.
--
module Network.TypedProtocol.Stateful.Proofs
  ( connect
  , TerminalStates (..)
  , removeState
  ) where

import Control.Monad.Class.MonadSTM

import Data.Kind (Type)
import Data.Singletons

import Network.TypedProtocol.Core
import Network.TypedProtocol.Stateful.Peer qualified as ST
import Network.TypedProtocol.Peer
import Network.TypedProtocol.Proofs (TerminalStates (..))
import Network.TypedProtocol.Proofs qualified as TP



-- | Remove the state from a stateful, non-pipelined peer, producing an
-- ordinary non-pipelined 'Peer'.
--
-- The protocol state @f st@ is threaded through the peer:
--
-- * on 'ST.Effect' the current state is kept;
-- * on 'ST.Yield' the state after the transition is the one carried by the
--   constructor;
-- * on 'ST.Await' the continuation computes both the next peer and the next
--   state from the current state and the received message.
--
-- TODO: It is difficult to write a type-safe `removeState` for pipelined
-- peers.  The `Peer` type doesn't track all pipelined transitions, just the
-- depth of pipelining, so we cannot push `f st` to a queue whose type is
-- linked to `Peer`.  For a similar reason there is no way to write a
-- `forgetPipelined` function.
--
-- However, this would be possible if `Peer` tracked all transitions.
--
removeState
  :: Functor m
  => f st
  -> ST.Peer ps pr              st f m a
  ->    Peer ps pr NonPipelined st   m a
removeState = go
  where
    go
      :: forall ps (pr :: PeerRole)
                (st :: ps)
                (f :: ps -> Type)
                m a.
         Functor m
      => f st
      -> ST.Peer ps pr              st f m a
      ->    Peer ps pr NonPipelined st   m a
    -- Run the effect, continuing with the unchanged state.
    go f (ST.Effect k) = Effect (go f <$> k)
    -- We send a message: the state for the target protocol state is carried
    -- by the 'ST.Yield' constructor itself, so the current state is unused.
    go _ (ST.Yield refl f msg k) = Yield refl msg (go f k)
    -- We receive a message: the continuation consumes the current state and
    -- the message, and returns the next peer paired with the next state.
    go f (ST.Await refl k) = Await refl $ \msg ->
      case k f msg of
        (k', f') -> go f' k'
    -- Terminal state: the result is passed through unchanged.
    go _ (ST.Done refl a) = Done refl a
a


-- | Run two complementary stateful peers against each other, both starting
-- from the same initial protocol state @f st@.
--
-- Each peer first has its state erased with 'removeState'; the resulting
-- non-pipelined peers are then connected with 'TP.connect'.  Returns the
-- results of both peers together with their 'TerminalStates'.
--
connect
  :: forall ps (pr :: PeerRole)
               (st :: ps)
               (f :: ps -> Type)
               m a b.
       (MonadSTM m, SingI pr)
    => f st
    -> ST.Peer ps             pr  st f m a
    -> ST.Peer ps (FlipAgency pr) st f m b
    -> m (a, b, TerminalStates ps)
connect f a b = TP.connect (removeState f a) (removeState f b)