Propose alternative API

This commit is contained in:
Alex 2022-07-21 20:22:56 +02:00
parent 7d148c7e76
commit 4934ed726d
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
7 changed files with 219 additions and 96 deletions

View file

@ -159,11 +159,15 @@ impl Example {
#[async_trait] #[async_trait]
impl EndpointHandler<ExampleMessage> for Example { impl EndpointHandler<ExampleMessage> for Example {
async fn handle(self: &Arc<Self>, msg: ExampleMessage, _from: NodeID) -> ExampleResponse { async fn handle(
self: &Arc<Self>,
msg: Req<ExampleMessage>,
_from: NodeID,
) -> Resp<ExampleMessage> {
debug!("Got example message: {:?}, sending example response", msg); debug!("Got example message: {:?}, sending example response", msg);
ExampleResponse { Resp::new(ExampleResponse {
example_field: false, example_field: false,
} })
} }
} }

View file

@ -135,10 +135,10 @@ impl ClientConn {
pub(crate) async fn call<T>( pub(crate) async fn call<T>(
self: Arc<Self>, self: Arc<Self>,
rq: T, req: Req<T>,
path: &str, path: &str,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<T as Message>::Response, Error> ) -> Result<Resp<T>, Error>
where where
T: Message, T: Message,
{ {
@ -162,9 +162,8 @@ impl ClientConn {
}; };
// Encode request // Encode request
let (rq, stream) = rq.into_parts(); let body = req.msg_ser.unwrap().clone();
let body = rmp_to_vec_all_named(&rq)?; let stream = req.body.into_stream();
drop(rq);
let request = QueryMessage { let request = QueryMessage {
prio, prio,
@ -216,7 +215,11 @@ impl ClientConn {
let code = resp[0]; let code = resp[0];
if code == 0 { if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
Ok(T::Response::from_parts(ser_resp, stream)) Ok(Resp {
_phantom: Default::default(),
msg: ser_resp,
body: BodyData::Stream(stream),
})
} else { } else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg)) Err(Error::Remote(code, msg))

View file

@ -18,7 +18,7 @@ pub trait EndpointHandler<M>: Send + Sync
where where
M: Message, M: Message,
{ {
async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response; async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
} }
/// If one simply wants to use an endpoint in a client fashion, /// If one simply wants to use an endpoint in a client fashion,
@ -27,7 +27,7 @@ where
/// it will panic if it is ever made to handle request. /// it will panic if it is ever made to handle request.
#[async_trait] #[async_trait]
impl<M: Message + 'static> EndpointHandler<M> for () { impl<M: Message + 'static> EndpointHandler<M> for () {
async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response { async fn handle(self: &Arc<()>, _m: Req<M>, _from: NodeID) -> Resp<M> {
panic!("This endpoint should not have a local handler."); panic!("This endpoint should not have a local handler.");
} }
} }
@ -80,16 +80,19 @@ where
/// Call this endpoint on a remote node (or on the local node, /// Call this endpoint on a remote node (or on the local node,
/// for that matter) /// for that matter)
pub async fn call( pub async fn call_full<T>(
&self, &self,
target: &NodeID, target: &NodeID,
req: M, req: T,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> { ) -> Result<Resp<M>, Error>
where
T: IntoReq<M>,
{
if *target == self.netapp.id { if *target == self.netapp.id {
match self.handler.load_full() { match self.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req, self.netapp.id).await), Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await),
} }
} else { } else {
let conn = self let conn = self
@ -104,10 +107,21 @@ where
"Not connected: {}", "Not connected: {}",
hex::encode(&target[..8]) hex::encode(&target[..8])
))), ))),
Some(c) => c.call(req, self.path.as_str(), prio).await, Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await,
} }
} }
} }
/// Call this endpoint on a remote node, without the possibility
/// of adding or receiving a body
pub async fn call(
&self,
target: &NodeID,
req: M,
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> {
Ok(self.call_full(target, req, prio).await?.into_msg())
}
} }
// ---- Internal stuff ---- // ---- Internal stuff ----
@ -148,11 +162,20 @@ where
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => { Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?; let req = rmp_serde::decode::from_read_ref(buf)?;
let req = M::from_parts(req, stream); let req = Req {
_phantom: Default::default(),
msg: Arc::new(req),
msg_ser: None,
body: BodyData::Stream(stream),
};
let res = h.handle(req, from).await; let res = h.handle(req, from).await;
let (res, res_stream) = res.into_parts(); let Resp {
let res_bytes = rmp_to_vec_all_named(&res)?; msg,
Ok((res_bytes, res_stream)) body,
_phantom,
} = res;
let res_bytes = rmp_to_vec_all_named(&msg)?;
Ok((res_bytes, body.into_stream()))
} }
} }
} }

