Clean up framing protocol

This commit is contained in:
Alex 2022-07-22 12:45:38 +02:00
parent c358fe3c92
commit 0b71ca12f9
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
9 changed files with 432 additions and 268 deletions

View file

@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
@ -28,6 +29,7 @@ use crate::message::*;
use crate::netapp::*;
use crate::recv::*;
use crate::send::*;
use crate::stream::*;
use crate::util::*;
pub(crate) struct ClientConn {
@ -155,24 +157,16 @@ impl ClientConn {
.with_kind(SpanKind::Client)
.start(&tracer);
let propagator = BinaryPropagator::new();
let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
} else {
let telemetry_id: Option<Vec<u8>> = None;
let telemetry_id: Bytes = Bytes::new();
}
};
// Encode request
let body = req.msg_ser.unwrap().clone();
let stream = req.body.into_stream();
let request = QueryMessage {
prio,
path: path.as_bytes(),
telemetry_id,
body: &body[..],
};
let bytes = request.encode();
drop(body);
let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
let req_msg_len = req_enc.msg.len();
let req_stream = req_enc.encode();
// Send request through
let (resp_send, resp_recv) = oneshot::channel();
@ -181,17 +175,19 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
if old_ch.send(unbounded().1).is_err() {
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
}
let _ = old_ch.send(unbounded().1);
}
trace!("request: query_send {}, {} bytes", id, bytes.len());
trace!(
"request: query_send {} (serialized message: {} bytes)",
id,
req_msg_len
);
#[cfg(feature = "telemetry")]
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
query_send.send((id, prio, req_stream))?;
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
@ -202,28 +198,10 @@ impl ClientConn {
let stream = resp_recv.await?;
}
}
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
if resp.is_empty() {
return Err(Error::Message(
"Response is 0 bytes, either a collision or a protocol error".into(),
));
}
trace!("request response {}: ", id);
let code = resp[0];
if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
Ok(Resp {
_phantom: Default::default(),
msg: ser_resp,
body: BodyData::Stream(stream),
})
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
}
let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
trace!("request response {}", id);
Resp::from_enc(resp_enc)
}
}

View file

@ -158,12 +158,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait]
pub(crate) trait GenericEndpoint {
async fn handle(
&self,
buf: &[u8],
stream: ByteStream,
from: NodeID,
) -> Result<(Vec<u8>, Option<ByteStream>), Error>;
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
@ -180,30 +175,13 @@ where
M: Message + 'static,
H: StreamingEndpointHandler<M> + 'static,
{
async fn handle(
&self,
buf: &[u8],
stream: ByteStream,
from: NodeID,
) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
let req = rmp_serde::decode::from_read_ref(buf)?;
let req = Req {
_phantom: Default::default(),
msg: Arc::new(req),
msg_ser: None,
body: BodyData::Stream(stream),
};
let req = Req::from_enc(req_enc)?;
let res = h.handle(req, from).await;
let Resp {
msg,
body,
_phantom,
} = res;
let res_bytes = rmp_to_vec_all_named(&msg)?;
Ok((res_bytes, body.into_stream()))
Ok(res.into_enc()?)
}
}
}

View file

@ -14,6 +14,7 @@
//! Also check out the examples to learn how to use this crate.
pub mod error;
pub mod stream;
pub mod util;
pub mod endpoint;

View file

