Rename AutoSerialize into SimpleMessage and refactor a bit

This commit is contained in:
Alex 2022-07-21 19:05:51 +02:00
parent 26989bba14
commit 44bbc1c00c
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
8 changed files with 116 additions and 71 deletions

View file

@ -26,7 +26,7 @@ tokio = { version = "1.0", default-features = false, features = ["net", "rt", "r
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7"
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
rmp-serde = "0.14.3"
hex = "0.4.2"

View file

@ -134,15 +134,14 @@ impl ClientConn {
self.query_send.store(None);
}
pub(crate) async fn call<T, B>(
pub(crate) async fn call<T>(
self: Arc<Self>,
rq: B,
rq: T,
path: &str,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message,
B: Borrow<T>,
{
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@ -164,7 +163,8 @@ impl ClientConn {
};
// Encode request
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
let (rq, stream) = rq.into_parts();
let body = rmp_to_vec_all_named(&rq)?;
drop(rq);
let request = QueryMessage {
@ -217,7 +217,7 @@ impl ClientConn {
let code = resp[0];
if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
Ok(T::Response::deserialize_msg(ser_resp, stream).await)
Ok(T::Response::from_parts(ser_resp, stream))
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))

View file

@ -19,7 +19,7 @@ pub trait EndpointHandler<M>: Send + Sync
where
M: Message,
{
async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
}
/// If one simply wants to use an endpoint in a client fashion,
@ -28,7 +28,7 @@ where
/// it will panic if it is ever made to handle request.
#[async_trait]
impl<M: Message + 'static> EndpointHandler<M> for () {
async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
panic!("This endpoint should not have a local handler.");
}
}
@ -81,19 +81,16 @@ where
/// Call this endpoint on a remote node (or on the local node,
/// for that matter)
pub async fn call<B>(
pub async fn call(
&self,
target: &NodeID,
req: B,
req: M,
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error>
where
B: Borrow<M> + Send + Sync,
{
) -> Result<<M as Message>::Response, Error> {
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
Some(h) => Ok(h.handle(req, self.netapp.id).await),
}
} else {
let conn = self
@ -152,10 +149,11 @@ where
None => Err(Error::NoHandler),
Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?;
let req = M::deserialize_msg(req, stream).await;
let res = h.handle(&req, from).await;
let req = M::from_parts(req, stream);
let res = h.handle(req, from).await;
let (res, res_stream) = res.into_parts();
let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes)
Ok((res_bytes, res_stream))
}
}
}

View file

@ -1,7 +1,9 @@
use async_trait::async_trait;
use futures::stream::{Stream, StreamExt};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use futures::stream::{Stream, StreamExt};
use crate::error::*;
use crate::util::*;
@ -41,67 +43,113 @@ pub trait Message: SerializeMessage + Send + Sync {
}
/// A trait for de/serializing messages, with possible associated stream.
#[async_trait]
pub trait SerializeMessage: Sized {
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>);
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
// TODO should return Result
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
}
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
// ----
impl<T, E> SerializeMessage for Result<T, E>
where
T: SerializeMessage + Send,
E: Serialize + for<'de> Deserialize<'de> + Send,
{
type SerializableSelf = Result<T::SerializableSelf, E>;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.into_parts();
(Ok(msg), stream)
}
Err(err) => (Err(err), None),
}
}
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::from_parts(ok, stream)),
Err(err) => Err(err),
}
}
}
// ---
pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
#[async_trait]
impl<T> SerializeMessage for T
where
T: AutoSerialize,
T: SimpleMessage,
{
type SerializableSelf = Self;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self.clone(), None)
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self, None)
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
// TODO verify no stream
ser_self
}
}
impl AutoSerialize for () {}
impl SimpleMessage for () {}
#[async_trait]
impl<T, E> SerializeMessage for Result<T, E>
where
T: SerializeMessage + Send,
E: SerializeMessage + Send,
{
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.serialize_msg();
(Ok(msg), stream)
}
Err(err) => {
let (msg, stream) = err.serialize_msg();
(Err(msg), stream)
}
}
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
Err(err) => Err(E::deserialize_msg(err, stream).await),
}
}
}
impl<T: SimpleMessage> SimpleMessage for std::sync::Arc<T> {}
// ----
#[derive(Clone)]
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
pub T,
pub Bytes,
);
impl<T> SerializeMessage for WithFixedBody<T>
where
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
type SerializableSelf = T;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
let body = self.1;
(
self.0,
Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
)
}
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
panic!("Cannot reconstruct a WithFixedBody type from parts");
}
}
pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
pub T,
pub ByteStream,
);
impl<T> SerializeMessage for WithStreamingBody<T>
where
T: Serialize + for<'de> Deserialize<'de> + Send,
{
type SerializableSelf = T;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self.0, Some(self.1))
}
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
WithStreamingBody(ser_self, stream)
}
}
// ---- ----
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
pub(crate) path: &'a [u8],
@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> {
}
}
// ---- ----
pub(crate) struct Framing {
direct: Vec<u8>,
stream: Option<ByteStream>,

View file

@ -38,7 +38,7 @@ pub(crate) struct HelloMessage {
pub server_port: u16,
}
impl AutoSerialize for HelloMessage {}
impl SimpleMessage for HelloMessage {}
impl Message for HelloMessage {
type Response = ();
@ -399,7 +399,7 @@ impl NetApp {
hello_endpoint
.call(
&conn.peer_id,
&HelloMessage {
HelloMessage {
server_addr,
server_port,
},
@ -434,7 +434,7 @@ impl NetApp {
#[async_trait]
impl EndpointHandler<HelloMessage> for NetApp {
async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) {
async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&from) {

View file

@ -40,7 +40,7 @@ impl Message for PingMessage {
type Response = PingMessage;
}
impl AutoSerialize for PingMessage {}
impl SimpleMessage for PingMessage {}
#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage {
@ -51,7 +51,7 @@ impl Message for PeerListMessage {
type Response = PeerListMessage;
}
impl AutoSerialize for PeerListMessage {}
impl SimpleMessage for PeerListMessage {}
// -- Algorithm data structures --
@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy {
ping_time
);
let ping_response = select! {
r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r,
r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
_ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
};
@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy {
let pex_message = PeerListMessage { list: peer_list };
match self
.peer_list_endpoint
.call(id, &pex_message, PRIO_BACKGROUND)
.call(id, pex_message, PRIO_BACKGROUND)
.await
{
Err(e) => warn!("Error doing peer exchange: {}", e),
@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy {
#[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage {
async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
let ping_resp = PingMessage {
id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash,
@ -601,7 +601,7 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle(
self: &Arc<Self>,
peer_list: &PeerListMessage,
peer_list: PeerListMessage,
_from: NodeID,
) -> PeerListMessage {
self.handle_peer_list(&peer_list.list[..]);

View file

@ -3,8 +3,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use async_trait::async_trait;
use bytes::Bytes;
use log::trace;
use futures::AsyncWriteExt;

View file

@ -2,9 +2,9 @@ use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use bytes::Bytes;
use log::info;
use serde::Serialize;
use bytes::Bytes;
use futures::Stream;
use tokio::sync::watch;
@ -35,19 +35,16 @@ pub type Packet = Result<Bytes, u8>;
///
/// Field names and variant names are included in the serialization.
/// This is used internally by the netapp communication protocol.
pub fn rmp_to_vec_all_named<T>(
val: &T,
) -> Result<(Vec<u8>, Option<ByteStream>), rmp_serde::encode::Error>
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
where
T: SerializeMessage + ?Sized,
T: Serialize + ?Sized,
{
let mut wr = Vec::with_capacity(128);
let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?;
Ok((wr, stream))
Ok(wr)
}
/// This async function returns only when a true signal was received