Skip to content

Commit

Permalink
Merge pull request #730 from input-output-hk/plt-8047-propagate-errors
Browse files Browse the repository at this point in the history
PLT-8047 propagate errors
  • Loading branch information
jhbertra authored Oct 17, 2023
2 parents cd39f39 + 73b6c1b commit 587333d
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 85 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Added

- Propagate errors from server to client in protocol sessions.
1 change: 1 addition & 0 deletions marlowe-protocols/marlowe-protocols.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ library
Network.Protocol.Connection
Network.Protocol.Driver
Network.Protocol.Driver.Trace
Network.Protocol.Driver.Untyped
Network.Protocol.Handshake.Client
Network.Protocol.Handshake.Server
Network.Protocol.Handshake.Types
Expand Down
3 changes: 2 additions & 1 deletion marlowe-protocols/src/Network/Channel/Typed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import qualified Data.ByteString.Lazy as LBS
import Network.Channel (socketAsChannel)
import Network.Protocol.Codec (BinaryMessage)
import Network.Protocol.Driver.Trace
import qualified Network.Protocol.Driver.Untyped as Untyped
import Network.Protocol.Peer.Trace
import Network.Socket (
AddrInfo (addrSocketType),
Expand Down Expand Up @@ -331,7 +332,7 @@ tcpClientChannel inj host port = withInjectEvent inj Connect \ev -> do
let driver =
mkDriverTraced
(composeInjectSelector inj $ injectSelector $ ClientDriver addr)
(socketAsChannel socket)
(Untyped.mkDriver $ socketAsChannel socket)
pure
( driverToChannel (composeInjectSelector inj $ injectSelector $ ClientPeer addr) driver
, reference ev
Expand Down
2 changes: 1 addition & 1 deletion marlowe-protocols/src/Network/Protocol/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module Network.Protocol.Connection where
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Resource (runResourceT, transResourceT)
import Control.Monad.Trans.Resource.Internal (ResourceT (..))
import Data.ByteString.Lazy (ByteString)
import Data.ByteString (ByteString)
import Network.Protocol.Peer.Trace
import Network.TypedProtocol (Message, PeerHasAgency)
import UnliftIO (MonadUnliftIO)
Expand Down
42 changes: 20 additions & 22 deletions marlowe-protocols/src/Network/Protocol/Driver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import Control.Concurrent.Component
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Resource (runResourceT)
import Data.ByteString.Lazy (ByteString)
import Data.ByteString (ByteString)
import Data.Proxy (Proxy (Proxy))
import Network.Channel (Channel (..), socketAsChannel)
import Network.Protocol.Codec (BinaryMessage, DeserializeError, binaryCodec)
import qualified Data.Text as T
import Network.Channel (socketAsChannel)
import Network.Protocol.Codec (BinaryMessage (..))
import Network.Protocol.Connection (Connection (..), Connector (..), ServerSource (..), ToPeer)
import qualified Network.Protocol.Driver.Untyped as Untyped
import Network.Protocol.Handshake.Client (handshakeClientPeer, simpleHandshakeClient)
import Network.Protocol.Handshake.Server (handshakeServerPeer, simpleHandshakeServer)
import Network.Protocol.Handshake.Types (HasSignature, signature)
Expand All @@ -35,40 +37,29 @@ import Network.Socket (
openSocket,
)
import Network.TypedProtocol (Message, PeerHasAgency, PeerRole (..), SomeMessage (..), runPeerWithDriver)
import Network.TypedProtocol.Codec (Codec (..), DecodeStep (..))
import Network.TypedProtocol.Driver (Driver (..))
import UnliftIO (MonadIO, MonadUnliftIO, finally, throwIO, withRunInIO)
import UnliftIO (Exception (..), MonadIO, MonadUnliftIO, SomeException (..), catch, finally, throwIO, withRunInIO)

mkDriver
:: forall ps m
. (MonadIO m, BinaryMessage ps)
=> Channel m ByteString
=> Untyped.Driver m
-> Driver ps (Maybe ByteString) m
mkDriver Channel{..} = Driver{..}
mkDriver Untyped.Driver{..} = Driver{..}
where
Codec{..} = binaryCodec
sendMessage
:: forall (pr :: PeerRole) (st :: ps) (st' :: ps)
. PeerHasAgency pr st
-> Message ps st st'
-> m ()
sendMessage tok = send . encode tok
sendMessage = fmap sendSuccessMessage . putMessage

recvMessage
:: forall (pr :: PeerRole) (st :: ps)
. PeerHasAgency pr st
-> Maybe ByteString
-> m (SomeMessage st, Maybe ByteString)
recvMessage tok trailing = decodeChannel trailing =<< decode tok

decodeChannel
:: Maybe ByteString
-> DecodeStep ByteString DeserializeError m a
-> m (a, Maybe ByteString)
decodeChannel _ (DecodeDone a trailing) = pure (a, trailing)
decodeChannel _ (DecodeFail failure) = throwIO failure
decodeChannel Nothing (DecodePartial next) = recv >>= next >>= decodeChannel Nothing
decodeChannel trailing (DecodePartial next) = next trailing >>= decodeChannel Nothing
recvMessage tok trailing = either throwIO pure =<< recvMessageUntyped trailing (getMessage tok)

startDState :: Maybe ByteString
startDState = Nothing
Expand Down Expand Up @@ -97,10 +88,11 @@ tcpServer
tcpServer name = component_ (name <> "-tcp-server") \TcpServerDependencies{..} ->
withRunInIO \runInIO -> runTCPServer (Just host) (show port) $ runComponent_ $ hoistComponent runInIO $ component_ (name <> "-tcp-worker") \socket -> runResourceT do
server <- getServer serverSource
let driver = mkDriver $ socketAsChannel socket
let handshakeServer = simpleHandshakeServer (signature $ Proxy @ps) server
let peer = peerTracedToPeer $ handshakeServerPeer toPeer handshakeServer
lift $ fst <$> runPeerWithDriver driver peer (startDState driver)
let untypedDriver = Untyped.mkDriver $ socketAsChannel socket
let driver = mkDriver untypedDriver
lift $ rethrowErrors untypedDriver $ fst <$> runPeerWithDriver driver peer (startDState driver)

tcpClient
:: forall client ps st m
Expand All @@ -121,8 +113,14 @@ tcpClient host port toPeer = Connector $ liftIO $ do
pure
Connection
{ runConnection = \client -> do
let driver = mkDriver $ socketAsChannel socket
let untypedDriver = Untyped.mkDriver $ socketAsChannel socket
let driver = mkDriver untypedDriver
let handshakeClient = simpleHandshakeClient (signature $ Proxy @ps) client
let peer = peerTracedToPeer $ handshakeClientPeer toPeer handshakeClient
fst <$> runPeerWithDriver driver peer (startDState driver) `finally` liftIO (close socket)
}

rethrowErrors :: (MonadUnliftIO m) => Untyped.Driver m -> m a -> m a
rethrowErrors Untyped.Driver{..} = flip catch \(SomeException ex) -> do
catch (sendFailureMessage $ T.pack $ displayException ex) \SomeException{} -> pure ()
throwIO ex
99 changes: 38 additions & 61 deletions marlowe-protocols/src/Network/Protocol/Driver/Trace.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@ import Data.Binary (Binary, get, getWord8, put)
import Data.Binary.Get (runGet)
import Data.Binary.Put (putWord8, runPut)
import qualified Data.ByteString as B
import Data.ByteString.Lazy (ByteString)
import Data.ByteString.Base16 (encodeBase16)
import qualified Data.ByteString.Lazy as LBS
import Data.ByteString.Lazy.Base16 (encodeBase16)
import Data.Foldable (traverse_)
import Data.Int (Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.String (fromString)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import Data.Void (Void)
import Network.Channel hiding (close)
import Network.Protocol.Codec (BinaryMessage, DeserializeError, decodeGet, getMessage, putMessage)
import Network.Protocol.Codec (BinaryMessage, getMessage, putMessage)
import Network.Protocol.Codec.Spec (ShowProtocol (..))
import Network.Protocol.Connection (
Connection (..),
Expand All @@ -44,7 +42,8 @@ import Network.Protocol.Connection (
ServerSource (..),
ToPeer,
)
import Network.Protocol.Driver (TcpServerDependencies (..))
import Network.Protocol.Driver (TcpServerDependencies (..), rethrowErrors)
import qualified Network.Protocol.Driver.Untyped as Untyped
import Network.Protocol.Handshake.Client (handshakeClientPeer, simpleHandshakeClient)
import Network.Protocol.Handshake.Server (handshakeServerPeer, simpleHandshakeServer)
import Network.Protocol.Handshake.Types (Handshake, HasSignature, signature)
Expand Down Expand Up @@ -239,21 +238,20 @@ tcpServerTraced name inj = component_ (name <> "-tcp-server") \TcpServerDependen
let closeArgs = (simpleNewEventArgs CloseServer){newEventParent = Just parentRef}
server <- getServer serverSource
lift $ localBackend (setAncestorEventBackend parentRef) do
let driver =
mkDriverTraced
(composeInjectSelector inj $ injectSelector $ ServerDriver addr pName)
(socketAsChannel socket)
handshakeServer = simpleHandshakeServer (signature $ Proxy @ps) server
peer = handshakeServerPeer toPeer handshakeServer
let handshakeServer = simpleHandshakeServer (signature $ Proxy @ps) server
let peer = handshakeServerPeer toPeer handshakeServer
let untypedDriver = Untyped.mkDriver $ socketAsChannel socket
let driver = mkDriverTraced (composeInjectSelector inj $ injectSelector $ ServerDriver addr pName) untypedDriver
mask \restore -> do
result <-
restore $
try $
runPeerWithDriverTraced
(composeInjectSelector inj $ injectSelector $ ServerPeer addr pName)
driver
peer
(startDStateTraced driver)
rethrowErrors untypedDriver $
runPeerWithDriverTraced
(composeInjectSelector inj $ injectSelector $ ServerPeer addr pName)
driver
peer
(startDStateTraced driver)
withInjectEventArgs inj closeArgs \ev' -> do
case result of
Left ex -> do
Expand Down Expand Up @@ -291,21 +289,20 @@ tcpClientTraced inj host port toPeer = Connector $
pure
Connection
{ runConnection = \client -> localBackend (setAncestorEventBackend $ reference ev) do
let driver =
mkDriverTraced
(composeInjectSelector inj $ injectSelector $ ClientDriver addr)
(socketAsChannel socket)
handshakeClient = simpleHandshakeClient (signature $ Proxy @ps) client
peer = handshakeClientPeer toPeer handshakeClient
let untypedDriver = Untyped.mkDriver $ socketAsChannel socket
let driver = mkDriverTraced (composeInjectSelector inj $ injectSelector $ ClientDriver addr) untypedDriver
let handshakeClient = simpleHandshakeClient (signature $ Proxy @ps) client
let peer = handshakeClientPeer toPeer handshakeClient
mask \restore -> do
result <-
restore $
try $
runPeerWithDriverTraced
(composeInjectSelector inj $ injectSelector $ ClientPeer addr)
driver
peer
(startDStateTraced driver)
rethrowErrors untypedDriver $
runPeerWithDriverTraced
(composeInjectSelector inj $ injectSelector $ ClientPeer addr)
driver
peer
(startDStateTraced driver)
withInjectEventArgs inj closeArgs \ev' -> do
liftIO $ close socket
case result of
Expand All @@ -331,9 +328,9 @@ mkDriverTraced
:: forall ps r s m
. (MonadIO m, BinaryMessage ps, HasSpanContext r, MonadEvent r s m)
=> InjectSelector (DriverSelector ps) s
-> Channel m ByteString
-> DriverTraced ps (Maybe ByteString) r m
mkDriverTraced inj Channel{..} = DriverTraced{..}
-> Untyped.Driver m
-> DriverTraced ps (Maybe B.ByteString) r m
mkDriverTraced inj Untyped.Driver{..} = DriverTraced{..}
where
sendMessageTraced
:: forall (pr :: PeerRole) (st :: ps) (st' :: ps)
Expand All @@ -343,16 +340,17 @@ mkDriverTraced inj Channel{..} = DriverTraced{..}
-> m ()
sendMessageTraced r tok msg = withInjectEventFields inj (SendMessage tok msg) [()] \ev -> do
spanContext <- context r
addField ev =<< send (runPut $ put spanContext *> putMessage tok msg)
addField ev ()
sendSuccessMessage $ put spanContext
sendSuccessMessage $ putMessage tok msg

recvMessageTraced
:: forall (pr :: PeerRole) (st :: ps)
. PeerHasAgency pr st
-> Maybe ByteString
-> m (r, SomeMessage st, Maybe ByteString)
-> Maybe B.ByteString
-> m (r, SomeMessage st, Maybe B.ByteString)
recvMessageTraced tok trailing = do
let
(ctx, trailing') <- decodeChannel trailing =<< decodeGet get
(ctx, trailing') <- either throwIO pure =<< recvMessageUntyped trailing get
let r = wrapContext ctx
let args =
(simpleNewEventArgs $ RecvMessage tok)
Expand All @@ -363,33 +361,12 @@ mkDriverTraced inj Channel{..} = DriverTraced{..}
]
}
withInjectEventArgs inj args \ev -> do
(SomeMessage msg, trailing'') <- decodeChannel trailing' =<< decodeGet (getMessage tok)
(SomeMessage msg, trailing'') <- either throwIO pure =<< recvMessageUntyped trailing' (getMessage tok)
addField ev $ RecvMessageStateAfterMessage trailing''
addField ev $ RecvMessageMessage msg
pure (r, SomeMessage msg, trailing'')

decodeChannel
:: Maybe ByteString
-> DecodeStep ByteString DeserializeError m a
-> m (a, Maybe ByteString)
decodeChannel trailing (DecodeDone a _) = pure (a, trailing)
decodeChannel _ (DecodeFail failure) = throwIO failure
decodeChannel trailing (DecodePartial p) =
case trailing of
Nothing -> go $ DecodePartial p
Just trailing' -> go =<< p (Just trailing')
where
go = \case
DecodeDone a Nothing -> pure (a, Nothing)
DecodeDone a (Just trailing') -> do
pure (a, Just trailing')
DecodeFail failure -> throwIO failure
DecodePartial next -> do
mBytes <- recv
nextStep <- next mBytes
go nextStep

startDStateTraced :: Maybe ByteString
startDStateTraced :: Maybe B.ByteString
startDStateTraced = Nothing

instance Binary TraceFlags where
Expand Down Expand Up @@ -522,10 +499,10 @@ renderDriverSelectorOTel = \case
ClientAgency tok' -> showsPrecClientHasAgency 0 tok' ""
ServerAgency tok' -> showsPrecServerHasAgency 0 tok' ""
)
, ("typed-protocols.driver_state_before_span", toAttribute $ TL.toStrict $ foldMap encodeBase16 state)
, ("typed-protocols.driver_state_before_span", toAttribute $ foldMap encodeBase16 state)
]
RecvMessageStateBeforeMessage state -> [("typed-protocols.driver_state_before_message", toAttribute $ TL.toStrict $ foldMap encodeBase16 state)]
RecvMessageStateAfterMessage state -> [("typed-protocols.driver_state_after_message", toAttribute $ TL.toStrict $ foldMap encodeBase16 state)]
RecvMessageStateBeforeMessage state -> [("typed-protocols.driver_state_before_message", toAttribute $ foldMap encodeBase16 state)]
RecvMessageStateAfterMessage state -> [("typed-protocols.driver_state_after_message", toAttribute $ foldMap encodeBase16 state)]
RecvMessageMessage msg -> messageToAttributes $ AnyMessageAndAgency tok msg
}

Expand Down
74 changes: 74 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Driver/Untyped.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
{-# LANGUAGE RankNTypes #-}

module Network.Protocol.Driver.Untyped where

import Control.Exception (Exception)
import Control.Monad (guard)
import Data.Binary
import Data.Binary.Get (ByteOffset, Decoder (..), isEmpty, label, pushChunk, runGetIncremental)
import Data.Binary.Put (runPut)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Text (Text)
import GHC.Generics (Generic)
import Network.Channel

-- | An untyped protocol driver. Sits between a @Network.TypedProtocol.Driver@ which
-- is a stateful, typed channel for a specific protocol, and a @Network.Channel@ which is an unstructured,
-- stateless channel for raw bytes.
--
-- An untyped driver is able to send or receive arbitrary data, but it does so in a structured manner. Each payload will
-- be preceded by a status byte which indicates if the message is a normal, expected payload, or if it is an exception.
data Driver m = Driver
{ sendSuccessMessage :: Put -> m ()
-- ^ Send a normal message encoded as a @Data.Binary.Put@
, sendFailureMessage :: Text -> m ()
-- ^ Send an exception message.
, recvMessageUntyped :: forall a. Maybe BS.ByteString -> Get a -> m (Either RecvError (a, Maybe BS.ByteString))
-- ^ Receive a message and attempt to decode it using a @Data.Binary.Get@.
}

-- | What can go wrong during a recv call.
data RecvError
= -- | The peer disconnected unexpectedly.
PeerDisconnected
| -- | The peer crashed and sent an exception message.
PeerCrashed Text
| -- | The peer sent unexpected binary data.
DeserializeError BS.ByteString ByteOffset String
deriving stock (Show, Read, Eq, Ord)
deriving anyclass (Exception)

data StatusToken
= SuccessToken
| FailureToken
deriving stock (Show, Read, Eq, Ord, Enum, Bounded, Generic)
deriving anyclass (Binary)

-- | Create a driver which will operate over a channel.
mkDriver :: forall m. (Monad m) => Channel m LBS.ByteString -> Driver m
mkDriver Channel{..} =
Driver
{ sendSuccessMessage = send . runPut . (put SuccessToken *>)
, sendFailureMessage = send . runPut . (put FailureToken *>) . put
, recvMessageUntyped = \trailing getMessage ->
runDecoder $ maybe id (flip pushChunk) trailing $ runGetIncremental do
isEmpty >>= \case
True -> pure Nothing
False ->
Just <$> do
token <- label "StatusToken" get
case token of
SuccessToken -> Right <$> label "Message" getMessage
FailureToken -> Left <$> label "Failure" get
}
where
runDecoder :: Decoder (Maybe (Either Text a)) -> m (Either RecvError (a, Maybe BS.ByteString))
runDecoder = \case
Fail unconsumed byteOffset msg -> pure $ Left $ DeserializeError unconsumed byteOffset msg
Partial consumeNext -> do
next <- fmap LBS.toStrict <$> recv
runDecoder $ consumeNext next
Done _ _ Nothing -> pure $ Left PeerDisconnected
Done _ _ (Just (Left msg)) -> pure $ Left $ PeerCrashed msg
Done unconsumed _ (Just (Right msg)) -> pure $ Right (msg, unconsumed <$ guard (not $ BS.null unconsumed))

0 comments on commit 587333d

Please sign in to comment.