forked from lx/netapp
WIP: associated stream #1
11 changed files with 566 additions and 117 deletions
44
Cargo.lock
generated
44
Cargo.lock
generated
|
@ -151,6 +151,19 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"humantime 1.3.0",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.8.4"
|
||||
|
@ -158,7 +171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"humantime",
|
||||
"humantime 2.1.0",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
|
@ -322,6 +335,15 @@ version = "0.4.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||
dependencies = [
|
||||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
|
@ -440,7 +462,7 @@ dependencies = [
|
|||
"bytes 0.6.0",
|
||||
"cfg-if",
|
||||
"chrono",
|
||||
"env_logger",
|
||||
"env_logger 0.8.4",
|
||||
"err-derive",
|
||||
"futures",
|
||||
"hex",
|
||||
|
@ -450,6 +472,8 @@ dependencies = [
|
|||
"lru",
|
||||
"opentelemetry",
|
||||
"opentelemetry-contrib",
|
||||
"pin-project",
|
||||
"pretty_env_logger",
|
||||
"rand 0.5.6",
|
||||
"rmp-serde",
|
||||
"serde",
|
||||
|
@ -582,6 +606,16 @@ version = "0.2.16"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "pretty_env_logger"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d"
|
||||
dependencies = [
|
||||
"env_logger 0.7.1",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
|
@ -627,6 +661,12 @@ dependencies = [
|
|||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.10"
|
||||
|
|
|
@ -21,6 +21,7 @@ telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"]
|
|||
|
||||
[dependencies]
|
||||
futures = "0.3.17"
|
||||
pin-project = "1.0.10"
|
||||
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
||||
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
||||
tokio-stream = "0.1.7"
|
||||
|
@ -47,6 +48,7 @@ opentelemetry-contrib = { version = "0.9", optional = true }
|
|||
|
||||
[dev-dependencies]
|
||||
env_logger = "0.8"
|
||||
pretty_env_logger = "0.4"
|
||||
structopt = { version = "0.3", default-features = false }
|
||||
chrono = "0.4"
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex};
|
|||
use arc_swap::ArcSwapOption;
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
|
@ -37,10 +38,11 @@ pub(crate) struct ClientConn {
|
|||
pub(crate) remote_addr: SocketAddr,
|
||||
pub(crate) peer_id: NodeID,
|
||||
|
||||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
query_send:
|
||||
ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
|
||||
|
||||
next_query_number: AtomicU32,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
|
||||
}
|
||||
|
||||
impl ClientConn {
|
||||
|
@ -166,7 +168,7 @@ impl ClientConn {
|
|||
};
|
||||
|
||||
// Encode request
|
||||
let body = rmp_to_vec_all_named(rq.borrow())?;
|
||||
let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
|
||||
drop(rq);
|
||||
|
||||
let request = QueryMessage {
|
||||
|
@ -185,7 +187,7 @@ impl ClientConn {
|
|||
error!(
|
||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||
);
|
||||
if old_ch.send(vec![]).is_err() {
|
||||
if old_ch.send(unbounded().1).is_err() {
|
||||
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
}
|
||||
|
@ -195,17 +197,18 @@ impl ClientConn {
|
|||
#[cfg(feature = "telemetry")]
|
||||
span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
|
||||
|
||||
query_send.send((id, prio, bytes))?;
|
||||
query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
|
||||
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "telemetry")] {
|
||||
let resp = resp_recv
|
||||
let stream = resp_recv
|
||||
.with_context(Context::current_with_span(span))
|
||||
.await?;
|
||||
} else {
|
||||
let resp = resp_recv.await?;
|
||||
let stream = resp_recv.await?;
|
||||
}
|
||||
}
|
||||
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||
|
||||
if resp.is_empty() {
|
||||
return Err(Error::Message(
|
||||
|
@ -217,10 +220,8 @@ impl ClientConn {
|
|||
|
||||
let code = resp[0];
|
||||
if code == 0 {
|
||||
Ok(rmp_serde::decode::from_read_ref::<
|
||||
_,
|
||||
<T as Message>::Response,
|
||||
>(&resp[1..])?)
|
||||
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
|
||||
Ok(T::Response::deserialize_msg(ser_resp, stream).await)
|
||||
} else {
|
||||
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
|
||||
Err(Error::Remote(code, msg))
|
||||
|
@ -232,12 +233,12 @@ impl SendLoop for ClientConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ClientConn {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||
trace!("ClientConn recv_handler {}", id);
|
||||
|
||||
let mut inflight = self.inflight.lock().unwrap();
|
||||
if let Some(ch) = inflight.remove(&id) {
|
||||
if ch.send(msg).is_err() {
|
||||
if ch.send(stream).is_err() {
|
||||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,8 +14,68 @@ use crate::util::*;
|
|||
|
||||
/// This trait should be implemented by all messages your application
|
||||
/// wants to handle
|
||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
||||
pub trait Message: SerializeMessage + Send + Sync {
|
||||
type Response: SerializeMessage + Send + Sync;
|
||||
}
|
||||
|
||||
/// A trait for de/serializing messages, with possible associated stream.
|
||||
#[async_trait]
|
||||
pub trait SerializeMessage: Sized {
|
||||
type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
|
||||
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
|
||||
|
||||
// TODO should return Result
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
|
||||
}
|
||||
|
||||
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> SerializeMessage for T
|
||||
where
|
||||
T: AutoSerialize,
|
||||
{
|
||||
type SerializableSelf = Self;
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
|
||||
(self.clone(), None)
|
||||
}
|
||||
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self {
|
||||
// TODO verify no stream
|
||||
ser_self
|
||||
}
|
||||
}
|
||||
|
||||
impl AutoSerialize for () {}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, E> SerializeMessage for Result<T, E>
|
||||
where
|
||||
T: SerializeMessage + Send,
|
||||
E: SerializeMessage + Send,
|
||||
{
|
||||
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
|
||||
|
||||
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
|
||||
match self {
|
||||
Ok(ok) => {
|
||||
let (msg, stream) = ok.serialize_msg();
|
||||
(Ok(msg), stream)
|
||||
}
|
||||
Err(err) => {
|
||||
let (msg, stream) = err.serialize_msg();
|
||||
(Err(msg), stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
|
||||
match ser_self {
|
||||
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
|
||||
Err(err) => Err(E::deserialize_msg(err, stream).await),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This trait should be implemented by an object of your application
|
||||
|
@ -96,7 +156,7 @@ where
|
|||
prio: RequestPriority,
|
||||
) -> Result<<M as Message>::Response, Error>
|
||||
where
|
||||
B: Borrow<M>,
|
||||
B: Borrow<M> + Send + Sync,
|
||||
{
|
||||
if *target == self.netapp.id {
|
||||
match self.handler.load_full() {
|
||||
|
@ -128,7 +188,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
|
|||
|
||||
#[async_trait]
|
||||
pub(crate) trait GenericEndpoint {
|
||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
|
||||
async fn handle(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
stream: AssociatedStream,
|
||||
from: NodeID,
|
||||
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
|
||||
fn drop_handler(&self);
|
||||
fn clone_endpoint(&self) -> DynEndpoint;
|
||||
}
|
||||
|
@ -145,11 +210,17 @@ where
|
|||
M: Message + 'static,
|
||||
H: EndpointHandler<M> + 'static,
|
||||
{
|
||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
|
||||
async fn handle(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
stream: AssociatedStream,
|
||||
from: NodeID,
|
||||
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||
match self.0.handler.load_full() {
|
||||
None => Err(Error::NoHandler),
|
||||
Some(h) => {
|
||||
let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
|
||||
let req = rmp_serde::decode::from_read_ref(buf)?;
|
||||
let req = M::deserialize_msg(req, stream).await;
|
||||
let res = h.handle(&req, from).await;
|
||||
let res_bytes = rmp_to_vec_all_named(&res)?;
|
||||
Ok(res_bytes)
|
||||
|
|
|
@ -25,6 +25,9 @@ pub enum Error {
|
|||
#[error(display = "UTF8 error: {}", _0)]
|
||||
UTF8(#[error(source)] std::string::FromUtf8Error),
|
||||
|
||||
#[error(display = "Framing protocol error")]
|
||||
Framing,
|
||||
|
||||
#[error(display = "{}", _0)]
|
||||
Message(String),
|
||||
|
||||
|
@ -50,6 +53,7 @@ impl Error {
|
|||
Self::RMPEncode(_) => 10,
|
||||
Self::RMPDecode(_) => 11,
|
||||
Self::UTF8(_) => 12,
|
||||
Self::Framing => 13,
|
||||
Self::NoHandler => 20,
|
||||
Self::ConnectionClosed => 21,
|
||||
Self::Handshake(_) => 30,
|
||||
|
|
|
@ -32,12 +32,14 @@ pub(crate) type VersionTag = [u8; 16];
|
|||
/// Value of the Netapp version used in the version tag
|
||||
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub(crate) struct HelloMessage {
|
||||
pub server_addr: Option<IpAddr>,
|
||||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl AutoSerialize for HelloMessage {}
|
||||
|
||||
impl Message for HelloMessage {
|
||||
type Response = ();
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ const FAILED_PING_THRESHOLD: usize = 3;
|
|||
|
||||
// -- Protocol messages --
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct PingMessage {
|
||||
pub id: u64,
|
||||
pub peer_list_hash: hash::Digest,
|
||||
|
@ -39,7 +39,9 @@ impl Message for PingMessage {
|
|||
type Response = PingMessage;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
impl AutoSerialize for PingMessage {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct PeerListMessage {
|
||||
pub list: Vec<(NodeID, SocketAddr)>,
|
||||
}
|
||||
|
@ -48,6 +50,8 @@ impl Message for PeerListMessage {
|
|||
type Response = PeerListMessage;
|
||||
}
|
||||
|
||||
impl AutoSerialize for PeerListMessage {}
|
||||
|
||||
// -- Algorithm data structures --
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
423
src/proto.rs
423
src/proto.rs
|
@ -1,9 +1,13 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use log::trace;
|
||||
|
||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use futures::{Stream, StreamExt};
|
||||
use kuska_handshake::async_std::BoxStreamWrite;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
@ -11,6 +15,7 @@ use tokio::sync::mpsc;
|
|||
use async_trait::async_trait;
|
||||
|
||||
use crate::error::*;
|
||||
use crate::util::{AssociatedStream, Packet};
|
||||
|
||||
/// Priority of a request (click to read more about priorities).
|
||||
///
|
||||
|
@ -48,14 +53,148 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
|
|||
|
||||
pub(crate) type RequestID = u32;
|
||||
type ChunkLength = u16;
|
||||
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
||||
const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
|
||||
const ERROR_MARKER: ChunkLength = 0x4000;
|
||||
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||
|
||||
struct SendQueueItem {
|
||||
id: RequestID,
|
||||
prio: RequestPriority,
|
||||
data: Vec<u8>,
|
||||
cursor: usize,
|
||||
data: DataReader,
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
struct DataReader {
|
||||
#[pin]
|
||||
reader: AssociatedStream,
|
||||
packet: Packet,
|
||||
pos: usize,
|
||||
buf: Vec<u8>,
|
||||
eos: bool,
|
||||
}
|
||||
|
||||
impl From<AssociatedStream> for DataReader {
|
||||
fn from(data: AssociatedStream) -> DataReader {
|
||||
DataReader {
|
||||
reader: data,
|
||||
packet: Ok(Vec::new()),
|
||||
pos: 0,
|
||||
buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
|
||||
eos: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum DataFrame {
|
||||
Data {
|
||||
/// a fixed size buffer containing some data, possibly padded with 0s
|
||||
data: [u8; MAX_CHUNK_LENGTH as usize],
|
||||
/// actual lenght of data
|
||||
len: usize,
|
||||
},
|
||||
Error(u8),
|
||||
}
|
||||
|
||||
struct DataReaderItem {
|
||||
data: DataFrame,
|
||||
/// whethere there may be more data comming from this stream. Can be used for some
|
||||
/// optimization. It's an error to set it to false if there is more data, but it is correct
|
||||
/// (albeit sub-optimal) to set it to true if there is nothing coming after
|
||||
may_have_more: bool,
|
||||
}
|
||||
|
||||
impl DataReaderItem {
|
||||
fn empty_last() -> Self {
|
||||
DataReaderItem {
|
||||
data: DataFrame::Data {
|
||||
data: [0; MAX_CHUNK_LENGTH as usize],
|
||||
len: 0,
|
||||
},
|
||||
may_have_more: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn header(&self) -> [u8; 2] {
|
||||
let continuation = if self.may_have_more {
|
||||
CHUNK_HAS_CONTINUATION
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let len = match self.data {
|
||||
DataFrame::Data { len, .. } => len as u16,
|
||||
DataFrame::Error(e) => e as u16 | ERROR_MARKER,
|
||||
};
|
||||
|
||||
ChunkLength::to_be_bytes(len | continuation)
|
||||
}
|
||||
|
||||
fn data(&self) -> &[u8] {
|
||||
match self.data {
|
||||
DataFrame::Data { ref data, len } => &data[..len],
|
||||
DataFrame::Error(_) => &[],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for DataReader {
|
||||
type Item = DataReaderItem;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.project();
|
||||
|
||||
if *this.eos {
|
||||
// eos was reached at previous call to poll_next, where a partial packet
|
||||
// was returned. Now return None
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
loop {
|
||||
let packet = match this.packet {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
let e = *e;
|
||||
*this.packet = Ok(Vec::new());
|
||||
return Poll::Ready(Some(DataReaderItem {
|
||||
data: DataFrame::Error(e),
|
||||
may_have_more: true,
|
||||
}));
|
||||
}
|
||||
};
|
||||
let packet_left = packet.len() - *this.pos;
|
||||
let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
|
||||
let to_read = std::cmp::min(buf_left, packet_left);
|
||||
this.buf
|
||||
.extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
|
||||
*this.pos += to_read;
|
||||
if this.buf.len() == MAX_CHUNK_LENGTH as usize {
|
||||
// we have a full buf, ready to send
|
||||
break;
|
||||
}
|
||||
|
||||
// we don't have a full buf, packet is empty; try receive more
|
||||
if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
|
||||
*this.packet = p;
|
||||
*this.pos = 0;
|
||||
// if buf is empty, we will loop and return the error directly. If buf
|
||||
// isn't empty, send it before by breaking.
|
||||
if this.packet.is_err() && !this.buf.is_empty() {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
*this.eos = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut body = [0; MAX_CHUNK_LENGTH as usize];
|
||||
let len = this.buf.len();
|
||||
body[..len].copy_from_slice(this.buf);
|
||||
this.buf.clear();
|
||||
Poll::Ready(Some(DataReaderItem {
|
||||
data: DataFrame::Data { data: body, len },
|
||||
may_have_more: !*this.eos,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
struct SendQueue {
|
||||
|
@ -79,6 +218,8 @@ impl SendQueue {
|
|||
};
|
||||
self.items[pos_prio].1.push_back(item);
|
||||
}
|
||||
// used only in tests. They should probably be rewriten
|
||||
#[allow(dead_code)]
|
||||
fn pop(&mut self) -> Option<SendQueueItem> {
|
||||
match self.items.pop_front() {
|
||||
None => None,
|
||||
|
@ -94,6 +235,54 @@ impl SendQueue {
|
|||
fn is_empty(&self) -> bool {
|
||||
self.items.iter().all(|(_k, v)| v.is_empty())
|
||||
}
|
||||
|
||||
// this is like an async fn, but hand implemented
|
||||
fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
|
||||
SendQueuePollNextReady { queue: self }
|
||||
}
|
||||
}
|
||||
|
||||
struct SendQueuePollNextReady<'a> {
|
||||
queue: &'a mut SendQueue,
|
||||
}
|
||||
|
||||
impl<'a> futures::Future for SendQueuePollNextReady<'a> {
|
||||
type Output = (RequestID, DataReaderItem);
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
for i in 0..self.queue.items.len() {
|
||||
let (_prio, items_at_prio) = &mut self.queue.items[i];
|
||||
|
||||
for _ in 0..items_at_prio.len() {
|
||||
let mut item = items_at_prio.pop_front().unwrap();
|
||||
|
||||
match Pin::new(&mut item.data).poll_next(ctx) {
|
||||
Poll::Pending => items_at_prio.push_back(item),
|
||||
Poll::Ready(Some(data)) => {
|
||||
let id = item.id;
|
||||
if data.may_have_more {
|
||||
self.queue.push(item);
|
||||
} else {
|
||||
if items_at_prio.is_empty() {
|
||||
// this priority level is empty, remove it
|
||||
self.queue.items.remove(i);
|
||||
}
|
||||
}
|
||||
return Poll::Ready((id, data));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
if items_at_prio.is_empty() {
|
||||
// this priority level is empty, remove it
|
||||
self.queue.items.remove(i);
|
||||
}
|
||||
return Poll::Ready((item.id, DataReaderItem::empty_last()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO what do we do if self.queue is empty? We won't get scheduled again.
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
/// The SendLoop trait, which is implemented both by the client and the server
|
||||
|
@ -108,7 +297,7 @@ impl SendQueue {
|
|||
pub(crate) trait SendLoop: Sync {
|
||||
async fn send_loop<W>(
|
||||
self: Arc<Self>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
|
||||
mut write: BoxStreamWrite<W>,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
|
@ -117,55 +306,34 @@ pub(crate) trait SendLoop: Sync {
|
|||
let mut sending = SendQueue::new();
|
||||
let mut should_exit = false;
|
||||
while !should_exit || !sending.is_empty() {
|
||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else if let Some(mut item) = sending.pop() {
|
||||
trace!(
|
||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
||||
item.id,
|
||||
item.data.len(),
|
||||
item.cursor
|
||||
);
|
||||
let header_id = RequestID::to_be_bytes(item.id);
|
||||
write.write_all(&header_id[..]).await?;
|
||||
let recv_fut = msg_recv.recv();
|
||||
futures::pin_mut!(recv_fut);
|
||||
let send_fut = sending.next_ready();
|
||||
|
||||
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
|
||||
let size_header =
|
||||
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
|
||||
write.write_all(&size_header[..]).await?;
|
||||
|
||||
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
|
||||
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
||||
item.cursor = new_cursor;
|
||||
|
||||
sending.push(item);
|
||||
} else {
|
||||
let send_len = (item.data.len() - item.cursor) as ChunkLength;
|
||||
|
||||
let size_header = ChunkLength::to_be_bytes(send_len);
|
||||
write.write_all(&size_header[..]).await?;
|
||||
|
||||
write.write_all(&item.data[item.cursor..]).await?;
|
||||
}
|
||||
write.flush().await?;
|
||||
} else {
|
||||
let sth = msg_recv.recv().await;
|
||||
// recv_fut is cancellation-safe according to tokio doc,
|
||||
// send_fut is cancellation-safe as implemented above?
|
||||
use futures::future::Either;
|
||||
match futures::future::select(recv_fut, send_fut).await {
|
||||
Either::Left((sth, _send_fut)) => {
|
||||
if let Some((id, prio, data)) = sth {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
data: data.into(),
|
||||
});
|
||||
} else {
|
||||
should_exit = true;
|
||||
};
|
||||
}
|
||||
Either::Right(((id, data), _recv_fut)) => {
|
||||
trace!("send_loop: sending bytes for {}", id);
|
||||
|
||||
let header_id = RequestID::to_be_bytes(id);
|
||||
write.write_all(&header_id[..]).await?;
|
||||
|
||||
write.write_all(&data.header()).await?;
|
||||
write.write_all(data.data()).await?;
|
||||
write.flush().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -175,6 +343,118 @@ pub(crate) trait SendLoop: Sync {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Framing {
|
||||
direct: Vec<u8>,
|
||||
stream: Option<AssociatedStream>,
|
||||
}
|
||||
|
||||
impl Framing {
|
||||
pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
|
||||
assert!(direct.len() <= u32::MAX as usize);
|
||||
Framing { direct, stream }
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> AssociatedStream {
|
||||
use futures::stream;
|
||||
let len = self.direct.len() as u32;
|
||||
// required because otherwise the borrow-checker complains
|
||||
let Framing { direct, stream } = self;
|
||||
|
||||
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
|
||||
.chain(stream::once(async move { Ok(direct) }));
|
||||
|
||||
if let Some(stream) = stream {
|
||||
Box::pin(res.chain(stream))
|
||||
} else {
|
||||
Box::pin(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
|
||||
mut stream: S,
|
||||
) -> Result<Self, Error> {
|
||||
let mut packet = stream
|
||||
.next()
|
||||
.await
|
||||
.ok_or(Error::Framing)?
|
||||
.map_err(|_| Error::Framing)?;
|
||||
if packet.len() < 4 {
|
||||
return Err(Error::Framing);
|
||||
}
|
||||
|
||||
let mut len = [0; 4];
|
||||
len.copy_from_slice(&packet[..4]);
|
||||
let len = u32::from_be_bytes(len);
|
||||
packet.drain(..4);
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let len = len as usize;
|
||||
loop {
|
||||
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
|
||||
|
||||
buffer.extend_from_slice(&packet[..max_cp]);
|
||||
if buffer.len() == len {
|
||||
packet.drain(..max_cp);
|
||||
break;
|
||||
}
|
||||
packet = stream
|
||||
.next()
|
||||
.await
|
||||
.ok_or(Error::Framing)?
|
||||
.map_err(|_| Error::Framing)?;
|
||||
}
|
||||
|
||||
let stream: AssociatedStream = if packet.is_empty() {
|
||||
Box::pin(stream)
|
||||
} else {
|
||||
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
|
||||
};
|
||||
|
||||
Ok(Framing {
|
||||
direct: buffer,
|
||||
stream: Some(stream),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
|
||||
let Framing { direct, stream } = self;
|
||||
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
||||
/// connection to some remote drops while transmitting data
|
||||
struct Sender {
|
||||
inner: UnboundedSender<Packet>,
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
impl Sender {
|
||||
fn new(inner: UnboundedSender<Packet>) -> Self {
|
||||
Sender {
|
||||
inner,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn send(&self, packet: Packet) {
|
||||
let _ = self.inner.unbounded_send(packet);
|
||||
}
|
||||
|
||||
fn end(&mut self) {
|
||||
self.closed = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Sender {
|
||||
fn drop(&mut self) {
|
||||
if !self.closed {
|
||||
self.send(Err(255));
|
||||
}
|
||||
self.inner.close_channel();
|
||||
}
|
||||
}
|
||||
|
||||
/// The RecvLoop trait, which is implemented both by the client and the server
|
||||
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
|
||||
/// and a prototype of a handler for received messages `.recv_handler()` that
|
||||
|
@ -184,13 +464,13 @@ pub(crate) trait SendLoop: Sync {
|
|||
/// the full message is passed to the receive handler.
|
||||
#[async_trait]
|
||||
pub(crate) trait RecvLoop: Sync + 'static {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
|
||||
|
||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||
where
|
||||
R: AsyncReadExt + Unpin + Send + Sync,
|
||||
{
|
||||
let mut receiving = HashMap::new();
|
||||
let mut streams: HashMap<RequestID, Sender> = HashMap::new();
|
||||
loop {
|
||||
trace!("recv_loop: reading packet");
|
||||
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
||||
|
@ -208,19 +488,33 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
trace!("recv_loop: got header size: {:04x}", size);
|
||||
|
||||
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
||||
let is_error = (size & ERROR_MARKER) != 0;
|
||||
let packet = if is_error {
|
||||
Err(size as u8)
|
||||
} else {
|
||||
let size = size & !CHUNK_HAS_CONTINUATION;
|
||||
|
||||
let mut next_slice = vec![0; size as usize];
|
||||
read.read_exact(&mut next_slice[..]).await?;
|
||||
trace!("recv_loop: read {} bytes", next_slice.len());
|
||||
Ok(next_slice)
|
||||
};
|
||||
|
||||
let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
|
||||
msg_bytes.extend_from_slice(&next_slice[..]);
|
||||
let mut sender = if let Some(send) = streams.remove(&(id)) {
|
||||
send
|
||||
} else {
|
||||
let (send, recv) = unbounded();
|
||||
self.recv_handler(id, recv);
|
||||
Sender::new(send)
|
||||
};
|
||||
|
||||
// if we get an error, the receiving end is disconnected. We still need to
|
||||
// reach eos before dropping this sender
|
||||
sender.send(packet);
|
||||
|
||||
if has_cont {
|
||||
receiving.insert(id, msg_bytes);
|
||||
streams.insert(id, sender);
|
||||
} else {
|
||||
self.recv_handler(id, msg_bytes);
|
||||
sender.end();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -231,43 +525,44 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
mod test {
|
||||
use super::*;
|
||||
|
||||
fn empty_data() -> DataReader {
|
||||
type Item = Packet;
|
||||
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
|
||||
Box::pin(futures::stream::empty::<Packet>());
|
||||
stream.into()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_queue() {
|
||||
let i1 = SendQueueItem {
|
||||
id: 1,
|
||||
prio: PRIO_NORMAL,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
let i2 = SendQueueItem {
|
||||
id: 2,
|
||||
prio: PRIO_HIGH,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
let i2bis = SendQueueItem {
|
||||
id: 20,
|
||||
prio: PRIO_HIGH,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
let i3 = SendQueueItem {
|
||||
id: 3,
|
||||
prio: PRIO_HIGH | PRIO_SECONDARY,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
let i4 = SendQueueItem {
|
||||
id: 4,
|
||||
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
let i5 = SendQueueItem {
|
||||
id: 5,
|
||||
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
|
||||
data: vec![],
|
||||
cursor: 0,
|
||||
data: empty_data(),
|
||||
};
|
||||
|
||||
let mut q = SendQueue::new();
|
||||
|
|
|
@ -2,7 +2,6 @@ use std::net::SocketAddr;
|
|||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, trace};
|
||||
|
||||
#[cfg(feature = "telemetry")]
|
||||
|
@ -20,6 +19,7 @@ use tokio::select;
|
|||
use tokio::sync::{mpsc, watch};
|
||||
use tokio_util::compat::*;
|
||||
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
@ -55,7 +55,7 @@ pub(crate) struct ServerConn {
|
|||
|
||||
netapp: Arc<NetApp>,
|
||||
|
||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
|
||||
}
|
||||
|
||||
impl ServerConn {
|
||||
|
@ -123,7 +123,11 @@ impl ServerConn {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
async fn recv_handler_aux(
|
||||
self: &Arc<Self>,
|
||||
bytes: &[u8],
|
||||
stream: AssociatedStream,
|
||||
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
|
||||
let msg = QueryMessage::decode(bytes)?;
|
||||
let path = String::from_utf8(msg.path.to_vec())?;
|
||||
|
||||
|
@ -156,11 +160,11 @@ impl ServerConn {
|
|||
span.set_attribute(KeyValue::new("path", path.to_string()));
|
||||
span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
|
||||
|
||||
handler.handle(msg.body, self.peer_id)
|
||||
handler.handle(msg.body, stream, self.peer_id)
|
||||
.with_context(Context::current_with_span(span))
|
||||
.await
|
||||
} else {
|
||||
handler.handle(msg.body, self.peer_id).await
|
||||
handler.handle(msg.body, stream, self.peer_id).await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -173,35 +177,40 @@ impl SendLoop for ServerConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ServerConn {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||
let resp_send = self.resp_send.load_full().unwrap();
|
||||
|
||||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
|
||||
let bytes: Bytes = bytes.into();
|
||||
trace!("ServerConn recv_handler {}", id);
|
||||
let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
|
||||
|
||||
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
||||
let resp = self2.recv_handler_aux(&bytes[..]).await;
|
||||
let resp = self2.recv_handler_aux(&bytes[..], stream).await;
|
||||
|
||||
let resp_bytes = match resp {
|
||||
Ok(rb) => {
|
||||
let (resp_bytes, resp_stream) = match resp {
|
||||
Ok((rb, rs)) => {
|
||||
let mut resp_bytes = vec![0u8];
|
||||
resp_bytes.extend(rb);
|
||||
resp_bytes
|
||||
(resp_bytes, rs)
|
||||
}
|
||||
Err(e) => {
|
||||
let mut resp_bytes = vec![e.code()];
|
||||
resp_bytes.extend(e.to_string().into_bytes());
|
||||
resp_bytes
|
||||
(resp_bytes, None)
|
||||
}
|
||||
};
|
||||
|
||||
trace!("ServerConn sending response to {}: ", id);
|
||||
|
||||
resp_send
|
||||
.send((id, prio, resp_bytes))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
.send((
|
||||
id,
|
||||
prio,
|
||||
Framing::new(resp_bytes, resp_stream).into_stream(),
|
||||
))
|
||||
.log_err("ServerConn recv_handler send resp bytes");
|
||||
Ok::<_, Error>(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ use crate::NodeID;
|
|||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_with_basic_scheduler() {
|
||||
pretty_env_logger::init();
|
||||
run_test().await
|
||||
}
|
||||
|
||||
|
|
28
src/util.rs
28
src/util.rs
|
@ -1,10 +1,15 @@
|
|||
use crate::endpoint::SerializeMessage;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::pin::Pin;
|
||||
|
||||
use serde::Serialize;
|
||||
use futures::Stream;
|
||||
|
||||
use log::info;
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// A node's identifier, which is also its public cryptographic key
|
||||
|
@ -14,21 +19,36 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
|
|||
/// A network key
|
||||
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
|
||||
|
||||
/// A stream of associated data.
|
||||
///
|
||||
/// The Stream can continue after receiving an error.
|
||||
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
|
||||
/// consecutive Vec may get merged, but Vec and error code may not be reordered
|
||||
///
|
||||
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
|
||||
/// meaning, it's up to your application to define their semantic.
|
||||
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
|
||||
|
||||
pub type Packet = Result<Vec<u8>, u8>;
|
||||
|
||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||
/// using the RMP library.
|
||||
///
|
||||
/// Field names and variant names are included in the serialization.
|
||||
/// This is used internally by the netapp communication protocol.
|
||||
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
|
||||
pub fn rmp_to_vec_all_named<T>(
|
||||
val: &T,
|
||||
) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
T: SerializeMessage + ?Sized,
|
||||
{
|
||||
let mut wr = Vec::with_capacity(128);
|
||||
let mut se = rmp_serde::Serializer::new(&mut wr)
|
||||
.with_struct_map()
|
||||
.with_string_variants();
|
||||
let (val, stream) = val.serialize_msg();
|
||||
val.serialize(&mut se)?;
|
||||
Ok(wr)
|
||||
Ok((wr, stream))
|
||||
}
|
||||
|
||||
/// This async function returns only when a true signal was received
|
||||
|
|
Loading…
Reference in a new issue