diff --git a/src/client.rs b/src/client.rs index bce7aca..bc16fb1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -227,9 +227,8 @@ impl ClientConn { let code = resp[0]; if code == 0 { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); - let res = T::Response::deserialize_msg(&mut deser, stream).await?; - Ok(res) + let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; + Ok(T::Response::deserialize_msg(ser_resp, stream).await) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 81ed036..c25365a 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,8 +5,7 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::de::Error as DeError; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use crate::error::Error; use crate::netapp::*; @@ -22,42 +21,61 @@ pub trait Message: SerializeMessage + Send + Sync { /// A trait for de/serializing messages, with possible associated stream. #[async_trait] pub trait SerializeMessage: Sized { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error>; + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - stream: AssociatedStream, - ) -> Result; + // TODO should return Result + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; } +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + #[async_trait] impl SerializeMessage for T where - T: Serialize + for<'de> Deserialize<'de> + Send + Sync, + T: AutoSerialize, { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error> { - self.serialize(serializer).map(|r| (r, None)) + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) } - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - mut stream: AssociatedStream, - ) -> Result { - use futures::StreamExt; + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + // TODO verify no stream + ser_self + } +} - let res = Self::deserialize(deserializer)?; - if stream.next().await.is_some() { - return Err(D::Error::custom( - "failed to deserialize: found associated stream when none expected", - )); +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), } - Ok(res) } } @@ -139,7 +157,7 @@ where prio: RequestPriority, ) -> Result<::Response, Error> where - B: Borrow, + B: Borrow + Send + Sync, { if *target == self.netapp.id { match self.handler.load_full() { @@ -202,8 +220,8 @@ where match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); - let req = M::deserialize_msg(&mut deser, stream).await?; + let req = rmp_serde::decode::from_read_ref(buf)?; + let req = M::deserialize_msg(req, stream).await; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..27f17e6 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct HelloMessage { pub server_addr: Option, pub server_port: u16, } +impl AutoSerialize for HelloMessage {} + impl Message for HelloMessage { type Response = (); } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 012c5a0..7dfc5c4 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3; // -- Protocol messages -- -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] struct PingMessage { pub id: u64, pub peer_list_hash: hash::Digest, @@ -39,7 +39,9 @@ impl Message for PingMessage { type Response = PingMessage; } -#[derive(Serialize, Deserialize)] +impl AutoSerialize for PingMessage {} + +#[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, } @@ -48,6 +50,8 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } +impl AutoSerialize for PeerListMessage {} + // -- Algorithm data structures -- #[derive(Debug)] diff --git a/src/proto.rs b/src/proto.rs index ca1a3d2..073a317 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -151,9 +151,10 @@ impl Stream for DataReader { } let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..buf.len()].copy_from_slice(&buf); + let len = buf.len(); + body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) + Poll::Ready(Some((body, len))) } } } diff --git a/src/util.rs b/src/util.rs index 4333080..02b4e7d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -8,6 +8,8 @@ use futures::Stream; use log::info; +use serde::Serialize; + use tokio::sync::watch; /// A node's identifier, which is also its public cryptographic key @@ -34,7 +36,8 @@ where let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - let (_, stream) = val.serialize_msg(&mut se)?; + let (val, stream) = val.serialize_msg(); + val.serialize(&mut se)?; Ok((wr, stream)) }