diff --git a/Cargo.lock b/Cargo.lock index fe2a29d..356c3ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -151,6 +151,19 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "env_logger" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36" +dependencies = [ + "atty", + "humantime 1.3.0", + "log", + "regex", + "termcolor", +] + [[package]] name = "env_logger" version = "0.8.4" @@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" dependencies = [ "atty", - "humantime", + "humantime 2.1.0", "log", "regex", "termcolor", @@ -322,6 +335,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "humantime" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" +dependencies = [ + "quick-error", +] + [[package]] name = "humantime" version = "2.1.0" @@ -440,7 +462,7 @@ dependencies = [ "bytes 0.6.0", "cfg-if", "chrono", - "env_logger", + "env_logger 0.8.4", "err-derive", "futures", "hex", @@ -450,6 +472,8 @@ dependencies = [ "lru", "opentelemetry", "opentelemetry-contrib", + "pin-project", + "pretty_env_logger", "rand 0.5.6", "rmp-serde", "serde", @@ -582,6 +606,16 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "pretty_env_logger" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d" +dependencies = [ + "env_logger 0.7.1", + "log", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -627,6 +661,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.10" diff --git a/Cargo.toml b/Cargo.toml index 536a1e6..a2f0ab1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"] [dependencies] futures = "0.3.17" +pin-project = "1.0.10" tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } tokio-stream = "0.1.7" @@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = true } [dev-dependencies] env_logger = "0.8" +pretty_env_logger = "0.4" structopt = { version = "0.3", default-features = false } chrono = "0.4" diff --git a/src/client.rs b/src/client.rs index 8227e8f..6d49f5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -37,10 +38,11 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption)>>, + 
query_send: + ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex>>>, } impl ClientConn { @@ -166,7 +168,7 @@ impl ClientConn { }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; + let (body, stream) = rmp_to_vec_all_named(rq.borrow())?; drop(rq); let request = QueryMessage { @@ -185,7 +187,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { + if old_ch.send(unbounded().1).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -195,17 +197,18 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); - query_send.send((id, prio, bytes))?; + query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let stream = resp_recv.await?; } } + let (resp, stream) = Framing::from_stream(stream).await?.into_parts(); if resp.is_empty() { return Err(Error::Message( @@ -217,10 +220,8 @@ impl ClientConn { let code = resp[0]; if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - ::Response, - >(&resp[1..])?) + let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?; + Ok(T::Response::deserialize_msg(ser_resp, stream).await) } else { let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); Err(Error::Remote(code, msg)) @@ -232,12 +233,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..f31141d 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -14,8 +14,68 @@ use crate::util::*; /// This trait should be implemented by all messages your application /// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. 
+#[async_trait] +pub trait SerializeMessage: Sized { + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; +} + +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +#[async_trait] +impl SerializeMessage for T +where + T: AutoSerialize, +{ + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), + } + } } /// This trait should be implemented by an object of your application @@ -96,7 +156,7 @@ where prio: RequestPriority, ) -> Result<::Response, Error> where - B: Borrow, + B: Borrow + Send + Sync, { if *target == self.netapp.id { match self.handler.load_full() { @@ -128,7 +188,12 @@ pub(crate) type DynEndpoint = Box; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error>; + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -145,11 +210,17 @@ where M: Message + 'static, H: EndpointHandler + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error> { + async fn handle( + &self, + buf: &[u8], + stream: AssociatedStream, + from: NodeID, + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let req = rmp_serde::decode::from_read_ref(buf)?; + let req = M::deserialize_msg(req, stream).await; let res = h.handle(&req, from).await; let res_bytes = rmp_to_vec_all_named(&res)?; Ok(res_bytes) diff --git a/src/error.rs b/src/error.rs index 99acdd1..7911c29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,9 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + #[error(display = "{}", _0)] Message(String), @@ -50,6 +53,7 @@ impl Error { Self::RMPEncode(_) => 10, Self::RMPDecode(_) => 11, Self::UTF8(_) => 12, + Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, Self::Handshake(_) => 30, diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..27f17e6 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 
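To make the new endpoint API concrete, here is how an application message with an attached body could implement the `SerializeMessage` trait from src/endpoint.rs by hand. This is a sketch only: the `PutBlob`/`PutBlobHeader` names and the `netapp::endpoint`/`netapp::util` paths are assumptions, not part of this patch. Plain request/response types without a stream keep working through the `AutoSerialize` blanket impl, as `HelloMessage` and the fullmesh messages below show.

    // Hypothetical application code written against this patch's API.
    use std::sync::Mutex;

    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    use netapp::endpoint::{Message, SerializeMessage};
    use netapp::util::AssociatedStream;

    #[derive(Serialize, Deserialize)]
    pub struct PutBlobHeader {
        pub name: String,
    }

    pub struct PutBlob {
        pub header: PutBlobHeader,
        // serialize_msg takes &self but has to move the stream out,
        // hence the Mutex<Option<_>> around it
        pub body: Mutex<Option<AssociatedStream>>,
    }

    #[async_trait]
    impl SerializeMessage for PutBlob {
        type SerializableSelf = PutBlobHeader;

        fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
            let header = PutBlobHeader {
                name: self.header.name.clone(),
            };
            (header, self.body.lock().unwrap().take())
        }

        async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
            PutBlob {
                header: ser_self,
                body: Mutex::new(Some(stream)),
            }
        }
    }

    impl Message for PutBlob {
        type Response = (); // () is AutoSerialize, so it uses the blanket impl
    }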
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub(crate) struct HelloMessage {
     pub server_addr: Option<IpAddr>,
     pub server_port: u16,
 }
 
+impl AutoSerialize for HelloMessage {}
+
 impl Message for HelloMessage {
     type Response = ();
 }
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 012c5a0..7dfc5c4 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
 
 // -- Protocol messages --
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 struct PingMessage {
     pub id: u64,
     pub peer_list_hash: hash::Digest,
@@ -39,7 +39,9 @@ impl Message for PingMessage {
     type Response = PingMessage;
 }
 
-#[derive(Serialize, Deserialize)]
+impl AutoSerialize for PingMessage {}
+
+#[derive(Serialize, Deserialize, Clone)]
 struct PeerListMessage {
     pub list: Vec<(NodeID, SocketAddr)>,
 }
@@ -48,6 +50,8 @@ impl Message for PeerListMessage {
     type Response = PeerListMessage;
 }
 
+impl AutoSerialize for PeerListMessage {}
+
 // -- Algorithm data structures --
 
 #[derive(Debug)]
diff --git a/src/proto.rs b/src/proto.rs
index e843bff..92d8d80 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -1,9 +1,13 @@
 use std::collections::{HashMap, VecDeque};
+use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
 use log::trace;
 
+use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
 use futures::{AsyncReadExt, AsyncWriteExt};
+use futures::{Stream, StreamExt};
 use kuska_handshake::async_std::BoxStreamWrite;
 
 use tokio::sync::mpsc;
@@ -11,6 +15,7 @@ use tokio::sync::mpsc;
 use async_trait::async_trait;
 
 use crate::error::*;
+use crate::util::{AssociatedStream, Packet};
 
 /// Priority of a request (click to read more about priorities).
 ///
@@ -48,14 +53,148 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
 pub(crate) type RequestID = u32;
 type ChunkLength = u16;
 
-const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
+const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
+const ERROR_MARKER: ChunkLength = 0x4000;
 const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
 
 struct SendQueueItem {
     id: RequestID,
     prio: RequestPriority,
-    data: Vec<u8>,
-    cursor: usize,
+    data: DataReader,
+}
+
+#[pin_project::pin_project]
+struct DataReader {
+    #[pin]
+    reader: AssociatedStream,
+    packet: Packet,
+    pos: usize,
+    buf: Vec<u8>,
+    eos: bool,
+}
+
+impl From<AssociatedStream> for DataReader {
+    fn from(data: AssociatedStream) -> DataReader {
+        DataReader {
+            reader: data,
+            packet: Ok(Vec::new()),
+            pos: 0,
+            buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
+            eos: false,
+        }
+    }
+}
+
+enum DataFrame {
+    Data {
+        /// a fixed-size buffer containing some data, possibly padded with 0s
+        data: [u8; MAX_CHUNK_LENGTH as usize],
+        /// actual length of the data
+        len: usize,
+    },
+    Error(u8),
+}
+
+struct DataReaderItem {
+    data: DataFrame,
+    /// whether there may be more data coming from this stream. Can be used for some
+    /// optimization. It's an error to set it to false if there is more data, but it is correct
+    /// (albeit sub-optimal) to set it to true if there is nothing coming after
+    may_have_more: bool,
+}
+
+impl DataReaderItem {
+    fn empty_last() -> Self {
+        DataReaderItem {
+            data: DataFrame::Data {
+                data: [0; MAX_CHUNK_LENGTH as usize],
+                len: 0,
+            },
+            may_have_more: false,
+        }
+    }
+
+    fn header(&self) -> [u8; 2] {
+        let continuation = if self.may_have_more {
+            CHUNK_HAS_CONTINUATION
+        } else {
+            0
+        };
+        let len = match self.data {
+            DataFrame::Data { len, .. } => len as u16,
+            DataFrame::Error(e) => e as u16 | ERROR_MARKER,
+        };
+
+        ChunkLength::to_be_bytes(len | continuation)
+    }
+
+    fn data(&self) -> &[u8] {
+        match self.data {
+            DataFrame::Data { ref data, len } => &data[..len],
+            DataFrame::Error(_) => &[],
+        }
+    }
+}
+
+impl Stream for DataReader {
+    type Item = DataReaderItem;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut this = self.project();
+
+        if *this.eos {
+            // eos was reached at previous call to poll_next, where a partial packet
+            // was returned. Now return None
+            return Poll::Ready(None);
+        }
+
+        loop {
+            let packet = match this.packet {
+                Ok(v) => v,
+                Err(e) => {
+                    let e = *e;
+                    *this.packet = Ok(Vec::new());
+                    return Poll::Ready(Some(DataReaderItem {
+                        data: DataFrame::Error(e),
+                        may_have_more: true,
+                    }));
+                }
+            };
+            let packet_left = packet.len() - *this.pos;
+            let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
+            let to_read = std::cmp::min(buf_left, packet_left);
+            this.buf
+                .extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
+            *this.pos += to_read;
+            if this.buf.len() == MAX_CHUNK_LENGTH as usize {
+                // we have a full buf, ready to send
+                break;
+            }
+
+            // we don't have a full buf, packet is empty; try to receive more
+            if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
+                *this.packet = p;
+                *this.pos = 0;
+                // if buf is empty, we will loop and return the error directly. If buf
+                // isn't empty, send it first by breaking.
+                if this.packet.is_err() && !this.buf.is_empty() {
+                    break;
+                }
+            } else {
+                *this.eos = true;
+                break;
+            }
+        }
+
+        let mut body = [0; MAX_CHUNK_LENGTH as usize];
+        let len = this.buf.len();
+        body[..len].copy_from_slice(this.buf);
+        this.buf.clear();
+        Poll::Ready(Some(DataReaderItem {
+            data: DataFrame::Data { data: body, len },
+            may_have_more: !*this.eos,
+        }))
+    }
 }
 
 struct SendQueue {
@@ -79,6 +218,8 @@ impl SendQueue {
         };
         self.items[pos_prio].1.push_back(item);
     }
 
+    // used only in tests. They should probably be rewritten
+    #[allow(dead_code)]
     fn pop(&mut self) -> Option<SendQueueItem> {
         match self.items.pop_front() {
@@ -94,6 +235,54 @@ impl SendQueue {
     fn is_empty(&self) -> bool {
         self.items.iter().all(|(_k, v)| v.is_empty())
     }
+
+    // this is like an async fn, but hand implemented
+    fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
+        SendQueuePollNextReady { queue: self }
+    }
+}
+
+struct SendQueuePollNextReady<'a> {
+    queue: &'a mut SendQueue,
+}
+
+impl<'a> futures::Future for SendQueuePollNextReady<'a> {
+    type Output = (RequestID, DataReaderItem);
+
+    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
+        for i in 0..self.queue.items.len() {
+            let (_prio, items_at_prio) = &mut self.queue.items[i];
+
+            for _ in 0..items_at_prio.len() {
+                let mut item = items_at_prio.pop_front().unwrap();
+
+                match Pin::new(&mut item.data).poll_next(ctx) {
+                    Poll::Pending => items_at_prio.push_back(item),
+                    Poll::Ready(Some(data)) => {
+                        let id = item.id;
+                        if data.may_have_more {
+                            self.queue.push(item);
+                        } else {
+                            if items_at_prio.is_empty() {
+                                // this priority level is empty, remove it
+                                self.queue.items.remove(i);
+                            }
+                        }
+                        return Poll::Ready((id, data));
+                    }
+                    Poll::Ready(None) => {
+                        if items_at_prio.is_empty() {
+                            // this priority level is empty, remove it
+                            self.queue.items.remove(i);
+                        }
+                        return Poll::Ready((item.id, DataReaderItem::empty_last()));
+                    }
+                }
+            }
+        }
+        // TODO what do we do if self.queue is empty? We won't get scheduled again.
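For reference, this is how the two-byte chunk header built by DataReaderItem::header() above is laid out on the wire; the decoding sketch below mirrors what recv_loop() does further down in this file (the decode_chunk_header helper itself is illustrative and not part of the patch):

    // Big-endian u16: bit 15 = CHUNK_HAS_CONTINUATION (more chunks of this
    // request follow), bit 14 = ERROR_MARKER (the low byte is an error code,
    // not a length). MAX_CHUNK_LENGTH is now 0x3FF0, so a real payload length
    // can never collide with either flag bit.
    fn decode_chunk_header(header: [u8; 2]) -> (bool, Result<ChunkLength, u8>) {
        let size = ChunkLength::from_be_bytes(header);
        let has_continuation = size & CHUNK_HAS_CONTINUATION != 0;
        if size & ERROR_MARKER != 0 {
            (has_continuation, Err(size as u8))
        } else {
            (has_continuation, Ok(size & !CHUNK_HAS_CONTINUATION))
        }
    }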
+ Poll::Pending + } } /// The SendLoop trait, which is implemented both by the client and the server @@ -108,7 +297,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -117,55 +306,34 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor - ); - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); - write.write_all(&size_header[..]).await?; - - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; - - sending.push(item); - } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); - write.write_all(&size_header[..]).await?; - - write.write_all(&item.data[item.cursor..]).await?; + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? 
+ use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; } - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else { - should_exit = true; + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; } } } @@ -175,6 +343,118 @@ pub(crate) trait SendLoop: Sync { } } +pub(crate) struct Framing { + direct: Vec, + stream: Option, +} + +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } + } + + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; + + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); + + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) + } + } + + pub async fn from_stream + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? 
+ .map_err(|_| Error::Framing)?; + } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + } +} + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +464,13 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -208,19 +488,33 @@ pub(crate) trait RecvLoop: Sync + 'static { trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; + let is_error = (size & ERROR_MARKER) != 0; + let packet = if is_error { + Err(size as u8) + } else { + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) + }; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = unbounded(); + self.recv_handler(id, recv); + Sender::new(send) + }; - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + // if we get an error, the receiving end is disconnected. 
We still need to + // reach eos before dropping this sender + sender.send(packet); if has_cont { - receiving.insert(id, msg_bytes); + streams.insert(id, sender); } else { - self.recv_handler(id, msg_bytes); + sender.end(); } } Ok(()) @@ -231,43 +525,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Packet; + let stream: Pin + Send + 'static>> = + Box::pin(futures::stream::empty::()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: empty_data(), }; let mut q = SendQueue::new(); diff --git a/src/server.rs b/src/server.rs index 5465307..8075484 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,6 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; -use bytes::Bytes; use log::{debug, trace}; #[cfg(feature = "telemetry")] @@ -20,6 +19,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; +use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -55,7 +55,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption)>>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -123,7 +123,11 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc, bytes: &[u8]) -> Result, Error> { + async fn recv_handler_aux( + self: &Arc, + bytes: &[u8], + stream: AssociatedStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; @@ -156,11 +160,11 @@ impl ServerConn { span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(msg.body, stream, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(msg.body, stream, self.peer_id).await } } } else { @@ -173,35 +177,40 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + trace!("ServerConn recv_handler {}", id); + let (bytes, stream) = Framing::from_stream(stream).await?.into_parts(); let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; + let resp = self2.recv_handler_aux(&bytes[..], stream).await; - let resp_bytes = match resp { - Ok(rb) => { + let (resp_bytes, resp_stream) = match resp { + Ok((rb, rs)) => { let mut resp_bytes = vec![0u8]; 
                     resp_bytes.extend(rb);
-                    resp_bytes
+                    (resp_bytes, rs)
                 }
                 Err(e) => {
                     let mut resp_bytes = vec![e.code()];
                     resp_bytes.extend(e.to_string().into_bytes());
-                    resp_bytes
+                    (resp_bytes, None)
                 }
             };
 
             trace!("ServerConn sending response to {}: ", id);
 
             resp_send
-                .send((id, prio, resp_bytes))
-                .log_err("ServerConn recv_handler send resp");
+                .send((
+                    id,
+                    prio,
+                    Framing::new(resp_bytes, resp_stream).into_stream(),
+                ))
+                .log_err("ServerConn recv_handler send resp bytes");
+            Ok::<_, Error>(())
         });
     }
 }
diff --git a/src/test.rs b/src/test.rs
index 82c7ba6..ecd5450 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -14,6 +14,7 @@ use crate::NodeID;
 
 #[tokio::test(flavor = "current_thread")]
 async fn test_with_basic_scheduler() {
+    pretty_env_logger::init();
     run_test().await
 }
 
diff --git a/src/util.rs b/src/util.rs
index f4dfac7..186678d 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,10 +1,15 @@
+use crate::endpoint::SerializeMessage;
+
 use std::net::SocketAddr;
 use std::net::ToSocketAddrs;
+use std::pin::Pin;
 
-use serde::Serialize;
+use futures::Stream;
 
 use log::info;
 
+use serde::Serialize;
+
 use tokio::sync::watch;
 
 /// A node's identifier, which is also its public cryptographic key
@@ -14,21 +19,36 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
 /// A network key
 pub type NetworkKey = sodiumoxide::crypto::auth::Key;
 
+/// A stream of associated data.
+///
+/// The stream can continue after receiving an error.
+/// When sent through Netapp, a Vec<u8> may be split into smaller chunks, and consecutive
+/// Vec<u8>s may be merged, but data and error codes are never reordered.
+///
+/// Error code 255 means the stream was cut before its end. Other codes have no predefined
+/// meaning; it's up to your application to define their semantics.
+pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
+
+pub type Packet = Result<Vec<u8>, u8>;
+
 /// Utility function: encodes any serializable value in MessagePack binary format
 /// using the RMP library.
 ///
 /// Field names and variant names are included in the serialization.
 /// This is used internally by the netapp communication protocol.
-pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
+pub fn rmp_to_vec_all_named<T>(
+    val: &T,
+) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
 where
-    T: Serialize + ?Sized,
+    T: SerializeMessage + ?Sized,
 {
     let mut wr = Vec::with_capacity(128);
     let mut se = rmp_serde::Serializer::new(&mut wr)
         .with_struct_map()
         .with_string_variants();
+    let (val, stream) = val.serialize_msg();
     val.serialize(&mut se)?;
-    Ok(wr)
+    Ok((wr, stream))
 }
 
 /// This async function returns only when a true signal was received
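Putting the pieces together: Framing::new(..).into_stream() prefixes the rmp-encoded message with its length as a big-endian u32 and then forwards the attached stream unchanged, and Framing::from_stream() undoes exactly that. Below is a test-style round-trip sketch (it would have to live inside the crate, since Framing is crate-internal; the function name is made up for illustration):

    use futures::StreamExt;

    async fn framing_roundtrip() -> Result<(), Error> {
        let direct = b"hello".to_vec();
        let attached: AssociatedStream =
            Box::pin(futures::stream::once(async { Ok::<_, u8>(vec![1u8, 2, 3]) }));

        // On the wire this yields three packets:
        //   Ok([0, 0, 0, 5])   4-byte big-endian length of `direct`
        //   Ok(b"hello")       the message bytes themselves
        //   Ok([1, 2, 3])      then every packet of the attached stream, unchanged
        let wire = Framing::new(direct.clone(), Some(attached)).into_stream();

        let (decoded, mut rest) = Framing::from_stream(wire).await?.into_parts();
        assert_eq!(decoded, direct);
        assert_eq!(rest.next().await, Some(Ok(vec![1u8, 2, 3])));
        Ok(())
    }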