diff --git a/examples/basalt.rs b/examples/basalt.rs index 3841786..a5a25c3 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -159,15 +159,11 @@ impl Example { #[async_trait] impl EndpointHandler for Example { - async fn handle( - self: &Arc, - msg: Req, - _from: NodeID, - ) -> Resp { + async fn handle(self: &Arc, msg: &ExampleMessage, _from: NodeID) -> ExampleResponse { debug!("Got example message: {:?}, sending example response", msg); - Resp::new(ExampleResponse { + ExampleResponse { example_field: false, - }) + } } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 8ee64a5..ff626d8 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -10,11 +10,13 @@ use crate::netapp::*; use crate::util::*; /// This trait should be implemented by an object of your application -/// that can handle a message of type `M`. +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. /// /// The handler object should be in an Arc, see `Endpoint::set_handler` #[async_trait] -pub trait EndpointHandler: Send + Sync +pub trait StreamingEndpointHandler: Send + Sync where M: Message, { @@ -27,11 +29,34 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: Req, _from: NodeID) -> Resp { + async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } +// ---- + +#[async_trait] +pub trait EndpointHandler: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc, m: &M, from: NodeID) -> ::Response; +} + +#[async_trait] +impl StreamingEndpointHandler for T +where + T: EndpointHandler, + M: Message + 'static, +{ + async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp { + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) + } +} + +// ---- + /// This struct represents an endpoint for message of type `M`. /// /// Creating a new endpoint is done by calling `NetApp::endpoint`. @@ -41,13 +66,13 @@ impl EndpointHandler for () { /// An `Endpoint` is used both to send requests to remote nodes, /// and to specify the handler for such requests on the local node. /// The type `H` represents the type of the handler object for -/// endpoint messages (see `EndpointHandler`). +/// endpoint messages (see `StreamingEndpointHandler`). pub struct Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { - phantom: PhantomData, + _phantom: PhantomData, netapp: Arc, path: String, handler: ArcSwapOption, @@ -56,11 +81,11 @@ where impl Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { pub(crate) fn new(netapp: Arc, path: String) -> Self { Self { - phantom: PhantomData::default(), + _phantom: PhantomData::default(), netapp, path, handler: ArcSwapOption::from(None), @@ -79,8 +104,10 @@ where } /// Call this endpoint on a remote node (or on the local node, - /// for that matter) - pub async fn call_full( + /// for that matter). This function invokes the full version that + /// allows to attach a streaming body to the request and to + /// receive such a body attached to the response. + pub async fn call_streaming( &self, target: &NodeID, req: T, @@ -112,15 +139,16 @@ where } } - /// Call this endpoint on a remote node, without the possibility - /// of adding or receiving a body + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. pub async fn call( &self, target: &NodeID, req: M, prio: RequestPriority, ) -> Result<::Response, Error> { - Ok(self.call_full(target, req, prio).await?.into_msg()) + Ok(self.call_streaming(target, req, prio).await?.into_msg()) } } @@ -144,13 +172,13 @@ pub(crate) trait GenericEndpoint { pub(crate) struct EndpointArc(pub(crate) Arc>) where M: Message, - H: EndpointHandler; + H: StreamingEndpointHandler; #[async_trait] impl GenericEndpoint for EndpointArc where M: Message + 'static, - H: EndpointHandler + 'static, + H: StreamingEndpointHandler + 'static, { async fn handle( &self, diff --git a/src/message.rs b/src/message.rs index d918c29..5721318 100644 --- a/src/message.rs +++ b/src/message.rs @@ -311,7 +311,7 @@ impl Framing { } } - pub async fn from_stream + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + Sync + 'static>( mut stream: S, ) -> Result { let mut packet = stream diff --git a/src/netapp.rs b/src/netapp.rs index 0cebac0..166f560 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -152,7 +152,7 @@ impl NetApp { pub fn endpoint(self: &Arc, path: String) -> Arc> where M: Message + 'static, - H: EndpointHandler + 'static, + H: StreamingEndpointHandler + 'static, { let endpoint = Arc::new(Endpoint::::new(self.clone(), path.clone())); let endpoint_arc = EndpointArc(endpoint.clone()); @@ -433,8 +433,7 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: Req, from: NodeID) -> Resp { - let msg = msg.msg(); + async fn handle(self: &Arc, msg: &HelloMessage, from: NodeID) { debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -443,6 +442,5 @@ impl EndpointHandler for NetApp { h(from, remote_addr, true); } } - Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 71dea84..310077f 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,24 +468,15 @@ impl Basalt { #[async_trait] impl EndpointHandler for Basalt { - async fn handle( - self: &Arc, - _pullmsg: Req, - _from: NodeID, - ) -> Resp { - Resp::new(self.make_push_message()) + async fn handle(self: &Arc, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage { + self.make_push_message() } } #[async_trait] impl EndpointHandler for Basalt { - async fn handle( - self: &Arc, - pushmsg: Req, - _from: NodeID, - ) -> Resp { - self.handle_peer_list(&pushmsg.msg().peers[..]); - Resp::new(()) + async fn handle(self: &Arc, pushmsg: &PushMessage, _from: NodeID) { + self.handle_peer_list(&pushmsg.peers[..]); } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 9b7b666..ccbd0ba 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,14 +583,13 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: Req, from: NodeID) -> Resp { - let ping = ping.msg(); + async fn handle(self: &Arc, ping: &PingMessage, from: NodeID) -> PingMessage { let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - Resp::new(ping_resp) + ping_resp } } @@ -598,12 +597,11 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: Req, + peer_list: &PeerListMessage, _from: NodeID, - ) -> Resp { - let peer_list = peer_list.msg(); + ) -> PeerListMessage { self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - Resp::new(PeerListMessage { list: peer_list }) + PeerListMessage { list: peer_list } } } diff --git a/src/util.rs b/src/util.rs index e7ecea8..01c392c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -24,7 +24,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type ByteStream = Pin + Send>>; +pub type ByteStream = Pin + Send + Sync>>; pub type Packet = Result;