@ -2,12 +2,13 @@ use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use bytes::Bytes;
use bytes::{BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use futures::stream::{Stream, StreamExt};
use futures::stream::StreamExt;
use crate::error::*;
use crate::stream::*;
use crate::util::*;
/// Priority of a request (click to read more about priorities).
@ -45,6 +46,15 @@ pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
// ----
/// The Req<M> is a helper object used to create requests and attach them
/// a streaming body. If the body is a fixed Bytes and not a ByteStream,
/// Req<M> is cheaply clonable to allow the request to be sent to different
/// peers (Clone will panic if the body is a ByteStream).
///
/// Internally, this is also used to encode and decode requests
/// from/to byte streams to be sent over the network.
pub struct Req<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: Arc<M>,
@ -52,30 +62,6 @@ pub struct Req<M: Message> {
pub(crate) body: BodyData,
}
pub struct Resp<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: M::Response,
pub(crate) body: BodyData,
}
pub(crate) enum BodyData {
None,
Fixed(Bytes),
Stream(ByteStream),
}
impl BodyData {
pub fn into_stream(self) -> Option<ByteStream> {
match self {
BodyData::None => None,
BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
BodyData::Stream(s) => Some(s),
}
}
}
// ----
impl<M: Message> Req<M> {
pub fn msg(&self) -> &M {
&self.msg
@ -94,6 +80,31 @@ impl<M: Message> Req<M> {
..self
}
}
pub(crate) fn into_enc(
self,
prio: RequestPriority,
path: Bytes,
telemetry_id: Bytes,
) -> ReqEnc {
ReqEnc {
prio,
path,
telemetry_id,
msg: self.msg_ser.unwrap(),
stream: self.body.into_stream(),
}
}
pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
Ok(Req {
_phantom: Default::default(),
msg: Arc::new(msg),
msg_ser: Some(enc.msg),
body: enc.stream.map(BodyData::Stream).unwrap_or(BodyData::None),
})
}
}
pub trait IntoReq<M: Message> {
@ -160,19 +171,14 @@ where
}
}
impl<M> fmt::Debug for Resp<M>
where
M: Message,
<M as Message>::Response: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Resp[{:?}", self.msg)?;
match &self.body {
BodyData::None => write!(f, "]"),
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
BodyData::Stream(_) => write!(f, "; body=stream]"),
}
}
// ----
/// The Resp<M> represents a full response from a RPC that may have
/// an attached body stream.
pub struct Resp<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: M::Response,
pub(crate) body: BodyData,
}
impl<M: Message> Resp<M> {
@ -205,160 +211,213 @@ impl<M: Message> Resp<M> {
pub fn into_msg(self) -> M::Response {
self.msg
}
pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
Ok(RespEnc::Success {
msg: rmp_to_vec_all_named(&self.msg)?.into(),
stream: self.body.into_stream(),
})
}
pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
match enc {
RespEnc::Success { msg, stream } => {
let msg = rmp_serde::decode::from_read_ref(&msg)?;
Ok(Self {
_phantom: Default::default(),
msg,
body: stream.map(BodyData::Stream).unwrap_or(BodyData::None),
})
}
RespEnc::Error { code, message } => Err(Error::Remote(code, message)),
}
}
}
impl<M> fmt::Debug for Resp<M>
where
M: Message,
<M as Message>::Response: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Resp[{:?}", self.msg)?;
match &self.body {
BodyData::None => write!(f, "]"),
BodyData::Fixed(b) => write!(f, "; body={}]", b.len()),
BodyData::Stream(_) => write!(f, "; body=stream]"),
}
}
}
// ----
pub(crate) enum BodyData {
None,
Fixed(Bytes),
Stream(ByteStream),
}
impl BodyData {
pub fn into_stream(self) -> Option<ByteStream> {
match self {
BodyData::None => None,
BodyData::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
BodyData::Stream(s) => Some(s),
}
}
}
// ---- ----
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
pub(crate) path: &'a [u8],
pub(crate) telemetry_id: Option<Vec<u8>>,
pub(crate) body: &'a [u8],
}
/// QueryMessage encoding:
/// Encoding for requests into a ByteStream:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - body [u8; ..]
impl<'a> QueryMessage<'a> {
pub(crate) fn encode(self) -> Vec<u8> {
let tel_len = match &self.telemetry_id {
Some(t) => t.len(),
None => 0,
};
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
pub(crate) struct ReqEnc {
pub(crate) prio: RequestPriority,
pub(crate) path: Bytes,
pub(crate) telemetry_id: Bytes,
pub(crate) msg: Bytes,
pub(crate) stream: Option<ByteStream>,
}
let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
impl ReqEnc {
pub(crate) fn encode(self) -> ByteStream {
let mut buf = BytesMut::with_capacity(64);
ret.push(self.prio);
buf.put_u8(self.prio);
ret.push(self.path.len() as u8);
ret.extend_from_slice(self.path);
buf.put_u8(self.path.len() as u8);
buf.put(self.path);
if let Some(t) = self.telemetry_id {
ret.push(t.len() as u8);
ret.extend(t);
buf.put_u8(self.telemetry_id.len() as u8);
buf.put(&self.telemetry_id[..]);
buf.put_u32(self.msg.len() as u32);
buf.put(&self.msg[..]);
let header = buf.freeze();
if let Some(stream) = self.stream {
Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
} else {
ret.push(0u8);
Box::pin(futures::stream::once(async move { Ok(header) }))
}
}
ret.extend_from_slice(self.body);
ret
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
Self::decode_aux(stream).await.map_err(|_| Error::Framing)
}
pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
if bytes.len() < 3 {
return Err(Error::Message("Invalid protocol message".into()));
}
pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
let mut reader = ByteStreamReader::new(stream);
let path_length = bytes[1] as usize;
if bytes.len() < 3 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let prio = reader.read_u8().await?;
let telemetry_id_len = bytes[2 + path_length] as usize;
if bytes.len() < 3 + path_length + telemetry_id_len {
return Err(Error::Message("Invalid protocol message".into()));
}
let path_len = reader.read_u8().await?;
let path = reader.read_exact(path_len as usize).await?;
let path = &bytes[2..2 + path_length];
let telemetry_id = if telemetry_id_len > 0 {
Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
} else {
None
};
let telemetry_id_len = reader.read_u8().await?;
let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
let body = &bytes[3 + path_length + telemetry_id_len..];
let msg_len = reader.read_u32().await?;
let msg = reader.read_exact(msg_len as usize).await?;
Ok(Self {
prio: bytes[0],
prio,
path,
telemetry_id,
body,
msg,
stream: Some(reader.into_stream()),
})
}
}
// ---- ----
pub(crate) struct Framing {
direct: Vec<u8>,
/// Encoding for responses into a ByteStream:
/// IF SUCCESS:
/// - 0: u8
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
/// IF ERROR:
/// - message length + 1: u8
/// - error code: u8
/// - message: [u8; message_length]
pub(crate) enum RespEnc {
Error {
code: u8,
message: String,
},
Success {
msg: Bytes,
stream: Option<ByteStream>,
},
}
impl Framing {
pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
assert!(direct.len() <= u32::MAX as usize);
Framing { direct, stream }
impl RespEnc {
pub(crate) fn from_err(e: Error) -> Self {
RespEnc::Error {
code: e.code(),
message: format!("{}", e),
}
}
pub fn into_stream(self) -> ByteStream {
use futures::stream;
let len = self.direct.len() as u32;
// required because otherwise the borrow-checker complains
let Framing { direct, stream } = self;
pub(crate) fn encode(self) -> ByteStream {
match self {
RespEnc::Success { msg, stream } => {
let mut buf = BytesMut::with_capacity(64);
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec().into()) })
.chain(stream::once(async move { Ok(direct.into()) }));
buf.put_u8(0);
buf.put_u32(msg.len() as u32);
buf.put(&msg[..]);
let header = buf.freeze();
if let Some(stream) = stream {
Box::pin(res.chain(stream))
Box::pin(futures::stream::once(async move { Ok(header) }).chain(stream))
} else {
Box::pin(res)
Box::pin(futures::stream::once(async move { Ok(header) }))
}
}
RespEnc::Error { code, message } => {
let mut buf = BytesMut::with_capacity(64);
buf.put_u8(1 + message.len() as u8);
buf.put_u8(code);
buf.put(message.as_bytes());
let header = buf.freeze();
Box::pin(futures::stream::once(async move { Ok(header) }))
}
}
}
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + Sync + 'static>(
mut stream: S,
) -> Result<Self, Error> {
let mut packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
if packet.len() < 4 {
return Err(Error::Framing);
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
Self::decode_aux(stream).await.map_err(|_| Error::Framing)
}
let mut len = [0; 4];
len.copy_from_slice(&packet[..4]);
let len = u32::from_be_bytes(len);
packet = packet.slice(4..);
pub(crate) async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
let mut reader = ByteStreamReader::new(stream);
let mut buffer = Vec::new();
let len = len as usize;
loop {
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
let is_err = reader.read_u8().await?;
buffer.extend_from_slice(&packet[..max_cp]);
if buffer.len() == len {
packet = packet.slice(max_cp..);
break;
}
packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
}
let stream: ByteStream = if packet.is_empty() {
Box::pin(stream)
if is_err > 0 {
let code = reader.read_u8().await?;
let message = reader.read_exact(is_err as usize - 1).await?;
let message = String::from_utf8(message.to_vec()).unwrap_or_default();
Ok(RespEnc::Error { code, message })
} else {
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
};
let msg_len = reader.read_u32().await?;
let msg = reader.read_exact(msg_len as usize).await?;
Ok(Framing {
direct: buffer,
stream: Some(stream),
Ok(RespEnc::Success {
msg,
stream: Some(reader.into_stream()),
})
}
pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
let Framing { direct, stream } = self;
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
}
}

View file

@ -10,7 +10,7 @@ use futures::AsyncReadExt;
use crate::error::*;
use crate::send::*;
use crate::util::Packet;
use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data

View file

@ -14,7 +14,7 @@ use tokio::sync::mpsc;
use crate::error::*;
use crate::message::*;
use crate::util::{ByteStream, Packet};
use crate::stream::*;
// Messages are sent by chunks
// Chunk format:

View file

@ -28,6 +28,7 @@ use crate::message::*;
use crate::netapp::*;
use crate::recv::*;
use crate::send::*;
use crate::stream::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@ -121,17 +122,12 @@ impl ServerConn {
Ok(())
}
async fn recv_handler_aux(
self: &Arc<Self>,
bytes: &[u8],
stream: ByteStream,
) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
let msg = QueryMessage::decode(bytes)?;
let path = String::from_utf8(msg.path.to_vec())?;
async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
let path = String::from_utf8(req_enc.path.to_vec())?;
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
endpoints.get(&path).map(|e| e.clone_endpoint())
endpoints.get(&path[..]).map(|e| e.clone_endpoint())
};
if let Some(handler) = handler_opt {
@ -139,9 +135,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp");
let mut span = if let Some(telemetry_id) = msg.telemetry_id {
let mut span = if !req_enc.telemetry_id.is_empty() {
let propagator = BinaryPropagator::new();
let context = propagator.from_bytes(telemetry_id);
let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server)
@ -156,13 +152,13 @@ impl ServerConn {
.start(&tracer)
};
span.set_attribute(KeyValue::new("path", path.to_string()));
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
handler.handle(msg.body, stream, self.peer_id)
handler.handle(req_enc, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
handler.handle(msg.body, stream, self.peer_id).await
handler.handle(req_enc, self.peer_id).await
}
}
} else {
@ -181,32 +177,23 @@ impl RecvLoop for ServerConn {
let self2 = self.clone();
tokio::spawn(async move {
trace!("ServerConn recv_handler {}", id);
let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
Ok(req_enc) => {
let prio = req_enc.prio;
let resp = self2.recv_handler_aux(req_enc).await;
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
let resp = self2.recv_handler_aux(&bytes[..], stream).await;
let (resp_bytes, resp_stream) = match resp {
Ok((rb, rs)) => {
let mut resp_bytes = vec![0u8];
resp_bytes.extend(rb);
(resp_bytes, rs)
}
Err(e) => {
let mut resp_bytes = vec![e.code()];
resp_bytes.extend(e.to_string().into_bytes());
(resp_bytes, None)
(prio, match resp {
Ok(resp_enc) => resp_enc,
Err(e) => RespEnc::from_err(e),
})
}
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
};
trace!("ServerConn sending response to {}: ", id);
resp_send
.send((
id,
prio,
Framing::new(resp_bytes, resp_stream).into_stream(),
))
.send((id, prio, resp_enc.encode()))
.log_err("ServerConn recv_handler send resp bytes");
Ok::<_, Error>(())
});

176
src/stream.rs Normal file
View file

@ -0,0 +1,176 @@
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::Future;
use futures::{Stream, StreamExt};
/// A stream of associated data.
///
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
/// consecutive Vec may get merged, but Vec and error code may not be reordered
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning, it's up to your application to define their semantic.
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
pub type Packet = Result<Bytes, u8>;
pub struct ByteStreamReader {
stream: ByteStream,
buf: VecDeque<Bytes>,
buf_len: usize,
eos: bool,
err: Option<u8>,
}
impl ByteStreamReader {
pub fn new(stream: ByteStream) -> Self {
ByteStreamReader {
stream,
buf: VecDeque::with_capacity(8),
buf_len: 0,
eos: false,
err: None,
}
}
pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: true,
}
}
pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: false,
}
}
pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
Ok(self.read_exact(1).await?[0])
}
pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
let bytes = self.read_exact(2).await?;
let mut b = [0u8; 2];
b.copy_from_slice(&bytes[..]);
Ok(u16::from_be_bytes(b))
}
pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
let bytes = self.read_exact(4).await?;
let mut b = [0u8; 4];
b.copy_from_slice(&bytes[..]);
Ok(u32::from_be_bytes(b))
}
pub fn into_stream(self) -> ByteStream {
let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
if let Some(err) = self.err {
Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
} else if self.eos {
Box::pin(buf_stream)
} else {
Box::pin(buf_stream.chain(self.stream))
}
}
fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
if self.buf_len >= read_len {
let mut slices = Vec::with_capacity(self.buf.len());
let mut taken = 0;
while taken < read_len {
let front = self.buf.pop_front().unwrap();
if taken + front.len() <= read_len {
taken += front.len();
self.buf_len -= front.len();
slices.push(front);
} else {
let front_take = read_len - taken;
slices.push(front.slice(..front_take));
self.buf.push_front(front.slice(front_take..));
self.buf_len -= front_take;
break;
}
}
Some(
slices
.iter()
.map(|x| &x[..])
.collect::<Vec<_>>()
.concat()
.into(),
)
} else {
None
}
}
}
pub enum ReadExactError {
UnexpectedEos,
Stream(u8),
}
#[pin_project::pin_project]
pub struct ByteStreamReadExact<'a> {
#[pin]
reader: &'a mut ByteStreamReader,
read_len: usize,
fail_on_eos: bool,
}
impl<'a> Future for ByteStreamReadExact<'a> {
type Output = Result<Bytes, ReadExactError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
let mut this = self.project();
loop {
if let Some(bytes) = this.reader.try_get(*this.read_len) {
return Poll::Ready(Ok(bytes));
}
if let Some(err) = this.reader.err {
return Poll::Ready(Err(ReadExactError::Stream(err)));
}
if this.reader.eos {
if *this.fail_on_eos {
return Poll::Ready(Err(ReadExactError::UnexpectedEos));
} else {
let bytes = Bytes::from(
this.reader
.buf
.iter()
.map(|x| &x[..])
.collect::<Vec<_>>()
.concat(),
);
this.reader.buf.clear();
this.reader.buf_len = 0;
return Poll::Ready(Ok(bytes));
}
}
match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) {
Some(Ok(slice)) => {
this.reader.buf_len += slice.len();
this.reader.buf.push_back(slice);
}
Some(Err(e)) => {
this.reader.err = Some(e);
this.reader.eos = true;
}
None => {
this.reader.eos = true;
}
}
}
}
}

View file

@ -1,12 +1,9 @@
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use bytes::Bytes;
use log::info;
use serde::Serialize;
use futures::Stream;
use tokio::sync::watch;
/// A node's identifier, which is also its public cryptographic key
@ -16,18 +13,6 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
/// A network key
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// A stream of associated data.
///
/// The Stream can continue after receiving an error.
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
/// consecutive Vec may get merged, but Vec and error code may not be reordered
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning, it's up to your application to define their semantic.
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
pub type Packet = Result<Bytes, u8>;
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.
///