diff --git a/Cargo.lock b/Cargo.lock index aa6b4a4..ca7e3ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -428,7 +428,7 @@ dependencies = [ [[package]] name = "netapp" -version = "0.4.5" +version = "0.5.0" dependencies = [ "arc-swap", "async-trait", @@ -445,6 +445,7 @@ dependencies = [ "lru", "opentelemetry", "opentelemetry-contrib", + "pin-project", "rand", "rmp-serde", "serde", diff --git a/Cargo.toml b/Cargo.toml index 9a69d66..4dce280 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "netapp" -version = "0.4.5" +version = "0.5.0" authors = ["Alex Auvolat "] edition = "2018" license-file = "LICENSE" @@ -16,26 +16,27 @@ name = "netapp" [features] default = [] -basalt = ["lru", "rand"] -telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"] +basalt = ["lru"] +telemetry = ["opentelemetry", "opentelemetry-contrib"] [dependencies] futures = "0.3.17" +pin-project = "1.0.10" tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] } -tokio-util = { version = "0.7", default-features = false, features = ["compat"] } +tokio-util = { version = "0.7", default-features = false, features = ["compat", "io"] } tokio-stream = "0.1.7" -serde = { version = "1.0", default-features = false, features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive", "rc"] } rmp-serde = "0.15" hex = "0.4.2" -rand = { version = "0.8", optional = true } +rand = { version = "0.8" } log = "0.4.8" arc-swap = "1.1" async-trait = "0.1.7" err-derive = "0.3" -bytes = "1.0" +bytes = "1.2" lru = { version = "0.7", optional = true } cfg-if = "1.0" diff --git a/Makefile b/Makefile index 5160725..de9a8f4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ all: - cargo build --all-features + #cargo build --all-features + cargo build cargo build --example fullmesh cargo build --all-features --example basalt RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 diff --git a/examples/basalt.rs b/examples/basalt.rs index 318e37c..a5a25c3 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -14,8 +14,8 @@ use sodiumoxide::crypto::sign::ed25519; use tokio::sync::watch; use netapp::endpoint::*; +use netapp::message::*; use netapp::peering::basalt::*; -use netapp::proto::*; use netapp::util::parse_peer_addr; use netapp::{NetApp, NodeID}; @@ -145,7 +145,7 @@ impl Example { tokio::spawn(async move { match self2 .example_endpoint - .call(&p, &ExampleMessage { example_field: 42 }, PRIO_NORMAL) + .call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) .await { Ok(resp) => debug!("Got example response: {:?}", resp), diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs index b068410..d0190ef 100644 --- a/examples/fullmesh.rs +++ b/examples/fullmesh.rs @@ -1,16 +1,24 @@ use std::io::Write; use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; -use log::info; - +use async_trait::async_trait; +use bytes::Bytes; +use futures::{stream, StreamExt}; +use log::*; +use serde::{Deserialize, Serialize}; use structopt::StructOpt; +use tokio::sync::watch; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; +use netapp::endpoint::*; +use netapp::message::*; use netapp::peering::fullmesh::*; use netapp::util::*; -use netapp::NetApp; +use netapp::{NetApp, NodeID}; #[derive(StructOpt, Debug)] #[structopt(name = "netapp")] @@ -91,8 +99,121 @@ async fn main() { let watch_cancel = netapp::util::watch_ctrl_c(); + let example = Arc::new(Example { + netapp: netapp.clone(), + fullmesh: peering.clone(), + example_endpoint: netapp.endpoint("__netapp/examples/fullmesh.rs/Example".into()), + }); + example.example_endpoint.set_handler(example.clone()); + tokio::join!( + example.exchange_loop(watch_cancel.clone()), netapp.listen(listen_addr, public_addr, watch_cancel.clone()), peering.run(watch_cancel), ); } + +// ---- + +struct Example { + netapp: Arc, + fullmesh: Arc, + example_endpoint: Arc>, +} + +impl Example { + async fn exchange_loop(self: Arc, must_exit: watch::Receiver) { + let mut i = 12000; + while !*must_exit.borrow() { + tokio::time::sleep(Duration::from_secs(2)).await; + + let peers = self.fullmesh.get_peer_list(); + for p in peers.iter() { + let id = p.id; + if id == self.netapp.id { + continue; + } + i += 1; + let example_field = i; + let self2 = self.clone(); + tokio::spawn(async move { + info!( + "Send example query {} to {}", + example_field, + hex::encode(id) + ); + // Fake data stream with some delays in item production + let stream = + Box::pin(stream::iter([100, 200, 300, 400]).then(|x| async move { + tokio::time::sleep(Duration::from_millis(500)).await; + Ok(Bytes::from(vec![(x % 256) as u8; 133 * x])) + })); + match self2 + .example_endpoint + .call_streaming( + &id, + Req::new(ExampleMessage { example_field }) + .unwrap() + .with_stream(stream), + PRIO_NORMAL, + ) + .await + { + Ok(resp) => { + let (resp, stream) = resp.into_parts(); + info!( + "Got example response to {} from {}: {:?}", + example_field, + hex::encode(id), + resp + ); + let mut stream = stream.unwrap(); + while let Some(x) = stream.next().await { + info!("Response: stream got bytes {:?}", x.map(|b| b.len())); + } + } + Err(e) => warn!("Error with example request: {}", e), + } + }); + } + } + } +} + +#[async_trait] +impl StreamingEndpointHandler for Example { + async fn handle( + self: &Arc, + mut msg: Req, + _from: NodeID, + ) -> Resp { + info!( + "Got example message: {:?}, sending example response", + msg.msg() + ); + let source_stream = msg.take_stream().unwrap(); + // Return same stream with 300ms delay + let new_stream = Box::pin(source_stream.then(|x| async move { + tokio::time::sleep(Duration::from_millis(300)).await; + x + })); + Resp::new(ExampleResponse { + example_field: false, + }) + .with_stream(new_stream) + } +} + +#[derive(Serialize, Deserialize, Debug)] +struct ExampleMessage { + example_field: usize, +} + +#[derive(Serialize, Deserialize, Debug)] +struct ExampleResponse { + example_field: bool, +} + +impl Message for ExampleMessage { + type Response = ExampleResponse; +} diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs new file mode 100644 index 0000000..931be82 --- /dev/null +++ b/src/bytes_buf.rs @@ -0,0 +1,175 @@ +use std::collections::VecDeque; + +use bytes::BytesMut; + +pub use bytes::Bytes; + +/// A circular buffer of bytes, internally represented as a list of Bytes +/// for optimization, but that for all intent and purposes acts just like +/// a big byte slice which can be extended on the right and from which +/// stuff can be taken on the left. +pub struct BytesBuf { + buf: VecDeque, + buf_len: usize, +} + +impl BytesBuf { + /// Creates a new empty BytesBuf + pub fn new() -> Self { + Self { + buf: VecDeque::new(), + buf_len: 0, + } + } + + /// Returns the number of bytes stored in the BytesBuf + #[inline] + pub fn len(&self) -> usize { + self.buf_len + } + + /// Returns true iff the BytesBuf contains zero bytes + #[inline] + pub fn is_empty(&self) -> bool { + self.buf_len == 0 + } + + /// Adds some bytes to the right of the buffer + pub fn extend(&mut self, b: Bytes) { + if !b.is_empty() { + self.buf_len += b.len(); + self.buf.push_back(b); + } + } + + /// Takes the whole content of the buffer and returns it as a single Bytes unit + pub fn take_all(&mut self) -> Bytes { + if self.buf.len() == 0 { + Bytes::new() + } else if self.buf.len() == 1 { + self.buf_len = 0; + self.buf.pop_back().unwrap() + } else { + let mut ret = BytesMut::with_capacity(self.buf_len); + for b in self.buf.iter() { + ret.extend_from_slice(&b[..]); + } + self.buf.clear(); + self.buf_len = 0; + ret.freeze() + } + } + + /// Takes at most max_len bytes from the left of the buffer + pub fn take_max(&mut self, max_len: usize) -> Bytes { + if self.buf_len <= max_len { + self.take_all() + } else { + self.take_exact_ok(max_len) + } + } + + /// Take exactly len bytes from the left of the buffer, returns None if + /// the BytesBuf doesn't contain enough data + pub fn take_exact(&mut self, len: usize) -> Option { + if self.buf_len < len { + None + } else { + Some(self.take_exact_ok(len)) + } + } + + fn take_exact_ok(&mut self, len: usize) -> Bytes { + assert!(len <= self.buf_len); + let front = self.buf.pop_front().unwrap(); + if front.len() > len { + self.buf.push_front(front.slice(len..)); + self.buf_len -= len; + front.slice(..len) + } else if front.len() == len { + self.buf_len -= len; + front + } else { + let mut ret = BytesMut::with_capacity(len); + ret.extend_from_slice(&front[..]); + self.buf_len -= front.len(); + while ret.len() < len { + let front = self.buf.pop_front().unwrap(); + if front.len() > len - ret.len() { + let take = len - ret.len(); + ret.extend_from_slice(&front[..take]); + self.buf.push_front(front.slice(take..)); + self.buf_len -= take; + break; + } else { + ret.extend_from_slice(&front[..]); + self.buf_len -= front.len(); + } + } + ret.freeze() + } + } + + /// Return the internal sequence of Bytes slices that make up the buffer + pub fn into_slices(self) -> VecDeque { + self.buf + } +} + +impl From for BytesBuf { + fn from(b: Bytes) -> BytesBuf { + let mut ret = BytesBuf::new(); + ret.extend(b); + ret + } +} + +impl From for Bytes { + fn from(mut b: BytesBuf) -> Bytes { + b.take_all() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_bytes_buf() { + let mut buf = BytesBuf::new(); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 13); + assert!(!buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!( + buf.take_all(), + Bytes::from(b"Hello, world!1234567890".to_vec()) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec())); + assert!(buf.len() == 11); + + assert_eq!(buf.take_exact(12), None); + assert!(buf.len() == 11); + assert_eq!( + buf.take_exact(11), + Some(Bytes::from(b"llo, world!".to_vec())) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + } +} diff --git a/src/client.rs b/src/client.rs index 5c5a05b..d82c91e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,18 @@ -use std::borrow::Borrow; use std::collections::HashMap; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use std::task::Poll; use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use bytes::Bytes; use log::{debug, error, trace}; +use futures::io::AsyncReadExt; +use futures::Stream; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -20,27 +26,22 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption)>>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex>>, } impl ClientConn { @@ -139,15 +140,14 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call( + pub(crate) async fn call( self: Arc, - rq: B, + req: Req, path: &str, prio: RequestPriority, - ) -> Result<::Response, Error> + ) -> Result, Error> where T: Message, - B: Borrow, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -162,24 +162,16 @@ impl ClientConn { .with_kind(SpanKind::Client) .start(&tracer); let propagator = BinaryPropagator::new(); - let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec()); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); } else { - let telemetry_id: Option> = None; + let telemetry_id: Bytes = Bytes::new(); } }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; - drop(rq); - - let request = QueryMessage { - prio, - path: path.as_bytes(), - telemetry_id, - body: &body[..], - }; - let bytes = request.encode(); - drop(body); + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let (req_stream, req_order) = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -188,46 +180,41 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { - debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); - } + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) + }))); } - trace!("request: query_send {}, {} bytes", id, bytes.len()); + debug!( + "request: query_send {}, path {}, prio {} (serialized message: {} bytes)", + id, path, prio, req_msg_len + ); #[cfg(feature = "telemetry")] - span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, bytes))?; + query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; + + let canceller = CancelOnDrop::new(id, query_send.as_ref().clone()); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let stream = resp_recv.await?; } } - if resp.is_empty() { - return Err(Error::Message( - "Response is 0 bytes, either a collision or a protocol error".into(), - )); - } + let stream = Box::pin(canceller.for_stream(stream)); - trace!("request response {}: ", id); - - let code = resp[0]; - if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - ::Response, - >(&resp[1..])?) - } else { - let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); - Err(Error::Remote(code, msg)) - } + let resp_enc = RespEnc::decode(stream).await?; + debug!("client: got response to request {} (path {})", id, path); + Resp::from_enc(resp_enc) } } @@ -235,14 +222,71 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } + } else { + debug!("Got unexpected response to request {}, dropping it", id); } } } + +// ---- + +struct CancelOnDrop { + id: RequestID, + query_send: mpsc::UnboundedSender, +} + +impl CancelOnDrop { + fn new(id: RequestID, query_send: mpsc::UnboundedSender) -> Self { + Self { id, query_send } + } + fn for_stream(self, stream: ByteStream) -> CancelOnDropStream { + CancelOnDropStream { + cancel: Some(self), + stream: stream, + } + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + trace!("cancelling request {}", self.id); + let _ = self.query_send.send(SendItem::Cancel(self.id)); + } +} + +#[pin_project::pin_project] +struct CancelOnDropStream { + cancel: Option, + #[pin] + stream: ByteStream, +} + +impl Stream for CancelOnDropStream { + type Item = Packet; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, Poll::Ready(None)) { + if let Some(c) = this.cancel.take() { + std::mem::forget(c) + } + } + res + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..3cafafe 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,28 +1,44 @@ -use std::borrow::Borrow; use std::marker::PhantomData; use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - use crate::error::Error; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::util::*; - -/// This trait should be implemented by all messages your application -/// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; -} /// This trait should be implemented by an object of your application -/// that can handle a message of type `M`. +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. /// /// The handler object should be in an Arc, see `Endpoint::set_handler` #[async_trait] +pub trait StreamingEndpointHandler: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc, m: Req, from: NodeID) -> Resp; +} + +/// If one simply wants to use an endpoint in a client fashion, +/// without locally serving requests to that endpoint, +/// use the unit type `()` as the handler type: +/// it will panic if it is ever made to handle request. +#[async_trait] +impl EndpointHandler for () { + async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { + panic!("This endpoint should not have a local handler."); + } +} + +// ---- + +/// This trait should be implemented by an object of your application +/// that can handle a message of type `M`, in the cases where it doesn't +/// care about attached stream in the request nor in the response. +#[async_trait] pub trait EndpointHandler: Send + Sync where M: Message, @@ -30,17 +46,22 @@ where async fn handle(self: &Arc, m: &M, from: NodeID) -> M::Response; } -/// If one simply wants to use an endpoint in a client fashion, -/// without locally serving requests to that endpoint, -/// use the unit type `()` as the handler type: -/// it will panic if it is ever made to handle request. #[async_trait] -impl EndpointHandler for () { - async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { - panic!("This endpoint should not have a local handler."); +impl StreamingEndpointHandler for T +where + T: EndpointHandler, + M: Message, +{ + async fn handle(self: &Arc, mut m: Req, from: NodeID) -> Resp { + // Immediately drop stream to ignore all data that comes in, + // instead of buffering it indefinitely + drop(m.take_stream()); + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) } } +// ---- + /// This struct represents an endpoint for message of type `M`. /// /// Creating a new endpoint is done by calling `NetApp::endpoint`. @@ -50,13 +71,13 @@ impl EndpointHandler for () { /// An `Endpoint` is used both to send requests to remote nodes, /// and to specify the handler for such requests on the local node. /// The type `H` represents the type of the handler object for -/// endpoint messages (see `EndpointHandler`). +/// endpoint messages (see `StreamingEndpointHandler`). pub struct Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { - phantom: PhantomData, + _phantom: PhantomData, netapp: Arc, path: String, handler: ArcSwapOption, @@ -65,11 +86,11 @@ where impl Endpoint where M: Message, - H: EndpointHandler, + H: StreamingEndpointHandler, { pub(crate) fn new(netapp: Arc, path: String) -> Self { Self { - phantom: PhantomData::default(), + _phantom: PhantomData::default(), netapp, path, handler: ArcSwapOption::from(None), @@ -88,20 +109,22 @@ where } /// Call this endpoint on a remote node (or on the local node, - /// for that matter) - pub async fn call( + /// for that matter). This function invokes the full version that + /// allows to attach a stream to the request and to + /// receive such a stream attached to the response. + pub async fn call_streaming( &self, target: &NodeID, - req: B, + req: T, prio: RequestPriority, - ) -> Result<::Response, Error> + ) -> Result, Error> where - B: Borrow, + T: IntoReq, { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), } } else { let conn = self @@ -116,10 +139,22 @@ where "Not connected: {}", hex::encode(&target[..8]) ))), - Some(c) => c.call(req, self.path.as_str(), prio).await, + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, } } } + + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<::Response, Error> { + Ok(self.call_streaming(target, req, prio).await?.into_msg()) + } } // ---- Internal stuff ---- @@ -128,7 +163,7 @@ pub(crate) type DynEndpoint = Box; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error>; + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -137,22 +172,21 @@ pub(crate) trait GenericEndpoint { pub(crate) struct EndpointArc(pub(crate) Arc>) where M: Message, - H: EndpointHandler; + H: StreamingEndpointHandler; #[async_trait] impl GenericEndpoint for EndpointArc where - M: Message + 'static, - H: EndpointHandler + 'static, + M: Message, + H: StreamingEndpointHandler + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result, Error> { + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; - let res = h.handle(&req, from).await; - let res_bytes = rmp_to_vec_all_named(&res)?; - Ok(res_bytes) + let req = Req::from_enc(req_enc)?; + let res = h.handle(req, from).await; + Ok(res.into_enc()?) } } } diff --git a/src/error.rs b/src/error.rs index 99acdd1..c0aeeac 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ -use err_derive::Error; use std::io; +use err_derive::Error; use log::error; #[derive(Debug, Error)] @@ -25,6 +25,15 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + + #[error(display = "Remote error ({:?}): {}", _0, _1)] + Remote(io::ErrorKind, String), + + #[error(display = "Request ID collision")] + IdCollision, + #[error(display = "{}", _0)] Message(String), @@ -36,28 +45,6 @@ pub enum Error { #[error(display = "Version mismatch: {}", _0)] VersionMismatch(String), - - #[error(display = "Remote error {}: {}", _0, _1)] - Remote(u8, String), -} - -impl Error { - pub fn code(&self) -> u8 { - match self { - Self::Io(_) => 100, - Self::TokioJoin(_) => 110, - Self::OneshotRecv(_) => 111, - Self::RMPEncode(_) => 10, - Self::RMPDecode(_) => 11, - Self::UTF8(_) => 12, - Self::NoHandler => 20, - Self::ConnectionClosed => 21, - Self::Handshake(_) => 30, - Self::VersionMismatch(_) => 31, - Self::Remote(c, _) => *c, - Self::Message(_) => 99, - } - } } impl From> for Error { @@ -101,3 +88,39 @@ where } } } + +// ---- Helpers for serializing I/O Errors + +pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind { + use std::io::ErrorKind; + match v { + 101 => ErrorKind::ConnectionAborted, + 102 => ErrorKind::BrokenPipe, + 103 => ErrorKind::WouldBlock, + 104 => ErrorKind::InvalidInput, + 105 => ErrorKind::InvalidData, + 106 => ErrorKind::TimedOut, + 107 => ErrorKind::Interrupted, + 108 => ErrorKind::UnexpectedEof, + 109 => ErrorKind::OutOfMemory, + 110 => ErrorKind::ConnectionReset, + _ => ErrorKind::Other, + } +} + +pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 { + use std::io::ErrorKind; + match kind { + ErrorKind::ConnectionAborted => 101, + ErrorKind::BrokenPipe => 102, + ErrorKind::WouldBlock => 103, + ErrorKind::InvalidInput => 104, + ErrorKind::InvalidData => 105, + ErrorKind::TimedOut => 106, + ErrorKind::Interrupted => 107, + ErrorKind::UnexpectedEof => 108, + ErrorKind::OutOfMemory => 109, + ErrorKind::ConnectionReset => 110, + _ => 100, + } +} diff --git a/src/lib.rs b/src/lib.rs index cb24337..8e30e40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,21 +13,23 @@ //! about message priorization. //! Also check out the examples to learn how to use this crate. +pub mod bytes_buf; pub mod error; +pub mod stream; pub mod util; pub mod endpoint; -pub mod proto; +pub mod message; mod client; -mod proto2; +mod recv; +mod send; mod server; pub mod netapp; pub mod peering; pub use crate::netapp::*; -pub use util::{NetworkKey, NodeID, NodeKey}; #[cfg(test)] mod test; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..9cc1a3e --- /dev/null +++ b/src/message.rs @@ -0,0 +1,522 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + +use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; + +use futures::stream::StreamExt; + +use crate::error::*; +use crate::stream::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// An order tag can be added to a message or a response to indicate +/// whether it should be sent after or before other messages with order tags +/// referencing a same stream +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(pub(crate) u64, pub(crate) u64); + +/// A stream is an opaque identifier that defines a set of messages +/// or responses that are ordered wrt one another using to order tags. +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); + +impl OrderTag { + /// Create a new stream from which to generate order tags. Example: + /// ```ignore + /// let stream = OrderTag.stream(); + /// let tag_1 = stream.order(1); + /// let tag_2 = stream.order(2); + /// ``` + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + /// Create the order tag for message `order` in this stream + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle. It specifies which data type should be sent +/// as a response to this message in the RPC protocol. +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { + /// The type of the response that is sent in response to this message + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static; +} + +// ---- + +/// The Req is a helper object used to create requests and attach them +/// a stream of data. If the stream is a fixed Bytes and not a ByteStream, +/// Req is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the stream is a ByteStream). +pub struct Req { + pub(crate) msg: Arc, + pub(crate) msg_ser: Option, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, +} + +impl Req { + /// Creates a new request from a base message `M` + pub fn new(v: M) -> Result { + Ok(v.into_req()?) + } + + /// Attach a stream to message in request, where the stream is streamed + /// from a fixed `Bytes` buffer + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + /// Attach a stream to message in request, where the stream is + /// an instance of `ByteStream`. Note than when a `Req` has an attached + /// stream which is a `ByteStream` instance, it can no longer be cloned + /// to be sent to different nodes (`.clone()` will panic) + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + /// Add an order tag to this request to indicate in which order it should + /// be sent. + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + /// Get a reference to the message `M` contained in this request + pub fn msg(&self) -> &M { + &self.msg + } + + /// Takes out the stream attached to this request, if any + pub fn take_stream(&mut self) -> Option { + std::mem::replace(&mut self.stream, AttachedStream::None).into_stream() + } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Req { + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +/// `IntoReq` represents any object that can be transformed into `Req` +pub trait IntoReq { + /// Transform the object into a `Req`, serializing the message M + /// to be sent to remote nodes + fn into_req(self) -> Result, rmp_serde::encode::Error>; + /// Transform the object into a `Req`, skipping the serialization + /// of message M, in the case we are not sending this RPC message to + /// a remote node + fn into_req_local(self) -> Req; +} + +impl IntoReq for M { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + stream: AttachedStream::None, + order_tag: None, + }) + } + fn into_req_local(self) -> Req { + Req { + msg: Arc::new(self), + msg_ser: None, + stream: AttachedStream::None, + order_tag: None, + } + } +} + +impl IntoReq for Req { + fn into_req(self) -> Result, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req { + self + } +} + +impl Clone for Req { + fn clone(&self) -> Self { + let stream = match &self.stream { + AttachedStream::None => AttachedStream::None, + AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()), + AttachedStream::Stream(_) => { + panic!("Cannot clone a Req<_> with a non-buffer attached stream") + } + }; + Self { + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + stream, + order_tag: self.order_tag, + } + } +} + +impl fmt::Debug for Req +where + M: Message + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +/// The Resp represents a full response from a RPC that may have +/// an attached stream. +pub struct Resp { + pub(crate) _phantom: PhantomData, + pub(crate) msg: M::Response, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, +} + +impl Resp { + /// Creates a new response from a base response message + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + stream: AttachedStream::None, + order_tag: None, + } + } + + /// Attach a stream to message in response, where the stream is streamed + /// from a fixed `Bytes` buffer + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + /// Attach a stream to message in response, where the stream is + /// an instance of `ByteStream`. + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + /// Add an order tag to this response to indicate in which order it should + /// be sent. + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + /// Get a reference to the response message contained in this request + pub fn msg(&self) -> &M::Response { + &self.msg + } + + /// Transforms the `Resp` into the response message it contains, + /// dropping everything else (including attached data stream) + pub fn into_msg(self) -> M::Response { + self.msg + } + + /// Transforms the `Resp` into, on the one side, the response message + /// it contains, and on the other side, the associated data stream + /// if it exists + pub fn into_parts(self) -> (M::Response, Option) { + (self.msg, self.stream.into_stream()) + } + + pub(crate) fn into_enc(self) -> Result { + Ok(RespEnc { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +impl fmt::Debug for Resp +where + M: Message, + ::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +pub(crate) enum AttachedStream { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl AttachedStream { + pub fn into_stream(self) -> Option { + match self { + AttachedStream::None => None, + AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + AttachedStream::Stream(s) => Some(s), + } + } +} + +// ---- ---- + +/// Encoding for requests into a ByteStream: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option, + pub(crate) order_tag: Option, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> (ByteStream, Option) { + let mut buf = BytesMut::with_capacity( + self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, + ); + + buf.put_u8(self.prio); + + buf.put_u8(self.path.len() as u8); + buf.put(self.path); + + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); + + buf.put_u32(self.msg.len() as u32); + + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = self.stream { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) + }; + (res_stream, self.order_tag) + } + + pub(crate) async fn decode(stream: ByteStream) -> Result { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result { + let mut reader = ByteStreamReader::new(stream); + + let prio = reader.read_u8().await?; + + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; + + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + Ok(Self { + prio, + path, + telemetry_id, + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option, + order_tag: Option, +} + +impl RespEnc { + pub(crate) fn encode(resp: Result) -> (ByteStream, Option) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); + buf.put_u32(msg.len() as u32); + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = stream { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) + }; + (res_stream, order_tag) + } + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) + } + } + } + + pub(crate) async fn decode(stream: ByteStream) -> Result { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result { + let mut reader = ByteStreamReader::new(stream); + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + // Check whether the response stream still has data or not. + // If no more data is coming, this will defuse the request canceller. + // If we didn't do this, and the client doesn't try to read from the stream, + // the request canceller doesn't know that we read everything and + // sends a cancellation message to the server (which they don't care about). + reader.fill_buffer().await; + + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, + } +} diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..b1ad9db 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -20,9 +20,15 @@ use tokio::sync::{mpsc, watch}; use crate::client::*; use crate::endpoint::*; use crate::error::*; -use crate::proto::*; +use crate::message::*; use crate::server::*; -use crate::util::*; + +/// A node's identifier, which is also its public cryptographic key +pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; +/// A node's secret key +pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; +/// A network key +pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// Tag which is exchanged between client and server upon connection establishment /// to check that they are running compatible versions of Netapp, @@ -30,7 +36,7 @@ use crate::util::*; pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag -pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 +pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 #[derive(Serialize, Deserialize, Debug)] pub(crate) struct HelloMessage { @@ -152,7 +158,7 @@ impl NetApp { pub fn endpoint(self: &Arc, path: String) -> Arc> where M: Message + 'static, - H: EndpointHandler + 'static, + H: StreamingEndpointHandler + 'static, { let endpoint = Arc::new(Endpoint::::new(self.clone(), path.clone())); let endpoint_arc = EndpointArc(endpoint.clone()); @@ -397,13 +403,14 @@ impl NetApp { hello_endpoint .call( &conn.peer_id, - &HelloMessage { + HelloMessage { server_addr, server_port, }, PRIO_NORMAL, ) .await + .map(|_| ()) .log_err("Sending hello message"); }); } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index cd70c2b..2c27dba 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -14,8 +14,8 @@ use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; use crate::NodeID; // -- Protocol messages -- @@ -331,7 +331,7 @@ impl Basalt { async fn do_pull(self: Arc, peer: NodeID) { match self .pull_endpoint - .call(&peer, &PullMessage {}, PRIO_NORMAL) + .call(&peer, PullMessage {}, PRIO_NORMAL) .await { Ok(resp) => { @@ -346,7 +346,7 @@ impl Basalt { async fn do_push(self: Arc, peer: NodeID) { let push_msg = self.make_push_message(); - match self.push_endpoint.call(&peer, &push_msg, PRIO_NORMAL).await { + match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await { Ok(_) => { trace!("KYEV PEXo {}", hex::encode(peer)); } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 208cfe4..fb2e3d1 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash; use crate::endpoint::*; use crate::error::*; use crate::netapp::*; -use crate::proto::*; + +use crate::message::*; use crate::NodeID; const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); @@ -80,6 +81,7 @@ impl PeerInfoInternal { } } +/// Information that the full mesh peering strategy can return about the peers it knows of #[derive(Copy, Clone, Debug)] pub struct PeerInfo { /// The node's identifier (its public key) @@ -382,7 +384,7 @@ impl FullMeshPeeringStrategy { ping_time ); let ping_response = select! { - r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r, + r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())), }; @@ -434,7 +436,7 @@ impl FullMeshPeeringStrategy { let pex_message = PeerListMessage { list: peer_list }; match self .peer_list_endpoint - .call(id, &pex_message, PRIO_BACKGROUND) + .call(id, pex_message, PRIO_BACKGROUND) .await { Err(e) => warn!("Error doing peer exchange: {}", e), diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 8f7e70f..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,358 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::fmt::Write; -use std::sync::Arc; - -use log::trace; - -use futures::{AsyncReadExt, AsyncWriteExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: Vec, - cursor: usize, -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - fn pop(&mut self) -> Option { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - fn dump(&self) -> String { - let mut ret = String::new(); - for (prio, q) in self.items.iter() { - for item in q.iter() { - write!( - &mut ret, - " [{} {} ({})]", - prio, - item.data.len() - item.cursor, - item.id - ) - .unwrap(); - } - } - ret - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop( - self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, - mut write: BoxStreamWrite, - debug_name: String, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - trace!("send_loop({}): queue = {}", debug_name, sending.dump()); - if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!( - "send_loop({}): new message to send, id = {}, prio = {}, {} bytes", - debug_name, - id, - prio, - data.len() - ); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop({}): sending bytes for {} ({} bytes, {} already sent)", - debug_name, - item.id, - item.data.len(), - item.cursor - ); - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; - - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); - write.write_all(&size_header[..]).await?; - - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; - - sending.push(item); - } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); - write.write_all(&size_header[..]).await?; - - write.write_all(&item.data[item.cursor..]).await?; - } - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - trace!( - "send_loop({}): new message to send, id = {}, prio = {}, {} bytes", - debug_name, - id, - prio, - data.len() - ); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else { - should_exit = true; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, msg: Vec); - - async fn recv_loop(self: Arc, mut read: R, debug_name: String) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut receiving = HashMap::new(); - loop { - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!( - "recv_loop({}): got header id = {}, size = 0x{:04x} ({} bytes)", - debug_name, - id, - size, - size & !CHUNK_HAS_CONTINUATION - ); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop({}): read {} bytes", debug_name, next_slice.len()); - - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - self.recv_handler(id, msg_bytes); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: vec![], - cursor: 0, - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: vec![], - cursor: 0, - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: vec![], - cursor: 0, - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} diff --git a/src/proto2.rs b/src/proto2.rs deleted file mode 100644 index 7210781..0000000 --- a/src/proto2.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::error::*; -use crate::proto::*; - -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option>, - pub(crate) body: &'a [u8], -} - -/// QueryMessage encoding: -/// - priority: u8 -/// - path length: u8 -/// - path: [u8; path length] -/// - telemetry id length: u8 -/// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; - - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); - - ret.push(self.prio); - - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); - - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); - } else { - ret.push(0u8); - } - - ret.extend_from_slice(self.body); - - ret - } - - pub(crate) fn decode(bytes: &'a [u8]) -> Result { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; - - let body = &bytes[3 + path_length + telemetry_id_len..]; - - Ok(Self { - prio: bytes[0], - path, - telemetry_id, - body, - }) - } -} diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..0de7bef --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,153 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use log::*; + +use futures::AsyncReadExt; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::send::*; +use crate::stream::*; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: Option>, +} + +impl Sender { + fn new(inner: mpsc::UnboundedSender) -> Self { + Sender { inner: Some(inner) } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet); + } + + fn end(&mut self) { + self.inner = None; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let _ = inner.send(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Netapp connection dropped before end of stream", + ))); + } + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream); + fn cancel_handler(self: &Arc, _id: RequestID) {} + + async fn recv_loop(self: Arc, mut read: R, debug_name: String) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap = HashMap::new(); + loop { + trace!( + "recv_loop({}): in_progress = {:?}", + debug_name, + streams.iter().map(|(id, _)| id).collect::>() + ); + + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + + if size == CANCEL_REQUEST { + if let Some(mut stream) = streams.remove(&id) { + let _ = stream.send(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "netapp: cancel requested", + ))); + stream.end(); + } + self.cancel_handler(id); + continue; + } + + let has_cont = (size & CHUNK_FLAG_HAS_CONTINUATION) != 0; + let is_error = (size & CHUNK_FLAG_ERROR) != 0; + let size = (size & CHUNK_LENGTH_MASK) as usize; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + + let packet = if is_error { + let kind = u8_to_io_errorkind(next_slice[0]); + let msg = + std::str::from_utf8(&next_slice[1..]).unwrap_or(""); + debug!( + "recv_loop({}): got id {}, error {:?}: {}", + debug_name, id, kind, msg + ); + Some(Err(std::io::Error::new(kind, msg.to_string()))) + } else { + trace!( + "recv_loop({}): got id {}, size {}, has_cont {}", + debug_name, + id, + size, + has_cont + ); + if !next_slice.is_empty() { + Some(Ok(Bytes::from(next_slice))) + } else { + None + } + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = mpsc::unbounded_channel(); + trace!("recv_loop({}): id {} is new channel", debug_name, id); + self.recv_handler( + id, + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)), + ); + Sender::new(send) + }; + + if let Some(packet) = packet { + // If we cannot put packet in channel, it means that the + // receiving end of the channel is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet); + } + + if has_cont { + assert!(!is_error); + streams.insert(id, sender); + } else { + trace!("recv_loop({}): close channel id {}", debug_name, id); + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/send.rs b/src/send.rs new file mode 100644 index 0000000..0db0ba7 --- /dev/null +++ b/src/send.rs @@ -0,0 +1,356 @@ +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; +use log::*; + +use futures::{AsyncWriteExt, Future}; +use kuska_handshake::async_std::BoxStreamWrite; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::message::*; +use crate::stream::*; + +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length + flags: +// CHUNK_FLAG_HAS_CONTINUATION when this is not the last chunk of the stream +// CHUNK_FLAG_ERROR if this chunk denotes an error +// (these two flags are exclusive, an error denotes the end of the stream) +// **special value** 0xFFFF indicates a CANCEL message +// - [u8; chunk_length], either +// - if not error: chunk data +// - if error: +// - u8: error kind, encoded using error::io_errorkind_to_u8 +// - rest: error message +// - absent for cancel messag + +pub(crate) type RequestID = u32; +pub(crate) type ChunkLength = u16; + +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const CHUNK_FLAG_ERROR: ChunkLength = 0x4000; +pub(crate) const CHUNK_FLAG_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; +pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF; + +pub(crate) enum SendItem { + Stream(RequestID, RequestPriority, Option, ByteStream), + Cancel(RequestID), +} + +// ---- + +struct SendQueue { + items: Vec<(u8, SendQueuePriority)>, +} + +struct SendQueuePriority { + items: VecDeque, + order: HashMap>, +} + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + order_tag: Option, + data: ByteStreamReader, + sent: usize, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: Vec::with_capacity(64), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, SendQueuePriority::new())); + i + } + }; + self.items[pos_prio].1.push(item); + } + fn remove(&mut self, id: RequestID) { + for (_, prioq) in self.items.iter_mut() { + prioq.remove(id); + } + self.items.retain(|(_prio, q)| !q.is_empty()); + } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +impl SendQueuePriority { + fn new() -> Self { + Self { + items: VecDeque::new(), + order: HashMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.entry(stream).or_default(); + let i = order_vec.iter().take_while(|o2| **o2 < order).count(); + order_vec.insert(i, order); + } + self.items.push_front(item); + } + fn remove(&mut self, id: RequestID) { + if let Some(i) = self.items.iter().position(|x| x.id == id) { + let item = self.items.remove(i).unwrap(); + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.get_mut(&stream).unwrap(); + let j = order_vec.iter().position(|x| *x == order).unwrap(); + order_vec.remove(j).unwrap(); + if order_vec.is_empty() { + self.order.remove(&stream); + } + } + } + } + fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> { + for (j, item) in self.items.iter_mut().enumerate() { + if let Some(OrderTag(stream, order)) = item.order_tag { + if order > *self.order.get(&stream).unwrap().front().unwrap() { + continue; + } + } + + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) { + let id = item.id; + let eos = item.data.eos(); + + let packet = bytes_or_err.map_err(|e| match e { + ReadExactError::Stream(err) => err, + _ => unreachable!(), + }); + + let is_err = packet.is_err(); + let data_frame = DataFrame::from_packet(packet, !eos); + item.sent += data_frame.data().len(); + + if eos || is_err { + // If item had an order tag, remove it from the corresponding ordering list + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_stream = self.order.get_mut(&stream).unwrap(); + assert_eq!(order_stream.pop_front(), Some(order)); + if order_stream.is_empty() { + self.order.remove(&stream); + } + } + // Remove item from sending queue + self.items.remove(j); + } else { + // Move item later in send queue to implement LAS scheduling + // (LAS = Least Attained Service) + for k in j..self.items.len() - 1 { + if self.items[k].sent >= self.items[k + 1].sent { + self.items.swap(k, k + 1); + } else { + break; + } + } + } + + return Poll::Ready((id, data_frame)); + } + } + + Poll::Pending + } + fn dump(&self, prio: u8) -> String { + self.items + .iter() + .map(|i| format!("[{} {} {:?} @{}]", prio, i.id, i.order_tag, i.sent)) + .collect::>() + .join(" ") + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataFrame); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { + if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) { + if items_at_prio.is_empty() { + self.queue.items.remove(i); + } + return Poll::Ready(res); + } + } + // If the queue is empty, this futures is eternally pending. + // This is ok because we use it in a select with another future + // that can interrupt it. + Poll::Pending + } +} + +enum DataFrame { + /// a fixed size buffer containing some data + a boolean indicating whether + /// there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + Data(Bytes, bool), + /// An error code automatically signals the end of the stream + Error(Bytes), +} + +impl DataFrame { + fn from_packet(p: Packet, has_cont: bool) -> Self { + match p { + Ok(bytes) => { + assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize); + Self::Data(bytes, has_cont) + } + Err(e) => { + let mut buf = BytesMut::new(); + buf.put_u8(io_errorkind_to_u8(e.kind())); + + let msg = format!("{}", e).into_bytes(); + if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize { + buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]); + } else { + buf.put(&msg[..]); + } + + Self::Error(buf.freeze()) + } + } + } + + fn header(&self) -> [u8; 2] { + let header_u16 = match self { + DataFrame::Data(data, false) => data.len() as u16, + DataFrame::Data(data, true) => data.len() as u16 | CHUNK_FLAG_HAS_CONTINUATION, + DataFrame::Error(msg) => msg.len() as u16 | CHUNK_FLAG_ERROR, + }; + ChunkLength::to_be_bytes(header_u16) + } + + fn data(&self) -> &[u8] { + match self { + DataFrame::Data(ref data, _) => &data[..], + DataFrame::Error(ref msg) => &msg[..], + } + } +} + +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop( + self: Arc, + msg_recv: mpsc::UnboundedReceiver, + mut write: BoxStreamWrite, + debug_name: String, + ) -> Result<(), Error> + where + W: AsyncWriteExt + Unpin + Send + Sync, + { + let mut sending = SendQueue::new(); + let mut msg_recv = Some(msg_recv); + while msg_recv.is_some() || !sending.is_empty() { + trace!( + "send_loop({}): queue = {:?}", + debug_name, + sending + .items + .iter() + .map(|(prio, i)| i.dump(*prio)) + .collect::>() + .join(" ; ") + ); + + let recv_fut = async { + if let Some(chan) = &mut msg_recv { + chan.recv().await + } else { + futures::future::pending().await + } + }; + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + tokio::select! { + biased; // always read incomming channel first if it has data + sth = recv_fut => { + match sth { + Some(SendItem::Stream(id, prio, order_tag, data)) => { + trace!("send_loop({}): add stream {} to send", debug_name, id); + sending.push(SendQueueItem { + id, + prio, + order_tag, + data: ByteStreamReader::new(data), + sent: 0, + }) + } + Some(SendItem::Cancel(id)) => { + trace!("send_loop({}): cancelling {}", debug_name, id); + sending.remove(id); + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?; + write.flush().await?; + } + None => { + msg_recv = None; + } + }; + } + (id, data) = send_fut => { + trace!( + "send_loop({}): id {}, send {} bytes, header_size {}", + debug_name, + id, + data.data().len(), + hex::encode(data.header()) + ); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; + } + } + } + + let _ = write.goodbye().await; + Ok(()) + } +} diff --git a/src/server.rs b/src/server.rs index a835959..cd367c4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,9 +1,17 @@ +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; -use bytes::Bytes; -use log::{debug, trace}; +use async_trait::async_trait; +use log::*; + +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; #[cfg(feature = "telemetry")] use opentelemetry::{ @@ -15,21 +23,12 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +54,8 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption)>>, + resp_send: ArcSwapOption>, + running_handlers: Mutex>>, } impl ServerConn { @@ -101,6 +101,7 @@ impl ServerConn { remote_addr, peer_id, resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), + running_handlers: Mutex::new(HashMap::new()), }); netapp.connected_as_server(peer_id, conn.clone()); @@ -126,9 +127,8 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc, bytes: &[u8]) -> Result, Error> { - let msg = QueryMessage::decode(bytes)?; - let path = String::from_utf8(msg.path.to_vec())?; + async fn recv_handler_aux(self: &Arc, req_enc: ReqEnc) -> Result { + let path = String::from_utf8(req_enc.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); @@ -140,9 +140,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if let Some(telemetry_id) = msg.telemetry_id { + let mut span = if !req_enc.telemetry_id.is_empty() { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(telemetry_id); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -157,13 +157,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(req_enc, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(req_enc, self.peer_id).await } } } else { @@ -176,35 +176,47 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, bytes: Vec) { - let resp_send = self.resp_send.load_full().unwrap(); + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { + let resp_send = match self.resp_send.load_full() { + Some(c) => c, + None => return, + }; + + let mut rh = self.running_handlers.lock().unwrap(); let self2 = self.clone(); - tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + let jh = tokio::spawn(async move { + debug!("server: recv_handler got {}", id); - let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; - - let resp_bytes = match resp { - Ok(rb) => { - let mut resp_bytes = vec![0u8]; - resp_bytes.extend(rb); - resp_bytes - } - Err(e) => { - let mut resp_bytes = vec![e.code()]; - resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes - } + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), }; - trace!("ServerConn sending response to {}: ", id); + debug!("server: sending response to {}", id); + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send(SendItem::Stream(id, prio, resp_order, resp_stream)) + .log_err("ServerConn recv_handler send resp bytes"); + + self2.running_handlers.lock().unwrap().remove(&id); }); + + rh.insert(id, jh); + } + + fn cancel_handler(self: &Arc, id: RequestID) { + trace!("received cancel for request {}", id); + + // If the handler is still running, abort it now + if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) { + jh.abort(); + } + + // Inform the response sender that we don't need to send the response + if let Some(resp_send) = self.resp_send.load_full() { + let _ = resp_send.send(SendItem::Cancel(id)); + } } } diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..88c3fed --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,202 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; +use tokio::io::AsyncRead; + +use crate::bytes_buf::BytesBuf; + +/// A stream of bytes (click to read more). +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Items sent in the ByteStream may be errors of type `std::io::Error`. +/// An error indicates the end of the ByteStream: a reader should no longer read +/// after recieving an error, and a writer should stop writing after sending an error. +pub type ByteStream = Pin + Send + Sync>>; + +/// A packet sent in a ByteStream, which may contain either +/// a Bytes object or an error +pub type Packet = Result; + +// ---- + +/// A helper struct to read defined lengths of data from a BytesStream +pub struct ByteStreamReader { + stream: ByteStream, + buf: BytesBuf, + eos: bool, + err: Option, +} + +impl ByteStreamReader { + /// Creates a new `ByteStreamReader` from a `ByteStream` + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: BytesBuf::new(), + eos: false, + err: None, + } + } + + /// Read exactly `read_len` bytes from the underlying stream + /// (returns a future) + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + /// Read at most `read_len` bytes from the underlying stream, or less + /// if the end of the stream is reached (returns a future) + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + /// Read exactly one byte from the underlying stream and returns it + /// as an u8 + pub async fn read_u8(&mut self) -> Result { + Ok(self.read_exact(1).await?[0]) + } + + /// Read exactly two bytes from the underlying stream and returns them as an u16 (using + /// big-endian decoding) + pub async fn read_u16(&mut self) -> Result { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + /// Read exactly four bytes from the underlying stream and returns them as an u32 (using + /// big-endian decoding) + pub async fn read_u32(&mut self) -> Result { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + /// Transforms the stream reader back into the underlying stream (starting + /// after everything that the reader has read) + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + /// Tries to fill the internal read buffer from the underlying stream if it is empty. + /// Calling this might be necessary to ensure that `.eos()` returns a correct + /// result, otherwise the reader might not be aware that the underlying + /// stream has nothing left to return. + pub async fn fill_buffer(&mut self) { + if self.buf.is_empty() { + let packet = self.stream.next().await; + self.add_stream_next(packet); + } + } + + /// Clears the internal read buffer and returns its content + pub fn take_buffer(&mut self) -> Bytes { + self.buf.take_all() + } + + /// Returns true if the end of the underlying stream has been reached + pub fn eos(&self) -> bool { + self.buf.is_empty() && self.eos + } + + fn try_get(&mut self, read_len: usize) -> Option { + self.buf.take_exact(read_len) + } + + fn add_stream_next(&mut self, packet: Option) { + match packet { + Some(Ok(slice)) => { + self.buf.extend(slice); + } + Some(Err(e)) => { + self.err = Some(e); + self.eos = true; + } + None => { + self.eos = true; + } + } + } +} + +/// The error kind that can be returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` +pub enum ReadExactError { + /// The end of the stream was reached before the requested number of bytes could be read + UnexpectedEos, + /// The underlying data stream returned an IO error when trying to read + Stream(std::io::Error), +} + +/// The future returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = &this.reader.err { + let err = std::io::Error::new(err.kind(), format!("{}", err)); + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + return Poll::Ready(Ok(this.reader.take_buffer())); + } + } + + let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx)); + this.reader.add_stream_next(next_packet); + } + } +} + +// ---- + +/// Turns a `tokio::io::AsyncRead` asynchronous reader into a `ByteStream` +pub fn asyncread_stream(reader: R) -> ByteStream { + Box::pin(tokio_util::io::ReaderStream::new(reader)) +} + +/// Turns a `ByteStream` into a `tokio::io::AsyncRead` asynchronous reader +pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { + tokio_util::io::StreamReader::new(stream) +} diff --git a/src/test.rs b/src/test.rs index 82c7ba6..19759cc 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use crate::NodeID; #[tokio::test(flavor = "current_thread")] async fn test_with_basic_scheduler() { + env_logger::init(); run_test().await } diff --git a/src/util.rs b/src/util.rs index f4dfac7..425d26f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,18 +1,12 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; -use serde::Serialize; - use log::info; +use serde::Serialize; use tokio::sync::watch; -/// A node's identifier, which is also its public cryptographic key -pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; -/// A node's secret key -pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; -/// A network key -pub type NetworkKey = sodiumoxide::crypto::auth::Key; +use crate::netapp::*; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library.