forked from lx/netapp
WIP: associated stream #1
11 changed files with 566 additions and 117 deletions
44
Cargo.lock
generated
44
Cargo.lock
generated
|
@ -151,6 +151,19 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "env_logger"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
|
||||||
|
dependencies = [
|
||||||
|
"atty",
|
||||||
|
"humantime 1.3.0",
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
"termcolor",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "env_logger"
|
name = "env_logger"
|
||||||
version = "0.8.4"
|
version = "0.8.4"
|
||||||
|
@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"atty",
|
"atty",
|
||||||
"humantime",
|
"humantime 2.1.0",
|
||||||
"log",
|
"log",
|
||||||
"regex",
|
"regex",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
|
@ -322,6 +335,15 @@ version = "0.4.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "humantime"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||||
|
dependencies = [
|
||||||
|
"quick-error",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "humantime"
|
name = "humantime"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
|
@ -440,7 +462,7 @@ dependencies = [
|
||||||
"bytes 0.6.0",
|
"bytes 0.6.0",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"chrono",
|
"chrono",
|
||||||
"env_logger",
|
"env_logger 0.8.4",
|
||||||
"err-derive",
|
"err-derive",
|
||||||
"futures",
|
"futures",
|
||||||
"hex",
|
"hex",
|
||||||
|
@ -450,6 +472,8 @@ dependencies = [
|
||||||
"lru",
|
"lru",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-contrib",
|
"opentelemetry-contrib",
|
||||||
|
"pin-project",
|
||||||
|
"pretty_env_logger",
|
||||||
"rand 0.5.6",
|
"rand 0.5.6",
|
||||||
"rmp-serde",
|
"rmp-serde",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -582,6 +606,16 @@ version = "0.2.16"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pretty_env_logger"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d"
|
||||||
|
dependencies = [
|
||||||
|
"env_logger 0.7.1",
|
||||||
|
"log",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro-error"
|
name = "proc-macro-error"
|
||||||
version = "1.0.4"
|
version = "1.0.4"
|
||||||
|
@ -627,6 +661,12 @@ dependencies = [
|
||||||
"unicode-xid",
|
"unicode-xid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quick-error"
|
||||||
|
version = "1.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quote"
|
name = "quote"
|
||||||
version = "1.0.10"
|
version = "1.0.10"
|
||||||
|
|
|
@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "0.3.17"
|
futures = "0.3.17"
|
||||||
|
pin-project = "1.0.10"
|
||||||
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
||||||
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
||||||
tokio-stream = "0.1.7"
|
tokio-stream = "0.1.7"
|
||||||
|
@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
env_logger = "0.8"
|
env_logger = "0.8"
|
||||||
|
pretty_env_logger = "0.4"
|
||||||
structopt = { version = "0.3", default-features = false }
|
structopt = { version = "0.3", default-features = false }
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex};
|
||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use log::{debug, error, trace};
|
use log::{debug, error, trace};
|
||||||
|
|
||||||
|
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
@ -37,10 +38,11 @@ pub(crate) struct ClientConn {
|
||||||
pub(crate) remote_addr: SocketAddr,
|
pub(crate) remote_addr: SocketAddr,
|
||||||
pub(crate) peer_id: NodeID,
|
pub(crate) peer_id: NodeID,
|
||||||
|
|
||||||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
query_send:
|
||||||
|
ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
|
||||||
|
|
||||||
next_query_number: AtomicU32,
|
next_query_number: AtomicU32,
|
||||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientConn {
|
impl ClientConn {
|
||||||
|
@ -166,7 +168,7 @@ impl ClientConn {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Encode request
|
// Encode request
|
||||||
let body = rmp_to_vec_all_named(rq.borrow())?;
|
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
|
||||||
drop(rq);
|
drop(rq);
|
||||||
|
|
||||||
let request = QueryMessage {
|
let request = QueryMessage {
|
||||||
|
@ -185,7 +187,7 @@ impl ClientConn {
|
||||||
error!(
|
error!(
|
||||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||||
);
|
);
|
||||||
if old_ch.send(vec![]).is_err() {
|
if old_ch.send(unbounded().1).is_err() {
|
||||||
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,17 +197,18 @@ impl ClientConn {
|
||||||
#[cfg(feature = "telemetry")]
|
#[cfg(feature = "telemetry")]
|
||||||
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
||||||
|
|
||||||
query_send.send((id, prio, bytes))?;
|
query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "telemetry")] {
|
if #[cfg(feature = "telemetry")] {
|
||||||
let resp = resp_recv
|
let stream = resp_recv
|
||||||
.with_context(Context::current_with_span(span))
|
.with_context(Context::current_with_span(span))
|
||||||
.await?;
|
.await?;
|
||||||
} else {
|
} else {
|
||||||
let resp = resp_recv.await?;
|
let stream = resp_recv.await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||||
|
|
||||||
if resp.is_empty() {
|
if resp.is_empty() {
|
||||||
return Err(Error::Message(
|
return Err(Error::Message(
|
||||||
|
@ -217,10 +220,8 @@ impl ClientConn {
|
||||||
|
|
||||||
let code = resp[0];
|
let code = resp[0];
|
||||||
if code == 0 {
|
if code == 0 {
|
||||||
Ok(rmp_serde::decode::from_read_ref::<
|
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
|
||||||
_,
|
Ok(T::Response::deserialize_msg(ser_resp, stream).await)
|
||||||
<T as Message>::Response,
|
|
||||||
>(&resp[1..])?)
|
|
||||||
} else {
|
} else {
|
||||||
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
||||||
Err(Error::Remote(code, msg))
|
Err(Error::Remote(code, msg))
|
||||||
|
@ -232,12 +233,12 @@ impl SendLoop for ClientConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ClientConn {
|
impl RecvLoop for ClientConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||||
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
|
trace!("ClientConn recv_handler {}", id);
|
||||||
|
|
||||||
let mut inflight = self.inflight.lock().unwrap();
|
let mut inflight = self.inflight.lock().unwrap();
|
||||||
if let Some(ch) = inflight.remove(&id) {
|
if let Some(ch) = inflight.remove(&id) {
|
||||||
if ch.send(msg).is_err() {
|
if ch.send(stream).is_err() {
|
||||||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,8 +14,68 @@ use crate::util::*;
|
||||||
|
|
||||||
/// This trait should be implemented by all messages your application
|
/// This trait should be implemented by all messages your application
|
||||||
/// wants to handle
|
/// wants to handle
|
||||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
pub trait Message: SerializeMessage + Send + Sync {
|
||||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
type Response: SerializeMessage + Send + Sync;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A trait for de/serializing messages, with possible associated stream.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait SerializeMessage: Sized {
|
||||||
|
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
|
||||||
|
|
||||||
|
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
|
||||||
|
|
||||||
|
// TODO should return Result
|
||||||
|
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T> SerializeMessage for T
|
||||||
|
where
|
||||||
|
T: AutoSerialize,
|
||||||
|
{
|
||||||
|
type SerializableSelf = Self;
|
||||||
|
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
|
||||||
|
(self.clone(), None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self {
|
||||||
|
// TODO verify no stream
|
||||||
|
ser_self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AutoSerialize for () {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T, E> SerializeMessage for Result<T, E>
|
||||||
|
where
|
||||||
|
T: SerializeMessage + Send,
|
||||||
|
E: SerializeMessage + Send,
|
||||||
|
{
|
||||||
|
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
|
||||||
|
|
||||||
|
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
|
||||||
|
match self {
|
||||||
|
Ok(ok) => {
|
||||||
|
let (msg, stream) = ok.serialize_msg();
|
||||||
|
(Ok(msg), stream)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let (msg, stream) = err.serialize_msg();
|
||||||
|
(Err(msg), stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
|
||||||
|
match ser_self {
|
||||||
|
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
|
||||||
|
Err(err) => Err(E::deserialize_msg(err, stream).await),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This trait should be implemented by an object of your application
|
/// This trait should be implemented by an object of your application
|
||||||
|
@ -96,7 +156,7 @@ where
|
||||||
prio: RequestPriority,
|
prio: RequestPriority,
|
||||||
) -> Result<<M as Message>::Response, Error>
|
) -> Result<<M as Message>::Response, Error>
|
||||||
where
|
where
|
||||||
B: Borrow<M>,
|
B: Borrow<M> + Send + Sync,
|
||||||
{
|
{
|
||||||
if *target == self.netapp.id {
|
if *target == self.netapp.id {
|
||||||
match self.handler.load_full() {
|
match self.handler.load_full() {
|
||||||
|
@ -128,7 +188,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait GenericEndpoint {
|
pub(crate) trait GenericEndpoint {
|
||||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
buf: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
from: NodeID,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
|
||||||
fn drop_handler(&self);
|
fn drop_handler(&self);
|
||||||
fn clone_endpoint(&self) -> DynEndpoint;
|
fn clone_endpoint(&self) -> DynEndpoint;
|
||||||
}
|
}
|
||||||
|
@ -145,11 +210,17 @@ where
|
||||||
M: Message + 'static,
|
M: Message + 'static,
|
||||||
H: EndpointHandler<M> + 'static,
|
H: EndpointHandler<M> + 'static,
|
||||||
{
|
{
|
||||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
|
async fn handle(
|
||||||
|
&self,
|
||||||
|
buf: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
from: NodeID,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||||
match self.0.handler.load_full() {
|
match self.0.handler.load_full() {
|
||||||
None => Err(Error::NoHandler),
|
None => Err(Error::NoHandler),
|
||||||
Some(h) => {
|
Some(h) => {
|
||||||
let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
|
let req = rmp_serde::decode::from_read_ref(buf)?;
|
||||||
|
let req = M::deserialize_msg(req, stream).await;
|
||||||
let res = h.handle(&req, from).await;
|
let res = h.handle(&req, from).await;
|
||||||
let res_bytes = rmp_to_vec_all_named(&res)?;
|
let res_bytes = rmp_to_vec_all_named(&res)?;
|
||||||
Ok(res_bytes)
|
Ok(res_bytes)
|
||||||
|
|
|
@ -25,6 +25,9 @@ pub enum Error {
|
||||||
#[error(display = "UTF8 error: {}", _0)]
|
#[error(display = "UTF8 error: {}", _0)]
|
||||||
UTF8(#[error(source)] std::string::FromUtf8Error),
|
UTF8(#[error(source)] std::string::FromUtf8Error),
|
||||||
|
|
||||||
|
#[error(display = "Framing protocol error")]
|
||||||
|
Framing,
|
||||||
|
|
||||||
#[error(display = "{}", _0)]
|
#[error(display = "{}", _0)]
|
||||||
Message(String),
|
Message(String),
|
||||||
|
|
||||||
|
@ -50,6 +53,7 @@ impl Error {
|
||||||
Self::RMPEncode(_) => 10,
|
Self::RMPEncode(_) => 10,
|
||||||
Self::RMPDecode(_) => 11,
|
Self::RMPDecode(_) => 11,
|
||||||
Self::UTF8(_) => 12,
|
Self::UTF8(_) => 12,
|
||||||
|
Self::Framing => 13,
|
||||||
Self::NoHandler => 20,
|
Self::NoHandler => 20,
|
||||||
Self::ConnectionClosed => 21,
|
Self::ConnectionClosed => 21,
|
||||||
Self::Handshake(_) => 30,
|
Self::Handshake(_) => 30,
|
||||||
|
|
|
@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16];
|
||||||
/// Value of the Netapp version used in the version tag
|
/// Value of the Netapp version used in the version tag
|
||||||
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
|
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub(crate) struct HelloMessage {
|
pub(crate) struct HelloMessage {
|
||||||
pub server_addr: Option<IpAddr>,
|
pub server_addr: Option<IpAddr>,
|
||||||
pub server_port: u16,
|
pub server_port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AutoSerialize for HelloMessage {}
|
||||||
|
|
||||||
impl Message for HelloMessage {
|
impl Message for HelloMessage {
|
||||||
type Response = ();
|
type Response = ();
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
|
||||||
|
|
||||||
// -- Protocol messages --
|
// -- Protocol messages --
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize, Clone)]
|
||||||
struct PingMessage {
|
struct PingMessage {
|
||||||
pub id: u64,
|
pub id: u64,
|
||||||
pub peer_list_hash: hash::Digest,
|
pub peer_list_hash: hash::Digest,
|
||||||
|
@ -39,7 +39,9 @@ impl Message for PingMessage {
|
||||||
type Response = PingMessage;
|
type Response = PingMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
impl AutoSerialize for PingMessage {}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Clone)]
|
||||||
struct PeerListMessage {
|
struct PeerListMessage {
|
||||||
pub list: Vec<(NodeID, SocketAddr)>,
|
pub list: Vec<(NodeID, SocketAddr)>,
|
||||||
}
|
}
|
||||||
|
@ -48,6 +50,8 @@ impl Message for PeerListMessage {
|
||||||
type Response = PeerListMessage;
|
type Response = PeerListMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AutoSerialize for PeerListMessage {}
|
||||||
|
|
||||||
// -- Algorithm data structures --
|
// -- Algorithm data structures --
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
423
src/proto.rs
423
src/proto.rs
|
@ -1,9 +1,13 @@
|
||||||
use std::collections::{HashMap, VecDeque};
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
use log::trace;
|
use log::trace;
|
||||||
|
|
||||||
|
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
|
||||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use futures::{Stream, StreamExt};
|
||||||
use kuska_handshake::async_std::BoxStreamWrite;
|
use kuska_handshake::async_std::BoxStreamWrite;
|
||||||
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
@ -11,6 +15,7 @@ use tokio::sync::mpsc;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use crate::error::*;
|
use crate::error::*;
|
||||||
|
use crate::util::{AssociatedStream, Packet};
|
||||||
|
|
||||||
/// Priority of a request (click to read more about priorities).
|
/// Priority of a request (click to read more about priorities).
|
||||||
///
|
///
|
||||||
|
@ -48,14 +53,148 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
|
||||||
|
|
||||||
pub(crate) type RequestID = u32;
|
pub(crate) type RequestID = u32;
|
||||||
type ChunkLength = u16;
|
type ChunkLength = u16;
|
||||||
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
|
||||||
|
const ERROR_MARKER: ChunkLength = 0x4000;
|
||||||
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||||
|
|
||||||
struct SendQueueItem {
|
struct SendQueueItem {
|
||||||
id: RequestID,
|
id: RequestID,
|
||||||
prio: RequestPriority,
|
prio: RequestPriority,
|
||||||
data: Vec<u8>,
|
data: DataReader,
|
||||||
cursor: usize,
|
}
|
||||||
|
|
||||||
|
#[pin_project::pin_project]
|
||||||
|
struct DataReader {
|
||||||
|
#[pin]
|
||||||
|
reader: AssociatedStream,
|
||||||
|
packet: Packet,
|
||||||
|
pos: usize,
|
||||||
|
buf: Vec<u8>,
|
||||||
|
eos: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AssociatedStream> for DataReader {
|
||||||
|
fn from(data: AssociatedStream) -> DataReader {
|
||||||
|
DataReader {
|
||||||
|
reader: data,
|
||||||
|
packet: Ok(Vec::new()),
|
||||||
|
pos: 0,
|
||||||
|
buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
|
||||||
|
eos: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum DataFrame {
|
||||||
|
Data {
|
||||||
|
/// a fixed size buffer containing some data, possibly padded with 0s
|
||||||
|
data: [u8; MAX_CHUNK_LENGTH as usize],
|
||||||
|
/// actual lenght of data
|
||||||
|
len: usize,
|
||||||
|
},
|
||||||
|
Error(u8),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DataReaderItem {
|
||||||
|
data: DataFrame,
|
||||||
|
/// whethere there may be more data comming from this stream. Can be used for some
|
||||||
|
/// optimization. It's an error to set it to false if there is more data, but it is correct
|
||||||
|
/// (albeit sub-optimal) to set it to true if there is nothing coming after
|
||||||
|
may_have_more: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DataReaderItem {
|
||||||
|
fn empty_last() -> Self {
|
||||||
|
DataReaderItem {
|
||||||
|
data: DataFrame::Data {
|
||||||
|
data: [0; MAX_CHUNK_LENGTH as usize],
|
||||||
|
len: 0,
|
||||||
|
},
|
||||||
|
may_have_more: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn header(&self) -> [u8; 2] {
|
||||||
|
let continuation = if self.may_have_more {
|
||||||
|
CHUNK_HAS_CONTINUATION
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let len = match self.data {
|
||||||
|
DataFrame::Data { len, .. } => len as u16,
|
||||||
|
DataFrame::Error(e) => e as u16 | ERROR_MARKER,
|
||||||
|
};
|
||||||
|
|
||||||
|
ChunkLength::to_be_bytes(len | continuation)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data(&self) -> &[u8] {
|
||||||
|
match self.data {
|
||||||
|
DataFrame::Data { ref data, len } => &data[..len],
|
||||||
|
DataFrame::Error(_) => &[],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for DataReader {
|
||||||
|
type Item = DataReaderItem;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
let mut this = self.project();
|
||||||
|
|
||||||
|
if *this.eos {
|
||||||
|
// eos was reached at previous call to poll_next, where a partial packet
|
||||||
|
// was returned. Now return None
|
||||||
|
return Poll::Ready(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let packet = match this.packet {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
let e = *e;
|
||||||
|
*this.packet = Ok(Vec::new());
|
||||||
|
return Poll::Ready(Some(DataReaderItem {
|
||||||
|
data: DataFrame::Error(e),
|
||||||
|
may_have_more: true,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let packet_left = packet.len() - *this.pos;
|
||||||
|
let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
|
||||||
|
let to_read = std::cmp::min(buf_left, packet_left);
|
||||||
|
this.buf
|
||||||
|
.extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
|
||||||
|
*this.pos += to_read;
|
||||||
|
if this.buf.len() == MAX_CHUNK_LENGTH as usize {
|
||||||
|
// we have a full buf, ready to send
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we don't have a full buf, packet is empty; try receive more
|
||||||
|
if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
|
||||||
|
*this.packet = p;
|
||||||
|
*this.pos = 0;
|
||||||
|
// if buf is empty, we will loop and return the error directly. If buf
|
||||||
|
// isn't empty, send it before by breaking.
|
||||||
|
if this.packet.is_err() && !this.buf.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*this.eos = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut body = [0; MAX_CHUNK_LENGTH as usize];
|
||||||
|
let len = this.buf.len();
|
||||||
|
body[..len].copy_from_slice(this.buf);
|
||||||
|
this.buf.clear();
|
||||||
|
Poll::Ready(Some(DataReaderItem {
|
||||||
|
data: DataFrame::Data { data: body, len },
|
||||||
|
may_have_more: !*this.eos,
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SendQueue {
|
struct SendQueue {
|
||||||
|
@ -79,6 +218,8 @@ impl SendQueue {
|
||||||
};
|
};
|
||||||
self.items[pos_prio].1.push_back(item);
|
self.items[pos_prio].1.push_back(item);
|
||||||
}
|
}
|
||||||
|
// used only in tests. They should probably be rewriten
|
||||||
|
#[allow(dead_code)]
|
||||||
fn pop(&mut self) -> Option<SendQueueItem> {
|
fn pop(&mut self) -> Option<SendQueueItem> {
|
||||||
match self.items.pop_front() {
|
match self.items.pop_front() {
|
||||||
None => None,
|
None => None,
|
||||||
|
@ -94,6 +235,54 @@ impl SendQueue {
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.items.iter().all(|(_k, v)| v.is_empty())
|
self.items.iter().all(|(_k, v)| v.is_empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this is like an async fn, but hand implemented
|
||||||
|
fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
|
||||||
|
SendQueuePollNextReady { queue: self }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SendQueuePollNextReady<'a> {
|
||||||
|
queue: &'a mut SendQueue,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> futures::Future for SendQueuePollNextReady<'a> {
|
||||||
|
type Output = (RequestID, DataReaderItem);
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
for i in 0..self.queue.items.len() {
|
||||||
|
let (_prio, items_at_prio) = &mut self.queue.items[i];
|
||||||
|
|
||||||
|
for _ in 0..items_at_prio.len() {
|
||||||
|
let mut item = items_at_prio.pop_front().unwrap();
|
||||||
|
|
||||||
|
match Pin::new(&mut item.data).poll_next(ctx) {
|
||||||
|
Poll::Pending => items_at_prio.push_back(item),
|
||||||
|
Poll::Ready(Some(data)) => {
|
||||||
|
let id = item.id;
|
||||||
|
if data.may_have_more {
|
||||||
|
self.queue.push(item);
|
||||||
|
} else {
|
||||||
|
if items_at_prio.is_empty() {
|
||||||
|
// this priority level is empty, remove it
|
||||||
|
self.queue.items.remove(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Poll::Ready((id, data));
|
||||||
|
}
|
||||||
|
Poll::Ready(None) => {
|
||||||
|
if items_at_prio.is_empty() {
|
||||||
|
// this priority level is empty, remove it
|
||||||
|
self.queue.items.remove(i);
|
||||||
|
}
|
||||||
|
return Poll::Ready((item.id, DataReaderItem::empty_last()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO what do we do if self.queue is empty? We won't get scheduled again.
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The SendLoop trait, which is implemented both by the client and the server
|
/// The SendLoop trait, which is implemented both by the client and the server
|
||||||
|
@ -108,7 +297,7 @@ impl SendQueue {
|
||||||
pub(crate) trait SendLoop: Sync {
|
pub(crate) trait SendLoop: Sync {
|
||||||
async fn send_loop<W>(
|
async fn send_loop<W>(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
|
||||||
mut write: BoxStreamWrite<W>,
|
mut write: BoxStreamWrite<W>,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
|
@ -117,55 +306,34 @@ pub(crate) trait SendLoop: Sync {
|
||||||
let mut sending = SendQueue::new();
|
let mut sending = SendQueue::new();
|
||||||
let mut should_exit = false;
|
let mut should_exit = false;
|
||||||
while !should_exit || !sending.is_empty() {
|
while !should_exit || !sending.is_empty() {
|
||||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
let recv_fut = msg_recv.recv();
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
futures::pin_mut!(recv_fut);
|
||||||
sending.push(SendQueueItem {
|
let send_fut = sending.next_ready();
|
||||||
id,
|
|
||||||
prio,
|
|
||||||
data,
|
|
||||||
cursor: 0,
|
|
||||||
});
|
|
||||||
} else if let Some(mut item) = sending.pop() {
|
|
||||||
trace!(
|
|
||||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
|
||||||
item.id,
|
|
||||||
item.data.len(),
|
|
||||||
item.cursor
|
|
||||||
);
|
|
||||||
let header_id = RequestID::to_be_bytes(item.id);
|
|
||||||
write.write_all(&header_id[..]).await?;
|
|
||||||
|
|
||||||
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
|
// recv_fut is cancellation-safe according to tokio doc,
|
||||||
let size_header =
|
// send_fut is cancellation-safe as implemented above?
|
||||||
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
|
use futures::future::Either;
|
||||||
write.write_all(&size_header[..]).await?;
|
match futures::future::select(recv_fut, send_fut).await {
|
||||||
|
Either::Left((sth, _send_fut)) => {
|
||||||
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
|
|
||||||
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
|
||||||
item.cursor = new_cursor;
|
|
||||||
|
|
||||||
sending.push(item);
|
|
||||||
} else {
|
|
||||||
let send_len = (item.data.len() - item.cursor) as ChunkLength;
|
|
||||||
|
|
||||||
let size_header = ChunkLength::to_be_bytes(send_len);
|
|
||||||
write.write_all(&size_header[..]).await?;
|
|
||||||
|
|
||||||
write.write_all(&item.data[item.cursor..]).await?;
|
|
||||||
}
|
|
||||||
write.flush().await?;
|
|
||||||
} else {
|
|
||||||
let sth = msg_recv.recv().await;
|
|
||||||
if let Some((id, prio, data)) = sth {
|
if let Some((id, prio, data)) = sth {
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
|
||||||
sending.push(SendQueueItem {
|
sending.push(SendQueueItem {
|
||||||
id,
|
id,
|
||||||
prio,
|
prio,
|
||||||
data,
|
data: data.into(),
|
||||||
cursor: 0,
|
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
should_exit = true;
|
should_exit = true;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Either::Right(((id, data), _recv_fut)) => {
|
||||||
|
trace!("send_loop: sending bytes for {}", id);
|
||||||
|
|
||||||
|
let header_id = RequestID::to_be_bytes(id);
|
||||||
|
write.write_all(&header_id[..]).await?;
|
||||||
|
|
||||||
|
write.write_all(&data.header()).await?;
|
||||||
|
write.write_all(data.data()).await?;
|
||||||
|
write.flush().await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,6 +343,118 @@ pub(crate) trait SendLoop: Sync {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct Framing {
|
||||||
|
direct: Vec<u8>,
|
||||||
|
stream: Option<AssociatedStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Framing {
|
||||||
|
pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
|
||||||
|
assert!(direct.len() <= u32::MAX as usize);
|
||||||
|
Framing { direct, stream }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_stream(self) -> AssociatedStream {
|
||||||
|
use futures::stream;
|
||||||
|
let len = self.direct.len() as u32;
|
||||||
|
// required because otherwise the borrow-checker complains
|
||||||
|
let Framing { direct, stream } = self;
|
||||||
|
|
||||||
|
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
|
||||||
|
.chain(stream::once(async move { Ok(direct) }));
|
||||||
|
|
||||||
|
if let Some(stream) = stream {
|
||||||
|
Box::pin(res.chain(stream))
|
||||||
|
} else {
|
||||||
|
Box::pin(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
|
||||||
|
mut stream: S,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
let mut packet = stream
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.ok_or(Error::Framing)?
|
||||||
|
.map_err(|_| Error::Framing)?;
|
||||||
|
if packet.len() < 4 {
|
||||||
|
return Err(Error::Framing);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut len = [0; 4];
|
||||||
|
len.copy_from_slice(&packet[..4]);
|
||||||
|
let len = u32::from_be_bytes(len);
|
||||||
|
packet.drain(..4);
|
||||||
|
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
let len = len as usize;
|
||||||
|
loop {
|
||||||
|
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
|
||||||
|
|
||||||
|
buffer.extend_from_slice(&packet[..max_cp]);
|
||||||
|
if buffer.len() == len {
|
||||||
|
packet.drain(..max_cp);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
packet = stream
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.ok_or(Error::Framing)?
|
||||||
|
.map_err(|_| Error::Framing)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream: AssociatedStream = if packet.is_empty() {
|
||||||
|
Box::pin(stream)
|
||||||
|
} else {
|
||||||
|
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Framing {
|
||||||
|
direct: buffer,
|
||||||
|
stream: Some(stream),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
|
||||||
|
let Framing { direct, stream } = self;
|
||||||
|
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
||||||
|
/// connection to some remote drops while transmitting data
|
||||||
|
struct Sender {
|
||||||
|
inner: UnboundedSender<Packet>,
|
||||||
|
closed: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sender {
|
||||||
|
fn new(inner: UnboundedSender<Packet>) -> Self {
|
||||||
|
Sender {
|
||||||
|
inner,
|
||||||
|
closed: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send(&self, packet: Packet) {
|
||||||
|
let _ = self.inner.unbounded_send(packet);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end(&mut self) {
|
||||||
|
self.closed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Sender {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.closed {
|
||||||
|
self.send(Err(255));
|
||||||
|
}
|
||||||
|
self.inner.close_channel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The RecvLoop trait, which is implemented both by the client and the server
|
/// The RecvLoop trait, which is implemented both by the client and the server
|
||||||
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
|
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
|
||||||
/// and a prototype of a handler for received messages `.recv_handler()` that
|
/// and a prototype of a handler for received messages `.recv_handler()` that
|
||||||
|
@ -184,13 +464,13 @@ pub(crate) trait SendLoop: Sync {
|
||||||
/// the full message is passed to the receive handler.
|
/// the full message is passed to the receive handler.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait RecvLoop: Sync + 'static {
|
pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
|
||||||
|
|
||||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
R: AsyncReadExt + Unpin + Send + Sync,
|
R: AsyncReadExt + Unpin + Send + Sync,
|
||||||
{
|
{
|
||||||
let mut receiving = HashMap::new();
|
let mut streams: HashMap<RequestID, Sender> = HashMap::new();
|
||||||
loop {
|
loop {
|
||||||
trace!("recv_loop: reading packet");
|
trace!("recv_loop: reading packet");
|
||||||
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
||||||
|
@ -208,19 +488,33 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
trace!("recv_loop: got header size: {:04x}", size);
|
trace!("recv_loop: got header size: {:04x}", size);
|
||||||
|
|
||||||
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
||||||
|
let is_error = (size & ERROR_MARKER) != 0;
|
||||||
|
let packet = if is_error {
|
||||||
|
Err(size as u8)
|
||||||
|
} else {
|
||||||
let size = size & !CHUNK_HAS_CONTINUATION;
|
let size = size & !CHUNK_HAS_CONTINUATION;
|
||||||
|
|
||||||
let mut next_slice = vec![0; size as usize];
|
let mut next_slice = vec![0; size as usize];
|
||||||
read.read_exact(&mut next_slice[..]).await?;
|
read.read_exact(&mut next_slice[..]).await?;
|
||||||
trace!("recv_loop: read {} bytes", next_slice.len());
|
trace!("recv_loop: read {} bytes", next_slice.len());
|
||||||
|
Ok(next_slice)
|
||||||
|
};
|
||||||
|
|
||||||
let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
|
let mut sender = if let Some(send) = streams.remove(&(id)) {
|
||||||
msg_bytes.extend_from_slice(&next_slice[..]);
|
send
|
||||||
|
} else {
|
||||||
|
let (send, recv) = unbounded();
|
||||||
|
self.recv_handler(id, recv);
|
||||||
|
Sender::new(send)
|
||||||
|
};
|
||||||
|
|
||||||
|
// if we get an error, the receiving end is disconnected. We still need to
|
||||||
|
// reach eos before dropping this sender
|
||||||
|
sender.send(packet);
|
||||||
|
|
||||||
if has_cont {
|
if has_cont {
|
||||||
receiving.insert(id, msg_bytes);
|
streams.insert(id, sender);
|
||||||
} else {
|
} else {
|
||||||
self.recv_handler(id, msg_bytes);
|
sender.end();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -231,43 +525,44 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
fn empty_data() -> DataReader {
|
||||||
|
type Item = Packet;
|
||||||
|
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
|
||||||
|
Box::pin(futures::stream::empty::<Packet>());
|
||||||
|
stream.into()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_priority_queue() {
|
fn test_priority_queue() {
|
||||||
let i1 = SendQueueItem {
|
let i1 = SendQueueItem {
|
||||||
id: 1,
|
id: 1,
|
||||||
prio: PRIO_NORMAL,
|
prio: PRIO_NORMAL,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
let i2 = SendQueueItem {
|
let i2 = SendQueueItem {
|
||||||
id: 2,
|
id: 2,
|
||||||
prio: PRIO_HIGH,
|
prio: PRIO_HIGH,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
let i2bis = SendQueueItem {
|
let i2bis = SendQueueItem {
|
||||||
id: 20,
|
id: 20,
|
||||||
prio: PRIO_HIGH,
|
prio: PRIO_HIGH,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
let i3 = SendQueueItem {
|
let i3 = SendQueueItem {
|
||||||
id: 3,
|
id: 3,
|
||||||
prio: PRIO_HIGH | PRIO_SECONDARY,
|
prio: PRIO_HIGH | PRIO_SECONDARY,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
let i4 = SendQueueItem {
|
let i4 = SendQueueItem {
|
||||||
id: 4,
|
id: 4,
|
||||||
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
|
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
let i5 = SendQueueItem {
|
let i5 = SendQueueItem {
|
||||||
id: 5,
|
id: 5,
|
||||||
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
|
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
|
||||||
data: vec![],
|
data: empty_data(),
|
||||||
cursor: 0,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut q = SendQueue::new();
|
let mut q = SendQueue::new();
|
||||||
|
|
|
@ -2,7 +2,6 @@ use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use bytes::Bytes;
|
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
|
|
||||||
#[cfg(feature = "telemetry")]
|
#[cfg(feature = "telemetry")]
|
||||||
|
@ -20,6 +19,7 @@ use tokio::select;
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
use tokio_util::compat::*;
|
use tokio_util::compat::*;
|
||||||
|
|
||||||
|
use futures::channel::mpsc::UnboundedReceiver;
|
||||||
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -55,7 +55,7 @@ pub(crate) struct ServerConn {
|
||||||
|
|
||||||
netapp: Arc<NetApp>,
|
netapp: Arc<NetApp>,
|
||||||
|
|
||||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerConn {
|
impl ServerConn {
|
||||||
|
@ -123,7 +123,11 @@ impl ServerConn {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
|
async fn recv_handler_aux(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
bytes: &[u8],
|
||||||
|
stream: AssociatedStream,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||||
let msg = QueryMessage::decode(bytes)?;
|
let msg = QueryMessage::decode(bytes)?;
|
||||||
let path = String::from_utf8(msg.path.to_vec())?;
|
let path = String::from_utf8(msg.path.to_vec())?;
|
||||||
|
|
||||||
|
@ -156,11 +160,11 @@ impl ServerConn {
|
||||||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||||
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
||||||
|
|
||||||
handler.handle(msg.body, self.peer_id)
|
handler.handle(msg.body, stream, self.peer_id)
|
||||||
.with_context(Context::current_with_span(span))
|
.with_context(Context::current_with_span(span))
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
handler.handle(msg.body, self.peer_id).await
|
handler.handle(msg.body, stream, self.peer_id).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -173,35 +177,40 @@ impl SendLoop for ServerConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ServerConn {
|
impl RecvLoop for ServerConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||||
let resp_send = self.resp_send.load_full().unwrap();
|
let resp_send = self.resp_send.load_full().unwrap();
|
||||||
|
|
||||||
let self2 = self.clone();
|
let self2 = self.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
|
trace!("ServerConn recv_handler {}", id);
|
||||||
let bytes: Bytes = bytes.into();
|
let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||||
|
|
||||||
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
||||||
let resp = self2.recv_handler_aux(&bytes[..]).await;
|
let resp = self2.recv_handler_aux(&bytes[..], stream).await;
|
||||||
|
|
||||||
let resp_bytes = match resp {
|
let (resp_bytes, resp_stream) = match resp {
|
||||||
Ok(rb) => {
|
Ok((rb, rs)) => {
|
||||||
let mut resp_bytes = vec![0u8];
|
let mut resp_bytes = vec![0u8];
|
||||||
resp_bytes.extend(rb);
|
resp_bytes.extend(rb);
|
||||||
resp_bytes
|
(resp_bytes, rs)
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut resp_bytes = vec![e.code()];
|
let mut resp_bytes = vec![e.code()];
|
||||||
resp_bytes.extend(e.to_string().into_bytes());
|
resp_bytes.extend(e.to_string().into_bytes());
|
||||||
resp_bytes
|
(resp_bytes, None)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("ServerConn sending response to {}: ", id);
|
trace!("ServerConn sending response to {}: ", id);
|
||||||
|
|
||||||
resp_send
|
resp_send
|
||||||
.send((id, prio, resp_bytes))
|
.send((
|
||||||
.log_err("ServerConn recv_handler send resp");
|
id,
|
||||||
|
prio,
|
||||||
|
Framing::new(resp_bytes, resp_stream).into_stream(),
|
||||||
|
))
|
||||||
|
.log_err("ServerConn recv_handler send resp bytes");
|
||||||
|
Ok::<_, Error>(())
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ use crate::NodeID;
|
||||||
|
|
||||||
#[tokio::test(flavor = "current_thread")]
|
#[tokio::test(flavor = "current_thread")]
|
||||||
async fn test_with_basic_scheduler() {
|
async fn test_with_basic_scheduler() {
|
||||||
|
pretty_env_logger::init();
|
||||||
run_test().await
|
run_test().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
28
src/util.rs
28
src/util.rs
|
@ -1,10 +1,15 @@
|
||||||
|
use crate::endpoint::SerializeMessage;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
use serde::Serialize;
|
use futures::Stream;
|
||||||
|
|
||||||
use log::info;
|
use log::info;
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
/// A node's identifier, which is also its public cryptographic key
|
/// A node's identifier, which is also its public cryptographic key
|
||||||
|
@ -14,21 +19,36 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
|
||||||
/// A network key
|
/// A network key
|
||||||
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
||||||
|
|
||||||
|
/// A stream of associated data.
|
||||||
|
///
|
||||||
|
/// The Stream can continue after receiving an error.
|
||||||
|
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
|
||||||
|
/// consecutive Vec may get merged, but Vec and error code may not be reordered
|
||||||
|
///
|
||||||
|
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
|
||||||
|
/// meaning, it's up to your application to define their semantic.
|
||||||
|
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
|
||||||
|
|
||||||
|
pub type Packet = Result<Vec<u8>, u8>;
|
||||||
|
|
||||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||||
/// using the RMP library.
|
/// using the RMP library.
|
||||||
///
|
///
|
||||||
/// Field names and variant names are included in the serialization.
|
/// Field names and variant names are included in the serialization.
|
||||||
/// This is used internally by the netapp communication protocol.
|
/// This is used internally by the netapp communication protocol.
|
||||||
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
|
pub fn rmp_to_vec_all_named<T>(
|
||||||
|
val: &T,
|
||||||
|
) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
|
||||||
where
|
where
|
||||||
T: Serialize + ?Sized,
|
T: SerializeMessage + ?Sized,
|
||||||
{
|
{
|
||||||
let mut wr = Vec::with_capacity(128);
|
let mut wr = Vec::with_capacity(128);
|
||||||
let mut se = rmp_serde::Serializer::new(&mut wr)
|
let mut se = rmp_serde::Serializer::new(&mut wr)
|
||||||
.with_struct_map()
|
.with_struct_map()
|
||||||
.with_string_variants();
|
.with_string_variants();
|
||||||
|
let (val, stream) = val.serialize_msg();
|
||||||
val.serialize(&mut se)?;
|
val.serialize(&mut se)?;
|
||||||
Ok(wr)
|
Ok((wr, stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This async function returns only when a true signal was received
|
/// This async function returns only when a true signal was received
|
||||||
|
|
Loading…
Reference in a new issue