{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}

module Network.TypedProtocol.Documentation.TH
( describeProtocol
)
where

import Network.TypedProtocol.Documentation.Types

import Control.Monad
#if MIN_VERSION_template_haskell(2,17,0)
-- This import is only needed when 'getDoc' is available.
import Data.Maybe (maybeToList)
#endif
import Data.Maybe (mapMaybe)
import Data.Proxy
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Network.TypedProtocol.Core


-- | Apply a type to a list of type-constructor arguments, left to right,
-- so that @applyTyArgs t [a, b]@ builds the type @(t a) b@.
applyTyArgs :: Type -> [Name] -> Type
applyTyArgs ty args = foldl AppT ty (map ConT args)

-- | Generate a 'ProtocolDescription' runtime representation of a typed
-- protocol specification, including serialization information.
--
-- Arguments:
--
-- * @protoTyCon@ \/ @protoTyArgs@: the protocol's state type constructor and
--   the type constructors it is applied to.
-- * @codecTyCon@ \/ @codecTyArgs@: the codec type constructor and the type
--   constructors it is applied to; passed on to each message description.
--
-- The splice aborts compilation (via 'fail') when the applied protocol type
-- does not have exactly one 'Message' data instance, and (via 'error') when a
-- state's 'StateAgency' instance cannot be interpreted.
describeProtocol :: Name -> [Name] -> Name -> [Name] -> ExpQ
describeProtocol protoTyCon protoTyArgs codecTyCon codecTyArgs = do
  info <- reifyDatatype protoTyCon
  protoDescription <- getDescription protoTyCon
  let pname = nameBase (datatypeName info)

  let extractAgency :: InstanceDec -> Maybe Name
      extractAgency (TySynInstD (TySynEqn _ _ (PromotedT agency))) = Just agency
      extractAgency dec = error $ "Unexpected InstanceDec: " ++ show dec

      extractAgencies :: [InstanceDec] -> [Name]
      extractAgencies = mapMaybe extractAgency

      -- Every state must resolve to exactly one agency.
      extractTheAgency :: [InstanceDec] -> Name
      extractTheAgency inst = case extractAgencies inst of
        [agency] -> agency
        xs -> error $ "Incorrect number of agencies: " ++ show xs

  pstates <- forM (datatypeCons info) $ \conInfo -> do
    let conName :: Name
        conName = constructorName conInfo
    stateDescription <- getDescription conName

    stateAgencies <- reifyInstances ''StateAgency [ConT conName]
    let agencyName = extractTheAgency stateAgencies
        -- Map the promoted agency constructor to its term-level ID by its
        -- unqualified name.
        agencyID = case nameBase agencyName of
          "ServerAgency" -> 'ServerAgencyID
          "ClientAgency" -> 'ClientAgencyID
          "NobodyAgency" -> 'NobodyAgencyID
          x -> error $ "Unknown agency type " ++ x ++ " in state " ++ nameBase conName

    return (conName, stateDescription, agencyID)

  let protocolTy = applyTyArgs (ConT protoTyCon) protoTyArgs

  -- Look the 'Message' instance up with an explicit case rather than a
  -- refutable pattern bind, so a mismatch produces a descriptive
  -- compile-time error instead of a generic pattern-match failure.
  messageInstances <- reifyInstances ''Message [protocolTy]
  msgCons <- case messageInstances of
    [DataInstD _ _ _ _ cons _] -> return cons
    decs ->
      fail $
        "describeProtocol: could not find a unique 'Message' data instance for "
          ++ pprint protocolTy
          ++ " (reifyInstances returned "
          ++ show (length decs)
          ++ " declaration(s))"

  let messageInfos =
        map
          (describeProtocolMessage protoTyCon protoTyArgs codecTyCon codecTyArgs . extractConName)
          msgCons

  [| ProtocolDescription
        $(litE (stringL pname))
        protoDescription
        ""
        $(listE
            [ [| ( $(makeState $ ConT conName), stateDescription, $(conE agencyID)) |]
            | (conName, stateDescription, agencyID) <- pstates
            ]
         )
        $(listE messageInfos)
   |]

