From f35fa7d18d9e0f51bed311355ec1310b1d311ab3 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 21 Jul 2022 17:34:53 +0200 Subject: [PATCH] Move things around --- Makefile | 3 +- examples/basalt.rs | 3 +- examples/fullmesh.rs | 1 + src/client.rs | 18 ++- src/endpoint.rs | 78 +----------- src/error.rs | 2 +- src/lib.rs | 5 +- src/message.rs | 255 ++++++++++++++++++++++++++++++++++++++ src/netapp.rs | 2 +- src/peering/basalt.rs | 3 +- src/peering/fullmesh.rs | 3 +- src/proto2.rs | 75 ----------- src/recv.rs | 114 +++++++++++++++++ src/{proto.rs => send.rs} | 235 +++-------------------------------- src/server.rs | 32 +++-- src/util.rs | 12 +- 16 files changed, 429 insertions(+), 412 deletions(-) create mode 100644 src/message.rs delete mode 100644 src/proto2.rs create mode 100644 src/recv.rs rename src/{proto.rs => send.rs} (60%) diff --git a/Makefile b/Makefile index 5160725..de9a8f4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ all: - cargo build --all-features + #cargo build --all-features + cargo build cargo build --example fullmesh cargo build --all-features --example basalt RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 diff --git a/examples/basalt.rs b/examples/basalt.rs index 318e37c..52fab4b 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -14,8 +14,9 @@ use sodiumoxide::crypto::sign::ed25519; use tokio::sync::watch; use netapp::endpoint::*; +use netapp::message::*; use netapp::peering::basalt::*; -use netapp::proto::*; +use netapp::send::*; use netapp::util::parse_peer_addr; use netapp::{NetApp, NodeID}; diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs index b068410..4ab8a8a 100644 --- a/examples/fullmesh.rs +++ b/examples/fullmesh.rs @@ -10,6 +10,7 @@ use sodiumoxide::crypto::sign::ed25519; use netapp::peering::fullmesh::*; use netapp::util::*; + use netapp::NetApp; #[derive(StructOpt, Debug)] diff --git a/src/client.rs b/src/client.rs index 6d49f5c..663a3e4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,9 +5,12 @@ use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, error, trace}; use futures::channel::mpsc::{unbounded, UnboundedReceiver}; +use futures::io::AsyncReadExt; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -21,25 +24,18 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: - ArcSwapOption>, + query_send: ArcSwapOption>, next_query_number: AtomicU32, inflight: Mutex>>>, diff --git a/src/endpoint.rs b/src/endpoint.rs index f31141d..e6b2236 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -5,79 +5,11 @@ use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - use crate::error::Error; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; use crate::util::*; -/// This trait should be implemented by all messages your application -/// wants to handle -pub trait Message: SerializeMessage + Send + Sync { - type Response: SerializeMessage + Send + Sync; -} - -/// A trait for de/serializing messages, with possible associated stream. -#[async_trait] -pub trait SerializeMessage: Sized { - type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option); - - // TODO should return Result - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self; -} - -pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} - -#[async_trait] -impl SerializeMessage for T -where - T: AutoSerialize, -{ - type SerializableSelf = Self; - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - (self.clone(), None) - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { - // TODO verify no stream - ser_self - } -} - -impl AutoSerialize for () {} - -#[async_trait] -impl SerializeMessage for Result -where - T: SerializeMessage + Send, - E: SerializeMessage + Send, -{ - type SerializableSelf = Result; - - fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { - match self { - Ok(ok) => { - let (msg, stream) = ok.serialize_msg(); - (Ok(msg), stream) - } - Err(err) => { - let (msg, stream) = err.serialize_msg(); - (Err(msg), stream) - } - } - } - - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { - match ser_self { - Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), - Err(err) => Err(E::deserialize_msg(err, stream).await), - } - } -} - /// This trait should be implemented by an object of your application /// that can handle a message of type `M`. /// @@ -191,9 +123,9 @@ pub(crate) trait GenericEndpoint { async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error>; + ) -> Result<(Vec, Option), Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -213,9 +145,9 @@ where async fn handle( &self, buf: &[u8], - stream: AssociatedStream, + stream: ByteStream, from: NodeID, - ) -> Result<(Vec, Option), Error> { + ) -> Result<(Vec, Option), Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { diff --git a/src/error.rs b/src/error.rs index 7911c29..665647c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ -use err_derive::Error; use std::io; +use err_derive::Error; use log::error; #[derive(Debug, Error)] diff --git a/src/lib.rs b/src/lib.rs index cb24337..1edb919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,10 +17,11 @@ pub mod error; pub mod util; pub mod endpoint; -pub mod proto; +pub mod message; mod client; -mod proto2; +mod recv; +mod send; mod server; pub mod netapp; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..dbcc857 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,255 @@ +use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; + +use crate::error::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle +pub trait Message: SerializeMessage + Send + Sync { + type Response: SerializeMessage + Send + Sync; +} + +/// A trait for de/serializing messages, with possible associated stream. +#[async_trait] +pub trait SerializeMessage: Sized { + type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option); + + // TODO should return Result + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self; +} + +pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {} + +#[async_trait] +impl SerializeMessage for T +where + T: AutoSerialize, +{ + type SerializableSelf = Self; + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + (self.clone(), None) + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self { + // TODO verify no stream + ser_self + } +} + +impl AutoSerialize for () {} + +#[async_trait] +impl SerializeMessage for Result +where + T: SerializeMessage + Send, + E: SerializeMessage + Send, +{ + type SerializableSelf = Result; + + fn serialize_msg(&self) -> (Self::SerializableSelf, Option) { + match self { + Ok(ok) => { + let (msg, stream) = ok.serialize_msg(); + (Ok(msg), stream) + } + Err(err) => { + let (msg, stream) = err.serialize_msg(); + (Err(msg), stream) + } + } + } + + async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self { + match ser_self { + Ok(ok) => Ok(T::deserialize_msg(ok, stream).await), + Err(err) => Err(E::deserialize_msg(err, stream).await), + } + } +} + +// ---- + +pub(crate) struct QueryMessage<'a> { + pub(crate) prio: RequestPriority, + pub(crate) path: &'a [u8], + pub(crate) telemetry_id: Option>, + pub(crate) body: &'a [u8], +} + +/// QueryMessage encoding: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - body [u8; ..] +impl<'a> QueryMessage<'a> { + pub(crate) fn encode(self) -> Vec { + let tel_len = match &self.telemetry_id { + Some(t) => t.len(), + None => 0, + }; + + let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); + + ret.push(self.prio); + + ret.push(self.path.len() as u8); + ret.extend_from_slice(self.path); + + if let Some(t) = self.telemetry_id { + ret.push(t.len() as u8); + ret.extend(t); + } else { + ret.push(0u8); + } + + ret.extend_from_slice(self.body); + + ret + } + + pub(crate) fn decode(bytes: &'a [u8]) -> Result { + if bytes.len() < 3 { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path_length = bytes[1] as usize; + if bytes.len() < 3 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let telemetry_id_len = bytes[2 + path_length] as usize; + if bytes.len() < 3 + path_length + telemetry_id_len { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let telemetry_id = if telemetry_id_len > 0 { + Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) + } else { + None + }; + + let body = &bytes[3 + path_length + telemetry_id_len..]; + + Ok(Self { + prio: bytes[0], + path, + telemetry_id, + body, + }) + } +} + +pub(crate) struct Framing { + direct: Vec, + stream: Option, +} + +impl Framing { + pub fn new(direct: Vec, stream: Option) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } + } + + pub fn into_stream(self) -> ByteStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; + + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); + + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) + } + } + + pub async fn from_stream + Unpin + Send + 'static>( + mut stream: S, + ) -> Result { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + } + + let stream: ByteStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec, ByteStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) + } +} diff --git a/src/netapp.rs b/src/netapp.rs index 27f17e6..dd22d90 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -20,7 +20,7 @@ use tokio::sync::{mpsc, watch}; use crate::client::*; use crate::endpoint::*; use crate::error::*; -use crate::proto::*; +use crate::message::*; use crate::server::*; use crate::util::*; diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 7f77995..98977a3 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -14,8 +14,9 @@ use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; +use crate::send::*; use crate::NodeID; // -- Protocol messages -- diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 7dfc5c4..5b489ae 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash; use crate::endpoint::*; use crate::error::*; use crate::netapp::*; -use crate::proto::*; + +use crate::message::*; use crate::NodeID; const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); diff --git a/src/proto2.rs b/src/proto2.rs deleted file mode 100644 index 7210781..0000000 --- a/src/proto2.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::error::*; -use crate::proto::*; - -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option>, - pub(crate) body: &'a [u8], -} - -/// QueryMessage encoding: -/// - priority: u8 -/// - path length: u8 -/// - path: [u8; path length] -/// - telemetry id length: u8 -/// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; - - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); - - ret.push(self.prio); - - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); - - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); - } else { - ret.push(0u8); - } - - ret.extend_from_slice(self.body); - - ret - } - - pub(crate) fn decode(bytes: &'a [u8]) -> Result { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; - - let body = &bytes[3 + path_length + telemetry_id_len..]; - - Ok(Self { - prio: bytes[0], - path, - telemetry_id, - body, - }) - } -} diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..628612b --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use std::sync::Arc; + +use log::trace; + +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::AsyncReadExt; + +use async_trait::async_trait; + +use crate::error::*; + +use crate::send::*; +use crate::util::Packet; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); + + async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap = HashMap::new(); + loop { + trace!("recv_loop: reading packet"); + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + trace!("recv_loop: got header id: {:04x}", id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + trace!("recv_loop: got header size: {:04x}", size); + + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let is_error = (size & ERROR_MARKER) != 0; + let packet = if is_error { + Err(size as u8) + } else { + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = unbounded(); + self.recv_handler(id, recv); + Sender::new(send) + }; + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + sender.send(packet); + + if has_cont { + streams.insert(id, sender); + } else { + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/proto.rs b/src/send.rs similarity index 60% rename from src/proto.rs rename to src/send.rs index 92d8d80..330d41d 100644 --- a/src/proto.rs +++ b/src/send.rs @@ -1,48 +1,19 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use async_trait::async_trait; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures::{Stream, StreamExt}; +use futures::AsyncWriteExt; +use futures::Stream; use kuska_handshake::async_std::BoxStreamWrite; - use tokio::sync::mpsc; -use async_trait::async_trait; - use crate::error::*; -use crate::util::{AssociatedStream, Packet}; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; +use crate::message::*; +use crate::util::{ByteStream, Packet}; // Messages are sent by chunks // Chunk format: @@ -52,10 +23,10 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; // - [u8; chunk_length] chunk data pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; -const ERROR_MARKER: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) type ChunkLength = u16; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; +pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, @@ -66,15 +37,15 @@ struct SendQueueItem { #[pin_project::pin_project] struct DataReader { #[pin] - reader: AssociatedStream, + reader: ByteStream, packet: Packet, pos: usize, buf: Vec, eos: bool, } -impl From for DataReader { - fn from(data: AssociatedStream) -> DataReader { +impl From for DataReader { + fn from(data: ByteStream) -> DataReader { DataReader { reader: data, packet: Ok(Vec::new()), @@ -297,7 +268,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>, mut write: BoxStreamWrite, ) -> Result<(), Error> where @@ -343,184 +314,6 @@ pub(crate) trait SendLoop: Sync { } } -pub(crate) struct Framing { - direct: Vec, - stream: Option, -} - -impl Framing { - pub fn new(direct: Vec, stream: Option) -> Self { - assert!(direct.len() <= u32::MAX as usize); - Framing { direct, stream } - } - - pub fn into_stream(self) -> AssociatedStream { - use futures::stream; - let len = self.direct.len() as u32; - // required because otherwise the borrow-checker complains - let Framing { direct, stream } = self; - - let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) - .chain(stream::once(async move { Ok(direct) })); - - if let Some(stream) = stream { - Box::pin(res.chain(stream)) - } else { - Box::pin(res) - } - } - - pub async fn from_stream + Unpin + Send + 'static>( - mut stream: S, - ) -> Result { - let mut packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - if packet.len() < 4 { - return Err(Error::Framing); - } - - let mut len = [0; 4]; - len.copy_from_slice(&packet[..4]); - let len = u32::from_be_bytes(len); - packet.drain(..4); - - let mut buffer = Vec::new(); - let len = len as usize; - loop { - let max_cp = std::cmp::min(len - buffer.len(), packet.len()); - - buffer.extend_from_slice(&packet[..max_cp]); - if buffer.len() == len { - packet.drain(..max_cp); - break; - } - packet = stream - .next() - .await - .ok_or(Error::Framing)? - .map_err(|_| Error::Framing)?; - } - - let stream: AssociatedStream = if packet.is_empty() { - Box::pin(stream) - } else { - Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) - }; - - Ok(Framing { - direct: buffer, - stream: Some(stream), - }) - } - - pub fn into_parts(self) -> (Vec, AssociatedStream) { - let Framing { direct, stream } = self; - (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) - } -} - -/// Structure to warn when the sender is dropped before end of stream was reached, like when -/// connection to some remote drops while transmitting data -struct Sender { - inner: UnboundedSender, - closed: bool, -} - -impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } - } - - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); - } - - fn end(&mut self) { - self.closed = true; - } -} - -impl Drop for Sender { - fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); - } - self.inner.close_channel(); - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); - - async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut streams: HashMap = HashMap::new(); - loop { - trace!("recv_loop: reading packet"); - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - trace!("recv_loop: got header id: {:04x}", id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!("recv_loop: got header size: {:04x}", size); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let is_error = (size & ERROR_MARKER) != 0; - let packet = if is_error { - Err(size as u8) - } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - Ok(next_slice) - }; - - let mut sender = if let Some(send) = streams.remove(&(id)) { - send - } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); - Sender::new(send) - }; - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); - - if has_cont { - streams.insert(id, sender); - } else { - sender.end(); - } - } - Ok(()) - } -} - #[cfg(test)] mod test { use super::*; diff --git a/src/server.rs b/src/server.rs index 8075484..1f1c22a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,17 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; +use async_trait::async_trait; use log::{debug, trace}; +use futures::channel::mpsc::UnboundedReceiver; +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; + #[cfg(feature = "telemetry")] use opentelemetry::{ trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer}, @@ -14,22 +23,11 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::channel::mpsc::UnboundedReceiver; -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: ArcSwapOption>, + resp_send: ArcSwapOption>, } impl ServerConn { @@ -126,8 +124,8 @@ impl ServerConn { async fn recv_handler_aux( self: &Arc, bytes: &[u8], - stream: AssociatedStream, - ) -> Result<(Vec, Option), Error> { + stream: ByteStream, + ) -> Result<(Vec, Option), Error> { let msg = QueryMessage::decode(bytes)?; let path = String::from_utf8(msg.path.to_vec())?; diff --git a/src/util.rs b/src/util.rs index 186678d..6fbafe6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,17 +1,15 @@ -use crate::endpoint::SerializeMessage; - use std::net::SocketAddr; use std::net::ToSocketAddrs; use std::pin::Pin; -use futures::Stream; - use log::info; - use serde::Serialize; +use futures::Stream; use tokio::sync::watch; +use crate::message::SerializeMessage; + /// A node's identifier, which is also its public cryptographic key pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; /// A node's secret key @@ -27,7 +25,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. -pub type AssociatedStream = Pin + Send>>; +pub type ByteStream = Pin + Send>>; pub type Packet = Result, u8>; @@ -38,7 +36,7 @@ pub type Packet = Result, u8>; /// This is used internally by the netapp communication protocol. pub fn rmp_to_vec_all_named( val: &T, -) -> Result<(Vec, Option), rmp_serde::encode::Error> +) -> Result<(Vec, Option), rmp_serde::encode::Error> where T: SerializeMessage + ?Sized, {