diff --git a/Cargo.toml b/Cargo.toml index 377e09d..5b1cadc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,8 @@ name = "netapp" [features] default = [] -basalt = ["lru", "rand"] -telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"] +basalt = ["lru"] +telemetry = ["opentelemetry", "opentelemetry-contrib"] [dependencies] futures = "0.3.17" @@ -30,7 +30,7 @@ serde = { version = "1.0", default-features = false, features = ["derive", "rc"] rmp-serde = "0.14.3" hex = "0.4.2" -rand = { version = "0.5.5", optional = true } +rand = { version = "0.5.5" } log = "0.4.8" arc-swap = "1.1" diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs index 46c7039..857be9d 100644 --- a/src/bytes_buf.rs +++ b/src/bytes_buf.rs @@ -146,7 +146,10 @@ mod test { assert!(buf.len() == 23); assert!(!buf.is_empty()); - assert_eq!(buf.take_all(), Bytes::from(b"Hello, world!1234567890".to_vec())); + assert_eq!( + buf.take_all(), + Bytes::from(b"Hello, world!1234567890".to_vec()) + ); assert!(buf.len() == 0); assert!(buf.is_empty()); @@ -160,7 +163,10 @@ mod test { assert_eq!(buf.take_exact(12), None); assert!(buf.len() == 11); - assert_eq!(buf.take_exact(11), Some(Bytes::from(b"llo, world!".to_vec()))); + assert_eq!( + buf.take_exact(11), + Some(Bytes::from(b"llo, world!".to_vec())) + ); assert!(buf.len() == 0); assert!(buf.is_empty()); } diff --git a/src/client.rs b/src/client.rs index 0dcbdf1..aef7bbb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -35,7 +35,7 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>, @@ -165,7 +165,7 @@ impl ClientConn { // Encode request let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); let req_msg_len = req_enc.msg.len(); - let req_stream = req_enc.encode(); + let (req_stream, req_order) = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -175,7 +175,10 @@ impl ClientConn { "Too many inflight requests! RequestID collision. Interrupting previous request." ); let _ = old_ch.send(Box::pin(futures::stream::once(async move { - Err(std::io::Error::new(std::io::ErrorKind::Other, "RequestID collision, too many inflight requests")) + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) }))); } diff --git a/src/error.rs b/src/error.rs index 2fa4594..c0aeeac 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(display = "Framing protocol error")] Framing, + #[error(display = "Remote error ({:?}): {}", _0, _1)] + Remote(io::ErrorKind, String), + #[error(display = "Request ID collision")] IdCollision, @@ -42,30 +45,6 @@ pub enum Error { #[error(display = "Version mismatch: {}", _0)] VersionMismatch(String), - - #[error(display = "Remote error {}: {}", _0, _1)] - Remote(u8, String), -} - -impl Error { - pub fn code(&self) -> u8 { - match self { - Self::Io(_) => 100, - Self::TokioJoin(_) => 110, - Self::OneshotRecv(_) => 111, - Self::RMPEncode(_) => 10, - Self::RMPDecode(_) => 11, - Self::UTF8(_) => 12, - Self::Framing => 13, - Self::NoHandler => 20, - Self::ConnectionClosed => 21, - Self::IdCollision => 22, - Self::Handshake(_) => 30, - Self::VersionMismatch(_) => 31, - Self::Remote(c, _) => *c, - Self::Message(_) => 99, - } - } } impl From> for Error { diff --git a/src/lib.rs b/src/lib.rs index 18091c8..8e30e40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,10 +13,10 @@ //! about message priorization. //! Also check out the examples to learn how to use this crate. +pub mod bytes_buf; pub mod error; pub mod stream; pub mod util; -pub mod bytes_buf; pub mod endpoint; pub mod message; diff --git a/src/message.rs b/src/message.rs index 61d01d0..ca68cac 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::sync::Arc; use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; use serde::{Deserialize, Serialize}; use futures::stream::StreamExt; @@ -40,6 +41,24 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; // ---- +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(u64, u64); + +impl OrderTag { + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + /// This trait should be implemented by all messages your application /// wants to handle pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { @@ -56,6 +75,7 @@ pub struct Req { pub(crate) msg: Arc, pub(crate) msg_ser: Option, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, } impl Req { @@ -77,6 +97,13 @@ impl Req { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M { &self.msg } @@ -97,6 +124,7 @@ impl Req { telemetry_id, msg: self.msg_ser.unwrap(), stream: self.stream.into_stream(), + order_tag: self.order_tag, } } @@ -109,6 +137,7 @@ impl Req { .stream .map(AttachedStream::Stream) .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, }) } } @@ -125,6 +154,7 @@ impl IntoReq for M { msg: Arc::new(self), msg_ser: Some(Bytes::from(msg_ser)), stream: AttachedStream::None, + order_tag: None, }) } fn into_req_local(self) -> Req { @@ -132,6 +162,7 @@ impl IntoReq for M { msg: Arc::new(self), msg_ser: None, stream: AttachedStream::None, + order_tag: None, } } } @@ -158,6 +189,7 @@ impl Clone for Req { msg: self.msg.clone(), msg_ser: self.msg_ser.clone(), stream, + order_tag: self.order_tag, } } } @@ -184,6 +216,7 @@ pub struct Resp { pub(crate) _phantom: PhantomData, pub(crate) msg: M::Response, pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option, } impl Resp { @@ -192,6 +225,7 @@ impl Resp { _phantom: Default::default(), msg: v, stream: AttachedStream::None, + order_tag: None, } } @@ -209,6 +243,13 @@ impl Resp { } } + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + pub fn msg(&self) -> &M::Response { &self.msg } @@ -222,26 +263,24 @@ impl Resp { } pub(crate) fn into_enc(self) -> Result { - Ok(RespEnc::Success { + Ok(RespEnc { msg: rmp_to_vec_all_named(&self.msg)?.into(), stream: self.stream.into_stream(), + order_tag: self.order_tag, }) } pub(crate) fn from_enc(enc: RespEnc) -> Result { - match enc { - RespEnc::Success { msg, stream } => { - let msg = rmp_serde::decode::from_read_ref(&msg)?; - Ok(Self { - _phantom: Default::default(), - msg, - stream: stream - .map(AttachedStream::Stream) - .unwrap_or(AttachedStream::None), - }) - } - RespEnc::Error { code, message } => Err(Error::Remote(code, message)), - } + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) } } @@ -295,10 +334,11 @@ pub(crate) struct ReqEnc { pub(crate) telemetry_id: Bytes, pub(crate) msg: Bytes, pub(crate) stream: Option, + pub(crate) order_tag: Option, } impl ReqEnc { - pub(crate) fn encode(self) -> ByteStream { + pub(crate) fn encode(self) -> (ByteStream, Option) { let mut buf = BytesMut::with_capacity( self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, ); @@ -315,15 +355,18 @@ impl ReqEnc { let header = buf.freeze(); - if let Some(stream) = self.stream { + let res_stream: ByteStream = if let Some(stream) = self.stream { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) - } + }; + (res_stream, self.order_tag) } pub(crate) async fn decode(stream: ByteStream) -> Result { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result { @@ -346,6 +389,7 @@ impl ReqEnc { telemetry_id, msg, stream: Some(reader.into_stream()), + order_tag: None, }) } } @@ -360,74 +404,67 @@ impl ReqEnc { /// - message length + 1: u8 /// - error code: u8 /// - message: [u8; message_length] -pub(crate) enum RespEnc { - Error { - code: u8, - message: String, - }, - Success { - msg: Bytes, - stream: Option, - }, +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option, + order_tag: Option, } impl RespEnc { - pub(crate) fn from_err(e: Error) -> Self { - RespEnc::Error { - code: e.code(), - message: format!("{}", e), - } - } - - pub(crate) fn encode(self) -> ByteStream { - match self { - RespEnc::Success { msg, stream } => { - let mut buf = BytesMut::with_capacity(msg.len() + 8); - - buf.put_u8(0); + pub(crate) fn encode(resp: Result) -> (ByteStream, Option) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); buf.put_u32(msg.len() as u32); - let header = buf.freeze(); - if let Some(stream) = stream { + let res_stream: ByteStream = if let Some(stream) = stream { Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) } else { Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) - } + }; + (res_stream, order_tag) } - RespEnc::Error { code, message } => { - let mut buf = BytesMut::with_capacity(message.len() + 8); - buf.put_u8(1 + message.len() as u8); - buf.put_u8(code); - buf.put(message.as_bytes()); - let header = buf.freeze(); - Box::pin(futures::stream::once(async move { Ok(header) })) + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) } } } pub(crate) async fn decode(stream: ByteStream) -> Result { - Self::decode_aux(stream).await.map_err(|_| Error::Framing) + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) } async fn decode_aux(stream: ByteStream) -> Result { let mut reader = ByteStreamReader::new(stream); - let is_err = reader.read_u8().await?; + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; - if is_err > 0 { - let code = reader.read_u8().await?; - let message = reader.read_exact(is_err as usize - 1).await?; - let message = String::from_utf8(message.to_vec()).unwrap_or_default(); - Ok(RespEnc::Error { code, message }) - } else { - let msg_len = reader.read_u32().await?; - let msg = reader.read_exact(msg_len as usize).await?; - - Ok(RespEnc::Success { - msg, - stream: Some(reader.into_stream()), - }) - } + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, } } diff --git a/src/recv.rs b/src/recv.rs index f8606f3..b5289fb 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -35,7 +35,10 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - let _ = inner.send(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Netapp connection dropped before end of stream"))); + let _ = inner.send(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Netapp connection dropped before end of stream", + ))); } } } @@ -82,7 +85,8 @@ pub(crate) trait RecvLoop: Sync + 'static { let packet = if is_error { let kind = u8_to_io_errorkind(next_slice[0]); - let msg = std::str::from_utf8(&next_slice[1..]).unwrap_or(""); + let msg = + std::str::from_utf8(&next_slice[1..]).unwrap_or(""); debug!("recv_loop: got id {}, error {:?}: {}", id, kind, msg); Some(Err(std::io::Error::new(kind, msg.to_string()))) } else { diff --git a/src/send.rs b/src/send.rs index 287fe40..c40787f 100644 --- a/src/send.rs +++ b/src/send.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use bytes::{Bytes, BytesMut, BufMut}; +use bytes::{BufMut, Bytes, BytesMut}; use log::*; use futures::AsyncWriteExt; @@ -36,6 +36,8 @@ pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; +pub(crate) type SendStream = (RequestID, RequestPriority, ByteStream); + struct SendQueue { items: Vec<(u8, VecDeque)>, } @@ -184,7 +186,7 @@ impl DataFrame { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, + msg_recv: mpsc::UnboundedReceiver, mut write: BoxStreamWrite, ) -> Result<(), Error> where diff --git a/src/server.rs b/src/server.rs index 57062d8..c23c9e4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -53,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -177,26 +177,16 @@ impl RecvLoop for ServerConn { tokio::spawn(async move { debug!("server: recv_handler got {}", id); - let (prio, resp_enc) = match ReqEnc::decode(stream).await { - Ok(req_enc) => { - let prio = req_enc.prio; - let resp = self2.recv_handler_aux(req_enc).await; - - ( - prio, - match resp { - Ok(resp_enc) => resp_enc, - Err(e) => RespEnc::from_err(e), - }, - ) - } - Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), }; debug!("server: sending response to {}", id); + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_enc.encode())) + .send((id, prio, resp_stream)) .log_err("ServerConn recv_handler send resp bytes"); Ok::<_, Error>(()) }); diff --git a/src/stream.rs b/src/stream.rs index 6e00e5f..efa0ebc 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -150,7 +150,6 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- - pub fn asyncread_stream(reader: R) -> ByteStream { Box::pin(tokio_util::io::ReaderStream::new(reader)) }