{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE GADTSyntax                #-}
{-# LANGUAGE KindSignatures            #-}
{-# LANGUAGE NamedFieldPuns            #-}
{-# LANGUAGE RankNTypes                #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeFamilies              #-}

module Network.Mux.Compat
  ( muxStart
    -- * Mux bearers
  , MuxBearer
  , MakeBearer (..)
    -- * Defining 'MuxApplication's
  , MuxMode (..)
  , HasInitiator
  , HasResponder
  , MuxApplication (..)
  , MuxMiniProtocol (..)
  , RunMiniProtocol (..)
  , MiniProtocolNum (..)
  , MiniProtocolLimits (..)
  , MiniProtocolDir (..)
    -- * Errors
  , MuxError (..)
  , MuxErrorType (..)
    -- * Tracing
  , traceMuxBearerState
  , MuxBearerState (..)
  , MuxTrace (..)
  , WithMuxBearer (..)
  ) where

import qualified Data.ByteString.Lazy as BL
import           Data.Void (Void)

import           Control.Applicative (Alternative (..), (<|>))
import           Control.Concurrent.Class.MonadSTM.Strict
import           Control.Monad
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTimer.SI
import           Control.Tracer

import           Network.Mux (StartOnDemandOrEagerly (..), newMux,
                     runMiniProtocol, runMux, stopMux, traceMuxBearerState)
import           Network.Mux.Bearer
import           Network.Mux.Channel
import           Network.Mux.Trace
import           Network.Mux.Types hiding (MiniProtocolInfo (..))
import qualified Network.Mux.Types as Types


-- | An application to be run by 'muxStart': simply the list of its
-- mini-protocols.  Every mini-protocol shares the application's 'MuxMode',
-- so the mode index rules out, e.g., responder-only protocols in an
-- initiator-only application.
newtype MuxApplication (mode :: MuxMode) m a b where
  MuxApplication :: [MuxMiniProtocol mode m a b] -> MuxApplication mode m a b

-- | Description of a single mini-protocol of a 'MuxApplication'.
data MuxMiniProtocol (mode :: MuxMode) m a b =
     MuxMiniProtocol {
       -- | Number identifying the mini-protocol; it is used both in the
       -- bundle handed to 'newMux' and when starting the protocol.
       miniProtocolNum    :: !MiniProtocolNum,
       -- | Limits copied unchanged into the mux's mini-protocol bundle.
       miniProtocolLimits :: !MiniProtocolLimits,
       -- | The application code for the directions supported in @mode@.
       miniProtocolRun    :: !(RunMiniProtocol mode m a b)
     }

-- | The code run for a mini-protocol, indexed by the 'MuxMode' it is valid
-- in.  @a@ is the result type of the initiator side and @b@ that of the
-- responder side; a side that cannot exist in the given mode has its result
-- type fixed to 'Void'.
--
-- Note that 'muxStart' discards whatever the application returns.
data RunMiniProtocol (mode :: MuxMode) m a b where
  -- | Initiator-only application.  The simplest application is
  -- @'runPeer'@ or @'runPipelinedPeer'@ supplied with a codec and a
  -- @'Peer'@ for each @ptcl@, but any function over a 'Channel' is accepted
  -- so that resources can be managed when applying @'runPeer'@ alone is not
  -- enough.  It is run in the @'InitiatorDir'@ direction.
  InitiatorProtocolOnly
    :: (Channel m -> m (a, Maybe BL.ByteString))
    -> RunMiniProtocol InitiatorMode m a Void

  -- | Responder-only application; like 'InitiatorProtocolOnly', but run in
  -- the @'ResponderDir'@ direction.
  ResponderProtocolOnly
    :: (Channel m -> m (b, Maybe BL.ByteString))
    -> RunMiniProtocol ResponderMode m Void b

  -- | Both an initiator and a responder application for the same
  -- mini-protocol; the first argument is the initiator, the second the
  -- responder.
  InitiatorAndResponderProtocol
    :: (Channel m -> m (a, Maybe BL.ByteString))
    -> (Channel m -> m (b, Maybe BL.ByteString))
    -> RunMiniProtocol InitiatorResponderMode m a b


-- | Run a 'MuxApplication' over a 'MuxBearer'.
--
-- This is a compatibility wrapper around 'newMux', 'runMiniProtocol',
-- 'runMux' and 'stopMux'.  Every mini-protocol of the application is started
-- eagerly ('StartEagerly') in each direction its 'RunMiniProtocol' supports.
-- Their results are discarded and they are never restarted.
--
-- The function waits until the first mini-protocol terminates (the outcome,
-- including any exception it reports, is ignored) or until 'runMux' itself
-- terminates, whichever comes first.  It then stops the mux and waits for
-- 'runMux' to finish, so an exception raised by 'runMux' is rethrown.
muxStart
    :: forall m mode a b.
       ( MonadAsync m
       , MonadFork m
       , MonadLabelledSTM m
       , Alternative (STM m)
       , MonadThrow (STM m)
       , MonadTimer m
       , MonadMask m
       )
    => Tracer m MuxTrace
    -> MuxApplication mode m a b
    -> MuxBearer m
    -> m ()
muxStart tracer muxapp bearer = do
    mux <- newMux (toMiniProtocolBundle muxapp)

    -- Start every (mini-protocol, direction) pair; each yields an STM action
    -- which completes with the outcome of that mini-protocol thread.
    resOps <- sequence
      [ runMiniProtocol
          mux
          miniProtocolNum
          ptclDir
          StartEagerly
          (\chan -> do
            r <- action chan
            return (r, Nothing) -- Compat interface doesn't do restarts
          )
      | let MuxApplication ptcls = muxapp
      , MuxMiniProtocol{miniProtocolNum, miniProtocolRun} <- ptcls
      , (ptclDir, action) <- selectRunner miniProtocolRun
      ]

    -- Wait for the first mini-protocol to finish, then stop the mux.  The
    -- 'runMux' thread is waited on as well.  Without it, an application with
    -- no mini-protocols would block forever on an STM 'retry' that reads no
    -- TVars, and a terminated mux would not be noticed here.
    withAsync (runMux tracer mux bearer) $ \aid -> do
      waitOnAny (waitCatchSTM aid : resOps)
      stopMux mux
      wait aid

  where
    -- Block until any of the given STM actions completes.  Its result
    -- (including a reported exception) is deliberately discarded.
    waitOnAny :: [STM m (Either SomeException ())] -> m ()
    waitOnAny ops = atomically $ void $ foldr (<|>) retry ops

    -- Describe the application's mini-protocols to 'newMux': one entry per
    -- (mini-protocol, direction) pair.  Directions are taken from
    -- 'selectRunner' so the bundle always matches what is actually started.
    toMiniProtocolBundle :: MuxApplication mode m a b -> MiniProtocolBundle mode
    toMiniProtocolBundle (MuxApplication ptcls) =
      MiniProtocolBundle
        [ Types.MiniProtocolInfo {
            Types.miniProtocolNum,
            Types.miniProtocolDir,
            Types.miniProtocolLimits
          }
        | MuxMiniProtocol {
            miniProtocolNum,
            miniProtocolLimits,
            miniProtocolRun
          } <- ptcls
        , (miniProtocolDir, _) <- selectRunner miniProtocolRun
        ]

    -- Map a 'RunMiniProtocol' to the directions it runs in, each paired with
    -- its application code.  The application's return value is dropped.
    selectRunner :: RunMiniProtocol mode m a b
                 -> [(MiniProtocolDirection mode, Channel m -> m ())]
    selectRunner (InitiatorProtocolOnly initiator) =
      [(InitiatorDirectionOnly, void . initiator)]
    selectRunner (ResponderProtocolOnly responder) =
      [(ResponderDirectionOnly, void . responder)]
    selectRunner (InitiatorAndResponderProtocol initiator responder) =
      [ (InitiatorDirection, void . initiator)
      , (ResponderDirection, void . responder)
      ]