View file

@ -1,3 +1,7 @@
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use bytes::Bytes; use bytes::Bytes;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -37,94 +41,169 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
/// This trait should be implemented by all messages your application /// This trait should be implemented by all messages your application
/// wants to handle /// wants to handle
pub trait Message: SerializeMessage + Send + Sync { pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
type Response: SerializeMessage + Send + Sync; type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
} }
/// A trait for de/serializing messages, with possible associated stream. pub struct Req<M: Message> {
/// This is default-implemented by anything that can already be serialized pub(crate) _phantom: PhantomData<M>,
/// and deserialized. Adapters are provided that implement this for pub(crate) msg: Arc<M>,
/// adding a body, either from a fixed Bytes buffer (which allows the thing pub(crate) msg_ser: Option<Bytes>,
/// to be Clone), or from a streaming byte stream. pub(crate) body: BodyData,
pub trait SerializeMessage: Sized {
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>);
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
} }
// ---- pub struct Resp<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: M::Response,
pub(crate) body: BodyData,
}
impl<T> SerializeMessage for T pub(crate) enum BodyData {
where None,
T: Serialize + for<'de> Deserialize<'de> + Send, Fixed(Bytes),
{ Stream(ByteStream),
type SerializableSelf = Self; }
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
(self, None) impl BodyData {
pub fn into_stream(self) -> Option<ByteStream> {
match self {
BodyData::None => None,
BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
BodyData::Stream(s) => Some(s),
} }
fn from_parts(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
// TODO verify no stream
ser_self
} }
} }
// ---- // ----
/// An adapter that adds a body from a fixed Bytes buffer to a serializable message, impl<M: Message> Req<M> {
/// implementing the SerializeMessage trait. This allows for the SerializeMessage object pub fn msg(&self) -> &M {
/// to be cloned, which is usefull for requests that must be sent to multiple servers. &self.msg
/// Note that cloning the body is cheap thanks to Bytes; make sure that your serializable
/// part is also easily clonable (e.g. by wrapping it in an Arc).
/// Note that this CANNOT be used for a response type, as it cannot be reconstructed
/// from a remote stream.
#[derive(Clone)]
pub struct WithFixedBody<T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static>(
pub T,
pub Bytes,
);
impl<T> SerializeMessage for WithFixedBody<T>
where
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + 'static,
{
type SerializableSelf = T;
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) {
let body = self.1;
(
self.0,
Some(Box::pin(futures::stream::once(async move { Ok(body) }))),
)
} }
fn from_parts(_ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { pub fn with_fixed_body(self, b: Bytes) -> Self {
panic!("Cannot use a WithFixedBody as a response type"); Self {
body: BodyData::Fixed(b),
..self
}
}
pub fn with_streaming_body(self, b: ByteStream) -> Self {
Self {
body: BodyData::Stream(b),
..self
}
} }
} }
/// An adapter that adds a body from a ByteStream. This is usefull for receiving pub trait IntoReq<M: Message> {
/// responses to requests that contain attached byte streams. This type is fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
/// not clonable. fn into_req_local(self) -> Req<M>;
pub struct WithStreamingBody<T: Serialize + for<'de> Deserialize<'de> + Send>( }
pub T,
pub ByteStream,
);
impl<T> SerializeMessage for WithStreamingBody<T> impl<M: Message> IntoReq<M> for M {
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
let msg_ser = rmp_to_vec_all_named(&self)?;
Ok(Req {
_phantom: Default::default(),
msg: Arc::new(self),
msg_ser: Some(Bytes::from(msg_ser)),
body: BodyData::None,
})
}
fn into_req_local(self) -> Req<M> {
Req {
_phantom: Default::default(),
msg: Arc::new(self),
msg_ser: None,
body: BodyData::None,
}
}
}
impl<M: Message> IntoReq<M> for Req<M> {
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
Ok(self)
}
fn into_req_local(self) -> Req<M> {
self
}
}
impl<M: Message> Clone for Req<M> {
fn clone(&self) -> Self {
let body = match &self.body {
BodyData::None => BodyData::None,
BodyData::Fixed(b) => BodyData::Fixed(b.clone()),
BodyData::Stream(_) => panic!("Cannot clone a Req<_> with a stream body"),
};
Self {
_phantom: Default::default(),
msg: self.msg.clone(),
msg_ser: self.msg_ser.clone(),
body,
}
}
}
impl<M> fmt::Debug for Req<M>
where where
T: Serialize + for<'de> Deserialize<'de> + Send, M: Message + fmt::Debug,
{ {
type SerializableSelf = T; fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Req[{:?}", self.msg)?;
match &self.body {
BodyData::None => write!(f, "]"),
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
BodyData::Stream(_) => write!(f, "; body=stream]"),
}
}
}
fn into_parts(self) -> (Self::SerializableSelf, Option<ByteStream>) { impl<M> fmt::Debug for Resp<M>
(self.0, Some(self.1)) where
M: Message,
<M as Message>::Response: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Resp[{:?}", self.msg)?;
match &self.body {
BodyData::None => write!(f, "]"),
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
BodyData::Stream(_) => write!(f, "; body=stream]"),
}
}
}
impl<M: Message> Resp<M> {
pub fn new(v: M::Response) -> Self {
Resp {
_phantom: Default::default(),
msg: v,
body: BodyData::None,
}
} }
fn from_parts(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { pub fn with_fixed_body(self, b: Bytes) -> Self {
WithStreamingBody(ser_self, stream) Self {
body: BodyData::Fixed(b),
..self
}
}
pub fn with_streaming_body(self, b: ByteStream) -> Self {
Self {
body: BodyData::Stream(b),
..self
}
}
pub fn msg(&self) -> &M::Response {
&self.msg
}
pub fn into_msg(self) -> M::Response {
self.msg
} }
} }

