From 3b8bff634198c5ae17ab16d5c85c30b3201ae593 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 21 Feb 2022 12:01:04 +0100 Subject: [PATCH] Refactoring --- src/client.rs | 31 ++++++++++---------- src/endpoint.rs | 10 ++++--- src/lib.rs | 1 + src/proto.rs | 24 ++++++++++++---- src/proto2.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++++++ src/server.rs | 42 +++++++++------------------ 6 files changed, 131 insertions(+), 52 deletions(-) create mode 100644 src/proto2.rs diff --git a/src/client.rs b/src/client.rs index d6caf68..e2d5d84 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use std::borrow::Borrow; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; @@ -29,6 +30,7 @@ use crate::endpoint::*; use crate::error::*; use crate::netapp::*; use crate::proto::*; +use crate::proto2::*; use crate::util::*; pub(crate) struct ClientConn { @@ -118,14 +120,15 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call( + pub(crate) async fn call<'a, T, B>( self: Arc, - rq: &T, - path: &str, + rq: B, + path: &'a str, prio: RequestPriority, ) -> Result<::Response, Error> where T: Message, + B: Borrow, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -147,19 +150,17 @@ impl ClientConn { }; // Encode request - let mut bytes = vec![]; + let body = rmp_to_vec_all_named(rq.borrow())?; + drop(rq); - bytes.extend_from_slice(&[prio, path.as_bytes().len() as u8]); - bytes.extend_from_slice(path.as_bytes()); - - if let Some(by) = telemetry_id { - bytes.push(by.len() as u8); - bytes.extend(by); - } else { - bytes.push(0); - } - - bytes.extend_from_slice(&rmp_to_vec_all_named(rq)?[..]); + let request = QueryMessage { + prio, + path: path.as_bytes(), + telemetry_id, + body: &body[..], + }; + let bytes = request.encode(); + drop(body); // Send request through let (resp_send, resp_recv) = oneshot::channel(); diff --git a/src/endpoint.rs b/src/endpoint.rs index 760bf32..b408241 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; use std::sync::Arc; +use std::borrow::Borrow; use arc_swap::ArcSwapOption; use async_trait::async_trait; @@ -88,16 +89,17 @@ where /// Call this endpoint on a remote node (or on the local node, /// for that matter) - pub async fn call( + pub async fn call( &self, target: &NodeID, - req: &M, + req: B, prio: RequestPriority, - ) -> Result<::Response, Error> { + ) -> Result<::Response, Error> + where B: Borrow { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req, self.netapp.id).await), + Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), } } else { let conn = self diff --git a/src/lib.rs b/src/lib.rs index 3162c42..89b4f32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ pub mod util; pub mod endpoint; pub mod proto; +mod proto2; mod client; mod server; diff --git a/src/proto.rs b/src/proto.rs index 18e7c44..2db3f83 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -96,6 +96,14 @@ impl SendQueue { } } +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop( @@ -128,9 +136,9 @@ pub(crate) trait SendLoop: Sync { write.write_all(&header_id[..]).await?; if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { - let header_size = + let size_header = ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); - write.write_all(&header_size[..]).await?; + write.write_all(&size_header[..]).await?; let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; write.write_all(&item.data[item.cursor..new_cursor]).await?; @@ -140,8 +148,8 @@ pub(crate) trait SendLoop: Sync { } else { let send_len = (item.data.len() - item.cursor) as ChunkLength; - let header_size = ChunkLength::to_be_bytes(send_len); - write.write_all(&header_size[..]).await?; + let size_header = ChunkLength::to_be_bytes(send_len); + write.write_all(&size_header[..]).await?; write.write_all(&item.data[item.cursor..]).await?; } @@ -166,9 +174,15 @@ pub(crate) trait SendLoop: Sync { } } +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - // Returns true if we should stop receiving after this fn recv_handler(self: &Arc, id: RequestID, msg: Vec); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> diff --git a/src/proto2.rs b/src/proto2.rs new file mode 100644 index 0000000..4e126d3 --- /dev/null +++ b/src/proto2.rs @@ -0,0 +1,75 @@ +use crate::proto::*; +use crate::error::*; + +pub(crate) struct QueryMessage<'a> { + pub(crate) prio: RequestPriority, + pub(crate) path: &'a [u8], + pub(crate) telemetry_id: Option>, + pub(crate) body: &'a [u8], +} + +/// QueryMessage encoding: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - body [u8; ..] +impl<'a> QueryMessage<'a> { + pub(crate) fn encode(self) -> Vec { + let tel_len = match &self.telemetry_id { + Some(t) => t.len(), + None => 0, + }; + + let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + + ret.push(self.prio); + + ret.push(self.path.len() as u8); + ret.extend_from_slice(self.path); + + if let Some(t) = self.telemetry_id { + ret.push(t.len() as u8); + ret.extend(t); + } else { + ret.push(0u8); + } + + ret.extend_from_slice(self.body); + + ret + } + + pub(crate) fn decode(bytes: &'a [u8]) -> Result { + if bytes.len() < 3 { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path_length = bytes[1] as usize; + if bytes.len() < 3 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let telemetry_id_len = bytes[2 + path_length] as usize; + if bytes.len() < 3 + path_length + telemetry_id_len { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let telemetry_id = if telemetry_id_len > 0 { + Some(bytes[3 + path_length .. 3 + path_length + telemetry_id_len].to_vec()) + } else { + None + }; + + let body = &bytes[3 + path_length + telemetry_id_len..]; + + Ok(Self { + prio: bytes[0], + path, + telemetry_id, + body, + }) + } +} diff --git a/src/server.rs b/src/server.rs index 7bf17df..eb70057 100644 --- a/src/server.rs +++ b/src/server.rs @@ -29,6 +29,7 @@ use kuska_handshake::async_std::{handshake_server, BoxStream}; use crate::error::*; use crate::netapp::*; use crate::proto::*; +use crate::proto2::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -116,22 +117,8 @@ impl ServerConn { } async fn recv_handler_aux(self: &Arc, bytes: &[u8]) -> Result, Error> { - if bytes.len() < 2 { - return Err(Error::Message("Invalid protocol message".into())); - } - - // byte 0 is the request priority, we don't care here - let path_length = bytes[1] as usize; - if bytes.len() < 2 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let path = String::from_utf8(path.to_vec())?; - - let telemetry_id_len = bytes[2 + path_length] as usize; - - let data = &bytes[3 + path_length + telemetry_id_len..]; + let msg = QueryMessage::decode(bytes)?; + let path = String::from_utf8(msg.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); @@ -143,10 +130,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if telemetry_id_len > 0 { - let by = bytes[3+path_length..3+path_length+telemetry_id_len].to_vec(); + let mut span = if let Some(telemetry_id) = msg.telemetry_id { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(by); + let context = propagator.from_bytes(telemetry_id); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -161,13 +147,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", data.len() as i64)); + span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); - handler.handle(data, self.peer_id) + handler.handle(msg.body, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(data, self.peer_id).await + handler.handle(msg.body, self.peer_id).await } } } else { @@ -191,16 +177,16 @@ impl RecvLoop for ServerConn { let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; let resp = self2.recv_handler_aux(&bytes[..]).await; - let mut resp_bytes = vec![]; - match resp { + let resp_bytes = match resp { Ok(rb) => { - resp_bytes.push(0u8); - resp_bytes.extend(&rb[..]); + let mut resp_bytes = vec![0u8]; + resp_bytes.extend(rb); + resp_bytes } Err(e) => { - resp_bytes.push(e.code()); + vec![e.code()] } - } + }; trace!("ServerConn sending response to {}: ", id);