From 368ba908794901bc793c6a087c02241be046bdf2 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 15:33:43 +0200 Subject: [PATCH 1/7] initial work on associated stream still require testing, and fixing a few kinks: - sending packets > 16k truncate them - send one more packet than it could at eos - probably update documentation /!\ contains breaking changes --- Cargo.lock | 44 +++++++- Cargo.toml | 2 + src/client.rs | 37 ++++--- src/endpoint.rs | 66 ++++++++++-- src/proto.rs | 260 ++++++++++++++++++++++++++++++++++++++++-------- src/server.rs | 38 ++++--- src/test.rs | 1 + src/util.rs | 17 +++- 8 files changed, 382 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fe2a29d..356c3ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -151,6 +151,19 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "env_logger" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36" +dependencies = [ + "atty", + "humantime 1.3.0", + "log", + "regex", + "termcolor", +] + [[package]] name = "env_logger" version = "0.8.4" @@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" dependencies = [ "atty", - "humantime", + "humantime 2.1.0", "log", "regex", "termcolor", @@ -322,6 +335,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "humantime" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" +dependencies = [ + "quick-error", +] + [[package]] name = "humantime" version = "2.1.0" @@ -440,7 +462,7 @@ dependencies = [ "bytes 0.6.0", "cfg-if", "chrono", - "env_logger", + "env_logger 0.8.4", "err-derive", "futures", "hex", @@ -450,6 +472,8 @@ dependencies = [ "lru", "opentelemetry", "opentelemetry-contrib", + "pin-project", + "pretty_env_logger", "rand 0.5.6", "rmp-serde", "serde", @@ -582,6 +606,16 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "pretty_env_logger" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d" +dependencies = [ + "env_logger 0.7.1", + "log", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -627,6 +661,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.10" diff --git a/Cargo.toml b/Cargo.toml index 536a1e6..a2f0ab1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"] [dependencies] futures = "0.3.17" +pin-project = "1.0.10" tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } tokio-stream = "0.1.7" @@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = 
true } [dev-dependencies] env_logger = "0.8" +pretty_env_logger = "0.4" structopt = { version = "0.3", default-features = false } chrono = "0.4" diff --git a/src/client.rs b/src/client.rs index 8227e8f..bce7aca 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,10 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption)>>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex, AssociatedStream)>>>, } impl ClientConn { @@ -148,9 +148,11 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; + // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(1, atomic::Ordering::Relaxed); + .fetch_add(2, atomic::Ordering::Relaxed); + let stream_id = id + 1; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -166,7 +168,7 @@ impl ClientConn { }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; + let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; drop(rq); let request = QueryMessage { @@ -185,7 +187,10 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { + if old_ch + .send((vec![], Box::pin(futures::stream::empty()))) + .is_err() + { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -195,15 +200,20 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, bytes))?; + query_send.send((id, prio, Data::Full(bytes)))?; + if let Some(stream) = stream { + query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; + } else { + query_send.send((stream_id, prio, Data::Full(Vec::new())))?; + } cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let (resp, stream) = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let (resp, stream) = resp_recv.await?; } } @@ -217,10 +227,9 @@ impl ClientConn { let code = resp[0]; if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - ::Response, - >(&resp[1..])?) + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); + let res = T::Response::deserialize_msg(&mut deser, stream).await?; + Ok(res) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) @@ -232,12 +241,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec) { + fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream) { trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send((msg, stream)).is_err() { debug!("Could not send request response, probably because request was interrupted. 
Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..81ed036 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use serde::de::Error as DeError; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::error::Error; use crate::netapp::*; @@ -14,8 +15,50 @@ use crate::util::*; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + fn serialize_msg( + &self, + serializer: S, + ) -> Result<(S::Ok, Option), S::Error>; + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + stream: AssociatedStream, + ) -> Result; +} + +#[async_trait] +impl SerializeMessage for T +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + fn serialize_msg( + &self, + serializer: S, + ) -> Result<(S::Ok, Option), S::Error> { + self.serialize(serializer).map(|r| (r, None)) + } + + async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( + deserializer: D, + mut stream: AssociatedStream, + ) -> Result { + use futures::StreamExt; + + let res = Self::deserialize(deserializer)?; + if stream.next().await.is_some() { + return Err(D::Error::custom( + "failed to deserialize: found associated stream when none expected", + )); + } + Ok(res) + } } /// This trait should be implemented by an object of your application @@ -128,7 +171,12 @@ pub(crate) type DynEndpoint = Box; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error>; + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -145,11 +193,17 @@ where M: Message + 'static, H: EndpointHandler + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error> { + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); + let req = M::deserialize_msg(&mut deser, stream).await?; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/proto.rs b/src/proto.rs index e843bff..b45ff13 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use log::trace; +use log::{trace, warn}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::Stream; +use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -11,6 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; +use crate::util::AssociatedStream; /// 
Priority of a request (click to read more about priorities). /// @@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: Vec, - cursor: usize, + data: DataReader, +} + +pub(crate) enum Data { + Full(Vec), + Streaming(AssociatedStream), +} + +#[pin_project::pin_project(project = DataReaderProj)] +enum DataReader { + Full { + #[pin] + data: Vec, + pos: usize, + }, + Streaming { + #[pin] + reader: AssociatedStream, + }, +} + +impl From for DataReader { + fn from(data: Data) -> DataReader { + match data { + Data::Full(data) => DataReader::Full { data, pos: 0 }, + Data::Streaming(reader) => DataReader::Streaming { reader }, + } + } +} + +impl Stream for DataReader { + type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + DataReaderProj::Full { data, pos } => { + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); + let end = *pos + len; + + if len == 0 { + Poll::Ready(None) + } else { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..len].copy_from_slice(&data[*pos..end]); + *pos = end; + Poll::Ready(Some((body, len))) + } + } + DataReaderProj::Streaming { reader } => { + reader.poll_next(cx).map(|opt| { + opt.map(|v| { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); + // TODO this can throw away long vec, they should be splited instead + body[..len].copy_from_slice(&v[..len]); + (body, len) + }) + }) + } + } + } } struct SendQueue { @@ -108,7 +172,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync { let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else if let Some(mut item) = sending.pop() { trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor + "send_loop: sending bytes for {}", + item.id, ); + + let data = futures::select! 
{ + data = item.data.next().fuse() => data, + default => { + // nothing to send yet; re-schedule and find something else to do + sending.push(item); + continue; + + // TODO if every SendQueueItem is waiting on data, use select_all to await + // something to do + // TODO find some way to not require sending empty last chunk + } + }; + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let data = match data.as_ref() { + Some((data, len)) => &data[..*len], + None => &[], + }; + + if !data.is_empty() { let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); + ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; + write.write_all(data).await?; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); + // this is always zero for now, but may be more when above TODO get fixed + let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - write.write_all(&item.data[item.cursor..]).await?; + write.write_all(data).await?; } + write.flush().await?; } else { let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else { should_exit = true; @@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync { } } +struct ChannelPair { + receiver: Option>>, + sender: Option>>, +} + +impl ChannelPair { + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } +} + +impl Default for ChannelPair { + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. 
#[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec); + fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut receiving: HashMap> = HashMap::new(); + let mut streams: HashMap< + RequestID, + ChannelPair, + > = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + if id & 1 == 0 { + // main stream + let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); + msg_bytes.extend_from_slice(&next_slice[..]); - if has_cont { - receiving.insert(id, msg_bytes); + if has_cont { + receiving.insert(id, msg_bytes); + } else { + let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); + + if let Some(receiver) = channel_pair.take_receiver() { + self.recv_handler(id, msg_bytes, Box::pin(receiver)); + } else { + warn!("Couldn't take receiver part of stream") + } + + channel_pair.insert_into(&mut streams, id | 1); + } } else { - self.recv_handler(id, msg_bytes); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } + + if !has_cont { + channel_pair.take_sender(); + } + + channel_pair.insert_into(&mut streams, id); } } Ok(()) @@ -236,38 +396,50 @@ mod test { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 5465307..6cd4056 100644 --- a/src/server.rs +++ b/src/server.rs @@ -55,7 +55,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption)>>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -123,7 +123,11 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc, bytes: &[u8]) -> Result, Error> { + async fn recv_handler_aux( + self: &Arc, + bytes: &[u8], + stream: AssociatedStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; @@ 
-156,11 +160,11 @@ impl ServerConn { span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(msg.body, stream, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(msg.body, stream, self.peer_id).await } } } else { @@ -173,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec) { + fn recv_handler(self: &Arc, id: RequestID, bytes: Vec, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); @@ -182,26 +186,36 @@ impl RecvLoop for ServerConn { let bytes: Bytes = bytes.into(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; + let resp = self2.recv_handler_aux(&bytes[..], stream).await; - let resp_bytes = match resp { - Ok(rb) => { + let (resp_bytes, resp_stream) = match resp { + Ok((rb, rs)) => { let mut resp_bytes = vec![0u8]; resp_bytes.extend(rb); - resp_bytes + (resp_bytes, rs) } Err(e) => { let mut resp_bytes = vec![e.code()]; resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes + (resp_bytes, None) } }; trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send((id, prio, Data::Full(resp_bytes))) + .log_err("ServerConn recv_handler send resp bytes"); + + if let Some(resp_stream) = resp_stream { + resp_send + .send((id + 1, prio, Data::Streaming(resp_stream))) + .log_err("ServerConn recv_handler send resp stream"); + } else { + resp_send + .send((id + 1, prio, Data::Full(Vec::new()))) + .log_err("ServerConn recv_handler send resp stream"); + } }); } } diff --git a/src/test.rs b/src/test.rs index 82c7ba6..ecd5450 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use crate::NodeID; #[tokio::test(flavor = "current_thread")] async fn test_with_basic_scheduler() { + pretty_env_logger::init(); run_test().await } diff --git a/src/util.rs b/src/util.rs index f4dfac7..4333080 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,10 @@ +use crate::endpoint::SerializeMessage; + use std::net::SocketAddr; use std::net::ToSocketAddrs; +use std::pin::Pin; -use serde::Serialize; +use futures::Stream; use log::info; @@ -14,21 +17,25 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +pub type AssociatedStream = Pin> + Send>>; + /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. /// /// Field names and variant names are included in the serialization. /// This is used internally by the netapp communication protocol. 
-pub fn rmp_to_vec_all_named(val: &T) -> Result, rmp_serde::encode::Error> +pub fn rmp_to_vec_all_named( + val: &T, +) -> Result<(Vec, Option), rmp_serde::encode::Error> where - T: Serialize + ?Sized, + T: SerializeMessage + ?Sized, { let mut wr = Vec::with_capacity(128); let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - val.serialize(&mut se)?; - Ok(wr) + let (_, stream) = val.serialize_msg(&mut se)?; + Ok((wr, stream)) } /// This async function returns only when a true signal was received -- 2.43.4 From fb5462ecdb6b5731a63a902519d3ec9b1061b8dd Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 5 Jun 2022 16:47:29 +0200 Subject: [PATCH 2/7] rechunk stream --- src/proto.rs | 153 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 91 insertions(+), 62 deletions(-) diff --git a/src/proto.rs b/src/proto.rs index b45ff13..ca1a3d2 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -53,7 +53,7 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -77,6 +77,10 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, + packet: Vec, + pos: usize, + buf: Vec, + eos: bool, }, } @@ -84,7 +88,13 @@ impl From for DataReader { fn from(data: Data) -> DataReader { match data { Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { reader }, + Data::Streaming(reader) => DataReader::Streaming { + reader, + packet: Vec::new(), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, + }, } } } @@ -107,16 +117,43 @@ impl Stream for DataReader { Poll::Ready(Some((body, len))) } } - DataReaderProj::Streaming { reader } => { - reader.poll_next(cx).map(|opt| { - opt.map(|v| { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); - // TODO this can throw away long vec, they should be splited instead - body[..len].copy_from_slice(&v[..len]); - (body, len) - }) - }) + DataReaderProj::Streaming { + mut reader, + packet, + pos, + buf, + eos, + } => { + if *eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + loop { + let packet_left = packet.len() - *pos; + let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + buf.extend_from_slice(&packet[*pos..*pos + to_read]); + *pos += to_read; + if buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; + } + + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { + *packet = p; + *pos = 0; + } else { + *eos = true; + break; + } + } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..buf.len()].copy_from_slice(&buf); + buf.clear(); + Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) } } } @@ -196,10 +233,7 @@ pub(crate) trait SendLoop: Sync { data: data.into(), }); } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop: sending bytes for {}", - item.id, - ); + trace!("send_loop: sending bytes for {}", item.id,); let data = futures::select! 
{ data = item.data.next().fuse() => data, @@ -210,7 +244,6 @@ pub(crate) trait SendLoop: Sync { // TODO if every SendQueueItem is waiting on data, use select_all to await // something to do - // TODO find some way to not require sending empty last chunk } }; @@ -222,7 +255,7 @@ pub(crate) trait SendLoop: Sync { None => &[], }; - if !data.is_empty() { + if data.len() == MAX_CHUNK_LENGTH as usize { let size_header = ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; @@ -231,7 +264,6 @@ pub(crate) trait SendLoop: Sync { sending.push(item); } else { - // this is always zero for now, but may be more when above TODO get fixed let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; @@ -267,38 +299,38 @@ pub(crate) trait SendLoop: Sync { } struct ChannelPair { - receiver: Option>>, - sender: Option>>, + receiver: Option>>, + sender: Option>>, } impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() - } + fn take_receiver(&mut self) -> Option>> { + self.receiver.take() + } - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } + fn take_sender(&mut self) -> Option>> { + self.sender.take() + } - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } + fn ref_sender(&mut self) -> Option<&UnboundedSender>> { + self.sender.as_ref().take() + } - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); - } - } + fn insert_into(self, map: &mut HashMap, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } } impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), - } - } + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } } /// The RecvLoop trait, which is implemented both by the client and the server @@ -317,10 +349,7 @@ pub(crate) trait RecvLoop: Sync + 'static { R: AsyncReadExt + Unpin + Send + Sync, { let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap< - RequestID, - ChannelPair, - > = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -345,7 +374,7 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: read {} bytes", next_slice.len()); if id & 1 == 0 { - // main stream + // main stream let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); msg_bytes.extend_from_slice(&next_slice[..]); @@ -357,30 +386,30 @@ pub(crate) trait RecvLoop: Sync + 'static { if let Some(receiver) = channel_pair.take_receiver() { self.recv_handler(id, msg_bytes, Box::pin(receiver)); } else { - warn!("Couldn't take receiver part of stream") - } + warn!("Couldn't take receiver part of stream") + } - channel_pair.insert_into(&mut streams, id | 1); + channel_pair.insert_into(&mut streams, id | 1); } } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); // if we get an error, the receiving end is disconnected. 
We still need to // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } if !has_cont { - channel_pair.take_sender(); - } + channel_pair.take_sender(); + } - channel_pair.insert_into(&mut streams, id); + channel_pair.insert_into(&mut streams, id); } } Ok(()) -- 2.43.4 From 4745e7c4ba5665d3303ae567087781778cec9c34 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Wed, 8 Jun 2022 00:30:56 +0200 Subject: [PATCH 3/7] further work on streams most changes still required are related to error handling --- src/client.rs | 5 ++- src/endpoint.rs | 78 +++++++++++++++++++++++++---------------- src/netapp.rs | 4 ++- src/peering/fullmesh.rs | 8 +++-- src/proto.rs | 5 +-- src/util.rs | 5 ++- 6 files changed, 66 insertions(+), 39 deletions(-) diff --git a/src/client.rs b/src/client.rs index bce7aca..bc16fb1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -227,9 +227,8 @@ impl ClientConn { let code = resp[0]; if code == 0 { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]); - let res = T::Response::deserialize_msg(&mut deser, stream).await?; - Ok(res) + let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; + Ok(T::Response::deserialize_msg(ser_resp, stream).await) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) diff --git a/src/endpoint.rs b/src/endpoint.rs index 81ed036..c25365a 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,8 +5,7 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::de::Error as DeError; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use crate::error::Error; use crate::netapp::*; @@ -22,42 +21,61 @@ pub trait Message: SerializeMessage + Send + Sync { /// A trait for de/serializing messages, with possible associated stream. 
#[async_trait] pub trait SerializeMessage: Sized { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error>; + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - stream: AssociatedStream, - ) -> Result; + // TODO should return Result + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; } +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + #[async_trait] impl SerializeMessage for T where - T: Serialize + for<'de> Deserialize<'de> + Send + Sync, + T: AutoSerialize, { - fn serialize_msg( - &self, - serializer: S, - ) -> Result<(S::Ok, Option), S::Error> { - self.serialize(serializer).map(|r| (r, None)) + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) } - async fn deserialize_msg<'de, D: Deserializer<'de> + Send>( - deserializer: D, - mut stream: AssociatedStream, - ) -> Result { - use futures::StreamExt; + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + // TODO verify no stream + ser_self + } +} - let res = Self::deserialize(deserializer)?; - if stream.next().await.is_some() { - return Err(D::Error::custom( - "failed to deserialize: found associated stream when none expected", - )); +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), } - Ok(res) } } @@ -139,7 +157,7 @@ where prio: RequestPriority, ) -> Result<::Response, Error> where - B: Borrow, + B: Borrow + Send + Sync, { if *target == self.netapp.id { match self.handler.load_full() { @@ -202,8 +220,8 @@ where match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf); - let req = M::deserialize_msg(&mut deser, stream).await?; + let req = rmp_serde::decode::from_read_ref(buf)?; + let req = M::deserialize_msg(req, stream).await; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..27f17e6 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct HelloMessage { pub server_addr: Option, pub server_port: u16, } +impl AutoSerialize for HelloMessage {} + impl Message for HelloMessage { type Response = (); } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 012c5a0..7dfc5c4 100644 --- a/src/peering/fullmesh.rs +++ 
b/src/peering/fullmesh.rs @@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3; // -- Protocol messages -- -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] struct PingMessage { pub id: u64, pub peer_list_hash: hash::Digest, @@ -39,7 +39,9 @@ impl Message for PingMessage { type Response = PingMessage; } -#[derive(Serialize, Deserialize)] +impl AutoSerialize for PingMessage {} + +#[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, } @@ -48,6 +50,8 @@ impl Message for PeerListMessage { type Response = PeerListMessage; } +impl AutoSerialize for PeerListMessage {} + // -- Algorithm data structures -- #[derive(Debug)] diff --git a/src/proto.rs b/src/proto.rs index ca1a3d2..073a317 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -151,9 +151,10 @@ impl Stream for DataReader { } let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..buf.len()].copy_from_slice(&buf); + let len = buf.len(); + body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, MAX_CHUNK_LENGTH as usize))) + Poll::Ready(Some((body, len))) } } } diff --git a/src/util.rs b/src/util.rs index 4333080..02b4e7d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -8,6 +8,8 @@ use futures::Stream; use log::info; +use serde::Serialize; + use tokio::sync::watch; /// A node's identifier, which is also its public cryptographic key @@ -34,7 +36,8 @@ where let mut se = rmp_serde::Serializer::new(&mut wr) .with_struct_map() .with_string_variants(); - let (_, stream) = val.serialize_msg(&mut se)?; + let (val, stream) = val.serialize_msg(); + val.serialize(&mut se)?; Ok((wr, stream)) } -- 2.43.4 From 5d7541e13a4c3640f0dc8aead595b51775fc0ac8 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 17:44:07 +0200 Subject: [PATCH 4/7] wait for any ready stream instead of the highest priority one --- src/endpoint.rs | 2 +- src/proto.rs | 185 ++++++++++++++++++++++++++++++------------------ src/util.rs | 8 +++ 3 files changed, 124 insertions(+), 71 deletions(-) diff --git a/src/endpoint.rs b/src/endpoint.rs index c25365a..c430d4e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -42,7 +42,7 @@ where (self.clone(), None) } - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { // TODO verify no stream ser_self } diff --git a/src/proto.rs b/src/proto.rs index 073a317..417b508 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,7 +7,7 @@ use log::{trace, warn}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::Stream; -use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; +use futures::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -53,7 +53,8 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +const ERROR_MARKER: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -99,8 +100,29 @@ impl From for DataReader { } } +struct DataReaderItem { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actuall lenght of data + len: usize, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. 
It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + may_have_more: false, + } + } +} + impl Stream for DataReader { - type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -114,7 +136,11 @@ impl Stream for DataReader { let mut body = [0; MAX_CHUNK_LENGTH as usize]; body[..len].copy_from_slice(&data[*pos..end]); *pos = end; - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: end < data.len(), + })) } } DataReaderProj::Streaming { @@ -154,7 +180,11 @@ impl Stream for DataReader { let len = buf.len(); body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: !*eos, + })) } } } @@ -181,6 +211,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -196,6 +228,54 @@ impl SendQueue { fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } } /// The SendLoop trait, which is implemented both by the client and the server @@ -219,77 +299,42 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? 
+ use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else if let Some(mut item) = sending.pop() { - trace!("send_loop: sending bytes for {}", item.id,); + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); - let data = futures::select! { - data = item.data.next().fuse() => data, - default => { - // nothing to send yet; re-schedule and find something else to do - sending.push(item); - continue; + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; - // TODO if every SendQueueItem is waiting on data, use select_all to await - // something to do - } - }; + let body = &data.data[..data.len]; - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; + let size_header = if data.may_have_more { + ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) + } else { + ChunkLength::to_be_bytes(data.len as u16) + }; - let data = match data.as_ref() { - Some((data, len)) => &data[..*len], - None => &[], - }; - - if data.len() == MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - - sending.push(item); - } else { - let size_header = ChunkLength::to_be_bytes(data.len() as u16); - write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - } - - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } - } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; + write.write_all(body).await?; + write.flush().await?; } } } diff --git a/src/util.rs b/src/util.rs index 02b4e7d..3ee0cb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -19,6 +19,14 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +/// A stream of associated data. +/// +/// The Stream can continue after receiving an error. +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// The error code have no predefined meaning, it's up to you application to define their +/// semantic. 
pub type AssociatedStream = Pin> + Send>>; /// Utility function: encodes any serializable value in MessagePack binary format -- 2.43.4 From 0fec85b47a1bc679d2684994bfae6ef0fe7d4911 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 18:42:27 +0200 Subject: [PATCH 5/7] start supporting sending error on stream --- src/proto.rs | 99 +++++++++++++++++++++++++++++++++++++--------------- src/util.rs | 2 +- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/src/proto.rs b/src/proto.rs index 417b508..e3f9be8 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -78,7 +78,7 @@ enum DataReader { Streaming { #[pin] reader: AssociatedStream, - packet: Vec, + packet: Result, u8>, pos: usize, buf: Vec, eos: bool, @@ -91,7 +91,7 @@ impl From for DataReader { Data::Full(data) => DataReader::Full { data, pos: 0 }, Data::Streaming(reader) => DataReader::Streaming { reader, - packet: Vec::new(), + packet: Ok(Vec::new()), pos: 0, buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), eos: false, @@ -100,11 +100,18 @@ impl From for DataReader { } } +enum DataFrame { + Data { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actual lenght of data + len: usize, + }, + Error(u8), +} + struct DataReaderItem { - /// a fixed size buffer containing some data, possibly padded with 0s - data: [u8; MAX_CHUNK_LENGTH as usize], - /// actuall lenght of data - len: usize, + data: DataFrame, /// whethere there may be more data comming from this stream. Can be used for some /// optimization. It's an error to set it to false if there is more data, but it is correct /// (albeit sub-optimal) to set it to true if there is nothing coming after @@ -114,11 +121,34 @@ struct DataReaderItem { impl DataReaderItem { fn empty_last() -> Self { DataReaderItem { - data: [0; MAX_CHUNK_LENGTH as usize], - len: 0, + data: DataFrame::Data { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + }, may_have_more: false, } } + + fn header(&self) -> [u8; 2] { + let continuation = if self.may_have_more { + CHUNK_HAS_CONTINUATION + } else { + 0 + }; + let len = match self.data { + DataFrame::Data { len, .. } => len as u16, + DataFrame::Error(e) => e as u16 | ERROR_MARKER, + }; + + ChunkLength::to_be_bytes(len | continuation) + } + + fn data(&self) -> &[u8] { + match self.data { + DataFrame::Data { ref data, len } => &data[..len], + DataFrame::Error(_) => &[], + } + } } impl Stream for DataReader { @@ -137,15 +167,14 @@ impl Stream for DataReader { body[..len].copy_from_slice(&data[*pos..end]); *pos = end; Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: end < data.len(), })) } } DataReaderProj::Streaming { mut reader, - packet, + packet: res_packet, pos, buf, eos, @@ -156,6 +185,17 @@ impl Stream for DataReader { return Poll::Ready(None); } loop { + let packet = match res_packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *res_packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); + } + }; let packet_left = packet.len() - *pos; let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); let to_read = std::cmp::min(buf_left, packet_left); @@ -168,8 +208,13 @@ impl Stream for DataReader { // we don't have a full buf, packet is empty; try receive more if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *packet = p; + *res_packet = p; *pos = 0; + // if buf is empty, we will loop and return the error directly. 
If buf + // isn't empty, send it before by breaking. + if res_packet.is_err() && !buf.is_empty() { + break; + } } else { *eos = true; break; @@ -181,8 +226,7 @@ impl Stream for DataReader { body[..len].copy_from_slice(buf); buf.clear(); Poll::Ready(Some(DataReaderItem { - data: body, - len, + data: DataFrame::Data { data: body, len }, may_have_more: !*eos, })) } @@ -211,8 +255,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } - // used only in tests. They should probably be rewriten - #[allow(dead_code)] + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -324,16 +368,8 @@ pub(crate) trait SendLoop: Sync { let header_id = RequestID::to_be_bytes(id); write.write_all(&header_id[..]).await?; - let body = &data.data[..data.len]; - - let size_header = if data.may_have_more { - ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) - } else { - ChunkLength::to_be_bytes(data.len as u16) - }; - - write.write_all(&size_header[..]).await?; - write.write_all(body).await?; + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; write.flush().await?; } } @@ -413,7 +449,13 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; + let is_error = (size & ERROR_MARKER) != 0; + let size = if !is_error { + size & !CHUNK_HAS_CONTINUATION + } else { + 0 + }; + // TODO propagate errors let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; @@ -430,7 +472,8 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); if let Some(receiver) = channel_pair.take_receiver() { - self.recv_handler(id, msg_bytes, Box::pin(receiver)); + use futures::StreamExt; + self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); } else { warn!("Couldn't take receiver part of stream") } diff --git a/src/util.rs b/src/util.rs index 3ee0cb9..76d7ecf 100644 --- a/src/util.rs +++ b/src/util.rs @@ -27,7 +27,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// The error code have no predefined meaning, it's up to you application to define their /// semantic. -pub type AssociatedStream = Pin> + Send>>; +pub type AssociatedStream = Pin, u8>> + Send>>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. 
-- 2.43.4 From d3d18b8e8bde5fee81022fd050d5f4c114262fcf Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 20 Jun 2022 23:40:31 +0200 Subject: [PATCH 6/7] use a framing protocol instead of even/odd channel --- src/client.rs | 32 ++--- src/endpoint.rs | 1 - src/error.rs | 4 + src/proto.rs | 362 ++++++++++++++++++++++-------------------------- src/server.rs | 26 ++-- 5 files changed, 192 insertions(+), 233 deletions(-) diff --git a/src/client.rs b/src/client.rs index bc16fb1..a630f87 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,11 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption>, + query_send: + ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex, AssociatedStream)>>>, + inflight: Mutex>>, } impl ClientConn { @@ -148,11 +149,9 @@ impl ClientConn { { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; - // increment by 2; even are direct data; odd are associated stream let id = self .next_query_number - .fetch_add(2, atomic::Ordering::Relaxed); - let stream_id = id + 1; + .fetch_add(1, atomic::Ordering::Relaxed); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -187,10 +186,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch - .send((vec![], Box::pin(futures::stream::empty()))) - .is_err() - { + if old_ch.send(Box::pin(futures::stream::empty())).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -200,22 +196,18 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, Data::Full(bytes)))?; - if let Some(stream) = stream { - query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?; - } else { - query_send.send((stream_id, prio, Data::Full(Vec::new())))?; - } + query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let (resp, stream) = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let (resp, stream) = resp_recv.await?; + let stream = resp_recv.await?; } } + let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); if resp.is_empty() { return Err(Error::Message( @@ -240,12 +232,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send((msg, stream)).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. 
Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index c430d4e..f31141d 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -23,7 +23,6 @@ pub trait Message: SerializeMessage + Send + Sync { pub trait SerializeMessage: Sized { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - // TODO should return Result fn serialize_msg(&self) -> (Self::SerializableSelf, Option); // TODO should return Result diff --git a/src/error.rs b/src/error.rs index 99acdd1..7911c29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,9 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + #[error(display = "{}", _0)] Message(String), @@ -50,6 +53,7 @@ impl Error { Self::RMPEncode(_) => 10, Self::RMPDecode(_) => 11, Self::UTF8(_) => 12, + Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, Self::Handshake(_) => 30, diff --git a/src/proto.rs b/src/proto.rs index e3f9be8..d6dc35a 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,11 +3,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::{trace, warn}; +use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::Stream; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -63,39 +63,24 @@ struct SendQueueItem { data: DataReader, } -pub(crate) enum Data { - Full(Vec), - Streaming(AssociatedStream), +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: AssociatedStream, + packet: Result, u8>, + pos: usize, + buf: Vec, + eos: bool, } -#[pin_project::pin_project(project = DataReaderProj)] -enum DataReader { - Full { - #[pin] - data: Vec, - pos: usize, - }, - Streaming { - #[pin] - reader: AssociatedStream, - packet: Result, u8>, - pos: usize, - buf: Vec, - eos: bool, - }, -} - -impl From for DataReader { - fn from(data: Data) -> DataReader { - match data { - Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { - reader, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - }, +impl From for DataReader { + fn from(data: AssociatedStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, } } } @@ -155,82 +140,60 @@ impl Stream for DataReader { type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - DataReaderProj::Full { data, pos } => { - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); - let end = *pos + len; + let mut this = self.project(); - if len == 0 { - Poll::Ready(None) - } else { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..len].copy_from_slice(&data[*pos..end]); - *pos = end; - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: end < data.len(), - })) + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. 
Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; } - DataReaderProj::Streaming { - mut reader, - packet: res_packet, - pos, - buf, - eos, - } => { - if *eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - loop { - let packet = match res_packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *res_packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *pos; - let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - buf.extend_from_slice(&packet[*pos..*pos + to_read]); - *pos += to_read; - if buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *res_packet = p; - *pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if res_packet.is_err() && !buf.is_empty() { - break; - } - } else { - *eos = true; - break; - } + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. 
+ if this.packet.is_err() && !this.buf.is_empty() { + break; } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = buf.len(); - body[..len].copy_from_slice(buf); - buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*eos, - })) + } else { + *this.eos = true; + break; } } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) } } @@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync { } } -struct ChannelPair { - receiver: Option>>, - sender: Option>>, +pub(crate) struct Framing { + direct: Vec, + stream: Option, } -impl ChannelPair { - fn take_receiver(&mut self) -> Option>> { - self.receiver.take() +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } } - fn take_sender(&mut self) -> Option>> { - self.sender.take() - } + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; - fn ref_sender(&mut self) -> Option<&UnboundedSender>> { - self.sender.as_ref().take() - } + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); - fn insert_into(self, map: &mut HashMap, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) } } -} -impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), + pub async fn from_stream, u8>> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } } @@ -424,14 +431,13 @@ impl Default for ChannelPair { /// the full message is passed to the receive handler. 
#[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving: HashMap> = HashMap::new(); - let mut streams: HashMap = HashMap::new(); + let mut streams: HashMap, u8>>> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; - let size = if !is_error { - size & !CHUNK_HAS_CONTINUATION + let packet = if is_error { + Err(size as u8) } else { - 0 + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) }; - // TODO propagate errors - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - - if id & 1 == 0 { - // main stream - let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); - - if let Some(receiver) = channel_pair.take_receiver() { - use futures::StreamExt; - self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); - } else { - warn!("Couldn't take receiver part of stream") - } - - channel_pair.insert_into(&mut streams, id | 1); - } + let sender = if let Some(send) = streams.remove(&(id)) { + send } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + let (send, recv) = unbounded(); + self.recv_handler(id, Box::pin(recv)); + send + }; - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + // if we get an error, the receiving end is disconnected. 
We still need to + // reach eos before dropping this sender + let _ = sender.unbounded_send(packet); - if !has_cont { - channel_pair.take_sender(); - } - - channel_pair.insert_into(&mut streams, id); + if has_cont { + streams.insert(id, sender); } } Ok(()) @@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Result, u8>; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::, u8>>()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 6cd4056..86e5156 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,6 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; -use bytes::Bytes; use log::{debug, trace}; #[cfg(feature = "telemetry")] @@ -55,7 +54,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -177,13 +176,13 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + trace!("ServerConn recv_handler {}", id); + let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; let resp = self2.recv_handler_aux(&bytes[..], stream).await; @@ -204,18 +203,13 @@ impl RecvLoop for ServerConn { trace!("ServerConn sending response to {}: ", id); resp_send - .send((id, prio, Data::Full(resp_bytes))) + .send(( + id, + prio, + Framing::new(resp_bytes, resp_stream).into_stream(), + )) .log_err("ServerConn recv_handler send resp bytes"); - - if let Some(resp_stream) = resp_stream { - resp_send - .send((id + 1, prio, Data::Streaming(resp_stream))) - .log_err("ServerConn recv_handler send resp stream"); - } else { - resp_send - .send((id + 1, prio, Data::Full(Vec::new()))) - .log_err("ServerConn recv_handler send resp stream"); - } + Ok::<_, Error>(()) }); } } -- 2.43.4 From cdff8ae1beab44a22d0eb0eb00c624e49971b6ca Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 18 Jul 2022 15:21:13 +0200 Subject: [PATCH 7/7] add detection of premature eos --- src/client.rs | 7 +++--- src/proto.rs | 59 ++++++++++++++++++++++++++++++++++++++++----------- src/server.rs | 3 ++- src/util.rs | 8 ++++--- 4 files changed, 58 insertions(+), 
19 deletions(-) diff --git a/src/client.rs b/src/client.rs index a630f87..6d49f5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -41,7 +42,7 @@ pub(crate) struct ClientConn { ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>, + inflight: Mutex>>>, } impl ClientConn { @@ -186,7 +187,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(Box::pin(futures::stream::empty())).is_err() { + if old_ch.send(unbounded().1).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -232,7 +233,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result, u8>, + packet: Packet, pos: usize, buf: Vec, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream, u8>> + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + 'static>( mut stream: S, ) -> Result { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. 
#[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap, u8>>> = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result, u8>; + type Item = Packet; let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::, u8>>()); + Box::pin(futures::stream::empty::()); stream.into() } diff --git a/src/server.rs b/src/server.rs index 86e5156..8075484 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,6 +19,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; +use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -176,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); diff --git a/src/util.rs b/src/util.rs index 76d7ecf..186678d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -25,9 +25,11 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// -/// The error code have no predefined meaning, it's up to you application to define their -/// semantic. -pub type AssociatedStream = Pin, u8>> + Send>>; +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type AssociatedStream = Pin + Send>>; + +pub type Packet = Result, u8>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- 2.43.4
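For readers unfamiliar with the framing protocol introduced in PATCH 6/7, the sketch below illustrates the wire layout it implies: Framing::into_stream emits a 4-byte big-endian length prefix, then the serialized message (the "direct" part), then the packets of the associated stream, and Framing::from_stream reads the prefix back, fills the direct buffer, and hands whatever remains over as the stream. This is a self-contained illustration only: the helper names encode_frame/decode_frame are made up for the example and operate on plain Vec<u8> packets, whereas the real code works on AssociatedStream (a stream of Result<Vec<u8>, u8>), reports malformed input as Error::Framing, and, per PATCH 7/7, reserves error code 255 for a stream cut before its end.

// Standalone sketch (hypothetical helpers, not part of the patches) of the frame
// layout: [u32 big-endian length of the direct part] ++ direct ++ associated stream.

fn encode_frame(direct: &[u8], stream_packets: &[Vec<u8>]) -> Vec<Vec<u8>> {
    assert!(direct.len() <= u32::MAX as usize);
    let mut out = Vec::with_capacity(2 + stream_packets.len());
    out.push((direct.len() as u32).to_be_bytes().to_vec()); // length-prefix packet
    out.push(direct.to_vec()); // serialized message ("direct" part)
    out.extend(stream_packets.iter().cloned()); // associated stream follows
    out
}

// Returns (direct, remaining stream bytes); None stands in for Error::Framing.
fn decode_frame(mut packets: impl Iterator<Item = Vec<u8>>) -> Option<(Vec<u8>, Vec<u8>)> {
    let mut first = packets.next()?;
    if first.len() < 4 {
        return None;
    }
    let mut len_bytes = [0u8; 4];
    len_bytes.copy_from_slice(&first[..4]);
    let len = u32::from_be_bytes(len_bytes) as usize;
    first.drain(..4);

    // Accumulate packets until the direct part is complete; anything read past
    // `len` already belongs to the associated stream.
    let mut direct = first;
    while direct.len() < len {
        direct.extend_from_slice(&packets.next()?);
    }
    let mut stream = direct.split_off(len);
    for p in packets {
        stream.extend_from_slice(&p);
    }
    Some((direct, stream))
}

fn main() {
    let frames = encode_frame(b"request body", &[b"chunk 1".to_vec(), b"chunk 2".to_vec()]);
    let (direct, stream) = decode_frame(frames.into_iter()).expect("well-formed frame");
    assert_eq!(direct, b"request body".to_vec());
    assert_eq!(stream, b"chunk 1chunk 2".to_vec());
}

Carrying the length prefix inside the same packet sequence, instead of splitting each request across an even (message) and odd (stream) channel as before these patches, is what lets recv_loop treat every RequestID as a single ordered sequence of chunks and route it through one Sender.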