{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Mux.Bearer.Pipe (
    PipeChannel (..)
  , pipeChannelFromHandles
#if defined(mingw32_HOST_OS)
  , pipeChannelFromNamedPipe
#endif
  , pipeAsMuxBearer
  ) where

import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime.SI
import           Control.Tracer
import qualified Data.ByteString.Lazy as BL
import           System.IO (Handle, hFlush)

#if defined(mingw32_HOST_OS)
import           Data.Foldable (traverse_)

import qualified System.Win32.Types as Win32 (HANDLE)
import qualified System.Win32.Async as Win32.Async
#endif

import           Network.Mux.Types (MuxBearer)
import qualified Network.Mux.Types as Mx
import qualified Network.Mux.Trace as Mx
import qualified Network.Mux.Codec as Mx
import qualified Network.Mux.Time as Mx
import qualified Network.Mux.Timeout as Mx


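-- | Abstraction over various types of handles.  Two ways of constructing a
-- 'PipeChannel' are provided:
--
--  * 'pipeChannelFromHandles', based on 'Handle': OS-independent, but does
--    not work well on Windows;
--  * 'pipeChannelFromNamedPipe', based on 'Win32.HANDLE': Windows-specific
--    (only compiled when @mingw32_HOST_OS@ is defined).
--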
data PipeChannel = PipeChannel {
    -- | Read from the channel; the argument is the number of bytes requested.
    -- 'pipeAsMuxBearer' treats an empty result as the pipe having been
    -- closed, and it requires that an implementation never return more bytes
    -- than were requested (it loops until the outstanding count reaches
    -- exactly zero).
    readHandle  :: Int -> IO BL.ByteString,
    -- | Write the given bytes to the channel.
    writeHandle :: BL.ByteString -> IO ()
  }

-- | Build a 'PipeChannel' from a pair of 'Handle's.
--
-- Reads are served by 'BL.hGet', so a read returns at most the requested
-- number of bytes.  Every write is followed by 'hFlush' on the write handle,
-- so written data is not left sitting in the handle's buffer.
pipeChannelFromHandles :: Handle
                       -- ^ read handle
                       -> Handle
                       -- ^ write handle
                       -> PipeChannel
pipeChannelFromHandles r w = PipeChannel {
    readHandle  = BL.hGet r,
    writeHandle = \a -> BL.hPut w a >> hFlush w
  }

#if defined(mingw32_HOST_OS)
-- | Wrap a Windows named-pipe 'Win32.HANDLE' as a 'PipeChannel'.  This lets
-- named pipes stand in for anonymous pipes on Windows.
--
-- Reads go through 'Win32.Async.readHandle' and are converted to a lazy
-- 'BL.ByteString'.  Writes send each strict chunk of the lazy input through
-- 'Win32.Async.writeHandle', in order.
--
pipeChannelFromNamedPipe :: Win32.HANDLE
                         -> PipeChannel
pipeChannelFromNamedPipe h =
    PipeChannel { readHandle = readChunk, writeHandle = writeChunks }
  where
    readChunk n = BL.fromStrict <$> Win32.Async.readHandle h n
    writeChunks = traverse_ (Win32.Async.writeHandle h) . BL.toChunks
#endif

-- | Expose a 'PipeChannel' as a 'MuxBearer'.
--
-- Incoming data is parsed one SDU at a time:
--
--  * 8 bytes are read and decoded with 'Mx.decodeMuxSDU' to obtain the
--    header;
--  * 'Mx.mhLength' further bytes are then read as the payload.
--
-- Outgoing SDUs are stamped with 'Mx.timestampMicrosecondsLow32Bits' of the
-- current monotonic time, encoded with 'Mx.encodeMuxSDU', and written.
--
-- The 'Mx.TimeoutFn' supplied by the mux is ignored in both directions.
--
-- Failure modes:
--
--  * a header that fails to decode is rethrown as the 'Mx.MuxError' returned
--    by the decoder;
--  * an empty read throws 'Mx.MuxError' with 'Mx.MuxBearerClosed';
--  * IO exceptions raised by the channel are passed to
--    'Mx.handleIOException'.
pipeAsMuxBearer
  :: Mx.SDUSize
  -> Tracer IO Mx.MuxTrace
  -> PipeChannel
  -> MuxBearer IO
pipeAsMuxBearer sduSize tracer channel =
      Mx.MuxBearer {
          Mx.read    = readPipe,
          Mx.write   = writePipe,
          Mx.sduSize = sduSize
        }
    where
      -- Read one SDU: the 8-byte header first, then the payload whose length
      -- the header announces.  The returned timestamp is taken after the
      -- payload has been fully received.
      readPipe :: Mx.TimeoutFn IO -> IO (Mx.MuxSDU, Time)
      readPipe _ = do
          traceWith tracer Mx.MuxTraceRecvHeaderStart
          hbuf <- recvLen' 8 []
          case Mx.decodeMuxSDU hbuf of
              Left e -> throwIO e
              Right header@Mx.MuxSDU { Mx.msHeader } -> do
                  traceWith tracer $ Mx.MuxTraceRecvHeaderEnd msHeader
                  blob <- recvLen' (fromIntegral $ Mx.mhLength msHeader) []
                  ts <- getMonotonicTime
                  traceWith tracer (Mx.MuxTraceRecvDeltaQObservation msHeader ts)
                  return (header {Mx.msBlob = blob}, ts)

      -- Read exactly @l@ bytes, looping over short reads.  Chunks are
      -- accumulated in reverse and concatenated once at the end.  An empty
      -- read is treated as the pipe having been closed.
      --
      -- NOTE(review): this relies on 'readHandle' never returning more than
      -- the requested number of bytes; otherwise @l@ would go negative and
      -- the @0@ base case would never be reached.
      recvLen' :: Int -> [BL.ByteString] -> IO BL.ByteString
      recvLen' 0 bufs = return $ BL.concat $ reverse bufs
      recvLen' l bufs = do
          traceWith tracer $ Mx.MuxTraceRecvStart l
          buf <- readHandle channel l
                    `catch` Mx.handleIOException "readHandle errored"
          if BL.null buf
              then throwIO $ Mx.MuxError Mx.MuxBearerClosed "Pipe closed when reading data"
              else do
                  traceWith tracer $ Mx.MuxTraceRecvEnd (fromIntegral $ BL.length buf)
                  recvLen' (l - fromIntegral (BL.length buf)) (buf : bufs)

      -- Timestamp, encode and write one SDU, returning the timestamp that
      -- was embedded in it.
      writePipe :: Mx.TimeoutFn IO -> Mx.MuxSDU -> IO Time
      writePipe _ sdu = do
          ts <- getMonotonicTime
          let ts32 = Mx.timestampMicrosecondsLow32Bits ts
              sdu' = Mx.setTimestamp sdu (Mx.RemoteClockModel ts32)
              buf  = Mx.encodeMuxSDU sdu'
          traceWith tracer $ Mx.MuxTraceSendStart (Mx.msHeader sdu')
          writeHandle channel buf
            `catch` Mx.handleIOException "writeHandle errored"
          traceWith tracer Mx.MuxTraceSendEnd
          return ts