diff --git a/Cargo.toml b/Cargo.toml index d8a4908..a19e11a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ tokio = { version = "1.0", default-features = false, features = ["net", "rt", "r tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } tokio-stream = "0.1.7" -serde = { version = "1.0", default-features = false, features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive", "rc"] } rmp-serde = "0.14.3" hex = "0.4.2" diff --git a/src/client.rs b/src/client.rs index 663a3e4..cf80746 100644 --- a/src/client.rs +++ b/src/client.rs @@ -134,15 +134,14 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call( + pub(crate) async fn call( self: Arc, - rq: B, + rq: T, path: &str, prio: RequestPriority, ) -> Result<::Response, Error> where T: Message, - B: Borrow, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -164,7 +163,8 @@ impl ClientConn { }; // Encode request - let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; + let (rq, stream) = rq.into_parts(); + let body = rmp_to_vec_all_named(&rq)?; drop(rq); let request = QueryMessage { @@ -217,7 +217,7 @@ impl ClientConn { let code = resp[0]; if code == 0 { let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; - Ok(T::Response::deserialize_msg(ser_resp, stream).await) + Ok(T::Response::from_parts(ser_resp, stream)) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index e6b2236..3f292d9 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -19,7 +19,7 @@ pub trait EndpointHandler: Send + Sync where M: Message, { - async fn handle(self: &Arc, m: &M, from: NodeID) -> M::Response; + async fn handle(self: &Arc, m: M, from: NodeID) -> M::Response; } /// If one simply wants to use an endpoint in a client fashion, @@ -28,7 +28,7 @@ where /// it will panic if it is ever made to handle request. #[async_trait] impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { + async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } @@ -81,19 +81,16 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call( &self, target: &NodeID, - req: B, + req: M, prio: RequestPriority, - ) -> Result<::Response, Error> - where - B: Borrow + Send + Sync, - { + ) -> Result<::Response, Error> { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), + Some(h) => Ok(h.handle(req, self.netapp.id).await), } } else { let conn = self @@ -152,10 +149,11 @@ where None => Err(Error::NoHandler), Some(h) => { let req = rmp_serde::decode::from_read_ref(buf)?; - let req = M::deserialize_msg(req, stream).await; - let res = h.handle(&req, from).await; + let req = M::from_parts(req, stream); + let res = h.handle(req, from).await; + let (res, res_stream) = res.into_parts(); let res_bytes = rmp_to_vec_all_named(&res)?; - Ok(res_bytes) + Ok((res_bytes, res_stream)) } } } diff --git a/src/message.rs b/src/message.rs index 6d50254..f92eb8c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; -use futures::stream::{Stream, StreamExt}; +use bytes::Bytes; use serde::{Deserialize, Serialize}; +use futures::stream::{Stream, StreamExt}; + use crate::error::*; use crate::util::*; @@ -41,66 +43,112 @@ pub trait Message: SerializeMessage + Send + Sync { } /// A trait for de/serializing messages, with possible associated stream. -#[async_trait] pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + fn into_parts(self) -> (Self::SerializableSelf, Option); - // TODO should return Result - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; } -pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} +// ---- -#[async_trait] -impl SerializeMessage for T +impl SerializeMessage for Result where - T: AutoSerialize, + T: SerializeMessage + Send, + E: Serialize + for<'de> Deserialize<'de> + Send, { - type SerializableSelf = Self; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - (self.clone(), None) + type SerializableSelf = Result; + + fn into_parts(self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.into_parts(); + (Ok(msg), stream) + } + Err(err) => (Err(err), None), + } } - async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::from_parts(ok, stream)), + Err(err) => Err(err), + } + } +} + +// --- + +pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +impl SerializeMessage for T +where + T: SimpleMessage, +{ + type SerializableSelf = Self; + fn into_parts(self) -> (Self::SerializableSelf, Option) { + (self, None) + } + + fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { // TODO verify no stream ser_self } } -impl AutoSerialize for () {} +impl SimpleMessage for () {} -#[async_trait] -impl SerializeMessage for Result +impl SimpleMessage for std::sync::Arc {} + +// ---- + +#[derive(Clone)] +pub struct WithFixedBody Deserialize<'de> + Clone + Send + 'static>( + pub T, + pub Bytes, +); + +impl SerializeMessage for WithFixedBody where - T: SerializeMessage + Send, - E: SerializeMessage + Send, + T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static, { - type SerializableSelf = Result; + type SerializableSelf = T; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.serialize_msg(); - (Ok(msg), stream) - } - Err(err) => { - let (msg, stream) = err.serialize_msg(); - (Err(msg), stream) - } - } + fn into_parts(self) -> (Self::SerializableSelf, Option) { + let body = self.1; + ( + self.0, + Some(Box::pin(futures::stream::once(async move { Ok(body) }))), + ) } - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), - Err(err) => Err(E::deserialize_msg(err, stream).await), - } + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + panic!("Cannot reconstruct a WithFixedBody type from parts"); } } -// ---- +pub struct WithStreamingBody Deserialize<'de> + Send>( + pub T, + pub ByteStream, +); + +impl SerializeMessage for WithStreamingBody +where + T: Serialize + for<'de> Deserialize<'de> + Send, +{ + type SerializableSelf = T; + + fn into_parts(self) -> (Self::SerializableSelf, Option) { + (self.0, Some(self.1)) + } + + fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + WithStreamingBody(ser_self, stream) + } +} + +// ---- ---- pub(crate) struct QueryMessage<'a> { pub(crate) prio: RequestPriority, @@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> { } } +// ---- ---- + pub(crate) struct Framing { direct: Vec, stream: Option, diff --git a/src/netapp.rs b/src/netapp.rs index dd22d90..8365de0 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -38,7 +38,7 @@ pub(crate) struct HelloMessage { pub server_port: u16, } -impl AutoSerialize for HelloMessage {} +impl SimpleMessage for HelloMessage {} impl Message for HelloMessage { type Response = (); @@ -399,7 +399,7 @@ impl NetApp { hello_endpoint .call( &conn.peer_id, - &HelloMessage { + HelloMessage { server_addr, server_port, }, @@ -434,7 +434,7 @@ impl NetApp { #[async_trait] impl EndpointHandler for NetApp { - async fn handle(self: &Arc, msg: &HelloMessage, from: NodeID) { + async fn handle(self: &Arc, msg: HelloMessage, from: NodeID) { debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&from) { diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 5b489ae..3eeebb3 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -40,7 +40,7 @@ impl Message for PingMessage { type Response = PingMessage; } -impl AutoSerialize for PingMessage {} +impl SimpleMessage for PingMessage {} #[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { @@ -51,7 +51,7 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } -impl AutoSerialize for PeerListMessage {} +impl SimpleMessage for PeerListMessage {} // -- Algorithm data structures -- @@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy { ping_time ); let ping_response = select! { - r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r, + r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())), }; @@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy { let pex_message = PeerListMessage { list: peer_list }; match self .peer_list_endpoint - .call(id, &pex_message, PRIO_BACKGROUND) + .call(id, pex_message, PRIO_BACKGROUND) .await { Err(e) => warn!("Error doing peer exchange: {}", e), @@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy { #[async_trait] impl EndpointHandler for FullMeshPeeringStrategy { - async fn handle(self: &Arc, ping: &PingMessage, from: NodeID) -> PingMessage { + async fn handle(self: &Arc, ping: PingMessage, from: NodeID) -> PingMessage { let ping_resp = PingMessage { id: ping.id, peer_list_hash: self.known_hosts.read().unwrap().hash, @@ -601,7 +601,7 @@ impl EndpointHandler for FullMeshPeeringStrategy { impl EndpointHandler for FullMeshPeeringStrategy { async fn handle( self: &Arc, - peer_list: &PeerListMessage, + peer_list: PeerListMessage, _from: NodeID, ) -> PeerListMessage { self.handle_peer_list(&peer_list.list[..]); diff --git a/src/send.rs b/src/send.rs index 660e85c..cc28d7c 100644 --- a/src/send.rs +++ b/src/send.rs @@ -3,8 +3,8 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::Bytes; use async_trait::async_trait; +use bytes::Bytes; use log::trace; use futures::AsyncWriteExt; diff --git a/src/util.rs b/src/util.rs index e81a89c..f860672 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,9 +2,9 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; use std::pin::Pin; +use bytes::Bytes; use log::info; use serde::Serialize; -use bytes::Bytes; use futures::Stream; use tokio::sync::watch; @@ -35,19 +35,16 @@ pub type Packet = Result; /// /// Field names and variant names are included in the serialization. /// This is used internally by the netapp communication protocol. -pub fn rmp_to_vec_all_named( - val: &T, -) -> Result<(Vec, Option), rmp_serde::encode::Error> +pub fn rmp_to_vec_all_named(val: &T) -> Result, rmp_serde::encode::Error> where - T: SerializeMessage + ?Sized, + T: Serialize + ?Sized, { let mut wr = Vec::with_capacity(128); let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - let (val, stream) = val.serialize_msg(); val.serialize(&mut se)?; - Ok((wr, stream)) + Ok(wr) } /// This async function returns only when a true signal was received