Refactoring

This commit is contained in:
Alex 2022-02-21 12:01:04 +01:00
parent 109d6c143d
commit 3b8bff6341
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
6 changed files with 131 additions and 52 deletions

View file

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU32}; use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::borrow::Borrow;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use log::{debug, error, trace}; use log::{debug, error, trace};
@ -29,6 +30,7 @@ use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*; use crate::proto::*;
use crate::proto2::*;
use crate::util::*; use crate::util::*;
pub(crate) struct ClientConn { pub(crate) struct ClientConn {
@ -118,14 +120,15 @@ impl ClientConn {
self.query_send.store(None); self.query_send.store(None);
} }
pub(crate) async fn call<T>( pub(crate) async fn call<'a, T, B>(
self: Arc<Self>, self: Arc<Self>,
rq: &T, rq: B,
path: &str, path: &'a str,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<T as Message>::Response, Error> ) -> Result<<T as Message>::Response, Error>
where where
T: Message, T: Message,
B: Borrow<T>,
{ {
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@ -147,19 +150,17 @@ impl ClientConn {
}; };
// Encode request // Encode request
let mut bytes = vec![]; let body = rmp_to_vec_all_named(rq.borrow())?;
drop(rq);
bytes.extend_from_slice(&[prio, path.as_bytes().len() as u8]); let request = QueryMessage {
bytes.extend_from_slice(path.as_bytes()); prio,
path: path.as_bytes(),
if let Some(by) = telemetry_id { telemetry_id,
bytes.push(by.len() as u8); body: &body[..],
bytes.extend(by); };
} else { let bytes = request.encode();
bytes.push(0); drop(body);
}
bytes.extend_from_slice(&rmp_to_vec_all_named(rq)?[..]);
// Send request through // Send request through
let (resp_send, resp_recv) = oneshot::channel(); let (resp_send, resp_recv) = oneshot::channel();

View file

@ -1,5 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::Arc; use std::sync::Arc;
use std::borrow::Borrow;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait; use async_trait::async_trait;
@ -88,16 +89,17 @@ where
/// Call this endpoint on a remote node (or on the local node, /// Call this endpoint on a remote node (or on the local node,
/// for that matter) /// for that matter)
pub async fn call( pub async fn call<B>(
&self, &self,
target: &NodeID, target: &NodeID,
req: &M, req: B,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> { ) -> Result<<M as Message>::Response, Error>
where B: Borrow<M> {
if *target == self.netapp.id { if *target == self.netapp.id {
match self.handler.load_full() { match self.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req, self.netapp.id).await), Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
} }
} else { } else {
let conn = self let conn = self

View file

@ -19,6 +19,7 @@ pub mod util;
pub mod endpoint; pub mod endpoint;
pub mod proto; pub mod proto;
mod proto2;
mod client; mod client;
mod server; mod server;

View file

@ -96,6 +96,14 @@ impl SendQueue {
} }
} }
/// The SendLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()`
/// that takes a channel of messages to send and an asynchronous writer,
/// and sends messages from the channel to the async writer, putting them in a queue
/// before being sent and doing the round-robin sending strategy.
///
/// The `.send_loop()` exits when the sending end of the channel is closed,
/// or if there is an error at any time writing to the async writer.
#[async_trait] #[async_trait]
pub(crate) trait SendLoop: Sync { pub(crate) trait SendLoop: Sync {
async fn send_loop<W>( async fn send_loop<W>(
@ -128,9 +136,9 @@ pub(crate) trait SendLoop: Sync {
write.write_all(&header_id[..]).await?; write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
let header_size = let size_header =
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
write.write_all(&header_size[..]).await?; write.write_all(&size_header[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
write.write_all(&item.data[item.cursor..new_cursor]).await?; write.write_all(&item.data[item.cursor..new_cursor]).await?;
@ -140,8 +148,8 @@ pub(crate) trait SendLoop: Sync {
} else { } else {
let send_len = (item.data.len() - item.cursor) as ChunkLength; let send_len = (item.data.len() - item.cursor) as ChunkLength;
let header_size = ChunkLength::to_be_bytes(send_len); let size_header = ChunkLength::to_be_bytes(send_len);
write.write_all(&header_size[..]).await?; write.write_all(&size_header[..]).await?;
write.write_all(&item.data[item.cursor..]).await?; write.write_all(&item.data[item.cursor..]).await?;
} }
@ -166,9 +174,15 @@ pub(crate) trait SendLoop: Sync {
} }
} }
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
/// must be filled by implementors. `.recv_loop()` receives messages in a loop
/// according to the protocol defined above: chunks of message in progress of being
/// received are stored in a buffer, and when the last chunk of a message is received,
/// the full message is passed to the receive handler.
#[async_trait] #[async_trait]
pub(crate) trait RecvLoop: Sync + 'static { pub(crate) trait RecvLoop: Sync + 'static {
// Returns true if we should stop receiving after this
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>); fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>

75
src/proto2.rs Normal file
View file

@ -0,0 +1,75 @@
use crate::proto::*;
use crate::error::*;
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
pub(crate) path: &'a [u8],
pub(crate) telemetry_id: Option<Vec<u8>>,
pub(crate) body: &'a [u8],
}
/// QueryMessage encoding:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - body [u8; ..]
impl<'a> QueryMessage<'a> {
pub(crate) fn encode(self) -> Vec<u8> {
let tel_len = match &self.telemetry_id {
Some(t) => t.len(),
None => 0,
};
let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
ret.push(self.prio);
ret.push(self.path.len() as u8);
ret.extend_from_slice(self.path);
if let Some(t) = self.telemetry_id {
ret.push(t.len() as u8);
ret.extend(t);
} else {
ret.push(0u8);
}
ret.extend_from_slice(self.body);
ret
}
pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
if bytes.len() < 3 {
return Err(Error::Message("Invalid protocol message".into()));
}
let path_length = bytes[1] as usize;
if bytes.len() < 3 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let telemetry_id_len = bytes[2 + path_length] as usize;
if bytes.len() < 3 + path_length + telemetry_id_len {
return Err(Error::Message("Invalid protocol message".into()));
}
let path = &bytes[2..2 + path_length];
let telemetry_id = if telemetry_id_len > 0 {
Some(bytes[3 + path_length .. 3 + path_length + telemetry_id_len].to_vec())
} else {
None
};
let body = &bytes[3 + path_length + telemetry_id_len..];
Ok(Self {
prio: bytes[0],
path,
telemetry_id,
body,
})
}
}

View file

@ -29,6 +29,7 @@ use kuska_handshake::async_std::{handshake_server, BoxStream};
use crate::error::*; use crate::error::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*; use crate::proto::*;
use crate::proto2::*;
use crate::util::*; use crate::util::*;
// The client and server connection structs (client.rs and server.rs) // The client and server connection structs (client.rs and server.rs)
@ -116,22 +117,8 @@ impl ServerConn {
} }
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
if bytes.len() < 2 { let msg = QueryMessage::decode(bytes)?;
return Err(Error::Message("Invalid protocol message".into())); let path = String::from_utf8(msg.path.to_vec())?;
}
// byte 0 is the request priority, we don't care here
let path_length = bytes[1] as usize;
if bytes.len() < 2 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let path = &bytes[2..2 + path_length];
let path = String::from_utf8(path.to_vec())?;
let telemetry_id_len = bytes[2 + path_length] as usize;
let data = &bytes[3 + path_length + telemetry_id_len..];
let handler_opt = { let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap(); let endpoints = self.netapp.endpoints.read().unwrap();
@ -143,10 +130,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] { if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp"); let tracer = opentelemetry::global::tracer("netapp");
let mut span = if telemetry_id_len > 0 { let mut span = if let Some(telemetry_id) = msg.telemetry_id {
let by = bytes[3+path_length..3+path_length+telemetry_id_len].to_vec();
let propagator = BinaryPropagator::new(); let propagator = BinaryPropagator::new();
let context = propagator.from_bytes(by); let context = propagator.from_bytes(telemetry_id);
let context = Context::new().with_remote_span_context(context); let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path)) tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
@ -161,13 +147,13 @@ impl ServerConn {
.start(&tracer) .start(&tracer)
}; };
span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("path", path.to_string()));
span.set_attribute(KeyValue::new("len_query", data.len() as i64)); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
handler.handle(data, self.peer_id) handler.handle(msg.body, self.peer_id)
.with_context(Context::current_with_span(span)) .with_context(Context::current_with_span(span))
.await .await
} else { } else {
handler.handle(data, self.peer_id).await handler.handle(msg.body, self.peer_id).await
} }
} }
} else { } else {
@ -191,16 +177,16 @@ impl RecvLoop for ServerConn {
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
let resp = self2.recv_handler_aux(&bytes[..]).await; let resp = self2.recv_handler_aux(&bytes[..]).await;
let mut resp_bytes = vec![]; let resp_bytes = match resp {
match resp {
Ok(rb) => { Ok(rb) => {
resp_bytes.push(0u8); let mut resp_bytes = vec![0u8];
resp_bytes.extend(&rb[..]); resp_bytes.extend(rb);
resp_bytes
} }
Err(e) => { Err(e) => {
resp_bytes.push(e.code()); vec![e.code()]
}
} }
};
trace!("ServerConn sending response to {}: ", id); trace!("ServerConn sending response to {}: ", id);