-- | Peel wrappers off a state type before it is rendered.
--
-- Kind signatures ('SigT') are stripped repeatedly. On template-haskell >=
-- 2.17, a 'MulArrowT' application @FUN m t@ is also reduced to @t@. Any other
-- type is returned unchanged.
unearthType :: Type -> Type
unearthType ty = case ty of
#if MIN_VERSION_template_haskell(2,17,0)
  AppT (AppT MulArrowT _) inner -> unearthType inner
#endif
  SigT inner _ -> unearthType inner
  _ -> ty


-- | Render a type compactly for documentation, using unqualified names.
--
-- Rendering rules:
--
-- * Type applications are printed left-associatively, with parentheses
--   around compound arguments.
-- * List types are printed as @[a]@.
-- * Kind signatures, foralls and visible kind applications are dropped.
-- * Any other shape falls back to its 'show' representation.
prettyTy :: Type -> String
prettyTy = snd . go
  where
    -- The 'Bool' says whether the rendering must be parenthesised when it
    -- appears as an argument of a type application.
    go :: Type -> (Bool, String)
    go (ConT n) = (False, nameBase n)
    go (PromotedT n) = (False, nameBase n)
    go (VarT n) = (False, nameBase n)
    -- Must precede the general 'AppT' case; otherwise lists print as
    -- "ListT a".
    go (AppT ListT a) = (False, "[" ++ snd (go a) ++ "]")
    go (AppT a b) =
      let (_, a') = go a
          (wrap, b') = go b
       in (True, a' ++ " " ++ if wrap then "(" ++ b' ++ ")" else b')
    go (ForallT _ _ a) = go a
    go (ForallVisT _ a) = go a
    -- 'AppKindT' is @T \@k@: keep the type @T@, drop the kind argument.
    go (AppKindT a _) = go a
    go (SigT a _) = go a
    go t = (True, show t)

-- | Collect the documentation attached to a declaration: its Haddock comment
-- (if one is available), followed by any 'Description' annotations on it.
--
-- Haddock text can only be queried with template-haskell >= 2.17, where
-- 'getDoc' exists. On older versions only annotations are returned.
getDescription :: Name -> Q [Description]
getDescription name = do
  docs <- haddockDocs
  annotations <- reifyAnnotations (AnnLookupName name)
  return (map (\doc -> Description [doc]) docs ++ annotations)
  where
    haddockDocs :: Q [String]
#if MIN_VERSION_template_haskell(2,17,0)
    haddockDocs = maybeToList <$> getDoc (DeclDoc name)
#else
    -- 'getDoc' does not exist before template-haskell-2.17.0
    haddockDocs = return []
#endif

-- | Strip an outer kind signature from a protocol state type, so that it can
-- be spliced as an index of the message type.
--
-- Type variables, promoted constructors and plain type constructors are
-- returned unchanged; these are the same shapes 'makeState' accepts. Any
-- other shape aborts the splice with an 'error' that names the offending
-- type.
unSigTy :: Type -> Type
unSigTy (SigT t _) = t
unSigTy t@(VarT {}) = t
unSigTy t@(PromotedT {}) = t
unSigTy t@(ConT {}) = t
unSigTy t = error $ "unSigTy: unsupported protocol state type: " ++ show t

-- | Build an expression describing a protocol state type.
--
-- Named states (plain or promoted type constructors) become
-- @State "<unqualified name>"@, and a type variable becomes 'AnyState'. Any
-- other shape aborts with 'error', showing the type.
makeState :: Type -> ExpQ
makeState ty = case ty of
  ConT name -> namedState name
  PromotedT name -> namedState name
  VarT _ -> conE 'AnyState
  _ -> error (show ty) -- alternative (disabled): conE 'AnyState
  where
    namedState :: Name -> ExpQ
    namedState name = appE (conE 'State) (litE (stringL (nameBase name)))

-- | Build a 'MessageDescription' expression for a single message constructor.
--
-- Arguments, in order:
--
-- 1. The protocol type constructor and the type constructors it is applied
--    to.
-- 2. The codec type constructor and the type constructors it is applied to.
-- 3. The name of the message constructor being described.
--
-- The from/to states are the last two type variables of the reified message
-- type. Each is resolved through the first matching @var ~ ty@ equality in
-- the constructor's context, or left as a bare type variable if there is
-- none.
--
-- NOTE(review): this presumably relies on th-abstraction turning GADT
-- return indices into such equalities; confirm against the th-abstraction
-- docs.
describeProtocolMessage :: Name -> [Name] -> Name -> [Name] -> Name -> ExpQ
describeProtocolMessage protoTyCon protoTyArgs codecTyCon codecTyArgs msgName = do
  msgInfo <- reifyConstructor msgName
  msgTyInfo <- reifyDatatype msgName
  msgDescription <- getDescription msgName

  -- Match the final two type variables totally. A malformed message type now
  -- produces a descriptive compile-time error rather than a crash inside
  -- 'init' / 'last'.
  (fromStateVar, toStateVar) <-
    case reverse (datatypeVars msgTyInfo) of
      toVar : fromVar : _ -> return (tyVarName fromVar, tyVarName toVar)
      vars ->
        fail $
          "describeProtocolMessage: expected at least two type variables (from and to state) on the message type of "
            ++ nameBase msgName
            ++ ", found "
            ++ show (length vars)

  let payloads = constructorFields msgInfo
      fromState = findType fromStateVar (constructorContext msgInfo)
      toState = findType toStateVar (constructorContext msgInfo)
      protocolTy = applyTyArgs (ConT protoTyCon) protoTyArgs
      codecTy = applyTyArgs (ConT codecTyCon) codecTyArgs

  [e| MessageDescription
        { messageName = $(litE . stringL . nameBase $ msgName)
        , messageDescription = msgDescription
        , messagePayload = $(listE (map (litE . stringL . prettyTy) payloads))
        , messageFromState = $(makeState . unearthType $ fromState)
        , messageToState = $(makeState . unearthType $ toState)
        , messageInfo =
            infoOf $(litE . stringL . nameBase $ msgName) $
            info
              (Proxy :: Proxy $(pure codecTy))
              (Proxy :: Proxy ( $(conT $ datatypeName msgTyInfo)
                                   $(pure protocolTy)
                                   $(pure $ unSigTy fromState)
                                   $(pure $ unSigTy toState)
                                 )
              )
        }
    |]
  where
#if MIN_VERSION_template_haskell(2,17,0)
    tyVarName :: TyVarBndr a -> Name
    tyVarName (PlainTV n _) = n
    tyVarName (KindedTV n _ _) = n
#else
    tyVarName :: TyVarBndr -> Name
    tyVarName (PlainTV n) = n
    tyVarName (KindedTV n _) = n
#endif

    -- Resolve a type variable through the first @var ~ ty@ equality in the
    -- given context, falling back to the bare variable.
    findType :: Name -> Cxt -> Type
    findType n (AppT (AppT EqualityT (VarT vn)) t : _)
      | vn == n = t
    findType n (_ : xs) = findType n xs
    findType n [] = VarT n

-- | Return the name of a data constructor declaration.
--
-- 'ForallC' wrappers are looked through. For GADT-style constructors that
-- list several names, the first one is returned. A GADT constructor with no
-- names aborts with 'error'.
extractConName :: Con -> Name
extractConName (NormalC n _) = n
extractConName (RecC n _) = n
extractConName (InfixC _ n _) = n
extractConName (ForallC _ _ inner) = extractConName inner
extractConName (GadtC (n : _) _ _) = n
extractConName (RecGadtC (n : _) _ _) = n
extractConName other = error ("Cannot extract constructor name from " ++ show other)