add streaming body to requests and responses #3
8 changed files with 116 additions and 71 deletions
|
@ -26,7 +26,7 @@ tokio = { version = "1.0", default-features = false, features = ["net", "rt", "r
|
|||
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
||||
tokio-stream = "0.1.7"
|
||||
|
||||
serde = { version = "1.0", default-features = false, features = ["derive"] }
|
||||
serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
|
||||
rmp-serde = "0.14.3"
|
||||
hex = "0.4.2"
|
||||
|
||||
|
|
|
@ -134,15 +134,14 @@ impl ClientConn {
|
|||
self.query_send.store(None);
|
||||
}
|
||||
|
||||
pub(crate) async fn call<T, B>(
|
||||
pub(crate) async fn call<T>(
|
||||
self: Arc<Self>,
|
||||
rq: B,
|
||||
rq: T,
|
||||
path: &str,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<T as Message>::Response, Error>
|
||||
where
|
||||
T: Message,
|
||||
B: Borrow<T>,
|
||||
{
|
||||
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
|
||||
|
||||
|
@ -164,7 +163,8 @@ impl ClientConn {
|
|||
};
|
||||
|
||||
// Encode request
|
||||
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
|
||||
let (rq, stream) = rq.into_parts();
|
||||
let body = rmp_to_vec_all_named(&rq)?;
|
||||
drop(rq);
|
||||
|
||||
let request = QueryMessage {
|
||||
|
@ -217,7 +217,7 @@ impl ClientConn {
|
|||
let code = resp[0];
|
||||
if code == 0 {
|
||||
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
|
||||
Ok(T::Response::deserialize_msg(ser_resp, stream).await)
|
||||
Ok(T::Response::from_parts(ser_resp, stream))
|
||||
} else {
|
||||
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
||||
Err(Error::Remote(code, msg))
|
||||
|
|
|
@ -19,7 +19,7 @@ pub trait EndpointHandler<M>: Send + Sync
|
|||
where
|
||||
M: Message,
|
||||
{
|
||||
async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
|
||||
async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
|
||||
}
|
||||
|
||||
/// If one simply wants to use an endpoint in a client fashion,
|
||||
|
@ -28,7 +28,7 @@ where
|
|||
/// it will panic if it is ever made to handle request.
|
||||
#[async_trait]
|
||||
impl<M: Message + 'static> EndpointHandler<M> for () {
|
||||
async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
|
||||
async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
|
||||
panic!("This endpoint should not have a local handler.");
|
||||
}
|
||||
}
|
||||
|
@ -81,19 +81,16 @@ where
|
|||
|
||||
/// Call this endpoint on a remote node (or on the local node,
|
||||
/// for that matter)
|
||||
pub async fn call<B>(
|
||||
pub async fn call(
|
||||
&self,
|
||||
target: &NodeID,
|
||||
req: B,
|
||||
req: M,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<M as Message>::Response, Error>
|
||||
where
|
||||
B: Borrow<M> + Send + Sync,
|
||||
{
|
||||
) -> Result<<M as Message>::Response, Error> {
|
||||
if *target == self.netapp.id {
|
||||
match self.handler.load_full() {
|
||||
None => Err(Error::NoHandler),
|
||||
Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
|
||||
Some(h) => Ok(h.handle(req, self.netapp.id).await),
|
||||
}
|
||||
} else {
|
||||
let conn = self
|
||||
|
@ -152,10 +149,11 @@ where
|
|||
None => Err(Error::NoHandler),
|
||||
Some(h) => {
|
||||
let req = rmp_serde::decode::from_read_ref(buf)?;
|
||||
let req = M::deserialize_msg(req, stream).await;
|
||||
let res = h.handle(&req, from).await;
|
||||
let req = M::from_parts(req, stream);
|
||||
let res = h.handle(req, from).await;
|
||||
let (res, res_stream) = res.into_parts();
|
||||
let res_bytes = rmp_to_vec_all_named(&res)?;
|
||||
Ok(res_bytes)
|
||||
Ok((res_bytes, res_stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
122
src/message.rs
122
src/message.rs
|
@ -1,7 +1,9 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
|
||||
use crate::error::*;
|
||||
use crate::util::*;
|
||||
|
||||
|
@ -41,66 +43,112 @@ pub trait Message: SerializeMessage + Send + Sync {
|
|||
}
|
||||
|
||||
/// A trait for de/serializing messages, with possible associated stream.
|
||||
#[async_trait]
|
||||
pub trait SerializeMessage: Sized {
|
||||
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
|
||||
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>);
|
||||
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
|
||||
|
||||
// TODO should return Result
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
|
||||
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
|
||||
}
|
||||
|
||||
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
|
||||
// ----
|
||||
|
||||
#[async_trait]
|
||||
impl<T> SerializeMessage for T
|
||||
impl<T, E> SerializeMessage for Result<T, E>
|
||||
where
|
||||
T: AutoSerialize,
|
||||
T: SerializeMessage + Send,
|
||||
E: Serialize + for<'de> Deserialize<'de> + Send,
|
||||
{
|
||||
type SerializableSelf = Self;
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
(self.clone(), None)
|
||||
type SerializableSelf = Result<T::SerializableSelf, E>;
|
||||
|
||||
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
match self {
|
||||
Ok(ok) => {
|
||||
let (msg, stream) = ok.into_parts();
|
||||
(Ok(msg), stream)
|
||||
}
|
||||
Err(err) => (Err(err), None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
|
||||
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
|
||||
match ser_self {
|
||||
Ok(ok) => Ok(T::from_parts(ok, stream)),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
pub trait SimpleMessage: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
|
||||
|
||||
impl<T> SerializeMessage for T
|
||||
where
|
||||
T: SimpleMessage,
|
||||
{
|
||||
type SerializableSelf = Self;
|
||||
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
(self, None)
|
||||
}
|
||||
|
||||
fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
|
||||
// TODO verify no stream
|
||||
ser_self
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoSerialize for () {}
|
||||
impl SimpleMessage for () {}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, E> SerializeMessage for Result<T, E>
|
||||
impl<T: SimpleMessage> SimpleMessage for std::sync::Arc<T> {}
|
||||
|
||||
// ----
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
|
||||
pub T,
|
||||
pub Bytes,
|
||||
);
|
||||
|
||||
impl<T> SerializeMessage for WithFixedBody<T>
|
||||
where
|
||||
T: SerializeMessage + Send,
|
||||
E: SerializeMessage + Send,
|
||||
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
|
||||
{
|
||||
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
|
||||
type SerializableSelf = T;
|
||||
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
match self {
|
||||
Ok(ok) => {
|
||||
let (msg, stream) = ok.serialize_msg();
|
||||
(Ok(msg), stream)
|
||||
}
|
||||
Err(err) => {
|
||||
let (msg, stream) = err.serialize_msg();
|
||||
(Err(msg), stream)
|
||||
}
|
||||
}
|
||||
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
let body = self.1;
|
||||
(
|
||||
self.0,
|
||||
Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
|
||||
)
|
||||
}
|
||||
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
|
||||
match ser_self {
|
||||
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
|
||||
Err(err) => Err(E::deserialize_msg(err, stream).await),
|
||||
}
|
||||
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
|
||||
panic!("Cannot reconstruct a WithFixedBody type from parts");
|
||||
}
|
||||
}
|
||||
|
||||
// ----
|
||||
pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>(
|
||||
pub T,
|
||||
pub ByteStream,
|
||||
);
|
||||
|
||||
impl<T> SerializeMessage for WithStreamingBody<T>
|
||||
where
|
||||
T: Serialize + for<'de> Deserialize<'de> + Send,
|
||||
{
|
||||
type SerializableSelf = T;
|
||||
|
||||
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
|
||||
(self.0, Some(self.1))
|
||||
}
|
||||
|
||||
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
|
||||
WithStreamingBody(ser_self, stream)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ----
|
||||
|
||||
pub(crate) struct QueryMessage<'a> {
|
||||
pub(crate) prio: RequestPriority,
|
||||
|
@ -175,6 +223,8 @@ impl<'a> QueryMessage<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
// ---- ----
|
||||
|
||||
pub(crate) struct Framing {
|
||||
direct: Vec<u8>,
|
||||
stream: Option<ByteStream>,
|
||||
|
|
|
@ -38,7 +38,7 @@ pub(crate) struct HelloMessage {
|
|||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl AutoSerialize for HelloMessage {}
|
||||
impl SimpleMessage for HelloMessage {}
|
||||
|
||||
impl Message for HelloMessage {
|
||||
type Response = ();
|
||||
|
@ -399,7 +399,7 @@ impl NetApp {
|
|||
hello_endpoint
|
||||
.call(
|
||||
&conn.peer_id,
|
||||
&HelloMessage {
|
||||
HelloMessage {
|
||||
server_addr,
|
||||
server_port,
|
||||
},
|
||||
|
@ -434,7 +434,7 @@ impl NetApp {
|
|||
|
||||
#[async_trait]
|
||||
impl EndpointHandler<HelloMessage> for NetApp {
|
||||
async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) {
|
||||
async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
|
||||
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
|
||||
if let Some(h) = self.on_connected_handler.load().as_ref() {
|
||||
if let Some(c) = self.server_conns.read().unwrap().get(&from) {
|
||||
|
|
|
@ -40,7 +40,7 @@ impl Message for PingMessage {
|
|||
type Response = PingMessage;
|
||||
}
|
||||
|
||||
impl AutoSerialize for PingMessage {}
|
||||
impl SimpleMessage for PingMessage {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct PeerListMessage {
|
||||
|
@ -51,7 +51,7 @@ impl Message for PeerListMessage {
|
|||
type Response = PeerListMessage;
|
||||
}
|
||||
|
||||
impl AutoSerialize for PeerListMessage {}
|
||||
impl SimpleMessage for PeerListMessage {}
|
||||
|
||||
// -- Algorithm data structures --
|
||||
|
||||
|
@ -379,7 +379,7 @@ impl FullMeshPeeringStrategy {
|
|||
ping_time
|
||||
);
|
||||
let ping_response = select! {
|
||||
r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r,
|
||||
r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
|
||||
_ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
|
||||
};
|
||||
|
||||
|
@ -431,7 +431,7 @@ impl FullMeshPeeringStrategy {
|
|||
let pex_message = PeerListMessage { list: peer_list };
|
||||
match self
|
||||
.peer_list_endpoint
|
||||
.call(id, &pex_message, PRIO_BACKGROUND)
|
||||
.call(id, pex_message, PRIO_BACKGROUND)
|
||||
.await
|
||||
{
|
||||
Err(e) => warn!("Error doing peer exchange: {}", e),
|
||||
|
@ -587,7 +587,7 @@ impl FullMeshPeeringStrategy {
|
|||
|
||||
#[async_trait]
|
||||
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
|
||||
async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage {
|
||||
async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
|
||||
let ping_resp = PingMessage {
|
||||
id: ping.id,
|
||||
peer_list_hash: self.known_hosts.read().unwrap().hash,
|
||||
|
@ -601,7 +601,7 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
|
|||
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
|
||||
async fn handle(
|
||||
self: &Arc<Self>,
|
||||
peer_list: &PeerListMessage,
|
||||
peer_list: PeerListMessage,
|
||||
_from: NodeID,
|
||||
) -> PeerListMessage {
|
||||
self.handle_peer_list(&peer_list.list[..]);
|
||||
|
|
|
@ -3,8 +3,8 @@ use std::pin::Pin;
|
|||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::Bytes;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use log::trace;
|
||||
|
||||
use futures::AsyncWriteExt;
|
||||
|
|
11
src/util.rs
11
src/util.rs
|
@ -2,9 +2,9 @@ use std::net::SocketAddr;
|
|||
use std::net::ToSocketAddrs;
|
||||
use std::pin::Pin;
|
||||
|
||||
use bytes::Bytes;
|
||||
use log::info;
|
||||
use serde::Serialize;
|
||||
use bytes::Bytes;
|
||||
|
||||
use futures::Stream;
|
||||
use tokio::sync::watch;
|
||||
|
@ -35,19 +35,16 @@ pub type Packet = Result<Bytes, u8>;
|
|||
///
|
||||
/// Field names and variant names are included in the serialization.
|
||||
/// This is used internally by the netapp communication protocol.
|
||||
pub fn rmp_to_vec_all_named<T>(
|
||||
val: &T,
|
||||
) -> Result<(Vec<u8>, Option<ByteStream>), rmp_serde::encode::Error>
|
||||
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
|
||||
where
|
||||
T: SerializeMessage + ?Sized,
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
let mut wr = Vec::with_capacity(128);
|
||||
let mut se = rmp_serde::Serializer::new(&mut wr)
|
||||
.with_struct_map()
|
||||
.with_string_variants();
|
||||
let (val, stream) = val.serialize_msg();
|
||||
val.serialize(&mut se)?;
|
||||
Ok((wr, stream))
|
||||
Ok(wr)
|
||||
}
|
||||
|
||||
/// This async function returns only when a true signal was received
|
||||
|
|
Loading…
Reference in a new issue