Rename AutoSerialize into SimpleMessage and refactor a bit

This commit is contained in:
Alex 2022-07-21 19:05:51 +02:00
parent 26989bba14
commit 44bbc1c00c
Signed by: lx
GPG key ID: 0E496D15096376BE
8 changed files with 116 additions and 71 deletions

View file

@ -26,7 +26,7 @@ tokio = { version = "1.0", default-features = false, features = ["net", "rt", "r
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7" tokio-stream = "0.1.7"
serde = { version = "1.0", default-features = false, features = ["derive"] } serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
rmp-serde = "0.14.3" rmp-serde = "0.14.3"
hex = "0.4.2" hex = "0.4.2"

View file

@ -134,15 +134,14 @@ impl ClientConn {
self.query_send.store(None); self.query_send.store(None);
} }
pub(crate) async fn call<T, B>( pub(crate) async fn call<T>(
self: Arc<Self>, self: Arc<Self>,
rq: B, rq: T,
path: &str, path: &str,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<T as Message>::Response, Error> ) -> Result<<T as Message>::Response, Error>
where where
T: Message, T: Message,
B: Borrow<T>,
{ {
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@ -164,7 +163,8 @@ impl ClientConn {
}; };
// Encode request // Encode request
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; let (rq, stream) = rq.into_parts();
let body = rmp_to_vec_all_named(&rq)?;
drop(rq); drop(rq);
let request = QueryMessage { let request = QueryMessage {
@ -217,7 +217,7 @@ impl ClientConn {
let code = resp[0]; let code = resp[0];
if code == 0 { if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
Ok(T::Response::deserialize_msg(ser_resp, stream).await) Ok(T::Response::from_parts(ser_resp, stream))
} else { } else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg)) Err(Error::Remote(code, msg))

View file

@ -19,7 +19,7 @@ pub trait EndpointHandler<M>: Send + Sync
where where
M: Message, M: Message,
{ {
async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response; async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
} }
/// If one simply wants to use an endpoint in a client fashion, /// If one simply wants to use an endpoint in a client fashion,
@ -28,7 +28,7 @@ where
/// it will panic if it is ever made to handle request. /// it will panic if it is ever made to handle request.
#[async_trait] #[async_trait]
impl<M: Message + 'static> EndpointHandler<M> for () { impl<M: Message + 'static> EndpointHandler<M> for () {
async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
panic!("This endpoint should not have a local handler."); panic!("This endpoint should not have a local handler.");
} }
} }
@ -81,19 +81,16 @@ where
/// Call this endpoint on a remote node (or on the local node, /// Call this endpoint on a remote node (or on the local node,
/// for that matter) /// for that matter)
pub async fn call<B>( pub async fn call(
&self, &self,
target: &NodeID, target: &NodeID,
req: B, req: M,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> ) -> Result<<M as Message>::Response, Error> {
where
B: Borrow<M> + Send + Sync,
{
if *target == self.netapp.id { if *target == self.netapp.id {
match self.handler.load_full() { match self.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), Some(h) => Ok(h.handle(req, self.netapp.id).await),
} }
} else { } else {
let conn = self let conn = self
@ -152,10 +149,11 @@ where
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => { Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?; let req = rmp_serde::decode::from_read_ref(buf)?;
let req = M::deserialize_msg(req, stream).await; let req = M::from_parts(req, stream);
let res = h.handle(&req, from).await; let res = h.handle(req, from).await;
let (res, res_stream) = res.into_parts();
let res_bytes = rmp_to_vec_all_named(&res)?; let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes) Ok((res_bytes, res_stream))
} }
} }
} }

View file

