From 4934ed726d51913afd97ca937d0ece39ef8b7371 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 20:22:56 +0200 Subject: [PATCH] Propose alternative API --- examples/basalt.rs | 10 +- src/client.rs | 15 +-- src/endpoint.rs | 45 ++++++--- src/message.rs | 211 +++++++++++++++++++++++++++------------- src/netapp.rs | 5 +- src/peering/basalt.rs | 17 +++- src/peering/fullmesh.rs | 12 ++- 7 files changed, 219 insertions(+), 96 deletions(-) diff --git a/examples/basalt.rs b/examples/basalt.rs index dd56cd7..3841786 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -159,11 +159,15 @@ impl Example { #[async_trait] impl EndpointHandler for Example { - async fn handle(self: &Arc, msg: ExampleMessage, _from: NodeID) -> ExampleResponse { + async fn handle( + self: &Arc, + msg: Req, + _from: NodeID, + ) -> Resp { debug!("Got example message: {:?}, sending example response", msg); - ExampleResponse { + Resp::new(ExampleResponse { example_field: false, - } + }) } } diff --git a/src/client.rs b/src/client.rs index 9d572aa..c878627 100644 --- a/src/client.rs +++ b/src/client.rs @@ -135,10 +135,10 @@ impl ClientConn { pub(crate) async fn call( self: Arc, - rq: T, + req: Req, path: &str, prio: RequestPriority, - ) -> Result<::Response, Error> + ) -> Result, Error> where T: Message, { @@ -162,9 +162,8 @@ impl ClientConn { }; // Encode request - let (rq, stream) = rq.into_parts(); - let body = rmp_to_vec_all_named(&rq)?; - drop(rq); + let body = req.msg_ser.unwrap().clone(); + let stream = req.body.into_stream(); let request = QueryMessage { prio, @@ -216,7 +215,11 @@ impl ClientConn { let code = resp[0]; if code == 0 { let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(T::Response::from_parts(ser_resp, stream)) + Ok(Resp { + _phantom: Default::default(), + msg: ser_resp, + body: BodyData::Stream(stream), + }) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 97a7644..8ee64a5 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -18,7 +18,7 @@ pub trait EndpointHandler: Send + Sync where M: Message, { - async fn handle(self: &Arc, m: M, from: NodeID) -> M::Response; + async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp; } /// If one simply wants to use an endpoint in a client fashion, @@ -27,7 +27,7 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { + async fn handle(self: &Arc<()>, _m: Req, _from: NodeID) -> Resp { panic!("This endpoint should not have a local handler."); } } @@ -80,16 +80,19 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call_full( &self, target: &NodeID, - req: M, + req: T, prio: RequestPriority, - ) -> Result<::Response, Error> { + ) -> Result, Error> + where + T: IntoReq, + { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req, self.netapp.id).await), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), } } else { let conn = self @@ -104,10 +107,21 @@ where "Not connected: {}", hex::encode(&target[..8]) ))), - Some(c) => c.call(req, self.path.as_str(), prio).await, + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, } } } + + /// Call this endpoint on a remote node, without the possibility + /// of adding or receiving a body + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<::Response, Error> { + Ok(self.call_full(target, req, prio).await?.into_msg()) + } } // ---- Internal stuff ---- @@ -148,11 +162,20 @@ where None => Err(Error::NoHandler), Some(h) => { let req = rmp_serde::decode::from_read_ref(buf)?; - let req = M::from_parts(req, stream); + let req = Req { + _phantom: Default::default(), + msg: Arc::new(req), + msg_ser: None, + body: BodyData::Stream(stream), + }; let res = h.handle(req, from).await; - let (res, res_stream) = res.into_parts(); - let res_bytes = rmp_to_vec_all_named(&res)?; - Ok((res_bytes, res_stream)) + let Resp { + msg, + body, + _phantom, + } = res; + let res_bytes = rmp_to_vec_all_named(&msg)?; + Ok((res_bytes, body.into_stream())) } } } diff --git a/src/message.rs b/src/message.rs index 22cae6a..d918c29 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,7 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -37,94 +41,169 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; } -/// A trait for de/serializing messages, with possible associated stream. -/// This is default-implemented by anything that can already be serialized -/// and deserialized. Adapters are provided that implement this for -/// adding a body, either from a fixed Bytes buffer (which allows the thing -/// to be Clone), or from a streaming byte stream. -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - - fn into_parts(self) -> (Self::SerializableSelf, Option); - - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +pub struct Req { + pub(crate) _phantom: PhantomData, + pub(crate) msg: Arc, + pub(crate) msg_ser: Option, + pub(crate) body: BodyData, } -// ---- +pub struct Resp { + pub(crate) _phantom: PhantomData, + pub(crate) msg: M::Response, + pub(crate) body: BodyData, +} -impl SerializeMessage for T -where - T: Serialize + for<'de> Deserialize<'de> + Send, -{ - type SerializableSelf = Self; - fn into_parts(self) -> (Self::SerializableSelf, Option) { - (self, None) - } +pub(crate) enum BodyData { + None, + Fixed(Bytes), + Stream(ByteStream), +} - fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - // TODO verify no stream - ser_self +impl BodyData { + pub fn into_stream(self) -> Option { + match self { + BodyData::None => None, + BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + BodyData::Stream(s) => Some(s), + } } } // ---- -/// An adapter that adds a body from a fixed Bytes buffer to a serializable message, -/// implementing the SerializeMessage trait. This allows for the SerializeMessage object -/// to be cloned, which is usefull for requests that must be sent to multiple servers. -/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable -/// part is also easily clonable (e.g. by wrapping it in an Arc). -/// Note that this CANNOT be used for a response type, as it cannot be reconstructed -/// from a remote stream. -#[derive(Clone)] -pub struct WithFixedBody Deserialize<'de> + Clone + Send + 'static>( - pub T, - pub Bytes, -); - -impl SerializeMessage for WithFixedBody -where - T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static, -{ - type SerializableSelf = T; - - fn into_parts(self) -> (Self::SerializableSelf, Option) { - let body = self.1; - ( - self.0, - Some(Box::pin(futures::stream::once(async move { Ok(body) }))), - ) +impl Req { + pub fn msg(&self) -> &M { + &self.msg } - fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { - panic!("Cannot use a WithFixedBody as a response type"); + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } } } -/// An adapter that adds a body from a ByteStream. This is usefull for receiving -/// responses to requests that contain attached byte streams. This type is -/// not clonable. -pub struct WithStreamingBody Deserialize<'de> + Send>( - pub T, - pub ByteStream, -); +pub trait IntoReq { + fn into_req(self) -> Result, rmp_serde::encode::Error>; + fn into_req_local(self) -> Req; +} -impl SerializeMessage for WithStreamingBody +impl IntoReq for M { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + body: BodyData::None, + }) + } + fn into_req_local(self) -> Req { + Req { + _phantom: Default::default(), + msg: Arc::new(self), + msg_ser: None, + body: BodyData::None, + } + } +} + +impl IntoReq for Req { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req { + self + } +} + +impl Clone for Req { + fn clone(&self) -> Self { + let body = match &self.body { + BodyData::None => BodyData::None, + BodyData::Fixed(b) => BodyData::Fixed(b.clone()), + BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"), + }; + Self { + _phantom: Default::default(), + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + body, + } + } +} + +impl fmt::Debug for Req where - T: Serialize + for<'de> Deserialize<'de> + Send, + M: Message + fmt::Debug, { - type SerializableSelf = T; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} - fn into_parts(self) -> (Self::SerializableSelf, Option) { - (self.0, Some(self.1)) +impl fmt::Debug for Resp +where + M: Message, + ::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.body { + BodyData::None => write!(f, "]"), + BodyData::Fixed(b) => write!(f, "; body={}]", b.len()), + BodyData::Stream(_) => write!(f, "; body=stream]"), + } + } +} + +impl Resp { + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + body: BodyData::None, + } } - fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - WithStreamingBody(ser_self, stream) + pub fn with_fixed_body(self, b: Bytes) -> Self { + Self { + body: BodyData::Fixed(b), + ..self + } + } + + pub fn with_streaming_body(self, b: ByteStream) -> Self { + Self { + body: BodyData::Stream(b), + ..self + } + } + + pub fn msg(&self) -> &M::Response { + &self.msg + } + + pub fn into_msg(self) -> M::Response { + self.msg } } diff --git a/src/netapp.rs b/src/netapp.rs index 32a5c23..0cebac0 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -404,6 +404,7 @@ impl NetApp { PRIO_NORMAL, ) .await + .map(|_| ()) .log_err("Sending hello message"); }); } @@ -432,7 +433,8 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: HelloMessage, from: NodeID) { + async fn handle(self: &Arc, msg: Req, from: NodeID) -> Resp { + let msg = msg.msg(); debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { @@ -441,5 +443,6 @@ impl EndpointHandler for NetApp { h(from, remote_addr, true); } } + Resp::new(()) } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index d7bc6a8..71dea84 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -468,15 +468,24 @@ impl Basalt { #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, _pullmsg: PullMessage, _from: NodeID) -> PushMessage { - self.make_push_message() + async fn handle( + self: &Arc, + _pullmsg: Req, + _from: NodeID, + ) -> Resp { + Resp::new(self.make_push_message()) } } #[async_trait] impl EndpointHandler for Basalt { - async fn handle(self: &Arc, pushmsg: PushMessage, _from: NodeID) { - self.handle_peer_list(&pushmsg.peers[..]); + async fn handle( + self: &Arc, + pushmsg: Req, + _from: NodeID, + ) -> Resp { + self.handle_peer_list(&pushmsg.msg().peers[..]); + Resp::new(()) } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index f8348af..9b7b666 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -583,13 +583,14 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: PingMessage, from: NodeID) -> PingMessage { + async fn handle(self: &Arc, ping: Req, from: NodeID) -> Resp { + let ping = ping.msg(); let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, }; debug!("Ping from {}", hex::encode(&from[..8])); - ping_resp + Resp::new(ping_resp) } } @@ -597,11 +598,12 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: PeerListMessage, + peer_list: Req, _from: NodeID, - ) -> PeerListMessage { + ) -> Resp { + let peer_list = peer_list.msg(); self.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); - PeerListMessage { list: peer_list } + Resp::new(PeerListMessage { list: peer_list }) } }