Clean up framing protocol
This commit is contained in:
parent
c358fe3c92
commit
0b71ca12f9
9 changed files with 432 additions and 268 deletions
|
@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex};
|
|||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
|
||||
|
@ -28,6 +29,7 @@ use crate::message::*;
|
|||
use crate::netapp::*;
|
||||
use crate::recv::*;
|
||||
use crate::send::*;
|
||||
use crate::stream::*;
|
||||
use crate::util::*;
|
||||
|
||||
pub(crate) struct ClientConn {
|
||||
|
@ -155,24 +157,16 @@ impl ClientConn {
|
|||
.with_kind(SpanKind::Client)
|
||||
.start(&tracer);
|
||||
let propagator = BinaryPropagator::new();
|
||||
let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
|
||||
let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
|
||||
} else {
|
||||
let telemetry_id: Option<Vec<u8>> = None;
|
||||
let telemetry_id: Bytes = Bytes::new();
|
||||
}
|
||||
};
|
||||
|
||||
// Encode request
|
||||
let body = req.msg_ser.unwrap().clone();
|
||||
let stream = req.body.into_stream();
|
||||
|
||||
let request = QueryMessage {
|
||||
prio,
|
||||
path: path.as_bytes(),
|
||||
telemetry_id,
|
||||
body: &body[..],
|
||||
};
|
||||
let bytes = request.encode();
|
||||
drop(body);
|
||||
let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
|
||||
let req_msg_len = req_enc.msg.len();
|
||||
let req_stream = req_enc.encode();
|
||||
|
||||
// Send request through
|
||||
let (resp_send, resp_recv) = oneshot::channel();
|
||||
|
@ -181,17 +175,19 @@ impl ClientConn {
|
|||
error!(
|
||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||
);
|
||||
if old_ch.send(unbounded().1).is_err() {
|
||||
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
let _ = old_ch.send(unbounded().1);
|
||||
}
|
||||
|
||||
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
||||
trace!(
|
||||
"request: query_send {} (serialized message: {} bytes)",
|
||||
id,
|
||||
req_msg_len
|
||||
);
|
||||
|
||||
#[cfg(feature = "telemetry")]
|
||||
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
||||
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
|
||||
|
||||
query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
|
||||
query_send.send((id, prio, req_stream))?;
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "telemetry")] {
|
||||
|
@ -202,28 +198,10 @@ impl ClientConn {
|
|||
let stream = resp_recv.await?;
|
||||
}
|
||||
}
|
||||
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||
|
||||
if resp.is_empty() {
|
||||
return Err(Error::Message(
|
||||
"Response is 0 bytes, either a collision or a protocol error".into(),
|
||||
));
|
||||
}
|
||||
|
||||
trace!("request response {}: ", id);
|
||||
|
||||
let code = resp[0];
|
||||
if code == 0 {
|
||||
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
|
||||
Ok(Resp {
|
||||
_phantom: Default::default(),
|
||||
msg: ser_resp,
|
||||
body: BodyData::Stream(stream),
|
||||
})
|
||||
} else {
|
||||
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
||||
Err(Error::Remote(code, msg))
|
||||
}
|
||||
let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
|
||||
trace!("request response {}", id);
|
||||
Resp::from_enc(resp_enc)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -158,12 +158,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
|
|||
|
||||
#[async_trait]
|
||||
pub(crate) trait GenericEndpoint {
|
||||
async fn handle(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
stream: ByteStream,
|
||||
from: NodeID,
|
||||
) -> Result<(Vec<u8>, Option<ByteStream>), Error>;
|
||||
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
|
||||
fn drop_handler(&self);
|
||||
fn clone_endpoint(&self) -> DynEndpoint;
|
||||
}
|
||||
|
@ -180,30 +175,13 @@ where
|
|||
M: Message + 'static,
|
||||
H: StreamingEndpointHandler<M> + 'static,
|
||||
{
|
||||
async fn handle(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
stream: ByteStream,
|
||||
from: NodeID,
|
||||
) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
|
||||
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
|
||||
match self.0.handler.load_full() {
|
||||
None => Err(Error::NoHandler),
|
||||
Some(h) => {
|
||||
let req = rmp_serde::decode::from_read_ref(buf)?;
|
||||
let req = Req {
|
||||
_phantom: Default::default(),
|
||||
msg: Arc::new(req),
|
||||
msg_ser: None,
|
||||
body: BodyData::Stream(stream),
|
||||
};
|
||||
let req = Req::from_enc(req_enc)?;
|
||||
let res = h.handle(req, from).await;
|
||||
let Resp {
|
||||
msg,
|
||||
body,
|
||||
_phantom,
|
||||
} = res;
|
||||
let res_bytes = rmp_to_vec_all_named(&msg)?;
|
||||
Ok((res_bytes, body.into_stream()))
|
||||
Ok(res.into_enc()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
//! Also check out the examples to learn how to use this crate.
|
||||
|
||||
pub mod error;
|
||||
pub mod stream;
|
||||
pub mod util;
|
||||
|
||||
pub mod endpoint;
|
||||
|
|
347
src/message.rs
347
src/message.rs
|
@ -2,12 +2,13 @@ use std::fmt;
|
|||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use futures::stream::StreamExt;
|
||||
|
||||
use crate::error::*;
|
||||
use crate::stream::*;
|
||||
use crate::util::*;
|
||||
|
||||
/// Priority of a request (click to read more about priorities).
|
||||
|
@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
|||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
/// The Req<M> is a helper object used to create requests and attach them
|
||||
/// a streaming body. If the body is a fixed Bytes and not a ByteStream,
|
||||
/// Req<M> is cheaply clonable to allow the request to be sent to different
|
||||
/// peers (Clone will panic if the body is a ByteStream).
|
||||
///
|
||||
/// Internally, this is also used to encode and decode requests
|
||||
/// from/to byte streams to be sent over the network.
|
||||
pub struct Req<M: Message> {
|
||||
pub(crate) _phantom: PhantomData<M>,
|
||||
pub(crate) msg: Arc<M>,
|
||||
|
@ -52,30 +62,6 @@ pub struct Req<M: Message> {
|
|||
pub(crate) body: BodyData,
|
||||
}
|
||||
|
||||
pub struct Resp<M: Message> {
|
||||
pub(crate) _phantom: PhantomData<M>,
|
||||
pub(crate) msg: M::Response,
|
||||
pub(crate) body: BodyData,
|
||||
}
|
||||
|
||||
pub(crate) enum BodyData {
|
||||
None,
|
||||
Fixed(Bytes),
|
||||
Stream(ByteStream),
|
||||
}
|
||||
|
||||
impl BodyData {
|
||||
pub fn into_stream(self) -> Option<ByteStream> {
|
||||
match self {
|
||||
BodyData::None => None,
|
||||
BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
|
||||
BodyData::Stream(s) => Some(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
impl<M: Message> Req<M> {
|
||||
pub fn msg(&self) -> &M {
|
||||
&self.msg
|
||||
|
@ -94,6 +80,31 @@ impl<M: Message> Req<M> {
|
|||
..self
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_enc(
|
||||
self,
|
||||
prio: RequestPriority,
|
||||
path: Bytes,
|
||||
telemetry_id: Bytes,
|
||||
) -> ReqEnc {
|
||||
ReqEnc {
|
||||
prio,
|
||||
path,
|
||||
telemetry_id,
|
||||
msg: self.msg_ser.unwrap(),
|
||||
stream: self.body.into_stream(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
|
||||
let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
|
||||
Ok(Req {
|
||||
_phantom: Default::default(),
|
||||
msg: Arc::new(msg),
|
||||
msg_ser: Some(enc.msg),
|
||||
body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoReq<M: Message> {
|
||||
|
@ -160,19 +171,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<M> fmt::Debug for Resp<M>
|
||||
where
|
||||
M: Message,
|
||||
<M as Message>::Response: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
write!(f, "Resp[{:?}", self.msg)?;
|
||||
match &self.body {
|
||||
BodyData::None => write!(f, "]"),
|
||||
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
|
||||
BodyData::Stream(_) => write!(f, "; body=stream]"),
|
||||
}
|
||||
}
|
||||
// ----
|
||||
|
||||
/// The Resp<M> represents a full response from a RPC that may have
|
||||
/// an attached body stream.
|
||||
pub struct Resp<M: Message> {
|
||||
pub(crate) _phantom: PhantomData<M>,
|
||||
pub(crate) msg: M::Response,
|
||||
pub(crate) body: BodyData,
|
||||
}
|
||||
|
||||
impl<M: Message> Resp<M> {
|
||||
|
@ -205,160 +211,213 @@ impl<M: Message> Resp<M> {
|
|||
pub fn into_msg(self) -> M::Response {
|
||||
self.msg
|
||||
}
|
||||
|
||||
pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
|
||||
Ok(RespEnc::Success {
|
||||
msg: rmp_to_vec_all_named(&self.msg)?.into(),
|
||||
stream: self.body.into_stream(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
|
||||
match enc {
|
||||
RespEnc::Success { msg, stream } => {
|
||||
let msg = rmp_serde::decode::from_read_ref(&msg)?;
|
||||
Ok(Self {
|
||||
_phantom: Default::default(),
|
||||
msg,
|
||||
body: stream.map(BodyData::Stream).unwrap_or(BodyData::None),
|
||||
})
|
||||
}
|
||||
RespEnc::Error { code, message } => Err(Error::Remote(code, message)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M> fmt::Debug for Resp<M>
|
||||
where
|
||||
M: Message,
|
||||
<M as Message>::Response: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
write!(f, "Resp[{:?}", self.msg)?;
|
||||
match &self.body {
|
||||
BodyData::None => write!(f, "]"),
|
||||
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
|
||||
BodyData::Stream(_) => write!(f, "; body=stream]"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
pub(crate) enum BodyData {
|
||||
None,
|
||||
Fixed(Bytes),
|
||||
Stream(ByteStream),
|
||||
}
|
||||
|
||||
impl BodyData {
|
||||
pub fn into_stream(self) -> Option<ByteStream> {
|
||||
match self {
|
||||
BodyData::None => None,
|
||||
BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
|
||||
BodyData::Stream(s) => Some(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ----
|
||||
|
||||
pub(crate) struct QueryMessage<'a> {
|
||||
pub(crate) prio: RequestPriority,
|
||||
pub(crate) path: &'a [u8],
|
||||
pub(crate) telemetry_id: Option<Vec<u8>>,
|
||||
pub(crate) body: &'a [u8],
|
||||
}
|
||||
|
||||
/// QueryMessage encoding:
|
||||
/// Encoding for requests into a ByteStream:
|
||||
/// - priority: u8
|
||||
/// - path length: u8
|
||||
/// - path: [u8; path length]
|
||||
/// - telemetry id length: u8
|
||||
/// - telemetry id: [u8; telemetry id length]
|
||||
/// - body [u8; ..]
|
||||
impl<'a> QueryMessage<'a> {
|
||||
pub(crate) fn encode(self) -> Vec<u8> {
|
||||
let tel_len = match &self.telemetry_id {
|
||||
Some(t) => t.len(),
|
||||
None => 0,
|
||||
};
|
||||
/// - msg len: u32
|
||||
/// - msg [u8; ..]
|
||||
/// - the attached stream as the rest of the encoded stream
|
||||
pub(crate) struct ReqEnc {
|
||||
pub(crate) prio: RequestPriority,
|
||||
pub(crate) path: Bytes,
|
||||
pub(crate) telemetry_id: Bytes,
|
||||
pub(crate) msg: Bytes,
|
||||
pub(crate) stream: Option<ByteStream>,
|
||||
}
|
||||
|
||||
let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
|
||||
impl ReqEnc {
|
||||
pub(crate) fn encode(self) -> ByteStream {
|
||||
let mut buf = BytesMut::with_capacity(64);
|
||||
|
||||
ret.push(self.prio);
|
||||
buf.put_u8(self.prio);
|
||||
|
||||
ret.push(self.path.len() as u8);
|
||||
ret.extend_from_slice(self.path);
|
||||
buf.put_u8(self.path.len() as u8);
|
||||
buf.put(self.path);
|
||||
|
||||
if let Some(t) = self.telemetry_id {
|
||||
ret.push(t.len() as u8);
|
||||
ret.extend(t);
|
||||
buf.put_u8(self.telemetry_id.len() as u8);
|
||||
buf.put(&self.telemetry_id[..]);
|
||||
|
||||
buf.put_u32(self.msg.len() as u32);
|
||||
buf.put(&self.msg[..]);
|
||||
|
||||
let header = buf.freeze();
|
||||
|
||||
if let Some(stream) = self.stream {
|
||||
Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
|
||||
} else {
|
||||
ret.push(0u8);
|
||||
Box::pin(futures::stream::once(async move { Ok(header) }))
|
||||
}
|
||||
}
|
||||
|
||||
ret.extend_from_slice(self.body);
|
||||
|
||||
ret
|
||||
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
|
||||
Self::decode_aux(stream).await.map_err(|_| Error::Framing)
|
||||
}
|
||||
|
||||
pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
|
||||
if bytes.len() < 3 {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
}
|
||||
pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
|
||||
let mut reader = ByteStreamReader::new(stream);
|
||||
|
||||
let path_length = bytes[1] as usize;
|
||||
if bytes.len() < 3 + path_length {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
}
|
||||
let prio = reader.read_u8().await?;
|
||||
|
||||
let telemetry_id_len = bytes[2 + path_length] as usize;
|
||||
if bytes.len() < 3 + path_length + telemetry_id_len {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
}
|
||||
let path_len = reader.read_u8().await?;
|
||||
let path = reader.read_exact(path_len as usize).await?;
|
||||
|
||||
let path = &bytes[2..2 + path_length];
|
||||
let telemetry_id = if telemetry_id_len > 0 {
|
||||
Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let telemetry_id_len = reader.read_u8().await?;
|
||||
let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
|
||||
|
||||
let body = &bytes[3 + path_length + telemetry_id_len..];
|
||||
let msg_len = reader.read_u32().await?;
|
||||
let msg = reader.read_exact(msg_len as usize).await?;
|
||||
|
||||
Ok(Self {
|
||||
prio: bytes[0],
|
||||
prio,
|
||||
path,
|
||||
telemetry_id,
|
||||
body,
|
||||
msg,
|
||||
stream: Some(reader.into_stream()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---- ----
|
||||
|
||||
pub(crate) struct Framing {
|
||||
direct: Vec<u8>,
|
||||
/// Encoding for responses into a ByteStream:
|
||||
/// IF SUCCESS:
|
||||
/// - 0: u8
|
||||
/// - msg len: u32
|
||||
/// - msg [u8; ..]
|
||||
/// - the attached stream as the rest of the encoded stream
|
||||
/// IF ERROR:
|
||||
/// - message length + 1: u8
|
||||
/// - error code: u8
|
||||
/// - message: [u8; message_length]
|
||||
pub(crate) enum RespEnc {
|
||||
Error {
|
||||
code: u8,
|
||||
message: String,
|
||||
},
|
||||
Success {
|
||||
msg: Bytes,
|
||||
stream: Option<ByteStream>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Framing {
|
||||
pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
|
||||
assert!(direct.len() <= u32::MAX as usize);
|
||||
Framing { direct, stream }
|
||||
impl RespEnc {
|
||||
pub(crate) fn from_err(e: Error) -> Self {
|
||||
RespEnc::Error {
|
||||
code: e.code(),
|
||||
message: format!("{}", e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> ByteStream {
|
||||
use futures::stream;
|
||||
let len = self.direct.len() as u32;
|
||||
// required because otherwise the borrow-checker complains
|
||||
let Framing { direct, stream } = self;
|
||||
pub(crate) fn encode(self) -> ByteStream {
|
||||
match self {
|
||||
RespEnc::Success { msg, stream } => {
|
||||
let mut buf = BytesMut::with_capacity(64);
|
||||
|
||||
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) })
|
||||
.chain(stream::once(async move { Ok(direct.into()) }));
|
||||
buf.put_u8(0);
|
||||
|
||||
buf.put_u32(msg.len() as u32);
|
||||
buf.put(&msg[..]);
|
||||
|
||||
let header = buf.freeze();
|
||||
|
||||
if let Some(stream) = stream {
|
||||
Box::pin(res.chain(stream))
|
||||
Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
|
||||
} else {
|
||||
Box::pin(res)
|
||||
Box::pin(futures::stream::once(async move { Ok(header) }))
|
||||
}
|
||||
}
|
||||
RespEnc::Error { code, message } => {
|
||||
let mut buf = BytesMut::with_capacity(64);
|
||||
buf.put_u8(1 + message.len() as u8);
|
||||
buf.put_u8(code);
|
||||
buf.put(message.as_bytes());
|
||||
let header = buf.freeze();
|
||||
Box::pin(futures::stream::once(async move { Ok(header) }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>(
|
||||
mut stream: S,
|
||||
) -> Result<Self, Error> {
|
||||
let mut packet = stream
|
||||
.next()
|
||||
.await
|
||||
.ok_or(Error::Framing)?
|
||||
.map_err(|_| Error::Framing)?;
|
||||
if packet.len() < 4 {
|
||||
return Err(Error::Framing);
|
||||
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
|
||||
Self::decode_aux(stream).await.map_err(|_| Error::Framing)
|
||||
}
|
||||
|
||||
let mut len = [0; 4];
|
||||
len.copy_from_slice(&packet[..4]);
|
||||
let len = u32::from_be_bytes(len);
|
||||
packet = packet.slice(4..);
|
||||
pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
|
||||
let mut reader = ByteStreamReader::new(stream);
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let len = len as usize;
|
||||
loop {
|
||||
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
|
||||
let is_err = reader.read_u8().await?;
|
||||
|
||||
buffer.extend_from_slice(&packet[..max_cp]);
|
||||
if buffer.len() == len {
|
||||
packet = packet.slice(max_cp..);
|
||||
break;
|
||||
}
|
||||
packet = stream
|
||||
.next()
|
||||
.await
|
||||
.ok_or(Error::Framing)?
|
||||
.map_err(|_| Error::Framing)?;
|
||||
}
|
||||
|
||||
let stream: ByteStream = if packet.is_empty() {
|
||||
Box::pin(stream)
|
||||
if is_err > 0 {
|
||||
let code = reader.read_u8().await?;
|
||||
let message = reader.read_exact(is_err as usize - 1).await?;
|
||||
let message = String::from_utf8(message.to_vec()).unwrap_or_default();
|
||||
Ok(RespEnc::Error { code, message })
|
||||
} else {
|
||||
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
|
||||
};
|
||||
let msg_len = reader.read_u32().await?;
|
||||
let msg = reader.read_exact(msg_len as usize).await?;
|
||||
|
||||
Ok(Framing {
|
||||
direct: buffer,
|
||||
stream: Some(stream),
|
||||
Ok(RespEnc::Success {
|
||||
msg,
|
||||
stream: Some(reader.into_stream()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
|
||||
let Framing { direct, stream } = self;
|
||||
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use futures::AsyncReadExt;
|
|||
|
||||
use crate::error::*;
|
||||
use crate::send::*;
|
||||
use crate::util::Packet;
|
||||
use crate::stream::*;
|
||||
|
||||
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
||||
/// connection to some remote drops while transmitting data
|
||||
|
|
|
@ -14,7 +14,7 @@ use tokio::sync::mpsc;
|
|||
|
||||
use crate::error::*;
|
||||
use crate::message::*;
|
||||
use crate::util::{ByteStream, Packet};
|
||||
use crate::stream::*;
|
||||
|
||||
// Messages are sent by chunks
|
||||
// Chunk format:
|
||||
|
|
|
@ -28,6 +28,7 @@ use crate::message::*;
|
|||
use crate::netapp::*;
|
||||
use crate::recv::*;
|
||||
use crate::send::*;
|
||||
use crate::stream::*;
|
||||
use crate::util::*;
|
||||
|
||||
// The client and server connection structs (client.rs and server.rs)
|
||||
|
@ -121,17 +122,12 @@ impl ServerConn {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_handler_aux(
|
||||
self: &Arc<Self>,
|
||||
bytes: &[u8],
|
||||
stream: ByteStream,
|
||||
) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
|
||||
let msg = QueryMessage::decode(bytes)?;
|
||||
let path = String::from_utf8(msg.path.to_vec())?;
|
||||
async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
|
||||
let path = String::from_utf8(req_enc.path.to_vec())?;
|
||||
|
||||
let handler_opt = {
|
||||
let endpoints = self.netapp.endpoints.read().unwrap();
|
||||
endpoints.get(&path).map(|e| e.clone_endpoint())
|
||||
endpoints.get(&path[..]).map(|e| e.clone_endpoint())
|
||||
};
|
||||
|
||||
if let Some(handler) = handler_opt {
|
||||
|
@ -139,9 +135,9 @@ impl ServerConn {
|
|||
if #[cfg(feature = "telemetry")] {
|
||||
let tracer = opentelemetry::global::tracer("netapp");
|
||||
|
||||
let mut span = if let Some(telemetry_id) = msg.telemetry_id {
|
||||
let mut span = if !req_enc.telemetry_id.is_empty() {
|
||||
let propagator = BinaryPropagator::new();
|
||||
let context = propagator.from_bytes(telemetry_id);
|
||||
let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
|
||||
let context = Context::new().with_remote_span_context(context);
|
||||
tracer.span_builder(format!(">> RPC {}", path))
|
||||
.with_kind(SpanKind::Server)
|
||||
|
@ -156,13 +152,13 @@ impl ServerConn {
|
|||
.start(&tracer)
|
||||
};
|
||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
||||
span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
|
||||
|
||||
handler.handle(msg.body, stream, self.peer_id)
|
||||
handler.handle(req_enc, self.peer_id)
|
||||
.with_context(Context::current_with_span(span))
|
||||
.await
|
||||
} else {
|
||||
handler.handle(msg.body, stream, self.peer_id).await
|
||||
handler.handle(req_enc, self.peer_id).await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -181,32 +177,23 @@ impl RecvLoop for ServerConn {
|
|||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
trace!("ServerConn recv_handler {}", id);
|
||||
let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
|
||||
Ok(req_enc) => {
|
||||
let prio = req_enc.prio;
|
||||
let resp = self2.recv_handler_aux(req_enc).await;
|
||||
|
||||
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
||||
let resp = self2.recv_handler_aux(&bytes[..], stream).await;
|
||||
|
||||
let (resp_bytes, resp_stream) = match resp {
|
||||
Ok((rb, rs)) => {
|
||||
let mut resp_bytes = vec![0u8];
|
||||
resp_bytes.extend(rb);
|
||||
(resp_bytes, rs)
|
||||
}
|
||||
Err(e) => {
|
||||
let mut resp_bytes = vec![e.code()];
|
||||
resp_bytes.extend(e.to_string().into_bytes());
|
||||
(resp_bytes, None)
|
||||
(prio, match resp {
|
||||
Ok(resp_enc) => resp_enc,
|
||||
Err(e) => RespEnc::from_err(e),
|
||||
})
|
||||
}
|
||||
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
|
||||
};
|
||||
|
||||
trace!("ServerConn sending response to {}: ", id);
|
||||
|
||||
resp_send
|
||||
.send((
|
||||
id,
|
||||
prio,
|
||||
Framing::new(resp_bytes, resp_stream).into_stream(),
|
||||
))
|
||||
.send((id, prio, resp_enc.encode()))
|
||||
.log_err("ServerConn recv_handler send resp bytes");
|
||||
Ok::<_, Error>(())
|
||||
});
|
||||
|
|
176
src/stream.rs
Normal file
176
src/stream.rs
Normal file
|
@ -0,0 +1,176 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use futures::Future;
|
||||
use futures::{Stream, StreamExt};
|
||||
|
||||
/// A stream of associated data.
|
||||
///
|
||||
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
|
||||
/// consecutive Vec may get merged, but Vec and error code may not be reordered
|
||||
///
|
||||
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
|
||||
/// meaning, it's up to your application to define their semantic.
|
||||
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
|
||||
|
||||
pub type Packet = Result<Bytes, u8>;
|
||||
|
||||
pub struct ByteStreamReader {
|
||||
stream: ByteStream,
|
||||
buf: VecDeque<Bytes>,
|
||||
buf_len: usize,
|
||||
eos: bool,
|
||||
err: Option<u8>,
|
||||
}
|
||||
|
||||
impl ByteStreamReader {
|
||||
pub fn new(stream: ByteStream) -> Self {
|
||||
ByteStreamReader {
|
||||
stream,
|
||||
buf: VecDeque::with_capacity(8),
|
||||
buf_len: 0,
|
||||
eos: false,
|
||||
err: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
|
||||
ByteStreamReadExact {
|
||||
reader: self,
|
||||
read_len,
|
||||
fail_on_eos: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
|
||||
ByteStreamReadExact {
|
||||
reader: self,
|
||||
read_len,
|
||||
fail_on_eos: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
|
||||
Ok(self.read_exact(1).await?[0])
|
||||
}
|
||||
|
||||
pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
|
||||
let bytes = self.read_exact(2).await?;
|
||||
let mut b = [0u8; 2];
|
||||
b.copy_from_slice(&bytes[..]);
|
||||
Ok(u16::from_be_bytes(b))
|
||||
}
|
||||
|
||||
pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
|
||||
let bytes = self.read_exact(4).await?;
|
||||
let mut b = [0u8; 4];
|
||||
b.copy_from_slice(&bytes[..]);
|
||||
Ok(u32::from_be_bytes(b))
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> ByteStream {
|
||||
let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
|
||||
if let Some(err) = self.err {
|
||||
Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
|
||||
} else if self.eos {
|
||||
Box::pin(buf_stream)
|
||||
} else {
|
||||
Box::pin(buf_stream.chain(self.stream))
|
||||
}
|
||||
}
|
||||
|
||||
fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
|
||||
if self.buf_len >= read_len {
|
||||
let mut slices = Vec::with_capacity(self.buf.len());
|
||||
let mut taken = 0;
|
||||
while taken < read_len {
|
||||
let front = self.buf.pop_front().unwrap();
|
||||
if taken + front.len() <= read_len {
|
||||
taken += front.len();
|
||||
self.buf_len -= front.len();
|
||||
slices.push(front);
|
||||
} else {
|
||||
let front_take = read_len - taken;
|
||||
slices.push(front.slice(..front_take));
|
||||
self.buf.push_front(front.slice(front_take..));
|
||||
self.buf_len -= front_take;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(
|
||||
slices
|
||||
.iter()
|
||||
.map(|x| &x[..])
|
||||
.collect::<Vec<_>>()
|
||||
.concat()
|
||||
.into(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ReadExactError {
|
||||
UnexpectedEos,
|
||||
Stream(u8),
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct ByteStreamReadExact<'a> {
|
||||
#[pin]
|
||||
reader: &'a mut ByteStreamReader,
|
||||
read_len: usize,
|
||||
fail_on_eos: bool,
|
||||
}
|
||||
|
||||
impl<'a> Future for ByteStreamReadExact<'a> {
|
||||
type Output = Result<Bytes, ReadExactError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
if let Some(bytes) = this.reader.try_get(*this.read_len) {
|
||||
return Poll::Ready(Ok(bytes));
|
||||
}
|
||||
if let Some(err) = this.reader.err {
|
||||
return Poll::Ready(Err(ReadExactError::Stream(err)));
|
||||
}
|
||||
if this.reader.eos {
|
||||
if *this.fail_on_eos {
|
||||
return Poll::Ready(Err(ReadExactError::UnexpectedEos));
|
||||
} else {
|
||||
let bytes = Bytes::from(
|
||||
this.reader
|
||||
.buf
|
||||
.iter()
|
||||
.map(|x| &x[..])
|
||||
.collect::<Vec<_>>()
|
||||
.concat(),
|
||||
);
|
||||
this.reader.buf.clear();
|
||||
this.reader.buf_len = 0;
|
||||
return Poll::Ready(Ok(bytes));
|
||||
}
|
||||
}
|
||||
|
||||
match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) {
|
||||
Some(Ok(slice)) => {
|
||||
this.reader.buf_len += slice.len();
|
||||
this.reader.buf.push_back(slice);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
this.reader.err = Some(e);
|
||||
this.reader.eos = true;
|
||||
}
|
||||
None => {
|
||||
this.reader.eos = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
15
src/util.rs
15
src/util.rs
|
@ -1,12 +1,9 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::pin::Pin;
|
||||
|
||||
use bytes::Bytes;
|
||||
use log::info;
|
||||
use serde::Serialize;
|
||||
|
||||
use futures::Stream;
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// A node's identifier, which is also its public cryptographic key
|
||||
|
@ -16,18 +13,6 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
|
|||
/// A network key
|
||||
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
||||
|
||||
/// A stream of associated data.
|
||||
///
|
||||
/// The Stream can continue after receiving an error.
|
||||
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
|
||||
/// consecutive Vec may get merged, but Vec and error code may not be reordered
|
||||
///
|
||||
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
|
||||
/// meaning, it's up to your application to define their semantic.
|
||||
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
|
||||
|
||||
pub type Packet = Result<Bytes, u8>;
|
||||
|
||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||
/// using the RMP library.
|
||||
///
|
||||
|
|
Loading…
Reference in a new issue