View file

@ -404,6 +404,7 @@ impl NetApp {
PRIO_NORMAL, PRIO_NORMAL,
) )
.await .await
.map(|_| ())
.log_err("Sending hello message"); .log_err("Sending hello message");
}); });
} }
@ -432,7 +433,8 @@ impl NetApp {
#[async_trait] #[async_trait]
impl EndpointHandler<HelloMessage> for NetApp { impl EndpointHandler<HelloMessage> for NetApp {
async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) { async fn handle(self: &Arc<Self>, msg: Req<HelloMessage>, from: NodeID) -> Resp<HelloMessage> {
let msg = msg.msg();
debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
if let Some(h) = self.on_connected_handler.load().as_ref() { if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&from) { if let Some(c) = self.server_conns.read().unwrap().get(&from) {
@ -441,5 +443,6 @@ impl EndpointHandler<HelloMessage> for NetApp {
h(from, remote_addr, true); h(from, remote_addr, true);
} }
} }
Resp::new(())
} }
} }

View file

@ -468,15 +468,24 @@ impl Basalt {
#[async_trait] #[async_trait]
impl EndpointHandler<PullMessage> for Basalt { impl EndpointHandler<PullMessage> for Basalt {
async fn handle(self: &Arc<Self>, _pullmsg: PullMessage, _from: NodeID) -> PushMessage { async fn handle(
self.make_push_message() self: &Arc<Self>,
_pullmsg: Req<PullMessage>,
_from: NodeID,
) -> Resp<PullMessage> {
Resp::new(self.make_push_message())
} }
} }
#[async_trait] #[async_trait]
impl EndpointHandler<PushMessage> for Basalt { impl EndpointHandler<PushMessage> for Basalt {
async fn handle(self: &Arc<Self>, pushmsg: PushMessage, _from: NodeID) { async fn handle(
self.handle_peer_list(&pushmsg.peers[..]); self: &Arc<Self>,
pushmsg: Req<PushMessage>,
_from: NodeID,
) -> Resp<PushMessage> {
self.handle_peer_list(&pushmsg.msg().peers[..]);
Resp::new(())
} }
} }

View file

@ -583,13 +583,14 @@ impl FullMeshPeeringStrategy {
#[async_trait] #[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage { async fn handle(self: &Arc<Self>, ping: Req<PingMessage>, from: NodeID) -> Resp<PingMessage> {
let ping = ping.msg();
let ping_resp = PingMessage { let ping_resp = PingMessage {
id: ping.id, id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash, peer_list_hash: self.known_hosts.read().unwrap().hash,
}; };
debug!("Ping from {}", hex::encode(&from[..8])); debug!("Ping from {}", hex::encode(&from[..8]));
ping_resp Resp::new(ping_resp)
} }
} }
@ -597,11 +598,12 @@ impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy { impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle( async fn handle(
self: &Arc<Self>, self: &Arc<Self>,
peer_list: PeerListMessage, peer_list: Req<PeerListMessage>,
_from: NodeID, _from: NodeID,
) -> PeerListMessage { ) -> Resp<PeerListMessage> {
let peer_list = peer_list.msg();
self.handle_peer_list(&peer_list.list[..]); self.handle_peer_list(&peer_list.list[..]);
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
PeerListMessage { list: peer_list } Resp::new(PeerListMessage { list: peer_list })
} }
} }