forked from lx/netapp
WIP: associated stream #1
8 changed files with 382 additions and 83 deletions
44
Cargo.lock
generated
44
Cargo.lock
generated
|
@ -151,6 +151,19 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_logger"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
|
||||||
|
dependencies = [
|
||||||
|
"atty",
|
||||||
|
"humantime 1.3.0",
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
"termcolor",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "env_logger"
|
name = "env_logger"
|
||||||
version = "0.8.4"
|
version = "0.8.4"
|
||||||
|
@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"atty",
|
"atty",
|
||||||
"humantime",
|
"humantime 2.1.0",
|
||||||
"log",
|
"log",
|
||||||
"regex",
|
"regex",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
|
@ -322,6 +335,15 @@ version = "0.4.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||||
|
dependencies = [
|
||||||
|
"quick-error",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "humantime"
|
name = "humantime"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
|
@ -440,7 +462,7 @@ dependencies = [
|
||||||
"bytes 0.6.0",
|
"bytes 0.6.0",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"chrono",
|
"chrono",
|
||||||
"env_logger",
|
"env_logger 0.8.4",
|
||||||
"err-derive",
|
"err-derive",
|
||||||
"futures",
|
"futures",
|
||||||
"hex",
|
"hex",
|
||||||
|
@ -450,6 +472,8 @@ dependencies = [
|
||||||
"lru",
|
"lru",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-contrib",
|
"opentelemetry-contrib",
|
||||||
|
"pin-project",
|
||||||
|
"pretty_env_logger",
|
||||||
"rand 0.5.6",
|
"rand 0.5.6",
|
||||||
"rmp-serde",
|
"rmp-serde",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -582,6 +606,16 @@ version = "0.2.16"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pretty_env_logger"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d"
|
||||||
|
dependencies = [
|
||||||
|
"env_logger 0.7.1",
|
||||||
|
"log",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro-error"
|
name = "proc-macro-error"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
|
@ -627,6 +661,12 @@ dependencies = [
|
||||||
"unicode-xid",
|
"unicode-xid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quick-error"
|
||||||
|
version = "1.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quote"
|
name = "quote"
|
||||||
version = "1.0.10"
|
version = "1.0.10"
|
||||||
|
|
|
@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "0.3.17"
|
futures = "0.3.17"
|
||||||
|
pin-project = "1.0.10"
|
||||||
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
||||||
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
||||||
tokio-stream = "0.1.7"
|
tokio-stream = "0.1.7"
|
||||||
|
@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
env_logger = "0.8"
|
env_logger = "0.8"
|
||||||
|
pretty_env_logger = "0.4"
|
||||||
structopt = { version = "0.3", default-features = false }
|
structopt = { version = "0.3", default-features = false }
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
|
||||||
|
|
|
@ -37,10 +37,10 @@ pub(crate) struct ClientConn {
|
||||||
pub(crate) remote_addr: SocketAddr,
|
pub(crate) remote_addr: SocketAddr,
|
||||||
pub(crate) peer_id: NodeID,
|
pub(crate) peer_id: NodeID,
|
||||||
|
|
||||||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>,
|
||||||
|
|
||||||
next_query_number: AtomicU32,
|
next_query_number: AtomicU32,
|
||||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
inflight: Mutex<HashMap<RequestID, oneshot::Sender<(Vec<u8>, AssociatedStream)>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientConn {
|
impl ClientConn {
|
||||||
|
@ -148,9 +148,11 @@ impl ClientConn {
|
||||||
{
|
{
|
||||||
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
|
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
|
||||||
|
|
||||||
|
// increment by 2; even are direct data; odd are associated stream
|
||||||
let id = self
|
let id = self
|
||||||
.next_query_number
|
.next_query_number
|
||||||
.fetch_add(1, atomic::Ordering::Relaxed);
|
.fetch_add(2, atomic::Ordering::Relaxed);
|
||||||
|
let stream_id = id + 1;
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "telemetry")] {
|
if #[cfg(feature = "telemetry")] {
|
||||||
|
@ -166,7 +168,7 @@ impl ClientConn {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Encode request
|
// Encode request
|
||||||
let body = rmp_to_vec_all_named(rq.borrow())?;
|
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
|
||||||
drop(rq);
|
drop(rq);
|
||||||
|
|
||||||
let request = QueryMessage {
|
let request = QueryMessage {
|
||||||
|
@ -185,7 +187,10 @@ impl ClientConn {
|
||||||
error!(
|
error!(
|
||||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||||
);
|
);
|
||||||
if old_ch.send(vec![]).is_err() {
|
if old_ch
|
||||||
|
.send((vec![], Box::pin(futures::stream::empty())))
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,15 +200,20 @@ impl ClientConn {
|
||||||
#[cfg(feature = "telemetry")]
|
#[cfg(feature = "telemetry")]
|
||||||
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
||||||
|
|
||||||
query_send.send((id, prio, bytes))?;
|
query_send.send((id, prio, Data::Full(bytes)))?;
|
||||||
|
if let Some(stream) = stream {
|
||||||
|
query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?;
|
||||||
|
} else {
|
||||||
|
query_send.send((stream_id, prio, Data::Full(Vec::new())))?;
|
||||||
|
}
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "telemetry")] {
|
if #[cfg(feature = "telemetry")] {
|
||||||
let resp = resp_recv
|
let (resp, stream) = resp_recv
|
||||||
.with_context(Context::current_with_span(span))
|
.with_context(Context::current_with_span(span))
|
||||||
.await?;
|
.await?;
|
||||||
} else {
|
} else {
|
||||||
let resp = resp_recv.await?;
|
let (resp, stream) = resp_recv.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,10 +227,9 @@ impl ClientConn {
|
||||||
|
|
||||||
let code = resp[0];
|
let code = resp[0];
|
||||||
if code == 0 {
|
if code == 0 {
|
||||||
Ok(rmp_serde::decode::from_read_ref::<
|
let mut deser = rmp_serde::decode::Deserializer::from_read_ref(&resp[1..]);
|
||||||
_,
|
let res = T::Response::deserialize_msg(&mut deser, stream).await?;
|
||||||
<T as Message>::Response,
|
Ok(res)
|
||||||
>(&resp[1..])?)
|
|
||||||
} else {
|
} else {
|
||||||
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
||||||
Err(Error::Remote(code, msg))
|
Err(Error::Remote(code, msg))
|
||||||
|
@ -232,12 +241,12 @@ impl SendLoop for ClientConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ClientConn {
|
impl RecvLoop for ClientConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream) {
|
||||||
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
|
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
|
||||||
|
|
||||||
let mut inflight = self.inflight.lock().unwrap();
|
let mut inflight = self.inflight.lock().unwrap();
|
||||||
if let Some(ch) = inflight.remove(&id) {
|
if let Some(ch) = inflight.remove(&id) {
|
||||||
if ch.send(msg).is_err() {
|
if ch.send((msg, stream)).is_err() {
|
||||||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,8 @@ use std::sync::Arc;
|
||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::de::Error as DeError;
|
||||||
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::netapp::*;
|
use crate::netapp::*;
|
||||||
|
@ -14,8 +15,50 @@ use crate::util::*;
|
||||||
|
|
||||||
/// This trait should be implemented by all messages your application
|
/// This trait should be implemented by all messages your application
|
||||||
/// wants to handle
|
/// wants to handle
|
||||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
pub trait Message: SerializeMessage + Send + Sync {
|
||||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
type Response: SerializeMessage + Send + Sync;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A trait for de/serializing messages, with possible associated stream.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait SerializeMessage: Sized {
|
||||||
|
fn serialize_msg<S: Serializer>(
|
||||||
|
&self,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<(S::Ok, Option<AssociatedStream>), S::Error>;
|
||||||
|
|
||||||
|
async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
|
||||||
|
deserializer: D,
|
||||||
|
stream: AssociatedStream,
|
||||||
|
) -> Result<Self, D::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T> SerializeMessage for T
|
||||||
|
where
|
||||||
|
T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
|
||||||
|
{
|
||||||
|
fn serialize_msg<S: Serializer>(
|
||||||
|
&self,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<(S::Ok, Option<AssociatedStream>), S::Error> {
|
||||||
|
self.serialize(serializer).map(|r| (r, None))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
|
||||||
|
deserializer: D,
|
||||||
|
mut stream: AssociatedStream,
|
||||||
|
) -> Result<Self, D::Error> {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
let res = Self::deserialize(deserializer)?;
|
||||||
|
if stream.next().await.is_some() {
|
||||||
|
return Err(D::Error::custom(
|
||||||
|
"failed to deserialize: found associated stream when none expected",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This trait should be implemented by an object of your application
|
/// This trait should be implemented by an object of your application
|
||||||
|
@ -128,7 +171,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait GenericEndpoint {
|
pub(crate) trait GenericEndpoint {
|
||||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
buf: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
from: NodeID,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
|
||||||
fn drop_handler(&self);
|
fn drop_handler(&self);
|
||||||
fn clone_endpoint(&self) -> DynEndpoint;
|
fn clone_endpoint(&self) -> DynEndpoint;
|
||||||
}
|
}
|
||||||
|
@ -145,11 +193,17 @@ where
|
||||||
M: Message + 'static,
|
M: Message + 'static,
|
||||||
H: EndpointHandler<M> + 'static,
|
H: EndpointHandler<M> + 'static,
|
||||||
{
|
{
|
||||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
buf: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
from: NodeID,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||||
match self.0.handler.load_full() {
|
match self.0.handler.load_full() {
|
||||||
None => Err(Error::NoHandler),
|
None => Err(Error::NoHandler),
|
||||||
Some(h) => {
|
Some(h) => {
|
||||||
let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
|
let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf);
|
||||||
|
let req = M::deserialize_msg(&mut deser, stream).await?;
|
||||||
let res = h.handle(&req, from).await;
|
let res = h.handle(&req, from).await;
|
||||||
let res_bytes = rmp_to_vec_all_named(&res)?;
|
let res_bytes = rmp_to_vec_all_named(&res)?;
|
||||||
Ok(res_bytes)
|
Ok(res_bytes)
|
||||||
|
|
260
src/proto.rs
260
src/proto.rs
|
@ -1,9 +1,13 @@
|
||||||
use std::collections::{HashMap, VecDeque};
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
use log::trace;
|
use log::{trace, warn};
|
||||||
|
|
||||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
|
||||||
|
use futures::Stream;
|
||||||
|
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt};
|
||||||
use kuska_handshake::async_std::BoxStreamWrite;
|
use kuska_handshake::async_std::BoxStreamWrite;
|
||||||
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
@ -11,6 +15,7 @@ use tokio::sync::mpsc;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use crate::error::*;
|
use crate::error::*;
|
||||||
|
use crate::util::AssociatedStream;
|
||||||
|
|
||||||
/// Priority of a request (click to read more about priorities).
|
/// Priority of a request (click to read more about priorities).
|
||||||
///
|
///
|
||||||
|
@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
|
||||||
|
|
||||||
pub(crate) type RequestID = u32;
|
pub(crate) type RequestID = u32;
|
||||||
type ChunkLength = u16;
|
type ChunkLength = u16;
|
||||||
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
||||||
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||||
|
|
||||||
struct SendQueueItem {
|
struct SendQueueItem {
|
||||||
id: RequestID,
|
id: RequestID,
|
||||||
prio: RequestPriority,
|
prio: RequestPriority,
|
||||||
data: Vec<u8>,
|
data: DataReader,
|
||||||
cursor: usize,
|
}
|
||||||
|
|
||||||
|
pub(crate) enum Data {
|
||||||
|
Full(Vec<u8>),
|
||||||
|
Streaming(AssociatedStream),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pin_project::pin_project(project = DataReaderProj)]
|
||||||
|
enum DataReader {
|
||||||
|
Full {
|
||||||
|
#[pin]
|
||||||
|
data: Vec<u8>,
|
||||||
|
pos: usize,
|
||||||
|
},
|
||||||
|
Streaming {
|
||||||
|
#[pin]
|
||||||
|
reader: AssociatedStream,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Data> for DataReader {
|
||||||
|
fn from(data: Data) -> DataReader {
|
||||||
|
match data {
|
||||||
|
Data::Full(data) => DataReader::Full { data, pos: 0 },
|
||||||
|
Data::Streaming(reader) => DataReader::Streaming { reader },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for DataReader {
|
||||||
|
type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize);
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match self.project() {
|
||||||
|
DataReaderProj::Full { data, pos } => {
|
||||||
|
let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos);
|
||||||
|
let end = *pos + len;
|
||||||
|
|
||||||
|
if len == 0 {
|
||||||
|
Poll::Ready(None)
|
||||||
|
} else {
|
||||||
|
let mut body = [0; MAX_CHUNK_LENGTH as usize];
|
||||||
|
body[..len].copy_from_slice(&data[*pos..end]);
|
||||||
|
*pos = end;
|
||||||
|
Poll::Ready(Some((body, len)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DataReaderProj::Streaming { reader } => {
|
||||||
|
reader.poll_next(cx).map(|opt| {
|
||||||
|
opt.map(|v| {
|
||||||
|
let mut body = [0; MAX_CHUNK_LENGTH as usize];
|
||||||
|
let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len());
|
||||||
|
// TODO this can throw away long vec, they should be splited instead
|
||||||
|
body[..len].copy_from_slice(&v[..len]);
|
||||||
|
(body, len)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SendQueue {
|
struct SendQueue {
|
||||||
|
@ -108,7 +172,7 @@ impl SendQueue {
|
||||||
pub(crate) trait SendLoop: Sync {
|
pub(crate) trait SendLoop: Sync {
|
||||||
async fn send_loop<W>(
|
async fn send_loop<W>(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>,
|
||||||
mut write: BoxStreamWrite<W>,
|
mut write: BoxStreamWrite<W>,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
|
@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync {
|
||||||
let mut should_exit = false;
|
let mut should_exit = false;
|
||||||
while !should_exit || !sending.is_empty() {
|
while !should_exit || !sending.is_empty() {
|
||||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
match &data {
|
||||||
|
Data::Full(data) => {
|
||||||
|
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||||
|
}
|
||||||
|
Data::Streaming(_) => {
|
||||||
|
trace!("send_loop: got {}, unknown size", id);
|
||||||
|
}
|
||||||
|
}
|
||||||
sending.push(SendQueueItem {
|
sending.push(SendQueueItem {
|
||||||
id,
|
id,
|
||||||
prio,
|
prio,
|
||||||
data,
|
data: data.into(),
|
||||||
cursor: 0,
|
|
||||||
});
|
});
|
||||||
} else if let Some(mut item) = sending.pop() {
|
} else if let Some(mut item) = sending.pop() {
|
||||||
trace!(
|
trace!(
|
||||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
"send_loop: sending bytes for {}",
|
||||||
item.id,
|
item.id,
|
||||||
item.data.len(),
|
|
||||||
item.cursor
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let data = futures::select! {
|
||||||
|
data = item.data.next().fuse() => data,
|
||||||
|
default => {
|
||||||
|
// nothing to send yet; re-schedule and find something else to do
|
||||||
|
sending.push(item);
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// TODO if every SendQueueItem is waiting on data, use select_all to await
|
||||||
|
// something to do
|
||||||
|
// TODO find some way to not require sending empty last chunk
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let header_id = RequestID::to_be_bytes(item.id);
|
let header_id = RequestID::to_be_bytes(item.id);
|
||||||
write.write_all(&header_id[..]).await?;
|
write.write_all(&header_id[..]).await?;
|
||||||
|
|
||||||
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
|
let data = match data.as_ref() {
|
||||||
|
Some((data, len)) => &data[..*len],
|
||||||
|
None => &[],
|
||||||
|
};
|
||||||
|
|
||||||
|
if !data.is_empty() {
|
||||||
let size_header =
|
let size_header =
|
||||||
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
|
ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION);
|
||||||
write.write_all(&size_header[..]).await?;
|
write.write_all(&size_header[..]).await?;
|
||||||
|
|
||||||
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
|
write.write_all(data).await?;
|
||||||
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
|
||||||
item.cursor = new_cursor;
|
|
||||||
|
|
||||||
sending.push(item);
|
sending.push(item);
|
||||||
} else {
|
} else {
|
||||||
let send_len = (item.data.len() - item.cursor) as ChunkLength;
|
// this is always zero for now, but may be more when above TODO get fixed
|
||||||
|
let size_header = ChunkLength::to_be_bytes(data.len() as u16);
|
||||||
let size_header = ChunkLength::to_be_bytes(send_len);
|
|
||||||
write.write_all(&size_header[..]).await?;
|
write.write_all(&size_header[..]).await?;
|
||||||
|
|
||||||
write.write_all(&item.data[item.cursor..]).await?;
|
write.write_all(data).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
write.flush().await?;
|
write.flush().await?;
|
||||||
} else {
|
} else {
|
||||||
let sth = msg_recv.recv().await;
|
let sth = msg_recv.recv().await;
|
||||||
if let Some((id, prio, data)) = sth {
|
if let Some((id, prio, data)) = sth {
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
match &data {
|
||||||
|
Data::Full(data) => {
|
||||||
|
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||||
|
}
|
||||||
|
Data::Streaming(_) => {
|
||||||
|
trace!("send_loop: got {}, unknown size", id);
|
||||||
|
}
|
||||||
|
}
|
||||||
sending.push(SendQueueItem {
|
sending.push(SendQueueItem {
|
||||||
id,
|
id,
|
||||||
prio,
|
prio,
|
||||||
data,
|
data: data.into(),
|
||||||
cursor: 0,
|
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
should_exit = true;
|
should_exit = true;
|
||||||
|
@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ChannelPair {
|
||||||
|
receiver: Option<UnboundedReceiver<Vec<u8>>>,
|
||||||
|
sender: Option<UnboundedSender<Vec<u8>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelPair {
|
||||||
|
fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
|
||||||
|
self.receiver.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
|
||||||
|
self.sender.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
|
||||||
|
self.sender.as_ref().take()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
|
||||||
|
if self.receiver.is_some() || self.sender.is_some() {
|
||||||
|
map.insert(index, self);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ChannelPair {
|
||||||
|
fn default() -> Self {
|
||||||
|
let (send, recv) = unbounded();
|
||||||
|
ChannelPair {
|
||||||
|
receiver: Some(recv),
|
||||||
|
sender: Some(send),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The RecvLoop trait, which is implemented both by the client and the server
|
/// The RecvLoop trait, which is implemented both by the client and the server
|
||||||
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
|
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
|
||||||
/// and a prototype of a handler for received messages `.recv_handler()` that
|
/// and a prototype of a handler for received messages `.recv_handler()` that
|
||||||
|
@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync {
|
||||||
/// the full message is passed to the receive handler.
|
/// the full message is passed to the receive handler.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait RecvLoop: Sync + 'static {
|
pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
|
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream);
|
||||||
|
|
||||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
R: AsyncReadExt + Unpin + Send + Sync,
|
R: AsyncReadExt + Unpin + Send + Sync,
|
||||||
{
|
{
|
||||||
let mut receiving = HashMap::new();
|
let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new();
|
||||||
|
let mut streams: HashMap<
|
||||||
|
RequestID,
|
||||||
|
ChannelPair,
|
||||||
|
> = HashMap::new();
|
||||||
loop {
|
loop {
|
||||||
trace!("recv_loop: reading packet");
|
trace!("recv_loop: reading packet");
|
||||||
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
||||||
|
@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
read.read_exact(&mut next_slice[..]).await?;
|
read.read_exact(&mut next_slice[..]).await?;
|
||||||
trace!("recv_loop: read {} bytes", next_slice.len());
|
trace!("recv_loop: read {} bytes", next_slice.len());
|
||||||
|
|
||||||
let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
|
if id & 1 == 0 {
|
||||||
msg_bytes.extend_from_slice(&next_slice[..]);
|
// main stream
|
||||||
|
let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
|
||||||
|
msg_bytes.extend_from_slice(&next_slice[..]);
|
||||||
|
|
||||||
if has_cont {
|
if has_cont {
|
||||||
receiving.insert(id, msg_bytes);
|
receiving.insert(id, msg_bytes);
|
||||||
|
} else {
|
||||||
|
let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default();
|
||||||
|
|
||||||
|
if let Some(receiver) = channel_pair.take_receiver() {
|
||||||
|
self.recv_handler(id, msg_bytes, Box::pin(receiver));
|
||||||
|
} else {
|
||||||
|
warn!("Couldn't take receiver part of stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
channel_pair.insert_into(&mut streams, id | 1);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
self.recv_handler(id, msg_bytes);
|
// associated stream
|
||||||
|
let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
|
||||||
|
|
||||||
|
// if we get an error, the receiving end is disconnected. We still need to
|
||||||
|
// reach eos before dropping this sender
|
||||||
|
if !next_slice.is_empty() {
|
||||||
|
if let Some(sender) = channel_pair.ref_sender() {
|
||||||
|
let _ = sender.unbounded_send(next_slice);
|
||||||
|
} else {
|
||||||
|
warn!("Couldn't take sending part of stream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !has_cont {
|
||||||
|
channel_pair.take_sender();
|
||||||
|
}
|
||||||
|
|
||||||
|
channel_pair.insert_into(&mut streams, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -236,38 +396,50 @@ mod test {
|
||||||
let i1 = SendQueueItem {
|
let i1 = SendQueueItem {
|
||||||
id: 1,
|
id: 1,
|
||||||
prio: PRIO_NORMAL,
|
prio: PRIO_NORMAL,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let i2 = SendQueueItem {
|
let i2 = SendQueueItem {
|
||||||
id: 2,
|
id: 2,
|
||||||
prio: PRIO_HIGH,
|
prio: PRIO_HIGH,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let i2bis = SendQueueItem {
|
let i2bis = SendQueueItem {
|
||||||
id: 20,
|
id: 20,
|
||||||
prio: PRIO_HIGH,
|
prio: PRIO_HIGH,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let i3 = SendQueueItem {
|
let i3 = SendQueueItem {
|
||||||
id: 3,
|
id: 3,
|
||||||
prio: PRIO_HIGH | PRIO_SECONDARY,
|
prio: PRIO_HIGH | PRIO_SECONDARY,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let i4 = SendQueueItem {
|
let i4 = SendQueueItem {
|
||||||
id: 4,
|
id: 4,
|
||||||
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
|
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
let i5 = SendQueueItem {
|
let i5 = SendQueueItem {
|
||||||
id: 5,
|
id: 5,
|
||||||
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
|
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
|
||||||
data: vec![],
|
data: DataReader::Full {
|
||||||
cursor: 0,
|
data: vec![],
|
||||||
|
pos: 0,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut q = SendQueue::new();
|
let mut q = SendQueue::new();
|
||||||
|
|
|
@ -55,7 +55,7 @@ pub(crate) struct ServerConn {
|
||||||
|
|
||||||
netapp: Arc<NetApp>,
|
netapp: Arc<NetApp>,
|
||||||
|
|
||||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerConn {
|
impl ServerConn {
|
||||||
|
@ -123,7 +123,11 @@ impl ServerConn {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
|
async fn recv_handler_aux(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
bytes: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||||
let msg = QueryMessage::decode(bytes)?;
|
let msg = QueryMessage::decode(bytes)?;
|
||||||
let path = String::from_utf8(msg.path.to_vec())?;
|
let path = String::from_utf8(msg.path.to_vec())?;
|
||||||
|
|
||||||
|
@ -156,11 +160,11 @@ impl ServerConn {
|
||||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||||
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
||||||
|
|
||||||
handler.handle(msg.body, self.peer_id)
|
handler.handle(msg.body, stream, self.peer_id)
|
||||||
.with_context(Context::current_with_span(span))
|
.with_context(Context::current_with_span(span))
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
handler.handle(msg.body, self.peer_id).await
|
handler.handle(msg.body, stream, self.peer_id).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -173,7 +177,7 @@ impl SendLoop for ServerConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ServerConn {
|
impl RecvLoop for ServerConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>, stream: AssociatedStream) {
|
||||||
let resp_send = self.resp_send.load_full().unwrap();
|
let resp_send = self.resp_send.load_full().unwrap();
|
||||||
|
|
||||||
let self2 = self.clone();
|
let self2 = self.clone();
|
||||||
|
@ -182,26 +186,36 @@ impl RecvLoop for ServerConn {
|
||||||
let bytes: Bytes = bytes.into();
|
let bytes: Bytes = bytes.into();
|
||||||
|
|
||||||
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
||||||
let resp = self2.recv_handler_aux(&bytes[..]).await;
|
let resp = self2.recv_handler_aux(&bytes[..], stream).await;
|
||||||
|
|
||||||
let resp_bytes = match resp {
|
let (resp_bytes, resp_stream) = match resp {
|
||||||
Ok(rb) => {
|
Ok((rb, rs)) => {
|
||||||
let mut resp_bytes = vec![0u8];
|
let mut resp_bytes = vec![0u8];
|
||||||
resp_bytes.extend(rb);
|
resp_bytes.extend(rb);
|
||||||
resp_bytes
|
(resp_bytes, rs)
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut resp_bytes = vec![e.code()];
|
let mut resp_bytes = vec![e.code()];
|
||||||
resp_bytes.extend(e.to_string().into_bytes());
|
resp_bytes.extend(e.to_string().into_bytes());
|
||||||
resp_bytes
|
(resp_bytes, None)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("ServerConn sending response to {}: ", id);
|
trace!("ServerConn sending response to {}: ", id);
|
||||||
|
|
||||||
resp_send
|
resp_send
|
||||||
.send((id, prio, resp_bytes))
|
.send((id, prio, Data::Full(resp_bytes)))
|
||||||
.log_err("ServerConn recv_handler send resp");
|
.log_err("ServerConn recv_handler send resp bytes");
|
||||||
|
|
||||||
|
if let Some(resp_stream) = resp_stream {
|
||||||
|
resp_send
|
||||||
|
.send((id + 1, prio, Data::Streaming(resp_stream)))
|
||||||
|
.log_err("ServerConn recv_handler send resp stream");
|
||||||
|
} else {
|
||||||
|
resp_send
|
||||||
|
.send((id + 1, prio, Data::Full(Vec::new())))
|
||||||
|
.log_err("ServerConn recv_handler send resp stream");
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ use crate::NodeID;
|
||||||
|
|
||||||
#[tokio::test(flavor = "current_thread")]
|
#[tokio::test(flavor = "current_thread")]
|
||||||
async fn test_with_basic_scheduler() {
|
async fn test_with_basic_scheduler() {
|
||||||
|
pretty_env_logger::init();
|
||||||
run_test().await
|
run_test().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
17
src/util.rs
17
src/util.rs
|
@ -1,7 +1,10 @@
|
||||||
|
use crate::endpoint::SerializeMessage;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
use serde::Serialize;
|
use futures::Stream;
|
||||||
|
|
||||||
use log::info;
|
use log::info;
|
||||||
|
|
||||||
|
@ -14,21 +17,25 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
|
||||||
/// A network key
|
/// A network key
|
||||||
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
||||||
|
|
||||||
|
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Vec<u8>> + Send>>;
|
||||||
|
|
||||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||||
/// using the RMP library.
|
/// using the RMP library.
|
||||||
///
|
///
|
||||||
/// Field names and variant names are included in the serialization.
|
/// Field names and variant names are included in the serialization.
|
||||||
/// This is used internally by the netapp communication protocol.
|
/// This is used internally by the netapp communication protocol.
|
||||||
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
|
pub fn rmp_to_vec_all_named<T>(
|
||||||
|
val: &T,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
|
||||||
where
|
where
|
||||||
T: Serialize + ?Sized,
|
T: SerializeMessage + ?Sized,
|
||||||
{
|
{
|
||||||
let mut wr = Vec::with_capacity(128);
|
let mut wr = Vec::with_capacity(128);
|
||||||
let mut se = rmp_serde::Serializer::new(&mut wr)
|
let mut se = rmp_serde::Serializer::new(&mut wr)
|
||||||
.with_struct_map()
|
.with_struct_map()
|
||||||
.with_string_variants();
|
.with_string_variants();
|
||||||
val.serialize(&mut se)?;
|
let (_, stream) = val.serialize_msg(&mut se)?;
|
||||||
Ok(wr)
|
Ok((wr, stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This async function returns only when a true signal was received
|
/// This async function returns only when a true signal was received
|
||||||
|
|
Loading…
Reference in a new issue