{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Network.TypedProtocol.Channel
( Channel (..)
, hoistChannel
, isoKleisliChannel
, fixedInputChannel
, mvarsAsChannel
, handlesAsChannel
#if !defined(mingw32_HOST_OS)
, socketAsChannel
#endif
, createConnectedChannels
, createConnectedBufferedChannels
, createConnectedBufferedChannelsUnbounded
, createPipelineTestChannels
, channelEffect
, delayChannel
, loggingChannel
) where
import Control.Concurrent.Class.MonadSTM
import Control.Monad ((>=>))
import Control.Monad.Class.MonadSay
import Control.Monad.Class.MonadTimer.SI
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.ByteString.Lazy.Internal (smallChunkSize)
import Data.Proxy
import Numeric.Natural
#if !defined(mingw32_HOST_OS)
import Network.Socket (Socket)
import qualified Network.Socket.ByteString.Lazy as Socket
#endif
import qualified System.IO as IO (Handle, hFlush, hIsEOF)
-- | One end of a message channel carrying values of type @a@, with its
-- operations running in the monad @m@.
--
-- The channel is just a record of two actions, so different transports
-- (STM buffers, file handles, sockets, ...) can be plugged in behind the
-- same interface.
--
data Channel m a = Channel {

    -- | Write one value to the channel.
    --
    send :: a -> m (),

    -- | Read the next value from the channel.
    --
    -- The handle- and socket-backed channels in this module return
    -- 'Nothing' once the underlying input is exhausted; the STM-backed
    -- channels always return 'Just'.
    --
    recv :: m (Maybe a)
  }
-- | Change the message type of a 'Channel' using a pair of monadic
-- conversion functions.
--
-- * Outgoing values of type @b@ are converted with the second function
--   and then passed to the underlying 'send'.
-- * Incoming values are read with the underlying 'recv' and, when present,
--   converted with the first function; a 'Nothing' is passed through
--   unchanged.
--
-- The name suggests the two functions are expected to be mutually inverse,
-- but nothing here checks that.
--
isoKleisliChannel
  :: forall a b m. Monad m
  => (a -> m b)
  -> (b -> m a)
  -> Channel m a
  -> Channel m b
isoKleisliChannel f finv Channel{send, recv} = Channel {
    send = finv >=> send,
    recv = recv >>= traverse f
  }
-- | Run a 'Channel' in a different monad by applying the given
-- transformation to the results of both 'send' and 'recv'.
--
hoistChannel
  :: (forall x . m x -> n x)
  -> Channel m a
  -> Channel n a
hoistChannel nat channel = Channel
  { send = nat . send channel
  , recv = nat (recv channel)
  }
-- | A 'Channel' whose input is a fixed list of values and whose output is
-- discarded.
--
-- Each 'recv' returns the next element of the list; once the list is
-- exhausted every further 'recv' returns 'Nothing'. 'send' ignores its
-- argument and does nothing.
--
fixedInputChannel :: MonadSTM m => [a] -> m (Channel m a)
fixedInputChannel xs0 = do
    v <- atomically $ newTVar xs0
    return Channel {send, recv = recv v}
  where
    -- Pop the head of the remaining input, if any, in one transaction.
    recv v = atomically $ do
               xs <- readTVar v
               case xs of
                 []      -> return Nothing
                 (x:xs') -> writeTVar v xs' >> return (Just x)

    send _ = return ()
-- | Make a 'Channel' from a pair of 'TMVar's: the first is used for
-- reading and the second for writing.
--
-- 'send' puts the value into the write 'TMVar' and 'recv' takes a value
-- from the read 'TMVar'. 'recv' always yields 'Just'; this channel never
-- reports end of input.
--
mvarsAsChannel :: MonadSTM m
               => TMVar m a
               -> TMVar m a
               -> Channel m a
mvarsAsChannel bufferRead bufferWrite =
    Channel{send, recv}
  where
    send x = atomically (putTMVar bufferWrite x)
    recv   = atomically (Just <$> takeTMVar bufferRead)
-- | Create a pair of 'Channel's that are connected to each other through
-- two 'TMVar's, one per direction.
--
-- What is sent on one channel is received on the other. Both 'TMVar's are
-- labelled (@buffer-a@ / @buffer-b@) and traced with a 'TraceString' that
-- shows their content, which is why a 'Show' instance is required.
--
createConnectedChannels :: forall m a. (MonadLabelledSTM m, MonadTraceSTM m, Show a) => m (Channel m a, Channel m a)
createConnectedChannels = do
    -- One buffer per direction; each is created, labelled and traced
    -- within a single transaction.
    bufferA <- atomically $ do
      v <- newEmptyTMVar
      labelTMVar v "buffer-a"
      traceTMVar (Proxy @m) v $ \_ a -> pure $ TraceString ("buffer-a: " ++ show a)
      return v
    bufferB <- atomically $ do
      v <- newEmptyTMVar
      traceTMVar (Proxy @m) v $ \_ a -> pure $ TraceString ("buffer-b: " ++ show a)
      labelTMVar v "buffer-b"
      return v

    -- The two ends use the same buffers with the read/write roles swapped.
    return (mvarsAsChannel bufferB bufferA,
            mvarsAsChannel bufferA bufferB)
-- | Create a pair of 'Channel's that are connected to each other through
-- two 'TBQueue's, one per direction, each created with the given size.
--
-- 'send' writes to the queue for its direction and 'recv' reads from the
-- queue for the opposite direction. 'recv' always yields 'Just'.
--
createConnectedBufferedChannels :: MonadSTM m
                                => Natural -> m (Channel m a, Channel m a)
createConnectedBufferedChannels sz = do
    bufferA <- atomically $ newTBQueue sz
    bufferB <- atomically $ newTBQueue sz

    -- The two ends use the same queues with the read/write roles swapped.
    return (queuesAsChannel bufferB bufferA,
            queuesAsChannel bufferA bufferB)
  where
    queuesAsChannel bufferRead bufferWrite =
        Channel{send, recv}
      where
        send x = atomically (writeTBQueue bufferWrite x)
        recv   = atomically (Just <$> readTBQueue bufferRead)
-- | Create a pair of 'Channel's that are connected to each other through
-- two 'TQueue's, one per direction.
--
-- 'send' writes to the queue for its direction and 'recv' reads from the
-- queue for the opposite direction. 'recv' always yields 'Just'.
--
createConnectedBufferedChannelsUnbounded :: forall m a. MonadSTM m
                                         => m (Channel m a, Channel m a)
createConnectedBufferedChannelsUnbounded = do
    bufferA <- newTQueueIO
    bufferB <- newTQueueIO

    -- The two ends use the same queues with the read/write roles swapped.
    return (queuesAsChannel bufferB bufferA,
            queuesAsChannel bufferA bufferB)
  where
    queuesAsChannel :: TQueue m a
                    -> TQueue m a
                    -> Channel m a
    queuesAsChannel bufferRead bufferWrite =
        Channel{send, recv}
      where
        send x = atomically (writeTQueue bufferWrite x)
        recv   = atomically (Just <$> readTQueue bufferRead)
-- | Create a pair of 'Channel's connected through two 'TBQueue's of the
-- given size, for testing protocol pipelining.
--
-- Unlike 'createConnectedBufferedChannels', 'send' checks whether the
-- outgoing queue is full and, if so, fails with 'error' (reporting the
-- size) instead of writing to it. 'recv' always yields 'Just'.
--
createPipelineTestChannels :: MonadSTM m
                           => Natural -> m (Channel m a, Channel m a)
createPipelineTestChannels sz = do
    bufferA <- atomically $ newTBQueue sz
    bufferB <- atomically $ newTBQueue sz

    -- The two ends use the same queues with the read/write roles swapped.
    return (queuesAsChannel bufferB bufferA,
            queuesAsChannel bufferA bufferB)
  where
    queuesAsChannel bufferRead bufferWrite =
        Channel{send, recv}
      where
        -- The fullness check and the write happen in one transaction.
        send x = atomically $ do
                   full <- isFullTBQueue bufferWrite
                   if full then error failureMsg
                           else writeTBQueue bufferWrite x
        recv   = atomically (Just <$> readTBQueue bufferRead)

    failureMsg :: String
    failureMsg = "createPipelineTestChannels: "
              ++ "maximum pipeline depth exceeded: " ++ show sz
-- | Make a 'Channel' from a pair of IO 'IO.Handle's, one for reading and
-- one for writing.
--
-- 'send' writes the whole lazy 'LBS.ByteString' to the write handle and
-- then flushes it. 'recv' returns 'Nothing' if the read handle is at end
-- of file; otherwise it reads up to 'smallChunkSize' bytes and returns
-- them as a lazy 'LBS.ByteString'.
--
handlesAsChannel :: IO.Handle -- ^ Read handle
                 -> IO.Handle -- ^ Write handle
                 -> Channel IO LBS.ByteString
handlesAsChannel hndRead hndWrite =
    Channel{send, recv}
  where
    send :: LBS.ByteString -> IO ()
    send chunk = do
      LBS.hPut hndWrite chunk
      IO.hFlush hndWrite

    recv :: IO (Maybe LBS.ByteString)
    recv = do
      eof <- IO.hIsEOF hndRead
      if eof
        then return Nothing
        else Just . LBS.fromStrict <$> BS.hGetSome hndRead smallChunkSize
-- | Wrap a 'Channel' with extra actions around its operations.
--
-- The first action is run with the outgoing value before every 'send'.
-- The second is run with the result of every 'recv' (including 'Nothing')
-- before that result is returned.
--
channelEffect :: forall m a.
                 Monad m
              => (a -> m ())        -- ^ Action before 'send'
              -> (Maybe a -> m ())  -- ^ Action after 'recv'
              -> Channel m a
              -> Channel m a
channelEffect beforeSend afterRecv Channel{send, recv} =
    Channel{
      send = \x -> do
        beforeSend x
        send x

    , recv = do
        mx <- recv
        afterRecv mx
        return mx
    }
-- | Wrap a 'Channel' so that a fixed delay is added on the receiving side.
--
-- Built with 'channelEffect': nothing extra happens before 'send', and
-- 'threadDelay' with the given duration is called after each 'recv',
-- whatever its result.
--
delayChannel :: MonadDelay m
             => DiffTime
             -> Channel m a
             -> Channel m a
delayChannel delay = channelEffect (\_ -> return ())
                                   (\_ -> threadDelay delay)
#if !defined(mingw32_HOST_OS)
-- | Make a 'Channel' from a 'Socket', used for both reading and writing.
--
-- 'send' passes the lazy 'LBS.ByteString' to 'Socket.sendAll'. 'recv'
-- asks the socket for up to 'smallChunkSize' bytes and returns 'Nothing'
-- when the result is empty, otherwise the bytes that were read.
--
-- Only compiled when @mingw32_HOST_OS@ is not defined.
--
socketAsChannel :: Socket
                -> Channel IO LBS.ByteString
socketAsChannel sock =
    Channel{send, recv}
  where
    send :: LBS.ByteString -> IO ()
    send = Socket.sendAll sock

    recv :: IO (Maybe LBS.ByteString)
    recv = do
      bs <- Socket.recv sock (fromIntegral smallChunkSize)
      if LBS.null bs
        then return Nothing
        else return (Just bs)
#endif
-- | Wrap a 'Channel' so that every message is logged with 'say'.
--
-- Each sent value is logged as @\<ident\>:send:\<value\>@ before it is
-- passed to the underlying 'send'. Each received 'Just' value is logged
-- as @\<ident\>:recv:\<value\>@ before it is returned; a 'Nothing' is
-- returned without logging.
--
loggingChannel :: ( MonadSay m
                  , Show id
                  , Show a
                  )
               => id
               -> Channel m a
               -> Channel m a
loggingChannel ident Channel{send,recv} =
  Channel {
    send = loggingSend,
    recv = loggingRecv
  }
 where
  loggingSend a = do
    say (show ident ++ ":send:" ++ show a)
    send a

  loggingRecv = do
    msg <- recv
    case msg of
      Nothing -> return ()
      Just a  -> say (show ident ++ ":recv:" ++ show a)
    return msg