@ -1,7 +1,9 @@
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::{Stream, StreamExt}; use bytes::Bytes;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use futures::stream::{Stream, StreamExt};
use crate::error::*; use crate::error::*;
use crate::util::*; use crate::util::*;
@ -41,66 +43,112 @@ pub trait Message: SerializeMessage + Send + Sync {
} }
/// A trait for de/serializing messages, with possible associated stream. /// A trait for de/serializing messages, with possible associated stream.
#[async_trait]
pub trait SerializeMessage: Sized { pub trait SerializeMessage: Sized {
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>); fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
// TODO should return Result fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
} }
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} // ----
#[async_trait] impl<T, E> SerializeMessage for Result<T, E>
impl<T> SerializeMessage for T
where where
T: AutoSerialize, T: SerializeMessage + Send,
E: Serialize + for<'de> Deserialize<'de> + Send,
{ {
type SerializableSelf = Self; type SerializableSelf = Result<T::SerializableSelf, E>;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self.clone(), None) fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.into_parts();
(Ok(msg), stream)
}
Err(err) => (Err(err), None),
}
} }
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::from_parts(ok, stream)),
Err(err) => Err(err),
}
}
}
// ---
pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
impl<T> SerializeMessage for T
where
T: SimpleMessage,
{
type SerializableSelf = Self;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self, None)
}
fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
// TODO verify no stream // TODO verify no stream
ser_self ser_self
} }
} }
impl AutoSerialize for () {} impl SimpleMessage for () {}
#[async_trait] impl<T: SimpleMessage> SimpleMessage for std::sync::Arc<T> {}
impl<T, E> SerializeMessage for Result<T, E>
// ----
#[derive(Clone)]
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
pub T,
pub Bytes,
);
impl<T> SerializeMessage for WithFixedBody<T>
where where
T: SerializeMessage + Send, T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
E: SerializeMessage + Send,
{ {
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>; type SerializableSelf = T;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) { fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
match self { let body = self.1;
Ok(ok) => { (
let (msg, stream) = ok.serialize_msg(); self.0,
(Ok(msg), stream) Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
} )
Err(err) => {
let (msg, stream) = err.serialize_msg();
(Err(msg), stream)
}
}
} }
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
match ser_self { panic!("Cannot reconstruct a WithFixedBody type from parts");
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
Err(err) => Err(E::deserialize_msg(err, stream).await),
}
} }
} }
// ---- pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
pub T,
pub ByteStream,
);
impl<T> SerializeMessage for WithStreamingBody<T>
where
T: Serialize + for<'de> Deserialize<'de> + Send,
{
type SerializableSelf = T;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self.0, Some(self.1))
}
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
WithStreamingBody(ser_self, stream)
}
}
// ---- ----
pub(crate) struct QueryMessage<'a> { pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority, pub(crate) prio: RequestPriority,
@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> {
} }
} }
// ---- ----
pub(crate) struct Framing { pub(crate) struct Framing {
direct: Vec<u8>, direct: Vec<u8>,
stream: Option<ByteStream>, stream: Option<ByteStream>,

View file

@ -38,7 +38,7 @@ pub(crate) struct HelloMessage {
pub server_port: u16, pub server_port: u16,
} }
impl AutoSerialize for HelloMessage {} impl SimpleMessage for HelloMessage {}
impl Message for HelloMessage { impl Message for HelloMessage {
type Response = (); type Response = ();
@ -399,7 +399,7 @@ impl NetApp {
hello_endpoint hello_endpoint
.call( .call(
&conn.peer_id, &conn.peer_id,
&HelloMessage { HelloMessage {
server_addr, server_addr,
server_port, server_port,
}, },
@ -434,7 +434,7 @@ impl NetApp {
#[async_trait] #[async_trait]
impl EndpointHandler<HelloMessage> for NetApp { impl EndpointHandler<HelloMessage> for NetApp {
async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) { async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&from) { if let Some(c) = self.server_conns.read().unwrap().get(&from) {

View file

@ -40,7 +40,7 @@ impl Message for PingMessage {
type Response = PingMessage; type Response = PingMessage;
} }
impl AutoSerialize for PingMessage {} impl SimpleMessage for PingMessage {}
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage { struct PeerListMessage {
@ -51,7 +51,7 @@ impl Message for PeerListMessage {
type Response = PeerListMessage; type Response = PeerListMessage;
} }
impl AutoSerialize for PeerListMessage {} impl SimpleMessage for PeerListMessage {}
// -- Algorithm data structures -- // -- Algorithm data structures --
@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy {
ping_time ping_time
); );
let ping_response = select! { let ping_response = select! {
r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r, r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
_ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())), _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
}; };
@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy {
let pex_message = PeerListMessage { list: peer_list }; let pex_message = PeerListMessage { list: peer_list };
match self match self
.peer_list_endpoint .peer_list_endpoint
.call(id, &pex_message, PRIO_BACKGROUND) .call(id, pex_message, PRIO_BACKGROUND)
.await .await
{ {
Err(e) => warn!("Error doing peer exchange: {}", e), Err(e) => warn!("Error doing peer exchange: {}", e),
@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy {
#[async_trait] #[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage { async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
let ping_resp = PingMessage { let ping_resp = PingMessage {
id: ping.id, id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash, peer_list_hash: self.known_hosts.read().unwrap().hash,
@ -601,7 +601,7 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle( async fn handle(
self: &Arc<Self>, self: &Arc<Self>,
peer_list: &PeerListMessage, peer_list: PeerListMessage,
_from: NodeID, _from: NodeID,
) -> PeerListMessage { ) -> PeerListMessage {
self.handle_peer_list(&peer_list.list[..]); self.handle_peer_list(&peer_list.list[..]);

View file

@ -3,8 +3,8 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use bytes::Bytes;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes;
use log::trace; use log::trace;
use futures::AsyncWriteExt; use futures::AsyncWriteExt;

View file

@ -2,9 +2,9 @@ use std::net::SocketAddr;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::pin::Pin; use std::pin::Pin;
use bytes::Bytes;
use log::info; use log::info;
use serde::Serialize; use serde::Serialize;
use bytes::Bytes;
use futures::Stream; use futures::Stream;
use tokio::sync::watch; use tokio::sync::watch;
@ -35,19 +35,16 @@ pub type Packet = Result<Bytes, u8>;
/// ///
/// Field names and variant names are included in the serialization. /// Field names and variant names are included in the serialization.
/// This is used internally by the netapp communication protocol. /// This is used internally by the netapp communication protocol.
pub fn rmp_to_vec_all_named<T>( pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
val: &T,
) -> Result<(Vec<u8>, Option<ByteStream>), rmp_serde::encode::Error>
where where
T: SerializeMessage + ?Sized, T: Serialize + ?Sized,
{ {
let mut wr = Vec::with_capacity(128); let mut wr = Vec::with_capacity(128);
let mut se = rmp_serde::Serializer::new(&mut wr) let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map() .with_struct_map()
.with_string_variants(); .with_string_variants();
let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?; val.serialize(&mut se)?;
Ok((wr, stream)) Ok(wr)
} }
/// This async function returns only when a true signal was received /// This async function returns only when a true signal was received