WIP: associated stream #1

Draft
trinity-1686a wants to merge 7 commits from stream-body into main
6 changed files with 66 additions and 39 deletions
Showing only changes of commit 4745e7c4ba - Show all commits

View file

@ -227,9 +227,8 @@ impl ClientConn {
let code = resp[0]; let code = resp[0];
if code == 0 { if code == 0 {
let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
let res = T::Response::deserialize_msg(&mut deser, stream).await?; Ok(T::Response::deserialize_msg(ser_resp, stream).await)
Ok(res)
} else { } else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg)) Err(Error::Remote(code, msg))

View file

@ -5,8 +5,7 @@ use std::sync::Arc;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait; use async_trait::async_trait;
use serde::de::Error as DeError; use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::error::Error; use crate::error::Error;
use crate::netapp::*; use crate::netapp::*;
@ -22,42 +21,61 @@ pub trait Message: SerializeMessage + Send + Sync {
/// A trait for de/serializing messages, with possible associated stream. /// A trait for de/serializing messages, with possible associated stream.
#[async_trait] #[async_trait]
pub trait SerializeMessage: Sized { pub trait SerializeMessage: Sized {
fn serialize_msg<S: Serializer>( type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
&self,
serializer: S,
) -> Result<(S::Ok, Option<AssociatedStream>), S::Error>;
async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( // TODO should return Result
deserializer: D, fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
stream: AssociatedStream,
) -> Result<Self, D::Error>; // TODO should return Result
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
} }
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
#[async_trait] #[async_trait]
impl<T> SerializeMessage for T impl<T> SerializeMessage for T
where where
T: Serialize + for<'de> Deserialize<'de> + Send + Sync, T: AutoSerialize,
{ {
fn serialize_msg<S: Serializer>( type SerializableSelf = Self;
&self, fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
serializer: S, (self.clone(), None)
) -> Result<(S::Ok, Option<AssociatedStream>), S::Error> {
self.serialize(serializer).map(|r| (r, None))
} }
async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
deserializer: D, // TODO verify no stream
mut stream: AssociatedStream, ser_self
) -> Result<Self, D::Error> { }
use futures::StreamExt; }
let res = Self::deserialize(deserializer)?; impl AutoSerialize for () {}
if stream.next().await.is_some() {
return Err(D::Error::custom( #[async_trait]
"failed to deserialize: found associated stream when none expected", impl<T, E> SerializeMessage for Result<T, E>
)); where
T: SerializeMessage + Send,
E: SerializeMessage + Send,
{
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.serialize_msg();
(Ok(msg), stream)
}
Err(err) => {
let (msg, stream) = err.serialize_msg();
(Err(msg), stream)
}
}
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
Err(err) => Err(E::deserialize_msg(err, stream).await),
} }
Ok(res)
} }
} }
@ -139,7 +157,7 @@ where
prio: RequestPriority, prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> ) -> Result<<M as Message>::Response, Error>
where where
B: Borrow<M>, B: Borrow<M> + Send + Sync,
{ {
if *target == self.netapp.id { if *target == self.netapp.id {
match self.handler.load_full() { match self.handler.load_full() {
@ -202,8 +220,8 @@ where
match self.0.handler.load_full() { match self.0.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => { Some(h) => {
let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); let req = rmp_serde::decode::from_read_ref(buf)?;
let req = M::deserialize_msg(&mut deser, stream).await?; let req = M::deserialize_msg(req, stream).await;
let res = h.handle(&req, from).await; let res = h.handle(&req, from).await;
let res_bytes = rmp_to_vec_all_named(&res)?; let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes) Ok(res_bytes)

View file

@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16];
/// Value of the Netapp version used in the version tag /// Value of the Netapp version used in the version tag
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct HelloMessage { pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>, pub server_addr: Option<IpAddr>,
pub server_port: u16, pub server_port: u16,
} }
impl AutoSerialize for HelloMessage {}
impl Message for HelloMessage { impl Message for HelloMessage {
type Response = (); type Response = ();
} }

View file

@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
// -- Protocol messages -- // -- Protocol messages --
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
struct PingMessage { struct PingMessage {
pub id: u64, pub id: u64,
pub peer_list_hash: hash::Digest, pub peer_list_hash: hash::Digest,
@ -39,7 +39,9 @@ impl Message for PingMessage {
type Response = PingMessage; type Response = PingMessage;
} }
#[derive(Serialize, Deserialize)] impl AutoSerialize for PingMessage {}
#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage { struct PeerListMessage {
pub list: Vec<(NodeID, SocketAddr)>, pub list: Vec<(NodeID, SocketAddr)>,
} }
@ -48,6 +50,8 @@ impl Message for PeerListMessage {
type Response = PeerListMessage; type Response = PeerListMessage;
} }
impl AutoSerialize for PeerListMessage {}
// -- Algorithm data structures -- // -- Algorithm data structures --
#[derive(Debug)] #[derive(Debug)]

View file

@ -151,9 +151,10 @@ impl Stream for DataReader {
} }
let mut body = [0; MAX_CHUNK_LENGTH as usize]; let mut body = [0; MAX_CHUNK_LENGTH as usize];
body[..buf.len()].copy_from_slice(&buf); let len = buf.len();
body[..len].copy_from_slice(buf);
buf.clear(); buf.clear();
Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) Poll::Ready(Some((body, len)))
} }
} }
} }

View file

@ -8,6 +8,8 @@ use futures::Stream;
use log::info; use log::info;
use serde::Serialize;
use tokio::sync::watch; use tokio::sync::watch;
/// A node's identifier, which is also its public cryptographic key /// A node's identifier, which is also its public cryptographic key
@ -34,7 +36,8 @@ where
let mut se = rmp_serde::Serializer::new(&mut wr) let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map() .with_struct_map()
.with_string_variants(); .with_string_variants();
let (_, stream) = val.serialize_msg(&mut se)?; let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?;
Ok((wr, stream)) Ok((wr, stream))
} }