WIP: associated stream #1

Draft
trinity-1686a wants to merge 7 commits from stream-body into main
11 changed files with 566 additions and 117 deletions

Cargo.lock generated
View file

@@ -151,6 +151,19 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "env_logger"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
dependencies = [
"atty",
"humantime 1.3.0",
"log",
"regex",
"termcolor",
]
[[package]]
name = "env_logger"
version = "0.8.4"
@@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
dependencies = [
"atty",
-"humantime",
+"humantime 2.1.0",
"log",
"regex",
"termcolor",
@@ -322,6 +335,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "humantime"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
dependencies = [
"quick-error",
]
[[package]]
name = "humantime"
version = "2.1.0"
@@ -440,7 +462,7 @@ dependencies = [
"bytes 0.6.0",
"cfg-if",
"chrono",
-"env_logger",
+"env_logger 0.8.4",
"err-derive",
"futures",
"hex",
@@ -450,6 +472,8 @@ dependencies = [
"lru",
"opentelemetry",
"opentelemetry-contrib",
"pin-project",
"pretty_env_logger",
"rand 0.5.6", "rand 0.5.6",
"rmp-serde", "rmp-serde",
"serde", "serde",
@@ -582,6 +606,16 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "pretty_env_logger"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d"
dependencies = [
"env_logger 0.7.1",
"log",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@@ -627,6 +661,12 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
version = "1.0.10"

View file

@@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"]
[dependencies]
futures = "0.3.17"
pin-project = "1.0.10"
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7"
@@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = true }
[dev-dependencies]
env_logger = "0.8"
pretty_env_logger = "0.4"
structopt = { version = "0.3", default-features = false }
chrono = "0.4"

View file

@@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
@@ -37,10 +38,11 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
-query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+query_send:
+ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
next_query_number: AtomicU32,
-inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
+inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
}
impl ClientConn {
@@ -166,7 +168,7 @@ impl ClientConn {
};
// Encode request
-let body = rmp_to_vec_all_named(rq.borrow())?;
+let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
drop(rq);
let request = QueryMessage {
@@ -185,7 +187,7 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
-if old_ch.send(vec![]).is_err() {
+if old_ch.send(unbounded().1).is_err() {
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
}
}
@@ -195,17 +197,18 @@ impl ClientConn {
#[cfg(feature = "telemetry")]
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
-query_send.send((id, prio, bytes))?;
+query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
-let resp = resp_recv
+let stream = resp_recv
.with_context(Context::current_with_span(span))
.await?;
} else {
-let resp = resp_recv.await?;
+let stream = resp_recv.await?;
}
}
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
if resp.is_empty() {
return Err(Error::Message(
@@ -217,10 +220,8 @@ impl ClientConn {
let code = resp[0];
if code == 0 {
-Ok(rmp_serde::decode::from_read_ref::<
-_,
-<T as Message>::Response,
->(&resp[1..])?)
+let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
+Ok(T::Response::deserialize_msg(ser_resp, stream).await)
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
@@ -232,12 +233,12 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
-fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
-trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
+fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
+trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap();
if let Some(ch) = inflight.remove(&id) {
-if ch.send(msg).is_err() {
+if ch.send(stream).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
}

View file

@@ -14,8 +14,68 @@ use crate::util::*;
/// This trait should be implemented by all messages your application
/// wants to handle
-pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
-type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
+pub trait Message: SerializeMessage + Send + Sync {
+type Response: SerializeMessage + Send + Sync;
}
/// A trait for de/serializing messages, with possible associated stream.
#[async_trait]
pub trait SerializeMessage: Sized {
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
// TODO should return Result
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
}
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
#[async_trait]
impl<T> SerializeMessage for T
where
T: AutoSerialize,
{
type SerializableSelf = Self;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
(self.clone(), None)
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self {
// TODO verify no stream
ser_self
}
}
impl AutoSerialize for () {}
#[async_trait]
impl<T, E> SerializeMessage for Result<T, E>
where
T: SerializeMessage + Send,
E: SerializeMessage + Send,
{
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.serialize_msg();
(Ok(msg), stream)
}
Err(err) => {
let (msg, stream) = err.serialize_msg();
(Err(msg), stream)
}
}
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
Err(err) => Err(E::deserialize_msg(err, stream).await),
}
}
}
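Side note, not part of the diff: a sketch of how an application message carrying a streaming body could implement the new `SerializeMessage` trait. The `PutObject`/`PutObjectHeader` names are hypothetical; the body is kept in a `Mutex<Option<_>>` only because `serialize_msg` takes `&self` and the stream must be moved out exactly once.

```rust
// Hypothetical example, assuming the SerializeMessage trait and
// AssociatedStream type from this branch are in scope.
use std::sync::Mutex;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

pub struct PutObject {
    pub key: String,
    // stored in a Mutex<Option<_>> so it can be taken out through &self
    pub body: Mutex<Option<AssociatedStream>>,
}

#[derive(Serialize, Deserialize)]
pub struct PutObjectHeader {
    pub key: String,
}

#[async_trait]
impl SerializeMessage for PutObject {
    type SerializableSelf = PutObjectHeader;

    fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
        let header = PutObjectHeader {
            key: self.key.clone(),
        };
        // the body (if any) is sent as the request's associated stream
        (header, self.body.lock().unwrap().take())
    }

    async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
        PutObject {
            key: ser_self.key,
            body: Mutex::new(Some(stream)),
        }
    }
}
```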
/// This trait should be implemented by an object of your application
@@ -96,7 +156,7 @@ where
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error>
where
-B: Borrow<M>,
+B: Borrow<M> + Send + Sync,
{
if *target == self.netapp.id {
match self.handler.load_full() {
@@ -128,7 +188,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait]
pub(crate) trait GenericEndpoint {
-async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
+async fn handle(
+&self,
+buf: &[u8],
+stream: AssociatedStream,
+from: NodeID,
+) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
@@ -145,11 +210,17 @@ where
M: Message + 'static,
H: EndpointHandler<M> + 'static,
{
-async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
+async fn handle(
+&self,
+buf: &[u8],
+stream: AssociatedStream,
+from: NodeID,
+) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
-let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
+let req = rmp_serde::decode::from_read_ref(buf)?;
+let req = M::deserialize_msg(req, stream).await;
let res = h.handle(&req, from).await;
let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes)

View file

@@ -25,6 +25,9 @@ pub enum Error {
#[error(display = "UTF8 error: {}", _0)]
UTF8(#[error(source)] std::string::FromUtf8Error),
#[error(display = "Framing protocol error")]
Framing,
#[error(display = "{}", _0)]
Message(String),
@@ -50,6 +53,7 @@ impl Error {
Self::RMPEncode(_) => 10,
Self::RMPDecode(_) => 11,
Self::UTF8(_) => 12,
Self::Framing => 13,
Self::NoHandler => 20,
Self::ConnectionClosed => 21,
Self::Handshake(_) => 30,

View file

@@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16];
/// Value of the Netapp version used in the version tag
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>,
pub server_port: u16,
}
impl AutoSerialize for HelloMessage {}
impl Message for HelloMessage {
type Response = ();
}
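Side note, not part of the diff: a plain message without an associated stream only needs the `AutoSerialize` marker to keep the old behaviour. A sketch with hypothetical types:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Clone)]
struct StatusRequest {
    pub verbose: bool,
}

#[derive(Serialize, Deserialize, Clone)]
struct StatusResponse {
    pub uptime_secs: u64,
}

impl AutoSerialize for StatusRequest {}
impl AutoSerialize for StatusResponse {}

impl Message for StatusRequest {
    type Response = StatusResponse;
}
```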

View file

@@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
// -- Protocol messages --
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
struct PingMessage {
pub id: u64,
pub peer_list_hash: hash::Digest,
@@ -39,7 +39,9 @@ impl Message for PingMessage {
type Response = PingMessage;
}
-#[derive(Serialize, Deserialize)]
+impl AutoSerialize for PingMessage {}
+
+#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage {
pub list: Vec<(NodeID, SocketAddr)>,
}
@@ -48,6 +50,8 @@ impl Message for PeerListMessage {
type Response = PeerListMessage;
}
impl AutoSerialize for PeerListMessage {}
// -- Algorithm data structures --
#[derive(Debug)]

View file

@@ -1,9 +1,13 @@
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use log::trace;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::{AsyncReadExt, AsyncWriteExt};
use futures::{Stream, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite;
use tokio::sync::mpsc;
@@ -11,6 +15,7 @@ use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*;
use crate::util::{AssociatedStream, Packet};
/// Priority of a request (click to read more about priorities).
///
@@ -48,14 +53,148 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
pub(crate) type RequestID = u32;
type ChunkLength = u16;
-const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
+const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
+const ERROR_MARKER: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem {
id: RequestID,
prio: RequestPriority,
-data: Vec<u8>,
-cursor: usize,
+data: DataReader,
}
#[pin_project::pin_project]
struct DataReader {
#[pin]
reader: AssociatedStream,
packet: Packet,
pos: usize,
buf: Vec<u8>,
eos: bool,
}
impl From<AssociatedStream> for DataReader {
fn from(data: AssociatedStream) -> DataReader {
DataReader {
reader: data,
packet: Ok(Vec::new()),
pos: 0,
buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
eos: false,
}
}
}
enum DataFrame {
Data {
/// a fixed size buffer containing some data, possibly padded with 0s
data: [u8; MAX_CHUNK_LENGTH as usize],
/// actual length of the data
len: usize,
},
Error(u8),
}
struct DataReaderItem {
data: DataFrame,
/// whether there may be more data coming from this stream. Can be used for some
/// optimizations. It's an error to set it to false if there is more data, but it is correct
/// (albeit sub-optimal) to set it to true if there is nothing coming after
may_have_more: bool,
}
impl DataReaderItem {
fn empty_last() -> Self {
DataReaderItem {
data: DataFrame::Data {
data: [0; MAX_CHUNK_LENGTH as usize],
len: 0,
},
may_have_more: false,
}
}
fn header(&self) -> [u8; 2] {
let continuation = if self.may_have_more {
CHUNK_HAS_CONTINUATION
} else {
0
};
let len = match self.data {
DataFrame::Data { len, .. } => len as u16,
DataFrame::Error(e) => e as u16 | ERROR_MARKER,
};
ChunkLength::to_be_bytes(len | continuation)
}
fn data(&self) -> &[u8] {
match self.data {
DataFrame::Data { ref data, len } => &data[..len],
DataFrame::Error(_) => &[],
}
}
}
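Side note, not part of the diff: a small illustration of the 2-byte chunk header produced by `header()` above, using the constants defined in this file. The length field multiplexes the data length, the error code (via `ERROR_MARKER`) and the continuation bit; `recv_loop` uses the same masks to tell the three cases apart.

```rust
// Sketch only: expected header bytes for a few chunk kinds.
fn chunk_header_examples() {
    // 5 bytes of data, more chunks follow for this request
    assert_eq!(ChunkLength::to_be_bytes(5 | CHUNK_HAS_CONTINUATION), [0x80, 0x05]);
    // last chunk of a request: 5 bytes, no continuation bit
    assert_eq!(ChunkLength::to_be_bytes(5), [0x00, 0x05]);
    // error code 255 forwarded through the stream, with continuation
    assert_eq!(
        ChunkLength::to_be_bytes(255u16 | ERROR_MARKER | CHUNK_HAS_CONTINUATION),
        [0xc0, 0xff]
    );
}
```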
impl Stream for DataReader {
type Item = DataReaderItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if *this.eos {
// eos was reached at previous call to poll_next, where a partial packet
// was returned. Now return None
return Poll::Ready(None);
}
loop {
let packet = match this.packet {
Ok(v) => v,
Err(e) => {
let e = *e;
*this.packet = Ok(Vec::new());
return Poll::Ready(Some(DataReaderItem {
data: DataFrame::Error(e),
may_have_more: true,
}));
}
};
let packet_left = packet.len() - *this.pos;
let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
let to_read = std::cmp::min(buf_left, packet_left);
this.buf
.extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
*this.pos += to_read;
if this.buf.len() == MAX_CHUNK_LENGTH as usize {
// we have a full buf, ready to send
break;
}
// we don't have a full buf and the current packet is exhausted; try to receive more
if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
*this.packet = p;
*this.pos = 0;
// if buf is empty, we will loop and return the error directly. If buf
// isn't empty, send it first by breaking.
if this.packet.is_err() && !this.buf.is_empty() {
break;
}
} else {
*this.eos = true;
break;
}
}
let mut body = [0; MAX_CHUNK_LENGTH as usize];
let len = this.buf.len();
body[..len].copy_from_slice(this.buf);
this.buf.clear();
Poll::Ready(Some(DataReaderItem {
data: DataFrame::Data { data: body, len },
may_have_more: !*this.eos,
}))
}
}

struct SendQueue {
@@ -79,6 +218,8 @@ impl SendQueue {
};
self.items[pos_prio].1.push_back(item);
}
// used only in tests. They should probably be rewritten
#[allow(dead_code)]
fn pop(&mut self) -> Option<SendQueueItem> {
match self.items.pop_front() {
None => None,
@@ -94,6 +235,54 @@ impl SendQueue {
fn is_empty(&self) -> bool {
self.items.iter().all(|(_k, v)| v.is_empty())
}
// this is like an async fn, but hand implemented
fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
SendQueuePollNextReady { queue: self }
}
}
struct SendQueuePollNextReady<'a> {
queue: &'a mut SendQueue,
}
impl<'a> futures::Future for SendQueuePollNextReady<'a> {
type Output = (RequestID, DataReaderItem);
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
for i in 0..self.queue.items.len() {
let (_prio, items_at_prio) = &mut self.queue.items[i];
for _ in 0..items_at_prio.len() {
let mut item = items_at_prio.pop_front().unwrap();
match Pin::new(&mut item.data).poll_next(ctx) {
Poll::Pending => items_at_prio.push_back(item),
Poll::Ready(Some(data)) => {
let id = item.id;
if data.may_have_more {
self.queue.push(item);
} else {
if items_at_prio.is_empty() {
// this priority level is empty, remove it
self.queue.items.remove(i);
}
}
return Poll::Ready((id, data));
}
Poll::Ready(None) => {
if items_at_prio.is_empty() {
// this priority level is empty, remove it
self.queue.items.remove(i);
}
return Poll::Ready((item.id, DataReaderItem::empty_last()));
}
}
}
}
// TODO what do we do if self.queue is empty? We won't get scheduled again.
Poll::Pending
}
}
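Side note, not part of the diff: a sketch of a test showing the chunk-by-chunk interleaving that `next_ready()` gives between same-priority requests. The bodies are one byte longer than `MAX_CHUNK_LENGTH`, so each request yields exactly two chunks and the ids alternate.

```rust
#[tokio::test]
async fn test_chunk_interleaving_sketch() {
    use futures::stream;

    fn body() -> Packet {
        Ok(vec![0u8; MAX_CHUNK_LENGTH as usize + 1])
    }
    let s1: AssociatedStream = Box::pin(stream::once(async { body() }));
    let s2: AssociatedStream = Box::pin(stream::once(async { body() }));

    let mut q = SendQueue::new();
    q.push(SendQueueItem { id: 1, prio: PRIO_NORMAL, data: s1.into() });
    q.push(SendQueueItem { id: 2, prio: PRIO_NORMAL, data: s2.into() });

    // chunks of the two requests come out alternately
    let mut order = vec![];
    for _ in 0..4 {
        let (id, _chunk) = q.next_ready().await;
        order.push(id);
    }
    assert_eq!(order, vec![1, 2, 1, 2]);
}
```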
/// The SendLoop trait, which is implemented both by the client and the server
@@ -108,7 +297,7 @@ impl SendQueue {
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
-mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
+mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
mut write: BoxStreamWrite<W>,
) -> Result<(), Error>
where
@@ -117,55 +306,34 @@ pub(crate) trait SendLoop: Sync {
let mut sending = SendQueue::new();
let mut should_exit = false;
while !should_exit || !sending.is_empty() {
-if let Ok((id, prio, data)) = msg_recv.try_recv() {
-trace!("send_loop: got {}, {} bytes", id, data.len());
-sending.push(SendQueueItem {
-id,
-prio,
-data,
-cursor: 0,
-});
-} else if let Some(mut item) = sending.pop() {
-trace!(
-"send_loop: sending bytes for {} ({} bytes, {} already sent)",
-item.id,
-item.data.len(),
-item.cursor
-);
-let header_id = RequestID::to_be_bytes(item.id);
-write.write_all(&header_id[..]).await?;
-if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
-let size_header =
-ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
-write.write_all(&size_header[..]).await?;
-let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
-write.write_all(&item.data[item.cursor..new_cursor]).await?;
-item.cursor = new_cursor;
-sending.push(item);
-} else {
-let send_len = (item.data.len() - item.cursor) as ChunkLength;
-let size_header = ChunkLength::to_be_bytes(send_len);
-write.write_all(&size_header[..]).await?;
-write.write_all(&item.data[item.cursor..]).await?;
-}
-write.flush().await?;
-} else {
-let sth = msg_recv.recv().await;
+let recv_fut = msg_recv.recv();
+futures::pin_mut!(recv_fut);
+let send_fut = sending.next_ready();
+// recv_fut is cancellation-safe according to tokio doc,
+// send_fut is cancellation-safe as implemented above?
+use futures::future::Either;
+match futures::future::select(recv_fut, send_fut).await {
+Either::Left((sth, _send_fut)) => {
if let Some((id, prio, data)) = sth {
-trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
id,
prio,
-data,
-cursor: 0,
+data: data.into(),
});
} else {
should_exit = true;
};
}
+Either::Right(((id, data), _recv_fut)) => {
+trace!("send_loop: sending bytes for {}", id);
+let header_id = RequestID::to_be_bytes(id);
+write.write_all(&header_id[..]).await?;
+write.write_all(&data.header()).await?;
+write.write_all(data.data()).await?;
+write.flush().await?;
+}
+}
}
@@ -175,6 +343,118 @@ pub(crate) trait SendLoop: Sync {
}
}
pub(crate) struct Framing {
direct: Vec<u8>,
stream: Option<AssociatedStream>,
}
impl Framing {
pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
assert!(direct.len() <= u32::MAX as usize);
Framing { direct, stream }
}
pub fn into_stream(self) -> AssociatedStream {
use futures::stream;
let len = self.direct.len() as u32;
// required because otherwise the borrow-checker complains
let Framing { direct, stream } = self;
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
.chain(stream::once(async move { Ok(direct) }));
if let Some(stream) = stream {
Box::pin(res.chain(stream))
} else {
Box::pin(res)
}
}
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
mut stream: S,
) -> Result<Self, Error> {
let mut packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
if packet.len() < 4 {
return Err(Error::Framing);
}
let mut len = [0; 4];
len.copy_from_slice(&packet[..4]);
let len = u32::from_be_bytes(len);
packet.drain(..4);
let mut buffer = Vec::new();
let len = len as usize;
loop {
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
buffer.extend_from_slice(&packet[..max_cp]);
if buffer.len() == len {
packet.drain(..max_cp);
break;
}
packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
}
let stream: AssociatedStream = if packet.is_empty() {
Box::pin(stream)
} else {
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
};
Ok(Framing {
direct: buffer,
stream: Some(stream),
})
}
pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
let Framing { direct, stream } = self;
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
}
}
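Side note, not part of the diff: a sketch of how `Framing` puts the serialized message in front of the associated stream and recovers both on the other side. The wire format is a u32 big-endian length, the direct part, then the associated stream's packets.

```rust
#[tokio::test]
async fn test_framing_roundtrip_sketch() {
    use futures::{stream, StreamExt};

    let direct = b"hello".to_vec();
    let assoc: AssociatedStream = Box::pin(stream::once(async { Ok::<_, u8>(vec![1u8, 2, 3]) }));

    // encode: [len(direct) as u32 BE][direct bytes][associated stream...]
    let framed = Framing::new(direct.clone(), Some(assoc)).into_stream();

    // decode and split back into the direct part and the remaining stream
    let (direct2, mut rest) = Framing::from_stream(framed).await.unwrap().into_parts();
    assert_eq!(direct2, direct);
    assert_eq!(rest.next().await, Some(Ok(vec![1u8, 2, 3])));
}
```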
/// Wrapper that signals when the sender is dropped before the end of the stream was reached,
/// for example when the connection to a remote peer drops while transmitting data
struct Sender {
inner: UnboundedSender<Packet>,
closed: bool,
}
impl Sender {
fn new(inner: UnboundedSender<Packet>) -> Self {
Sender {
inner,
closed: false,
}
}
fn send(&self, packet: Packet) {
let _ = self.inner.unbounded_send(packet);
}
fn end(&mut self) {
self.closed = true;
}
}
impl Drop for Sender {
fn drop(&mut self) {
if !self.closed {
self.send(Err(255));
}
self.inner.close_channel();
}
}
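Side note, not part of the diff: a sketch of the drop-guard behaviour. Dropping a `Sender` without calling `end()` makes the receiving side observe error code 255 before the channel closes.

```rust
#[tokio::test]
async fn test_sender_drop_sketch() {
    use futures::StreamExt;

    let (send, mut recv) = unbounded::<Packet>();
    {
        let sender = Sender::new(send);
        sender.send(Ok(vec![1, 2, 3]));
        // dropped here without end(): an Err(255) is appended and the channel closed
    }
    assert_eq!(recv.next().await, Some(Ok(vec![1, 2, 3])));
    assert_eq!(recv.next().await, Some(Err(255)));
    assert_eq!(recv.next().await, None);
}
```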
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
@@ -184,13 +464,13 @@ pub(crate) trait SendLoop: Sync {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
-fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
+fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
-let mut receiving = HashMap::new();
+let mut streams: HashMap<RequestID, Sender> = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -208,19 +488,33 @@ pub(crate) trait RecvLoop: Sync + 'static {
trace!("recv_loop: got header size: {:04x}", size);
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
+let is_error = (size & ERROR_MARKER) != 0;
+let packet = if is_error {
+Err(size as u8)
+} else {
let size = size & !CHUNK_HAS_CONTINUATION;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", next_slice.len());
+Ok(next_slice)
+};
-let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
-msg_bytes.extend_from_slice(&next_slice[..]);
+let mut sender = if let Some(send) = streams.remove(&(id)) {
+send
+} else {
+let (send, recv) = unbounded();
+self.recv_handler(id, recv);
+Sender::new(send)
+};
+// if we get an error, the receiving end is disconnected. We still need to
+// reach eos before dropping this sender
+sender.send(packet);
if has_cont {
-receiving.insert(id, msg_bytes);
+streams.insert(id, sender);
} else {
-self.recv_handler(id, msg_bytes);
+sender.end();
}
}
Ok(())
@@ -231,43 +525,44 @@ pub(crate) trait RecvLoop: Sync + 'static {
mod test {
use super::*;
fn empty_data() -> DataReader {
type Item = Packet;
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
Box::pin(futures::stream::empty::<Packet>());
stream.into()
}
#[test]
fn test_priority_queue() {
let i1 = SendQueueItem {
id: 1,
prio: PRIO_NORMAL,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let i2 = SendQueueItem {
id: 2,
prio: PRIO_HIGH,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let i2bis = SendQueueItem {
id: 20,
prio: PRIO_HIGH,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let i3 = SendQueueItem {
id: 3,
prio: PRIO_HIGH | PRIO_SECONDARY,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let i4 = SendQueueItem {
id: 4,
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let i5 = SendQueueItem {
id: 5,
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
-data: vec![],
-cursor: 0,
+data: empty_data(),
};
let mut q = SendQueue::new();

View file

@@ -2,7 +2,6 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
-use bytes::Bytes;
use log::{debug, trace};
#[cfg(feature = "telemetry")]
@@ -20,6 +19,7 @@ use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;
use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use async_trait::async_trait;
@@ -55,7 +55,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
-resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
}
impl ServerConn {
@@ -123,7 +123,11 @@ impl ServerConn {
Ok(())
}
-async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
+async fn recv_handler_aux(
+self: &Arc<Self>,
+bytes: &[u8],
+stream: AssociatedStream,
+) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
let msg = QueryMessage::decode(bytes)?;
let path = String::from_utf8(msg.path.to_vec())?;
@@ -156,11 +160,11 @@ impl ServerConn {
span.set_attribute(KeyValue::new("path", path.to_string()));
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
-handler.handle(msg.body, self.peer_id)
+handler.handle(msg.body, stream, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
-handler.handle(msg.body, self.peer_id).await
+handler.handle(msg.body, stream, self.peer_id).await
}
}
} else {
@@ -173,35 +177,40 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
-fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
+fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
let resp_send = self.resp_send.load_full().unwrap();
let self2 = self.clone();
tokio::spawn(async move {
-trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
-let bytes: Bytes = bytes.into();
+trace!("ServerConn recv_handler {}", id);
+let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
-let resp = self2.recv_handler_aux(&bytes[..]).await;
+let resp = self2.recv_handler_aux(&bytes[..], stream).await;
-let resp_bytes = match resp {
-Ok(rb) => {
+let (resp_bytes, resp_stream) = match resp {
+Ok((rb, rs)) => {
let mut resp_bytes = vec![0u8];
resp_bytes.extend(rb);
-resp_bytes
+(resp_bytes, rs)
}
Err(e) => {
let mut resp_bytes = vec![e.code()];
resp_bytes.extend(e.to_string().into_bytes());
-resp_bytes
+(resp_bytes, None)
}
};
trace!("ServerConn sending response to {}: ", id);
resp_send
-.send((id, prio, resp_bytes))
-.log_err("ServerConn recv_handler send resp");
+.send((
+id,
+prio,
+Framing::new(resp_bytes, resp_stream).into_stream(),
+))
+.log_err("ServerConn recv_handler send resp bytes");
+Ok::<_, Error>(())
});
}
}

View file

@@ -14,6 +14,7 @@ use crate::NodeID;
#[tokio::test(flavor = "current_thread")]
async fn test_with_basic_scheduler() {
pretty_env_logger::init();
run_test().await
}

View file

@@ -1,10 +1,15 @@
use crate::endpoint::SerializeMessage;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
-use serde::Serialize;
+use futures::Stream;
use log::info;
+use serde::Serialize;
use tokio::sync::watch;
/// A node's identifier, which is also its public cryptographic key /// A node's identifier, which is also its public cryptographic key
@@ -14,21 +19,36 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
/// A network key
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// A stream of associated data.
///
/// The stream can continue after receiving an error.
/// When sent through Netapp, the Vecs may be split into smaller chunks, and consecutive
/// Vecs may get merged, but Vecs and error codes will not be reordered.
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning; it's up to your application to define their semantics.
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
pub type Packet = Result<Vec<u8>, u8>;
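Side note, not part of the diff: a sketch of building an `AssociatedStream` by hand, for example for a body produced chunk by chunk. The function name and the error code 42 are purely illustrative.

```rust
use futures::stream;

fn example_stream() -> AssociatedStream {
    let chunks: Vec<Packet> = vec![
        Ok(b"first chunk".to_vec()),
        Ok(b"second chunk".to_vec()),
        // an application-defined error code; the stream may continue after it
        Err(42),
        Ok(b"and some more data".to_vec()),
    ];
    Box::pin(stream::iter(chunks))
}
```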
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.
///
/// Field names and variant names are included in the serialization.
/// This is used internally by the netapp communication protocol.
-pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
+pub fn rmp_to_vec_all_named<T>(
+val: &T,
+) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
where
-T: Serialize + ?Sized,
+T: SerializeMessage + ?Sized,
{
let mut wr = Vec::with_capacity(128);
let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
+let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?;
-Ok(wr)
+Ok((wr, stream))
}
/// This async function returns only when a true signal was received /// This async function returns only when a true signal was received