Compare commits


No commits in common. "main" and "stream-body" have entirely different histories.

24 changed files with 1361 additions and 2227 deletions


@ -1,17 +1,28 @@
---
kind: pipeline kind: pipeline
name: default name: default
workspace:
base: /drone
clone:
disable: true
steps: steps:
- name: clone
image: alpine/git
commands:
- mkdir -p cargo
- git clone $DRONE_GIT_HTTP_URL
- cd netapp
- git checkout $DRONE_COMMIT
- name: style - name: style
image: rust:1.58-buster image: rust:1.58-buster
environment: environment:
CARGO_HOME: /drone/cargo CARGO_HOME: /drone/cargo
volumes:
- name: cargo
path: /drone/cargo
commands: commands:
- rustup component add rustfmt clippy - rustup component add rustfmt clippy
- cd netapp
- cargo fmt -- --check - cargo fmt -- --check
- cargo clippy --all-features -- --deny warnings - cargo clippy --all-features -- --deny warnings
- cargo clippy --example fullmesh -- --deny warnings - cargo clippy --example fullmesh -- --deny warnings
@ -21,13 +32,11 @@ steps:
image: rust:1.58-buster image: rust:1.58-buster
environment: environment:
CARGO_HOME: /drone/cargo CARGO_HOME: /drone/cargo
volumes:
- name: cargo
path: /drone/cargo
commands: commands:
- apt-get update - apt-get update
- apt-get install --yes libsodium-dev - apt-get install --yes libsodium-dev
- cargo install -f cargo-all-features - cargo install -f cargo-all-features
- cd netapp
- cargo build-all-features - cargo build-all-features
- cargo build --example fullmesh - cargo build --example fullmesh
- cargo build --example basalt --features "basalt" - cargo build --example basalt --features "basalt"
@ -36,19 +45,8 @@ steps:
image: rust:1.58-buster image: rust:1.58-buster
environment: environment:
CARGO_HOME: /drone/cargo CARGO_HOME: /drone/cargo
volumes:
- name: cargo
path: /drone/cargo
commands: commands:
- apt-get update - apt-get update
- apt-get install --yes libsodium-dev - apt-get install --yes libsodium-dev
- cd netapp
- cargo test --all-features -- --test-threads 1 - cargo test --all-features -- --test-threads 1
volumes:
- name: cargo
temp: {}
---
kind: signature
hmac: f0d1a9e8d85a22c1d9084b4d90c9930be9700da52284f1875ece996cc52a6ce9
...

Cargo.lock (generated): 543 changes. File diff suppressed because it is too large.


@ -1,6 +1,6 @@
[package] [package]
name = "netapp" name = "netapp"
version = "0.10.0" version = "0.4.4"
authors = ["Alex Auvolat <alex@adnab.me>"] authors = ["Alex Auvolat <alex@adnab.me>"]
edition = "2018" edition = "2018"
license-file = "LICENSE" license-file = "LICENSE"
@ -16,28 +16,28 @@ name = "netapp"
[features] [features]
default = [] default = []
basalt = ["lru"] basalt = ["lru", "rand"]
telemetry = ["opentelemetry", "opentelemetry-contrib"] telemetry = ["opentelemetry", "opentelemetry-contrib", "rand"]
[dependencies] [dependencies]
futures = "0.3.17" futures = "0.3.17"
pin-project = "1.0.10" pin-project = "1.0.10"
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] } tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
tokio-util = { version = "0.7", default-features = false, features = ["compat", "io"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7" tokio-stream = "0.1.7"
serde = { version = "1.0", default-features = false, features = ["derive", "rc"] } serde = { version = "1.0", default-features = false, features = ["derive"] }
rmp-serde = "1.1" rmp-serde = "0.14.3"
hex = "0.4.2" hex = "0.4.2"
rand = { version = "0.8" } rand = { version = "0.5.5", optional = true }
log = "0.4.8" log = "0.4.8"
arc-swap = "1.1" arc-swap = "1.1"
async-trait = "0.1.7" async-trait = "0.1.7"
err-derive = "0.3" err-derive = "0.2.3"
bytes = "1.2" bytes = "0.6.0"
lru = { version = "0.7", optional = true } lru = { version = "0.6", optional = true }
cfg-if = "1.0" cfg-if = "1.0"
sodiumoxide = { version = "0.2.5-0", package = "kuska-sodiumoxide" } sodiumoxide = { version = "0.2.5-0", package = "kuska-sodiumoxide" }
@ -47,7 +47,8 @@ opentelemetry = { version = "0.17", optional = true }
opentelemetry-contrib = { version = "0.9", optional = true } opentelemetry-contrib = { version = "0.9", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.9" env_logger = "0.8"
pretty_env_logger = "0.4"
structopt = { version = "0.3", default-features = false } structopt = { version = "0.3", default-features = false }
chrono = "0.4" chrono = "0.4"


@ -1,6 +1,5 @@
all: all:
#cargo build --all-features cargo build --all-features
cargo build
cargo build --example fullmesh cargo build --example fullmesh
cargo build --all-features --example basalt cargo build --all-features --example basalt
RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7


@ -14,8 +14,8 @@ use sodiumoxide::crypto::sign::ed25519;
use tokio::sync::watch; use tokio::sync::watch;
use netapp::endpoint::*; use netapp::endpoint::*;
use netapp::message::*;
use netapp::peering::basalt::*; use netapp::peering::basalt::*;
use netapp::proto::*;
use netapp::util::parse_peer_addr; use netapp::util::parse_peer_addr;
use netapp::{NetApp, NodeID}; use netapp::{NetApp, NodeID};
@ -145,7 +145,7 @@ impl Example {
tokio::spawn(async move { tokio::spawn(async move {
match self2 match self2
.example_endpoint .example_endpoint
.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) .call(&p, &ExampleMessage { example_field: 42 }, PRIO_NORMAL)
.await .await
{ {
Ok(resp) => debug!("Got example response: {:?}", resp), Ok(resp) => debug!("Got example response: {:?}", resp),


@ -1,24 +1,16 @@
use std::io::Write; use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait; use log::info;
use bytes::Bytes;
use futures::{stream, StreamExt};
use log::*;
use serde::{Deserialize, Serialize};
use structopt::StructOpt; use structopt::StructOpt;
use tokio::sync::watch;
use sodiumoxide::crypto::auth; use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
use netapp::endpoint::*;
use netapp::message::*;
use netapp::peering::fullmesh::*; use netapp::peering::fullmesh::*;
use netapp::util::*; use netapp::util::*;
use netapp::{NetApp, NodeID}; use netapp::NetApp;
#[derive(StructOpt, Debug)] #[derive(StructOpt, Debug)]
#[structopt(name = "netapp")] #[structopt(name = "netapp")]
@ -72,7 +64,7 @@ async fn main() {
}; };
info!("Node private key: {}", hex::encode(&privkey)); info!("Node private key: {}", hex::encode(&privkey));
info!("Node public key: {}", hex::encode(privkey.public_key())); info!("Node public key: {}", hex::encode(&privkey.public_key()));
let public_addr = opt.public_addr.map(|x| x.parse().unwrap()); let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
let listen_addr: SocketAddr = opt.listen_addr.parse().unwrap(); let listen_addr: SocketAddr = opt.listen_addr.parse().unwrap();
@ -94,126 +86,13 @@ async fn main() {
info!("Add more peers to this mesh by running: fullmesh -n {} -l 127.0.0.1:$((1000 + $RANDOM)) -b {}@{}", info!("Add more peers to this mesh by running: fullmesh -n {} -l 127.0.0.1:$((1000 + $RANDOM)) -b {}@{}",
hex::encode(&netid), hex::encode(&netid),
hex::encode(privkey.public_key()), hex::encode(&privkey.public_key()),
listen_addr); listen_addr);
let watch_cancel = netapp::util::watch_ctrl_c(); let watch_cancel = netapp::util::watch_ctrl_c();
let example = Arc::new(Example {
netapp: netapp.clone(),
fullmesh: peering.clone(),
example_endpoint: netapp.endpoint("__netapp/examples/fullmesh.rs/Example".into()),
});
example.example_endpoint.set_handler(example.clone());
tokio::join!( tokio::join!(
example.exchange_loop(watch_cancel.clone()),
netapp.listen(listen_addr, public_addr, watch_cancel.clone()), netapp.listen(listen_addr, public_addr, watch_cancel.clone()),
peering.run(watch_cancel), peering.run(watch_cancel),
); );
} }
// ----
struct Example {
netapp: Arc<NetApp>,
fullmesh: Arc<FullMeshPeeringStrategy>,
example_endpoint: Arc<Endpoint<ExampleMessage, Self>>,
}
impl Example {
async fn exchange_loop(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
let mut i = 12000;
while !*must_exit.borrow() {
tokio::time::sleep(Duration::from_secs(2)).await;
let peers = self.fullmesh.get_peer_list();
for p in peers.iter() {
let id = p.id;
if id == self.netapp.id {
continue;
}
i += 1;
let example_field = i;
let self2 = self.clone();
tokio::spawn(async move {
info!(
"Send example query {} to {}",
example_field,
hex::encode(id)
);
// Fake data stream with some delays in item production
let stream =
Box::pin(stream::iter([100, 200, 300, 400]).then(|x| async move {
tokio::time::sleep(Duration::from_millis(500)).await;
Ok(Bytes::from(vec![(x % 256) as u8; 133 * x]))
}));
match self2
.example_endpoint
.call_streaming(
&id,
Req::new(ExampleMessage { example_field })
.unwrap()
.with_stream(stream),
PRIO_NORMAL,
)
.await
{
Ok(resp) => {
let (resp, stream) = resp.into_parts();
info!(
"Got example response to {} from {}: {:?}",
example_field,
hex::encode(id),
resp
);
let mut stream = stream.unwrap();
while let Some(x) = stream.next().await {
info!("Response: stream got bytes {:?}", x.map(|b| b.len()));
}
}
Err(e) => warn!("Error with example request: {}", e),
}
});
}
}
}
}
#[async_trait]
impl StreamingEndpointHandler<ExampleMessage> for Example {
async fn handle(
self: &Arc<Self>,
mut msg: Req<ExampleMessage>,
_from: NodeID,
) -> Resp<ExampleMessage> {
info!(
"Got example message: {:?}, sending example response",
msg.msg()
);
let source_stream = msg.take_stream().unwrap();
// Return same stream with 300ms delay
let new_stream = Box::pin(source_stream.then(|x| async move {
tokio::time::sleep(Duration::from_millis(300)).await;
x
}));
Resp::new(ExampleResponse {
example_field: false,
})
.with_stream(new_stream)
}
}
#[derive(Serialize, Deserialize, Debug)]
struct ExampleMessage {
example_field: usize,
}
#[derive(Serialize, Deserialize, Debug)]
struct ExampleResponse {
example_field: bool,
}
impl Message for ExampleMessage {
type Response = ExampleResponse;
}


@ -1,186 +0,0 @@
use std::cmp::Ordering;
use std::collections::VecDeque;
use bytes::BytesMut;
pub use bytes::Bytes;
/// A circular buffer of bytes, internally represented as a list of Bytes
/// for optimization, but that for all intents and purposes acts just like
/// a big byte slice which can be extended on the right and from which
/// bytes can be taken on the left.
pub struct BytesBuf {
buf: VecDeque<Bytes>,
buf_len: usize,
}
impl BytesBuf {
/// Creates a new empty BytesBuf
pub fn new() -> Self {
Self {
buf: VecDeque::new(),
buf_len: 0,
}
}
/// Returns the number of bytes stored in the BytesBuf
#[inline]
pub fn len(&self) -> usize {
self.buf_len
}
/// Returns true iff the BytesBuf contains zero bytes
#[inline]
pub fn is_empty(&self) -> bool {
self.buf_len == 0
}
/// Adds some bytes to the right of the buffer
pub fn extend(&mut self, b: Bytes) {
if !b.is_empty() {
self.buf_len += b.len();
self.buf.push_back(b);
}
}
/// Takes the whole content of the buffer and returns it as a single Bytes unit
pub fn take_all(&mut self) -> Bytes {
if self.buf.is_empty() {
Bytes::new()
} else if self.buf.len() == 1 {
self.buf_len = 0;
self.buf.pop_back().unwrap()
} else {
let mut ret = BytesMut::with_capacity(self.buf_len);
for b in self.buf.iter() {
ret.extend_from_slice(&b[..]);
}
self.buf.clear();
self.buf_len = 0;
ret.freeze()
}
}
/// Takes at most max_len bytes from the left of the buffer
pub fn take_max(&mut self, max_len: usize) -> Bytes {
if self.buf_len <= max_len {
self.take_all()
} else {
self.take_exact_ok(max_len)
}
}
/// Takes exactly len bytes from the left of the buffer; returns None if
/// the BytesBuf doesn't contain enough data
pub fn take_exact(&mut self, len: usize) -> Option<Bytes> {
if self.buf_len < len {
None
} else {
Some(self.take_exact_ok(len))
}
}
fn take_exact_ok(&mut self, len: usize) -> Bytes {
assert!(len <= self.buf_len);
let front = self.buf.pop_front().unwrap();
match front.len().cmp(&len) {
Ordering::Greater => {
self.buf.push_front(front.slice(len..));
self.buf_len -= len;
front.slice(..len)
}
Ordering::Equal => {
self.buf_len -= len;
front
}
Ordering::Less => {
let mut ret = BytesMut::with_capacity(len);
ret.extend_from_slice(&front[..]);
self.buf_len -= front.len();
while ret.len() < len {
let front = self.buf.pop_front().unwrap();
if front.len() > len - ret.len() {
let take = len - ret.len();
ret.extend_from_slice(&front[..take]);
self.buf.push_front(front.slice(take..));
self.buf_len -= take;
break;
} else {
ret.extend_from_slice(&front[..]);
self.buf_len -= front.len();
}
}
ret.freeze()
}
}
}
/// Return the internal sequence of Bytes slices that make up the buffer
pub fn into_slices(self) -> VecDeque<Bytes> {
self.buf
}
}
impl Default for BytesBuf {
fn default() -> Self {
Self::new()
}
}
impl From<Bytes> for BytesBuf {
fn from(b: Bytes) -> BytesBuf {
let mut ret = BytesBuf::new();
ret.extend(b);
ret
}
}
impl From<BytesBuf> for Bytes {
fn from(mut b: BytesBuf) -> Bytes {
b.take_all()
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_bytes_buf() {
let mut buf = BytesBuf::new();
assert!(buf.len() == 0);
assert!(buf.is_empty());
buf.extend(Bytes::from(b"Hello, world!".to_vec()));
assert!(buf.len() == 13);
assert!(!buf.is_empty());
buf.extend(Bytes::from(b"1234567890".to_vec()));
assert!(buf.len() == 23);
assert!(!buf.is_empty());
assert_eq!(
buf.take_all(),
Bytes::from(b"Hello, world!1234567890".to_vec())
);
assert!(buf.len() == 0);
assert!(buf.is_empty());
buf.extend(Bytes::from(b"1234567890".to_vec()));
buf.extend(Bytes::from(b"Hello, world!".to_vec()));
assert!(buf.len() == 23);
assert!(!buf.is_empty());
assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec()));
assert!(buf.len() == 11);
assert_eq!(buf.take_exact(12), None);
assert!(buf.len() == 11);
assert_eq!(
buf.take_exact(11),
Some(Bytes::from(b"llo, world!".to_vec()))
);
assert!(buf.len() == 0);
assert!(buf.is_empty());
}
}


@ -1,18 +1,13 @@
use std::borrow::Borrow;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{self, AtomicU32}; use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::task::Poll;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, error, trace}; use log::{debug, error, trace};
use futures::io::AsyncReadExt; use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::Stream;
use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select; use tokio::select;
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
@ -26,22 +21,28 @@ use opentelemetry::{
#[cfg(feature = "telemetry")] #[cfg(feature = "telemetry")]
use opentelemetry_contrib::trace::propagator::binary::*; use opentelemetry_contrib::trace::propagator::binary::*;
use futures::io::AsyncReadExt;
use async_trait::async_trait;
use kuska_handshake::async_std::{handshake_client, BoxStream};
use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::message::*;
use crate::netapp::*; use crate::netapp::*;
use crate::recv::*; use crate::proto::*;
use crate::send::*; use crate::proto2::*;
use crate::stream::*;
use crate::util::*; use crate::util::*;
pub(crate) struct ClientConn { pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr, pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID, pub(crate) peer_id: NodeID,
query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, query_send:
ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
next_query_number: AtomicU32, next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
} }
impl ClientConn { impl ClientConn {
@ -65,7 +66,7 @@ impl ClientConn {
debug!( debug!(
"Handshake complete (client) with {}@{}", "Handshake complete (client) with {}@{}",
hex::encode(peer_id), hex::encode(&peer_id),
remote_addr remote_addr
); );
@ -103,16 +104,13 @@ impl ClientConn {
netapp.connected_as_client(peer_id, conn.clone()); netapp.connected_as_client(peer_id, conn.clone());
let debug_name = format!("CLI {}", hex::encode(&peer_id[..8]));
tokio::spawn(async move { tokio::spawn(async move {
let debug_name_2 = debug_name.clone(); let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write));
let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write, debug_name_2));
let conn2 = conn.clone(); let conn2 = conn.clone();
let recv_future = tokio::spawn(async move { let recv_future = tokio::spawn(async move {
select! { select! {
r = conn2.recv_loop(read, debug_name) => r, r = conn2.recv_loop(read) => r,
_ = await_exit(stop_recv_loop_recv) => Ok(()) _ = await_exit(stop_recv_loop_recv) => Ok(())
} }
}); });
@ -140,14 +138,15 @@ impl ClientConn {
self.query_send.store(None); self.query_send.store(None);
} }
pub(crate) async fn call<T>( pub(crate) async fn call<T, B>(
self: Arc<Self>, self: Arc<Self>,
req: Req<T>, rq: B,
path: &str, path: &str,
prio: RequestPriority, prio: RequestPriority,
) -> Result<Resp<T>, Error> ) -> Result<<T as Message>::Response, Error>
where where
T: Message, T: Message,
B: Borrow<T>,
{ {
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@ -162,16 +161,24 @@ impl ClientConn {
.with_kind(SpanKind::Client) .with_kind(SpanKind::Client)
.start(&tracer); .start(&tracer);
let propagator = BinaryPropagator::new(); let propagator = BinaryPropagator::new();
let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
} else { } else {
let telemetry_id: Bytes = Bytes::new(); let telemetry_id: Option<Vec<u8>> = None;
} }
}; };
// Encode request // Encode request
let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); let (body, stream) = rmp_to_vec_all_named(rq.borrow())?;
let req_msg_len = req_enc.msg.len(); drop(rq);
let (req_stream, req_order) = req_enc.encode();
let request = QueryMessage {
prio,
path: path.as_bytes(),
telemetry_id,
body: &body[..],
};
let bytes = request.encode();
drop(body);
// Send request through // Send request through
let (resp_send, resp_recv) = oneshot::channel(); let (resp_send, resp_recv) = oneshot::channel();
@ -180,25 +187,17 @@ impl ClientConn {
error!( error!(
"Too many inflight requests! RequestID collision. Interrupting previous request." "Too many inflight requests! RequestID collision. Interrupting previous request."
); );
let _ = old_ch.send(Box::pin(futures::stream::once(async move { if old_ch.send(unbounded().1).is_err() {
Err(std::io::Error::new( debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
std::io::ErrorKind::Other, }
"RequestID collision, too many inflight requests",
))
})));
} }
debug!( trace!("request: query_send {}, {} bytes", id, bytes.len());
"request: query_send {}, path {}, prio {} (serialized message: {} bytes)",
id, path, prio, req_msg_len
);
#[cfg(feature = "telemetry")] #[cfg(feature = "telemetry")]
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;
let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] { if #[cfg(feature = "telemetry")] {
@ -209,12 +208,24 @@ impl ClientConn {
let stream = resp_recv.await?; let stream = resp_recv.await?;
} }
} }
let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
let stream = Box::pin(canceller.for_stream(stream)); if resp.is_empty() {
return Err(Error::Message(
"Response is 0 bytes, either a collision or a protocol error".into(),
));
}
let resp_enc = RespEnc::decode(stream).await?; trace!("request response {}: ", id);
debug!("client: got response to request {} (path {})", id, path);
Resp::from_enc(resp_enc) let code = resp[0];
if code == 0 {
let ser_resp = rmp_serde::decode::from_read_ref(&resp[1..])?;
Ok(T::Response::deserialize_msg(ser_resp, stream).await)
} else {
let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
Err(Error::Remote(code, msg))
}
} }
} }
@ -222,7 +233,7 @@ impl SendLoop for ClientConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ClientConn { impl RecvLoop for ClientConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
trace!("ClientConn recv_handler {}", id); trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap(); let mut inflight = self.inflight.lock().unwrap();
@ -230,63 +241,6 @@ impl RecvLoop for ClientConn {
if ch.send(stream).is_err() { if ch.send(stream).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response."); debug!("Could not send request response, probably because request was interrupted. Dropping response.");
} }
} else {
debug!("Got unexpected response to request {}, dropping it", id);
} }
} }
} }
// ----
struct CancelOnDrop {
id: RequestID,
query_send: mpsc::UnboundedSender<SendItem>,
}
impl CancelOnDrop {
fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
Self { id, query_send }
}
fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
CancelOnDropStream {
cancel: Some(self),
stream,
}
}
}
impl Drop for CancelOnDrop {
fn drop(&mut self) {
trace!("cancelling request {}", self.id);
let _ = self.query_send.send(SendItem::Cancel(self.id));
}
}
#[pin_project::pin_project]
struct CancelOnDropStream {
cancel: Option<CancelOnDrop>,
#[pin]
stream: ByteStream,
}
impl Stream for CancelOnDropStream {
type Item = Packet;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.project();
let res = this.stream.poll_next(cx);
if matches!(res, Poll::Ready(None)) {
if let Some(c) = this.cancel.take() {
std::mem::forget(c)
}
}
res
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
}


@ -1,43 +1,87 @@
use std::borrow::Borrow;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::sync::Arc; use std::sync::Arc;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait; use async_trait::async_trait;
use crate::error::Error; use serde::{Deserialize, Serialize};
use crate::message::*;
use crate::netapp::*;
/// This trait should be implemented by an object of your application use crate::error::Error;
/// that can handle a message of type `M`, if it wishes to handle use crate::netapp::*;
/// streams attached to the request and/or to send back streams use crate::proto::*;
/// attached to the response.. use crate::util::*;
///
/// The handler object should be in an Arc, see `Endpoint::set_handler` /// This trait should be implemented by all messages your application
#[async_trait] /// wants to handle
pub trait StreamingEndpointHandler<M>: Send + Sync pub trait Message: SerializeMessage + Send + Sync {
where type Response: SerializeMessage + Send + Sync;
M: Message,
{
async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
} }
/// If one simply wants to use an endpoint in a client fashion, /// A trait for de/serializing messages, with possible associated stream.
/// without locally serving requests to that endpoint,
/// use the unit type `()` as the handler type:
/// it will panic if it is ever made to handle request.
#[async_trait] #[async_trait]
impl<M: Message> EndpointHandler<M> for () { pub trait SerializeMessage: Sized {
async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
panic!("This endpoint should not have a local handler.");
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
// TODO should return Result
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
}
pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
#[async_trait]
impl<T> SerializeMessage for T
where
T: AutoSerialize,
{
type SerializableSelf = Self;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
(self.clone(), None)
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self {
// TODO verify no stream
ser_self
} }
} }
// ---- impl AutoSerialize for () {}
#[async_trait]
impl<T, E> SerializeMessage for Result<T, E>
where
T: SerializeMessage + Send,
E: SerializeMessage + Send,
{
type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
match self {
Ok(ok) => {
let (msg, stream) = ok.serialize_msg();
(Ok(msg), stream)
}
Err(err) => {
let (msg, stream) = err.serialize_msg();
(Err(msg), stream)
}
}
}
async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
match ser_self {
Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
Err(err) => Err(E::deserialize_msg(err, stream).await),
}
}
}
/// This trait should be implemented by an object of your application /// This trait should be implemented by an object of your application
/// that can handle a message of type `M`, in the cases where it doesn't /// that can handle a message of type `M`.
/// care about attached stream in the request nor in the response. ///
/// The handler object should be in an Arc, see `Endpoint::set_handler`
#[async_trait] #[async_trait]
pub trait EndpointHandler<M>: Send + Sync pub trait EndpointHandler<M>: Send + Sync
where where
@ -46,22 +90,17 @@ where
async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response; async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
} }
/// If one simply wants to use an endpoint in a client fashion,
/// without locally serving requests to that endpoint,
/// use the unit type `()` as the handler type:
/// it will panic if it is ever made to handle request.
#[async_trait] #[async_trait]
impl<T, M> StreamingEndpointHandler<M> for T impl<M: Message + 'static> EndpointHandler<M> for () {
where async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
T: EndpointHandler<M>, panic!("This endpoint should not have a local handler.");
M: Message,
{
async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> {
// Immediately drop stream to ignore all data that comes in,
// instead of buffering it indefinitely
drop(m.take_stream());
Resp::new(EndpointHandler::handle(self, m.msg(), from).await)
} }
} }
// ----
/// This struct represents an endpoint for message of type `M`. /// This struct represents an endpoint for message of type `M`.
/// ///
/// Creating a new endpoint is done by calling `NetApp::endpoint`. /// Creating a new endpoint is done by calling `NetApp::endpoint`.
@ -71,13 +110,13 @@ where
/// An `Endpoint` is used both to send requests to remote nodes, /// An `Endpoint` is used both to send requests to remote nodes,
/// and to specify the handler for such requests on the local node. /// and to specify the handler for such requests on the local node.
/// The type `H` represents the type of the handler object for /// The type `H` represents the type of the handler object for
/// endpoint messages (see `StreamingEndpointHandler`). /// endpoint messages (see `EndpointHandler`).
pub struct Endpoint<M, H> pub struct Endpoint<M, H>
where where
M: Message, M: Message,
H: StreamingEndpointHandler<M>, H: EndpointHandler<M>,
{ {
_phantom: PhantomData<M>, phantom: PhantomData<M>,
netapp: Arc<NetApp>, netapp: Arc<NetApp>,
path: String, path: String,
handler: ArcSwapOption<H>, handler: ArcSwapOption<H>,
@ -86,11 +125,11 @@ where
impl<M, H> Endpoint<M, H> impl<M, H> Endpoint<M, H>
where where
M: Message, M: Message,
H: StreamingEndpointHandler<M>, H: EndpointHandler<M>,
{ {
pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
Self { Self {
_phantom: PhantomData::default(), phantom: PhantomData::default(),
netapp, netapp,
path, path,
handler: ArcSwapOption::from(None), handler: ArcSwapOption::from(None),
@ -109,22 +148,20 @@ where
} }
/// Call this endpoint on a remote node (or on the local node, /// Call this endpoint on a remote node (or on the local node,
/// for that matter). This function invokes the full version that /// for that matter)
/// allows to attach a stream to the request and to pub async fn call<B>(
/// receive such a stream attached to the response.
pub async fn call_streaming<T>(
&self, &self,
target: &NodeID, target: &NodeID,
req: T, req: B,
prio: RequestPriority, prio: RequestPriority,
) -> Result<Resp<M>, Error> ) -> Result<<M as Message>::Response, Error>
where where
T: IntoReq<M>, B: Borrow<M> + Send + Sync,
{ {
if *target == self.netapp.id { if *target == self.netapp.id {
match self.handler.load_full() { match self.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
} }
} else { } else {
let conn = self let conn = self
@ -139,22 +176,10 @@ where
"Not connected: {}", "Not connected: {}",
hex::encode(&target[..8]) hex::encode(&target[..8])
))), ))),
Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, Some(c) => c.call(req, self.path.as_str(), prio).await,
} }
} }
} }
/// Call this endpoint on a remote node. This function is the simplified
/// version that doesn't allow to have streams attached to the request
/// or the response; see `call_streaming` for the full version.
pub async fn call(
&self,
target: &NodeID,
req: M,
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> {
Ok(self.call_streaming(target, req, prio).await?.into_msg())
}
} }
// ---- Internal stuff ---- // ---- Internal stuff ----
@ -163,7 +188,12 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait] #[async_trait]
pub(crate) trait GenericEndpoint { pub(crate) trait GenericEndpoint {
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>; async fn handle(
&self,
buf: &[u8],
stream: AssociatedStream,
from: NodeID,
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
fn drop_handler(&self); fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint; fn clone_endpoint(&self) -> DynEndpoint;
} }
@ -172,21 +202,28 @@ pub(crate) trait GenericEndpoint {
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
where where
M: Message, M: Message,
H: StreamingEndpointHandler<M>; H: EndpointHandler<M>;
#[async_trait] #[async_trait]
impl<M, H> GenericEndpoint for EndpointArc<M, H> impl<M, H> GenericEndpoint for EndpointArc<M, H>
where where
M: Message, M: Message + 'static,
H: StreamingEndpointHandler<M> + 'static, H: EndpointHandler<M> + 'static,
{ {
async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> { async fn handle(
&self,
buf: &[u8],
stream: AssociatedStream,
from: NodeID,
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
match self.0.handler.load_full() { match self.0.handler.load_full() {
None => Err(Error::NoHandler), None => Err(Error::NoHandler),
Some(h) => { Some(h) => {
let req = Req::from_enc(req_enc)?; let req = rmp_serde::decode::from_read_ref(buf)?;
let res = h.handle(req, from).await; let req = M::deserialize_msg(req, stream).await;
Ok(res.into_enc()?) let res = h.handle(&req, from).await;
let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes)
} }
} }
} }


@ -1,6 +1,6 @@
use err_derive::Error;
use std::io; use std::io;
use err_derive::Error;
use log::error; use log::error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -28,12 +28,6 @@ pub enum Error {
#[error(display = "Framing protocol error")] #[error(display = "Framing protocol error")]
Framing, Framing,
#[error(display = "Remote error ({:?}): {}", _0, _1)]
Remote(io::ErrorKind, String),
#[error(display = "Request ID collision")]
IdCollision,
#[error(display = "{}", _0)] #[error(display = "{}", _0)]
Message(String), Message(String),
@ -45,6 +39,29 @@ pub enum Error {
#[error(display = "Version mismatch: {}", _0)] #[error(display = "Version mismatch: {}", _0)]
VersionMismatch(String), VersionMismatch(String),
#[error(display = "Remote error {}: {}", _0, _1)]
Remote(u8, String),
}
impl Error {
pub fn code(&self) -> u8 {
match self {
Self::Io(_) => 100,
Self::TokioJoin(_) => 110,
Self::OneshotRecv(_) => 111,
Self::RMPEncode(_) => 10,
Self::RMPDecode(_) => 11,
Self::UTF8(_) => 12,
Self::Framing => 13,
Self::NoHandler => 20,
Self::ConnectionClosed => 21,
Self::Handshake(_) => 30,
Self::VersionMismatch(_) => 31,
Self::Remote(c, _) => *c,
Self::Message(_) => 99,
}
}
} }
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
@ -88,39 +105,3 @@ where
} }
} }
} }
// ---- Helpers for serializing I/O Errors
pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind {
use std::io::ErrorKind;
match v {
101 => ErrorKind::ConnectionAborted,
102 => ErrorKind::BrokenPipe,
103 => ErrorKind::WouldBlock,
104 => ErrorKind::InvalidInput,
105 => ErrorKind::InvalidData,
106 => ErrorKind::TimedOut,
107 => ErrorKind::Interrupted,
108 => ErrorKind::UnexpectedEof,
109 => ErrorKind::OutOfMemory,
110 => ErrorKind::ConnectionReset,
_ => ErrorKind::Other,
}
}
pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 {
use std::io::ErrorKind;
match kind {
ErrorKind::ConnectionAborted => 101,
ErrorKind::BrokenPipe => 102,
ErrorKind::WouldBlock => 103,
ErrorKind::InvalidInput => 104,
ErrorKind::InvalidData => 105,
ErrorKind::TimedOut => 106,
ErrorKind::Interrupted => 107,
ErrorKind::UnexpectedEof => 108,
ErrorKind::OutOfMemory => 109,
ErrorKind::ConnectionReset => 110,
_ => 100,
}
}


@ -13,23 +13,21 @@
//! about message priorization. //! about message priorization.
//! Also check out the examples to learn how to use this crate. //! Also check out the examples to learn how to use this crate.
pub mod bytes_buf;
pub mod error; pub mod error;
pub mod stream;
pub mod util; pub mod util;
pub mod endpoint; pub mod endpoint;
pub mod message; pub mod proto;
mod client; mod client;
mod recv; mod proto2;
mod send;
mod server; mod server;
pub mod netapp; pub mod netapp;
pub mod peering; pub mod peering;
pub use crate::netapp::*; pub use crate::netapp::*;
pub use util::{NetworkKey, NodeID, NodeKey};
#[cfg(test)] #[cfg(test)]
mod test; mod test;


@ -1,522 +0,0 @@
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use bytes::{BufMut, Bytes, BytesMut};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use futures::stream::StreamExt;
use crate::error::*;
use crate::stream::*;
use crate::util::*;
/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such a case, all of the messages in the send queue with the highest priority
/// will take turns to send individual chunks, in a round-robin fashion.
/// Once all highest priority messages are sent successfully, the messages with
/// the next highest priority will begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
pub type RequestPriority = u8;
/// Priority class: high
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;
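// Illustrative sketch (editorial addition, not part of the original file): priority
// values are plain u8 numbers, so a class constant composes with a sub-priority by
// bitwise OR, and a numerically lower value is always served first.
#[cfg(test)]
mod prio_example {
    use super::*;

    #[test]
    fn priority_ordering() {
        // 0x20 | 0x01 = 0x21: secondary traffic within the high class...
        let prio: RequestPriority = PRIO_HIGH | PRIO_SECONDARY;
        // ...still sorts before anything in the normal class (0x40).
        assert!(prio < PRIO_NORMAL);
        // Within the high class, the primary sub-priority (0x20) is served first.
        assert!((PRIO_HIGH | PRIO_PRIMARY) < prio);
    }
}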
// ----
/// An order tag can be added to a message or a response to indicate
/// whether it should be sent after or before other messages with order tags
/// referencing the same stream.
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
pub struct OrderTag(pub(crate) u64, pub(crate) u64);
/// A stream is an opaque identifier that defines a set of messages
/// or responses that are ordered with respect to one another using order tags.
#[derive(Clone, Copy)]
pub struct OrderTagStream(u64);
impl OrderTag {
/// Create a new stream from which to generate order tags. Example:
/// ```ignore
/// let stream = OrderTag::stream();
/// let tag_1 = stream.order(1);
/// let tag_2 = stream.order(2);
/// ```
pub fn stream() -> OrderTagStream {
OrderTagStream(thread_rng().gen())
}
}
impl OrderTagStream {
/// Create the order tag for message `order` in this stream
pub fn order(&self, order: u64) -> OrderTag {
OrderTag(self.0, order)
}
}
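// Illustrative sketch (editorial addition; `first_msg` and `second_msg` stand for
// values of some `Message` type): tags taken from a single `OrderTagStream` and
// attached to several requests ensure those requests are sent in tag order
// relative to one another.
//
//     let tags = OrderTag::stream();
//     let req1 = Req::new(first_msg)?.with_order_tag(tags.order(1));
//     let req2 = Req::new(second_msg)?.with_order_tag(tags.order(2));
//     // req1 is sent before req2: both tags reference the same stream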
// ----
/// This trait should be implemented by all messages your application
/// wants to handle. It specifies which data type should be sent
/// as a response to this message in the RPC protocol.
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
/// The type of the response that is sent in response to this message
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static;
}
// ----
/// The Req<M> is a helper object used to create requests and attach a stream
/// of data to them. If the stream is a fixed Bytes and not a ByteStream,
/// Req<M> is cheaply clonable, allowing the request to be sent to different
/// peers (Clone will panic if the stream is a ByteStream).
pub struct Req<M: Message> {
pub(crate) msg: Arc<M>,
pub(crate) msg_ser: Option<Bytes>,
pub(crate) stream: AttachedStream,
pub(crate) order_tag: Option<OrderTag>,
}
impl<M: Message> Req<M> {
/// Creates a new request from a base message `M`
pub fn new(v: M) -> Result<Self, Error> {
Ok(v.into_req()?)
}
/// Attach a stream to message in request, where the stream is streamed
/// from a fixed `Bytes` buffer
pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
Self {
stream: AttachedStream::Fixed(b),
..self
}
}
/// Attach a stream to message in request, where the stream is
/// an instance of `ByteStream`. Note that when a `Req<M>` has an attached
/// stream which is a `ByteStream` instance, it can no longer be cloned
/// to be sent to different nodes (`.clone()` will panic)
pub fn with_stream(self, b: ByteStream) -> Self {
Self {
stream: AttachedStream::Stream(b),
..self
}
}
/// Add an order tag to this request to indicate in which order it should
/// be sent.
pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
Self {
order_tag: Some(order_tag),
..self
}
}
/// Get a reference to the message `M` contained in this request
pub fn msg(&self) -> &M {
&self.msg
}
/// Takes out the stream attached to this request, if any
pub fn take_stream(&mut self) -> Option<ByteStream> {
std::mem::replace(&mut self.stream, AttachedStream::None).into_stream()
}
pub(crate) fn into_enc(
self,
prio: RequestPriority,
path: Bytes,
telemetry_id: Bytes,
) -> ReqEnc {
ReqEnc {
prio,
path,
telemetry_id,
msg: self.msg_ser.unwrap(),
stream: self.stream.into_stream(),
order_tag: self.order_tag,
}
}
pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
let msg = rmp_serde::decode::from_slice(&enc.msg)?;
Ok(Req {
msg: Arc::new(msg),
msg_ser: Some(enc.msg),
stream: enc
.stream
.map(AttachedStream::Stream)
.unwrap_or(AttachedStream::None),
order_tag: enc.order_tag,
})
}
}
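// Illustrative sketch (editorial addition, not part of the original file): a request
// whose attached stream comes from a fixed buffer stays cheaply clonable and can be
// sent to several peers, whereas attaching a live `ByteStream` makes `.clone()`
// panic (see the `Clone` impl below).
#[cfg(test)]
mod req_example {
    use super::*;
    use bytes::Bytes;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct ExampleMsg {
        example_field: u32,
    }

    #[derive(Serialize, Deserialize)]
    struct ExampleResponse {
        ok: bool,
    }

    impl Message for ExampleMsg {
        type Response = ExampleResponse;
    }

    #[test]
    fn fixed_stream_request_is_clonable() {
        let req = Req::new(ExampleMsg { example_field: 42 })
            .unwrap()
            .with_stream_from_buffer(Bytes::from_static(b"attached payload"));
        // Allowed because the attached stream is a fixed buffer, not a ByteStream.
        let _for_other_peer = req.clone();
    }
}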
/// `IntoReq<M>` represents any object that can be transformed into `Req<M>`
pub trait IntoReq<M: Message> {
/// Transform the object into a `Req<M>`, serializing the message M
/// to be sent to remote nodes
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
/// Transform the object into a `Req<M>`, skipping the serialization
/// of message M, in the case we are not sending this RPC message to
/// a remote node
fn into_req_local(self) -> Req<M>;
}
impl<M: Message> IntoReq<M> for M {
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
let msg_ser = rmp_to_vec_all_named(&self)?;
Ok(Req {
msg: Arc::new(self),
msg_ser: Some(Bytes::from(msg_ser)),
stream: AttachedStream::None,
order_tag: None,
})
}
fn into_req_local(self) -> Req<M> {
Req {
msg: Arc::new(self),
msg_ser: None,
stream: AttachedStream::None,
order_tag: None,
}
}
}
impl<M: Message> IntoReq<M> for Req<M> {
fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
Ok(self)
}
fn into_req_local(self) -> Req<M> {
self
}
}
impl<M: Message> Clone for Req<M> {
fn clone(&self) -> Self {
let stream = match &self.stream {
AttachedStream::None => AttachedStream::None,
AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()),
AttachedStream::Stream(_) => {
panic!("Cannot clone a Req<_> with a non-buffer attached stream")
}
};
Self {
msg: self.msg.clone(),
msg_ser: self.msg_ser.clone(),
stream,
order_tag: self.order_tag,
}
}
}
impl<M> fmt::Debug for Req<M>
where
M: Message + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Req[{:?}", self.msg)?;
match &self.stream {
AttachedStream::None => write!(f, "]"),
AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
AttachedStream::Stream(_) => write!(f, "; stream]"),
}
}
}
// ----
/// The Resp<M> represents a full response from a RPC that may have
/// an attached stream.
pub struct Resp<M: Message> {
pub(crate) _phantom: PhantomData<M>,
pub(crate) msg: M::Response,
pub(crate) stream: AttachedStream,
pub(crate) order_tag: Option<OrderTag>,
}
impl<M: Message> Resp<M> {
/// Creates a new response from a base response message
pub fn new(v: M::Response) -> Self {
Resp {
_phantom: Default::default(),
msg: v,
stream: AttachedStream::None,
order_tag: None,
}
}
/// Attach a stream to message in response, where the stream is streamed
/// from a fixed `Bytes` buffer
pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
Self {
stream: AttachedStream::Fixed(b),
..self
}
}
/// Attach a stream to message in response, where the stream is
/// an instance of `ByteStream`.
pub fn with_stream(self, b: ByteStream) -> Self {
Self {
stream: AttachedStream::Stream(b),
..self
}
}
/// Add an order tag to this response to indicate in which order it should
/// be sent.
pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
Self {
order_tag: Some(order_tag),
..self
}
}
/// Get a reference to the response message contained in this response
pub fn msg(&self) -> &M::Response {
&self.msg
}
/// Transforms the `Resp<M>` into the response message it contains,
/// dropping everything else (including attached data stream)
pub fn into_msg(self) -> M::Response {
self.msg
}
/// Transforms the `Resp<M>` into, on the one side, the response message
/// it contains, and on the other side, the associated data stream
/// if it exists
pub fn into_parts(self) -> (M::Response, Option<ByteStream>) {
(self.msg, self.stream.into_stream())
}
pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
Ok(RespEnc {
msg: rmp_to_vec_all_named(&self.msg)?.into(),
stream: self.stream.into_stream(),
order_tag: self.order_tag,
})
}
pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
let msg = rmp_serde::decode::from_slice(&enc.msg)?;
Ok(Self {
_phantom: Default::default(),
msg,
stream: enc
.stream
.map(AttachedStream::Stream)
.unwrap_or(AttachedStream::None),
order_tag: enc.order_tag,
})
}
}
impl<M> fmt::Debug for Resp<M>
where
M: Message,
<M as Message>::Response: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Resp[{:?}", self.msg)?;
match &self.stream {
AttachedStream::None => write!(f, "]"),
AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
AttachedStream::Stream(_) => write!(f, "; stream]"),
}
}
}
// ----
pub(crate) enum AttachedStream {
None,
Fixed(Bytes),
Stream(ByteStream),
}
impl AttachedStream {
pub fn into_stream(self) -> Option<ByteStream> {
match self {
AttachedStream::None => None,
AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
AttachedStream::Stream(s) => Some(s),
}
}
}
// ---- ----
/// Encoding for requests into a ByteStream:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
pub(crate) struct ReqEnc {
pub(crate) prio: RequestPriority,
pub(crate) path: Bytes,
pub(crate) telemetry_id: Bytes,
pub(crate) msg: Bytes,
pub(crate) stream: Option<ByteStream>,
pub(crate) order_tag: Option<OrderTag>,
}
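// Worked example (editorial addition, not part of the original file): for a request
// with prio = PRIO_NORMAL (0x40), path = b"ping", an empty telemetry id and a 3-byte
// serialized message, `encode()` below produces the header
//
//     0x40                     prio
//     0x04  b"ping"            path length, path bytes
//     0x00                     telemetry id length (empty)
//     0x00 0x00 0x00 0x03      msg length (u32, big-endian, as written by put_u32)
//
// followed by the 3 message bytes and then, if present, the attached stream.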
impl ReqEnc {
pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) {
let mut buf = BytesMut::with_capacity(
self.path.len() + self.telemetry_id.len() + self.msg.len() + 16,
);
buf.put_u8(self.prio);
buf.put_u8(self.path.len() as u8);
buf.put(self.path);
buf.put_u8(self.telemetry_id.len() as u8);
buf.put(&self.telemetry_id[..]);
buf.put_u32(self.msg.len() as u32);
let header = buf.freeze();
let res_stream: ByteStream = if let Some(stream) = self.stream {
Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream))
} else {
Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]))
};
(res_stream, self.order_tag)
}
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
Self::decode_aux(stream)
.await
.map_err(read_exact_error_to_error)
}
async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
let mut reader = ByteStreamReader::new(stream);
let prio = reader.read_u8().await?;
let path_len = reader.read_u8().await?;
let path = reader.read_exact(path_len as usize).await?;
let telemetry_id_len = reader.read_u8().await?;
let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
let msg_len = reader.read_u32().await?;
let msg = reader.read_exact(msg_len as usize).await?;
Ok(Self {
prio,
path,
telemetry_id,
msg,
stream: Some(reader.into_stream()),
order_tag: None,
})
}
}
/// Encoding for responses into a ByteStream:
/// IF SUCCESS:
/// - 0: u8
/// - msg len: u32
/// - msg [u8; ..]
/// - the attached stream as the rest of the encoded stream
/// IF ERROR:
/// - message length + 1: u8
/// - error code: u8
/// - message: [u8; message_length]
pub(crate) struct RespEnc {
msg: Bytes,
stream: Option<ByteStream>,
order_tag: Option<OrderTag>,
}
impl RespEnc {
pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) {
match resp {
Ok(Self {
msg,
stream,
order_tag,
}) => {
let mut buf = BytesMut::with_capacity(4);
buf.put_u32(msg.len() as u32);
let header = buf.freeze();
let res_stream: ByteStream = if let Some(stream) = stream {
Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream))
} else {
Box::pin(futures::stream::iter([Ok(header), Ok(msg)]))
};
(res_stream, order_tag)
}
Err(err) => {
let err = std::io::Error::new(
std::io::ErrorKind::Other,
format!("netapp error: {}", err),
);
(
Box::pin(futures::stream::once(async move { Err(err) })),
None,
)
}
}
}
pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
Self::decode_aux(stream)
.await
.map_err(read_exact_error_to_error)
}
async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
let mut reader = ByteStreamReader::new(stream);
let msg_len = reader.read_u32().await?;
let msg = reader.read_exact(msg_len as usize).await?;
// Check whether the response stream still has data or not.
// If no more data is coming, this will defuse the request canceller.
// If we didn't do this and the client didn't try to read from the stream,
// the request canceller wouldn't know that we read everything and would
// send a cancellation message to the server (which the server doesn't care about).
reader.fill_buffer().await;
Ok(Self {
msg,
stream: Some(reader.into_stream()),
order_tag: None,
})
}
}
fn read_exact_error_to_error(e: ReadExactError) -> Error {
match e {
ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()),
ReadExactError::UnexpectedEos => Error::Framing,
}
}


@ -20,15 +20,9 @@ use tokio::sync::{mpsc, watch};
use crate::client::*; use crate::client::*;
use crate::endpoint::*; use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::message::*; use crate::proto::*;
use crate::server::*; use crate::server::*;
use crate::util::*;
/// A node's identifier, which is also its public cryptographic key
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
/// A node's secret key
pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
/// A network key
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// Tag which is exchanged between client and server upon connection establishment /// Tag which is exchanged between client and server upon connection establishment
/// to check that they are running compatible versions of Netapp, /// to check that they are running compatible versions of Netapp,
@ -36,14 +30,16 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key;
pub(crate) type VersionTag = [u8; 16]; pub(crate) type VersionTag = [u8; 16];
/// Value of the Netapp version used in the version tag /// Value of the Netapp version used in the version tag
pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct HelloMessage { pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>, pub server_addr: Option<IpAddr>,
pub server_port: u16, pub server_port: u16,
} }
impl AutoSerialize for HelloMessage {}
impl Message for HelloMessage { impl Message for HelloMessage {
type Response = (); type Response = ();
} }
@ -158,7 +154,7 @@ impl NetApp {
pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>> pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>>
where where
M: Message + 'static, M: Message + 'static,
H: StreamingEndpointHandler<M> + 'static, H: EndpointHandler<M> + 'static,
{ {
let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone())); let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone()));
let endpoint_arc = EndpointArc(endpoint.clone()); let endpoint_arc = EndpointArc(endpoint.clone());
@ -403,14 +399,13 @@ impl NetApp {
hello_endpoint hello_endpoint
.call( .call(
&conn.peer_id, &conn.peer_id,
HelloMessage { &HelloMessage {
server_addr, server_addr,
server_port, server_port,
}, },
PRIO_NORMAL, PRIO_NORMAL,
) )
.await .await
.map(|_| ())
.log_err("Sending hello message"); .log_err("Sending hello message");
}); });
} }


@ -14,8 +14,8 @@ use sodiumoxide::crypto::hash;
use tokio::sync::watch; use tokio::sync::watch;
use crate::endpoint::*; use crate::endpoint::*;
use crate::message::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*;
use crate::NodeID; use crate::NodeID;
// -- Protocol messages -- // -- Protocol messages --
@ -138,7 +138,7 @@ impl BasaltView {
let mut ret = vec![]; let mut ret = vec![];
let mut rng = thread_rng(); let mut rng = thread_rng();
for _i in 0..count { for _i in 0..count {
let idx = rng.gen_range(0..possibles.len()); let idx = rng.gen_range(0, possibles.len());
ret.push(self.slots[possibles[idx]].peer.unwrap()); ret.push(self.slots[possibles[idx]].peer.unwrap());
} }
ret ret
@ -331,7 +331,7 @@ impl Basalt {
async fn do_pull(self: Arc<Self>, peer: NodeID) { async fn do_pull(self: Arc<Self>, peer: NodeID) {
match self match self
.pull_endpoint .pull_endpoint
.call(&peer, PullMessage {}, PRIO_NORMAL) .call(&peer, &PullMessage {}, PRIO_NORMAL)
.await .await
{ {
Ok(resp) => { Ok(resp) => {
@ -346,7 +346,7 @@ impl Basalt {
async fn do_push(self: Arc<Self>, peer: NodeID) { async fn do_push(self: Arc<Self>, peer: NodeID) {
let push_msg = self.make_push_message(); let push_msg = self.make_push_message();
match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await { match self.push_endpoint.call(&peer, &push_msg, PRIO_NORMAL).await {
Ok(_) => { Ok(_) => {
trace!("KYEV PEXo {}", hex::encode(peer)); trace!("KYEV PEXo {}", hex::encode(peer));
} }


@ -17,21 +17,19 @@ use sodiumoxide::crypto::hash;
use crate::endpoint::*; use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*;
use crate::message::*;
use crate::NodeID; use crate::NodeID;
const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
const CONN_MAX_RETRIES: usize = 10; const CONN_MAX_RETRIES: usize = 10;
const PING_INTERVAL: Duration = Duration::from_secs(15); const PING_INTERVAL: Duration = Duration::from_secs(10);
const LOOP_DELAY: Duration = Duration::from_secs(1); const LOOP_DELAY: Duration = Duration::from_secs(1);
const FAILED_PING_THRESHOLD: usize = 4; const PING_TIMEOUT: Duration = Duration::from_secs(5);
const FAILED_PING_THRESHOLD: usize = 3;
const DEFAULT_PING_TIMEOUT_MILLIS: u64 = 10_000;
// -- Protocol messages -- // -- Protocol messages --
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Clone)]
struct PingMessage { struct PingMessage {
pub id: u64, pub id: u64,
pub peer_list_hash: hash::Digest, pub peer_list_hash: hash::Digest,
@ -41,7 +39,9 @@ impl Message for PingMessage {
type Response = PingMessage; type Response = PingMessage;
} }
#[derive(Serialize, Deserialize)] impl AutoSerialize for PingMessage {}
#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage { struct PeerListMessage {
pub list: Vec<(NodeID, SocketAddr)>, pub list: Vec<(NodeID, SocketAddr)>,
} }
@ -50,6 +50,8 @@ impl Message for PeerListMessage {
type Response = PeerListMessage; type Response = PeerListMessage;
} }
impl AutoSerialize for PeerListMessage {}
// -- Algorithm data structures -- // -- Algorithm data structures --
#[derive(Debug)] #[derive(Debug)]
@ -62,27 +64,11 @@ struct PeerInfoInternal {
all_addrs: Vec<SocketAddr>, all_addrs: Vec<SocketAddr>,
state: PeerConnState, state: PeerConnState,
last_send_ping: Option<Instant>,
last_seen: Option<Instant>, last_seen: Option<Instant>,
ping: VecDeque<Duration>, ping: VecDeque<Duration>,
failed_pings: usize, failed_pings: usize,
} }
impl PeerInfoInternal {
fn new(addr: SocketAddr, state: PeerConnState) -> Self {
Self {
addr,
all_addrs: vec![addr],
state,
last_send_ping: None,
last_seen: None,
ping: VecDeque::new(),
failed_pings: 0,
}
}
}
/// Information that the full mesh peering strategy can return about the peers it knows of
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub struct PeerInfo { pub struct PeerInfo {
/// The node's identifier (its public key) /// The node's identifier (its public key)
@ -113,7 +99,7 @@ impl PeerInfo {
/// PeerConnState: possible states for our tentative connections to given peer /// PeerConnState: possible states for our tentative connections to given peer
/// This structure is only interested in recording connection info for outgoing /// This structure is only interested in recording connection info for outgoing
/// TCP connections /// TCP connections
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq)]
pub enum PeerConnState { pub enum PeerConnState {
/// This entry represents ourself (the local node) /// This entry represents ourself (the local node)
Ourself, Ourself,
@ -185,8 +171,6 @@ pub struct FullMeshPeeringStrategy {
next_ping_id: AtomicU64, next_ping_id: AtomicU64,
ping_endpoint: Arc<Endpoint<PingMessage, Self>>, ping_endpoint: Arc<Endpoint<PingMessage, Self>>,
peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>, peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>,
ping_timeout_millis: AtomicU64,
} }
impl FullMeshPeeringStrategy { impl FullMeshPeeringStrategy {
@ -204,7 +188,14 @@ impl FullMeshPeeringStrategy {
if id != netapp.id { if id != netapp.id {
known_hosts.list.insert( known_hosts.list.insert(
id, id,
PeerInfoInternal::new(addr, PeerConnState::Waiting(0, Instant::now())), PeerInfoInternal {
addr,
all_addrs: vec![addr],
state: PeerConnState::Waiting(0, Instant::now()),
last_seen: None,
ping: VecDeque::new(),
failed_pings: 0,
},
); );
} }
} }
@ -212,7 +203,14 @@ impl FullMeshPeeringStrategy {
if let Some(addr) = our_addr { if let Some(addr) = our_addr {
known_hosts.list.insert( known_hosts.list.insert(
netapp.id, netapp.id,
PeerInfoInternal::new(addr, PeerConnState::Ourself), PeerInfoInternal {
addr,
all_addrs: vec![addr],
state: PeerConnState::Ourself,
last_seen: None,
ping: VecDeque::new(),
failed_pings: 0,
},
); );
} }
@ -223,7 +221,6 @@ impl FullMeshPeeringStrategy {
next_ping_id: AtomicU64::new(42), next_ping_id: AtomicU64::new(42),
ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()), ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()),
peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()), peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()),
ping_timeout_millis: DEFAULT_PING_TIMEOUT_MILLIS.into(),
}); });
strat.update_public_peer_list(&strat.known_hosts.read().unwrap()); strat.update_public_peer_list(&strat.known_hosts.read().unwrap());
@ -261,7 +258,7 @@ impl FullMeshPeeringStrategy {
trace!("{}, {:?}", hex::encode(&id[..8]), info); trace!("{}, {:?}", hex::encode(&id[..8]), info);
match info.state { match info.state {
PeerConnState::Connected => { PeerConnState::Connected => {
let must_ping = match info.last_send_ping { let must_ping = match info.last_seen {
None => true, None => true,
Some(t) => Instant::now() - t > PING_INTERVAL, Some(t) => Instant::now() - t > PING_INTERVAL,
}; };
@ -281,16 +278,9 @@ impl FullMeshPeeringStrategy {
}; };
// 2. Dispatch ping to hosts // 2. Dispatch ping to hosts
trace!("to_ping: {} peers", to_ping.len()); trace!("to_ping: {} peers", to_retry.len());
if !to_ping.is_empty() { for id in to_ping {
let mut known_hosts = self.known_hosts.write().unwrap(); tokio::spawn(self.clone().ping(id));
for id in to_ping.iter() {
known_hosts.list.get_mut(id).unwrap().last_send_ping = Some(Instant::now());
}
drop(known_hosts);
for id in to_ping {
tokio::spawn(self.clone().ping(id));
}
} }
// 3. Try reconnects // 3. Try reconnects
@ -335,12 +325,6 @@ impl FullMeshPeeringStrategy {
self.public_peer_list.load_full() self.public_peer_list.load_full()
} }
/// Set the timeout for ping messages, in milliseconds
pub fn set_ping_timeout_millis(&self, timeout: u64) {
self.ping_timeout_millis
.store(timeout, atomic::Ordering::Relaxed);
}
// -- internal stuff -- // -- internal stuff --
fn update_public_peer_list(&self, known_hosts: &KnownHosts) { fn update_public_peer_list(&self, known_hosts: &KnownHosts) {
@ -382,8 +366,6 @@ impl FullMeshPeeringStrategy {
let peer_list_hash = self.known_hosts.read().unwrap().hash; let peer_list_hash = self.known_hosts.read().unwrap().hash;
let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed); let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed);
let ping_time = Instant::now(); let ping_time = Instant::now();
let ping_timeout =
Duration::from_millis(self.ping_timeout_millis.load(atomic::Ordering::Relaxed));
let ping_msg = PingMessage { let ping_msg = PingMessage {
id: ping_id, id: ping_id,
peer_list_hash, peer_list_hash,
@ -396,8 +378,8 @@ impl FullMeshPeeringStrategy {
ping_time ping_time
); );
let ping_response = select! { let ping_response = select! {
r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r,
_ = tokio::time::sleep(ping_timeout) => Err(Error::Message("Ping timeout".into())), _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
}; };
match ping_response { match ping_response {
@ -448,7 +430,7 @@ impl FullMeshPeeringStrategy {
let pex_message = PeerListMessage { list: peer_list }; let pex_message = PeerListMessage { list: peer_list };
match self match self
.peer_list_endpoint .peer_list_endpoint
.call(id, pex_message, PRIO_BACKGROUND) .call(id, &pex_message, PRIO_BACKGROUND)
.await .await
{ {
Err(e) => warn!("Error doing peer exchange: {}", e), Err(e) => warn!("Error doing peer exchange: {}", e),
@ -556,9 +538,17 @@ impl FullMeshPeeringStrategy {
host.all_addrs.push(addr); host.all_addrs.push(addr);
} }
} else { } else {
known_hosts known_hosts.list.insert(
.list id,
.insert(id, PeerInfoInternal::new(addr, PeerConnState::Connected)); PeerInfoInternal {
state: PeerConnState::Connected,
addr,
all_addrs: vec![addr],
last_seen: None,
ping: VecDeque::new(),
failed_pings: 0,
},
);
} }
} }
known_hosts.update_hash(); known_hosts.update_hash();
@ -583,7 +573,14 @@ impl FullMeshPeeringStrategy {
} else { } else {
PeerConnState::Waiting(0, Instant::now()) PeerConnState::Waiting(0, Instant::now())
}; };
PeerInfoInternal::new(addr, state) PeerInfoInternal {
addr,
all_addrs: vec![addr],
state,
last_seen: None,
ping: VecDeque::new(),
failed_pings: 0,
}
} }
} }

617
src/proto.rs Normal file
View file

@ -0,0 +1,617 @@
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use log::trace;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::{AsyncReadExt, AsyncWriteExt};
use futures::{Stream, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite;
use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*;
use crate::util::{AssociatedStream, Packet};
/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such a case, all of the messages in the send queue with the highest priority
/// will take turns to send individual chunks, in a round-robin fashion.
/// Once all highest priority messages are sent successfully, the messages with
/// the next highest priority will begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
pub type RequestPriority = u8;
/// Priority class: high
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;
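As a quick illustration (a minimal sketch, not part of the diff): a priority class constant is combined with a primary/secondary modifier by bitwise OR, and lower numeric values are scheduled first.
// Hypothetical snippet, for illustration only.
let prio: RequestPriority = PRIO_HIGH | PRIO_SECONDARY;
assert_eq!(prio, 0x21);
assert!(prio < PRIO_NORMAL); // lower value = higher priority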
// Messages are sent by chunks
// Chunk format:
// - u32 BE: request id (same for request and response)
// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag
// when this is not the last chunk of the message
// - [u8; chunk_length] chunk data
pub(crate) type RequestID = u32;
type ChunkLength = u16;
const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
const ERROR_MARKER: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
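A minimal sketch of how the u16 length field described above could be decoded, assuming the flag semantics stated in the comment (the actual decoding lives in `recv_loop` below; this helper is hypothetical):
// Sketch only: split a chunk header length field into its components.
fn decode_chunk_header(size: ChunkLength) -> (bool, bool, usize) {
    let has_continuation = (size & CHUNK_HAS_CONTINUATION) != 0;
    let is_error = (size & ERROR_MARKER) != 0;
    let len = (size & !(CHUNK_HAS_CONTINUATION | ERROR_MARKER)) as usize;
    (has_continuation, is_error, len)
}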
struct SendQueueItem {
id: RequestID,
prio: RequestPriority,
data: DataReader,
}
#[pin_project::pin_project]
struct DataReader {
#[pin]
reader: AssociatedStream,
packet: Packet,
pos: usize,
buf: Vec<u8>,
eos: bool,
}
impl From<AssociatedStream> for DataReader {
fn from(data: AssociatedStream) -> DataReader {
DataReader {
reader: data,
packet: Ok(Vec::new()),
pos: 0,
buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
eos: false,
}
}
}
enum DataFrame {
Data {
/// a fixed size buffer containing some data, possibly padded with 0s
data: [u8; MAX_CHUNK_LENGTH as usize],
/// actual length of the data
len: usize,
},
Error(u8),
}
struct DataReaderItem {
data: DataFrame,
/// whether there may be more data coming from this stream. Can be used for some
/// optimization. It's an error to set it to false if there is more data, but it is correct
/// (albeit sub-optimal) to set it to true if there is nothing coming after
may_have_more: bool,
}
impl DataReaderItem {
fn empty_last() -> Self {
DataReaderItem {
data: DataFrame::Data {
data: [0; MAX_CHUNK_LENGTH as usize],
len: 0,
},
may_have_more: false,
}
}
fn header(&self) -> [u8; 2] {
let continuation = if self.may_have_more {
CHUNK_HAS_CONTINUATION
} else {
0
};
let len = match self.data {
DataFrame::Data { len, .. } => len as u16,
DataFrame::Error(e) => e as u16 | ERROR_MARKER,
};
ChunkLength::to_be_bytes(len | continuation)
}
fn data(&self) -> &[u8] {
match self.data {
DataFrame::Data { ref data, len } => &data[..len],
DataFrame::Error(_) => &[],
}
}
}
impl Stream for DataReader {
type Item = DataReaderItem;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if *this.eos {
// eos was reached at previous call to poll_next, where a partial packet
// was returned. Now return None
return Poll::Ready(None);
}
loop {
let packet = match this.packet {
Ok(v) => v,
Err(e) => {
let e = *e;
*this.packet = Ok(Vec::new());
return Poll::Ready(Some(DataReaderItem {
data: DataFrame::Error(e),
may_have_more: true,
}));
}
};
let packet_left = packet.len() - *this.pos;
let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
let to_read = std::cmp::min(buf_left, packet_left);
this.buf
.extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
*this.pos += to_read;
if this.buf.len() == MAX_CHUNK_LENGTH as usize {
// we have a full buf, ready to send
break;
}
// we don't have a full buf and the packet is empty; try to receive more
if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
*this.packet = p;
*this.pos = 0;
// if buf is empty, we will loop and return the error directly. If buf
// isn't empty, send it first by breaking.
if this.packet.is_err() && !this.buf.is_empty() {
break;
}
} else {
*this.eos = true;
break;
}
}
let mut body = [0; MAX_CHUNK_LENGTH as usize];
let len = this.buf.len();
body[..len].copy_from_slice(this.buf);
this.buf.clear();
Poll::Ready(Some(DataReaderItem {
data: DataFrame::Data { data: body, len },
may_have_more: !*this.eos,
}))
}
}
struct SendQueue {
items: VecDeque<(u8, VecDeque<SendQueueItem>)>,
}
impl SendQueue {
fn new() -> Self {
Self {
items: VecDeque::with_capacity(64),
}
}
fn push(&mut self, item: SendQueueItem) {
let prio = item.prio;
let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
Ok(i) => i,
Err(i) => {
self.items.insert(i, (prio, VecDeque::new()));
i
}
};
self.items[pos_prio].1.push_back(item);
}
// used only in tests. They should probably be rewritten
#[allow(dead_code)]
fn pop(&mut self) -> Option<SendQueueItem> {
match self.items.pop_front() {
None => None,
Some((prio, mut items_at_prio)) => {
let ret = items_at_prio.pop_front();
if !items_at_prio.is_empty() {
self.items.push_front((prio, items_at_prio));
}
ret.or_else(|| self.pop())
}
}
}
fn is_empty(&self) -> bool {
self.items.iter().all(|(_k, v)| v.is_empty())
}
// this is like an async fn, but hand implemented
fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
SendQueuePollNextReady { queue: self }
}
}
struct SendQueuePollNextReady<'a> {
queue: &'a mut SendQueue,
}
impl<'a> futures::Future for SendQueuePollNextReady<'a> {
type Output = (RequestID, DataReaderItem);
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
for i in 0..self.queue.items.len() {
let (_prio, items_at_prio) = &mut self.queue.items[i];
for _ in 0..items_at_prio.len() {
let mut item = items_at_prio.pop_front().unwrap();
match Pin::new(&mut item.data).poll_next(ctx) {
Poll::Pending => items_at_prio.push_back(item),
Poll::Ready(Some(data)) => {
let id = item.id;
if data.may_have_more {
self.queue.push(item);
} else {
if items_at_prio.is_empty() {
// this priority level is empty, remove it
self.queue.items.remove(i);
}
}
return Poll::Ready((id, data));
}
Poll::Ready(None) => {
if items_at_prio.is_empty() {
// this priority level is empty, remove it
self.queue.items.remove(i);
}
return Poll::Ready((item.id, DataReaderItem::empty_last()));
}
}
}
}
// TODO what do we do if self.queue is empty? We won't get scheduled again.
Poll::Pending
}
}
/// The SendLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn), adds a method `.send_loop()`
/// that takes a channel of messages to send and an asynchronous writer,
/// and sends messages from the channel to the async writer, putting them in a queue
/// before they are sent and applying the round-robin sending strategy.
///
/// The `.send_loop()` exits when the sending end of the channel is closed,
/// or if there is an error at any time writing to the async writer.
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
mut write: BoxStreamWrite<W>,
) -> Result<(), Error>
where
W: AsyncWriteExt + Unpin + Send + Sync,
{
let mut sending = SendQueue::new();
let mut should_exit = false;
while !should_exit || !sending.is_empty() {
let recv_fut = msg_recv.recv();
futures::pin_mut!(recv_fut);
let send_fut = sending.next_ready();
// recv_fut is cancellation-safe according to tokio doc,
// send_fut is cancellation-safe as implemented above?
use futures::future::Either;
match futures::future::select(recv_fut, send_fut).await {
Either::Left((sth, _send_fut)) => {
if let Some((id, prio, data)) = sth {
sending.push(SendQueueItem {
id,
prio,
data: data.into(),
});
} else {
should_exit = true;
};
}
Either::Right(((id, data), _recv_fut)) => {
trace!("send_loop: sending bytes for {}", id);
let header_id = RequestID::to_be_bytes(id);
write.write_all(&header_id[..]).await?;
write.write_all(&data.header()).await?;
write.write_all(data.data()).await?;
write.flush().await?;
}
}
}
let _ = write.goodbye().await;
Ok(())
}
}
pub(crate) struct Framing {
direct: Vec<u8>,
stream: Option<AssociatedStream>,
}
impl Framing {
pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
assert!(direct.len() <= u32::MAX as usize);
Framing { direct, stream }
}
pub fn into_stream(self) -> AssociatedStream {
use futures::stream;
let len = self.direct.len() as u32;
// required because otherwise the borrow-checker complains
let Framing { direct, stream } = self;
let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
.chain(stream::once(async move { Ok(direct) }));
if let Some(stream) = stream {
Box::pin(res.chain(stream))
} else {
Box::pin(res)
}
}
pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
mut stream: S,
) -> Result<Self, Error> {
let mut packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
if packet.len() < 4 {
return Err(Error::Framing);
}
let mut len = [0; 4];
len.copy_from_slice(&packet[..4]);
let len = u32::from_be_bytes(len);
packet.drain(..4);
let mut buffer = Vec::new();
let len = len as usize;
loop {
let max_cp = std::cmp::min(len - buffer.len(), packet.len());
buffer.extend_from_slice(&packet[..max_cp]);
if buffer.len() == len {
packet.drain(..max_cp);
break;
}
packet = stream
.next()
.await
.ok_or(Error::Framing)?
.map_err(|_| Error::Framing)?;
}
let stream: AssociatedStream = if packet.is_empty() {
Box::pin(stream)
} else {
Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
};
Ok(Framing {
direct: buffer,
stream: Some(stream),
})
}
pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
let Framing { direct, stream } = self;
(direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
}
}
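A usage sketch for `Framing`, assuming the stream-body types shown in this diff (`AssociatedStream`, `Packet = Result<Vec<u8>, u8>`); `framing_roundtrip` is a hypothetical helper, not part of the crate:
// Sketch: frame a serialized body with no attached stream, then parse it back.
async fn framing_roundtrip(body: Vec<u8>) -> Result<Vec<u8>, Error> {
    let stream = Framing::new(body, None).into_stream();
    let (direct, _rest) = Framing::from_stream(stream).await?.into_parts();
    Ok(direct)
}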
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
struct Sender {
inner: UnboundedSender<Packet>,
closed: bool,
}
impl Sender {
fn new(inner: UnboundedSender<Packet>) -> Self {
Sender {
inner,
closed: false,
}
}
fn send(&self, packet: Packet) {
let _ = self.inner.unbounded_send(packet);
}
fn end(&mut self) {
self.closed = true;
}
}
impl Drop for Sender {
fn drop(&mut self) {
if !self.closed {
self.send(Err(255));
}
self.inner.close_channel();
}
}
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
/// must be filled by implementors. `.recv_loop()` receives messages in a loop
/// according to the protocol defined above: chunks of messages in progress of being
/// received are stored in a buffer, and when the last chunk of a message is received,
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
let mut streams: HashMap<RequestID, Sender> = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
match read.read_exact(&mut header_id[..]).await {
Ok(_) => (),
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e.into()),
};
let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
read.read_exact(&mut header_size[..]).await?;
let size = ChunkLength::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", size);
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
let is_error = (size & ERROR_MARKER) != 0;
let packet = if is_error {
Err(size as u8)
} else {
let size = size & !CHUNK_HAS_CONTINUATION;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", next_slice.len());
Ok(next_slice)
};
let mut sender = if let Some(send) = streams.remove(&(id)) {
send
} else {
let (send, recv) = unbounded();
self.recv_handler(id, recv);
Sender::new(send)
};
// if we get an error, the receiving end is disconnected. We still need to
// reach eos before dropping this sender
sender.send(packet);
if has_cont {
streams.insert(id, sender);
} else {
sender.end();
}
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
fn empty_data() -> DataReader {
type Item = Packet;
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
Box::pin(futures::stream::empty::<Packet>());
stream.into()
}
#[test]
fn test_priority_queue() {
let i1 = SendQueueItem {
id: 1,
prio: PRIO_NORMAL,
data: empty_data(),
};
let i2 = SendQueueItem {
id: 2,
prio: PRIO_HIGH,
data: empty_data(),
};
let i2bis = SendQueueItem {
id: 20,
prio: PRIO_HIGH,
data: empty_data(),
};
let i3 = SendQueueItem {
id: 3,
prio: PRIO_HIGH | PRIO_SECONDARY,
data: empty_data(),
};
let i4 = SendQueueItem {
id: 4,
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
data: empty_data(),
};
let i5 = SendQueueItem {
id: 5,
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
data: empty_data(),
};
let mut q = SendQueue::new();
q.push(i1); // 1
let a = q.pop().unwrap(); // empty -> 1
assert_eq!(a.id, 1);
assert!(q.pop().is_none());
q.push(a); // 1
q.push(i2); // 2 1
q.push(i2bis); // [2 20] 1
let a = q.pop().unwrap(); // 20 1 -> 2
assert_eq!(a.id, 2);
let b = q.pop().unwrap(); // 1 -> 20
assert_eq!(b.id, 20);
let c = q.pop().unwrap(); // empty -> 1
assert_eq!(c.id, 1);
assert!(q.pop().is_none());
q.push(a); // 2
q.push(b); // [2 20]
q.push(c); // [2 20] 1
q.push(i3); // [2 20] 3 1
q.push(i4); // [2 20] 3 1 4
q.push(i5); // [2 20] 3 1 5 4
let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2
assert_eq!(a.id, 2);
q.push(a); // [20 2] 3 1 5 4
let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20
assert_eq!(a.id, 20);
let b = q.pop().unwrap(); // 3 1 5 4 -> 2
assert_eq!(b.id, 2);
q.push(b); // 2 3 1 5 4
let b = q.pop().unwrap(); // 3 1 5 4 -> 2
assert_eq!(b.id, 2);
let c = q.pop().unwrap(); // 1 5 4 -> 3
assert_eq!(c.id, 3);
q.push(b); // 2 1 5 4
let b = q.pop().unwrap(); // 1 5 4 -> 2
assert_eq!(b.id, 2);
let e = q.pop().unwrap(); // 5 4 -> 1
assert_eq!(e.id, 1);
let f = q.pop().unwrap(); // 4 -> 5
assert_eq!(f.id, 5);
let g = q.pop().unwrap(); // empty -> 4
assert_eq!(g.id, 4);
assert!(q.pop().is_none());
}
}

75
src/proto2.rs Normal file
View file

@ -0,0 +1,75 @@
use crate::error::*;
use crate::proto::*;
pub(crate) struct QueryMessage<'a> {
pub(crate) prio: RequestPriority,
pub(crate) path: &'a [u8],
pub(crate) telemetry_id: Option<Vec<u8>>,
pub(crate) body: &'a [u8],
}
/// QueryMessage encoding:
/// - priority: u8
/// - path length: u8
/// - path: [u8; path length]
/// - telemetry id length: u8
/// - telemetry id: [u8; telemetry id length]
/// - body [u8; ..]
impl<'a> QueryMessage<'a> {
pub(crate) fn encode(self) -> Vec<u8> {
let tel_len = match &self.telemetry_id {
Some(t) => t.len(),
None => 0,
};
let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
ret.push(self.prio);
ret.push(self.path.len() as u8);
ret.extend_from_slice(self.path);
if let Some(t) = self.telemetry_id {
ret.push(t.len() as u8);
ret.extend(t);
} else {
ret.push(0u8);
}
ret.extend_from_slice(self.body);
ret
}
pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
if bytes.len() < 3 {
return Err(Error::Message("Invalid protocol message".into()));
}
let path_length = bytes[1] as usize;
if bytes.len() < 3 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let telemetry_id_len = bytes[2 + path_length] as usize;
if bytes.len() < 3 + path_length + telemetry_id_len {
return Err(Error::Message("Invalid protocol message".into()));
}
let path = &bytes[2..2 + path_length];
let telemetry_id = if telemetry_id_len > 0 {
Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
} else {
None
};
let body = &bytes[3 + path_length + telemetry_id_len..];
Ok(Self {
prio: bytes[0],
path,
telemetry_id,
body,
})
}
}
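To make the layout documented above concrete, here is a hedged sketch (crate-internal code, illustrative values only) of what `encode` produces for a small request:
// Sketch: [prio][path len][path...][telemetry len = 0][body...]
let msg = QueryMessage {
    prio: PRIO_NORMAL,
    path: b"ping",
    telemetry_id: None,
    body: b"hello",
};
let bytes = msg.encode();
assert_eq!(&bytes[..2], &[PRIO_NORMAL, 4]); // priority, path length
assert_eq!(&bytes[2..6], b"ping");
assert_eq!(bytes[6], 0); // no telemetry id
assert_eq!(&bytes[7..], b"hello");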

View file

@ -1,153 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use log::*;
use futures::AsyncReadExt;
use tokio::sync::mpsc;
use crate::error::*;
use crate::send::*;
use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
struct Sender {
inner: Option<mpsc::UnboundedSender<Packet>>,
}
impl Sender {
fn new(inner: mpsc::UnboundedSender<Packet>) -> Self {
Sender { inner: Some(inner) }
}
fn send(&self, packet: Packet) {
let _ = self.inner.as_ref().unwrap().send(packet);
}
fn end(&mut self) {
self.inner = None;
}
}
impl Drop for Sender {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
let _ = inner.send(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Netapp connection dropped before end of stream",
)));
}
}
}
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
/// must be filled by implementors. `.recv_loop()` receives messages in a loop
/// according to the protocol defined above: chunks of messages in progress of being
/// received are stored in a buffer, and when the last chunk of a message is received,
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
fn cancel_handler(self: &Arc<Self>, _id: RequestID) {}
async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
let mut streams: HashMap<RequestID, Sender> = HashMap::new();
loop {
trace!(
"recv_loop({}): in_progress = {:?}",
debug_name,
streams.iter().map(|(id, _)| id).collect::<Vec<_>>()
);
let mut header_id = [0u8; RequestID::BITS as usize / 8];
match read.read_exact(&mut header_id[..]).await {
Ok(_) => (),
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e.into()),
};
let id = RequestID::from_be_bytes(header_id);
let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
read.read_exact(&mut header_size[..]).await?;
let size = ChunkLength::from_be_bytes(header_size);
if size == CANCEL_REQUEST {
if let Some(mut stream) = streams.remove(&id) {
let _ = stream.send(Err(std::io::Error::new(
std::io::ErrorKind::Other,
"netapp: cancel requested",
)));
stream.end();
}
self.cancel_handler(id);
continue;
}
let has_cont = (size & CHUNK_FLAG_HAS_CONTINUATION) != 0;
let is_error = (size & CHUNK_FLAG_ERROR) != 0;
let size = (size & CHUNK_LENGTH_MASK) as usize;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;
let packet = if is_error {
let kind = u8_to_io_errorkind(next_slice[0]);
let msg =
std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
debug!(
"recv_loop({}): got id {}, error {:?}: {}",
debug_name, id, kind, msg
);
Some(Err(std::io::Error::new(kind, msg.to_string())))
} else {
trace!(
"recv_loop({}): got id {}, size {}, has_cont {}",
debug_name,
id,
size,
has_cont
);
if !next_slice.is_empty() {
Some(Ok(Bytes::from(next_slice)))
} else {
None
}
};
let mut sender = if let Some(send) = streams.remove(&(id)) {
send
} else {
let (send, recv) = mpsc::unbounded_channel();
trace!("recv_loop({}): id {} is new channel", debug_name, id);
self.recv_handler(
id,
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)),
);
Sender::new(send)
};
if let Some(packet) = packet {
// If we cannot put packet in channel, it means that the
// receiving end of the channel is disconnected.
// We still need to reach eos before dropping this sender
let _ = sender.send(packet);
}
if has_cont {
assert!(!is_error);
streams.insert(id, sender);
} else {
trace!("recv_loop({}): close channel id {}", debug_name, id);
sender.end();
}
}
Ok(())
}
}

View file

@ -1,356 +0,0 @@
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut};
use log::*;
use futures::{AsyncWriteExt, Future};
use kuska_handshake::async_std::BoxStreamWrite;
use tokio::sync::mpsc;
use crate::error::*;
use crate::message::*;
use crate::stream::*;
// Messages are sent by chunks
// Chunk format:
// - u32 BE: request id (same for request and response)
// - u16 BE: chunk length + flags:
// CHUNK_FLAG_HAS_CONTINUATION when this is not the last chunk of the stream
// CHUNK_FLAG_ERROR if this chunk denotes an error
// (these two flags are exclusive, an error denotes the end of the stream)
// **special value** 0xFFFF indicates a CANCEL message
// - [u8; chunk_length], either
// - if not error: chunk data
// - if error:
// - u8: error kind, encoded using error::io_errorkind_to_u8
// - rest: error message
// - absent for cancel messages
pub(crate) type RequestID = u32;
pub(crate) type ChunkLength = u16;
pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
pub(crate) const CHUNK_FLAG_ERROR: ChunkLength = 0x4000;
pub(crate) const CHUNK_FLAG_HAS_CONTINUATION: ChunkLength = 0x8000;
pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF;
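A small sketch of how a receiver could classify the u16 header under the newer flag layout above (mirroring the logic in recv.rs earlier in this diff; the helper itself is hypothetical):
// Sketch only.
fn classify_chunk(size: ChunkLength) -> &'static str {
    if size == CANCEL_REQUEST {
        "cancel marker"
    } else if size & CHUNK_FLAG_ERROR != 0 {
        "error chunk (ends the stream)"
    } else if size & CHUNK_FLAG_HAS_CONTINUATION != 0 {
        "data chunk, more chunks follow"
    } else {
        "data chunk, last of the stream"
    }
}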
pub(crate) enum SendItem {
Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream),
Cancel(RequestID),
}
// ----
struct SendQueue {
items: Vec<(u8, SendQueuePriority)>,
}
struct SendQueuePriority {
items: VecDeque<SendQueueItem>,
order: HashMap<u64, VecDeque<u64>>,
}
struct SendQueueItem {
id: RequestID,
prio: RequestPriority,
order_tag: Option<OrderTag>,
data: ByteStreamReader,
sent: usize,
}
impl SendQueue {
fn new() -> Self {
Self {
items: Vec::with_capacity(64),
}
}
fn push(&mut self, item: SendQueueItem) {
let prio = item.prio;
let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
Ok(i) => i,
Err(i) => {
self.items.insert(i, (prio, SendQueuePriority::new()));
i
}
};
self.items[pos_prio].1.push(item);
}
fn remove(&mut self, id: RequestID) {
for (_, prioq) in self.items.iter_mut() {
prioq.remove(id);
}
self.items.retain(|(_prio, q)| !q.is_empty());
}
fn is_empty(&self) -> bool {
self.items.iter().all(|(_k, v)| v.is_empty())
}
// this is like an async fn, but hand implemented
fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
SendQueuePollNextReady { queue: self }
}
}
impl SendQueuePriority {
fn new() -> Self {
Self {
items: VecDeque::new(),
order: HashMap::new(),
}
}
fn push(&mut self, item: SendQueueItem) {
if let Some(OrderTag(stream, order)) = item.order_tag {
let order_vec = self.order.entry(stream).or_default();
let i = order_vec.iter().take_while(|o2| **o2 < order).count();
order_vec.insert(i, order);
}
self.items.push_front(item);
}
fn remove(&mut self, id: RequestID) {
if let Some(i) = self.items.iter().position(|x| x.id == id) {
let item = self.items.remove(i).unwrap();
if let Some(OrderTag(stream, order)) = item.order_tag {
let order_vec = self.order.get_mut(&stream).unwrap();
let j = order_vec.iter().position(|x| *x == order).unwrap();
order_vec.remove(j).unwrap();
if order_vec.is_empty() {
self.order.remove(&stream);
}
}
}
}
fn is_empty(&self) -> bool {
self.items.is_empty()
}
fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> {
for (j, item) in self.items.iter_mut().enumerate() {
if let Some(OrderTag(stream, order)) = item.order_tag {
if order > *self.order.get(&stream).unwrap().front().unwrap() {
continue;
}
}
let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize);
if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) {
let id = item.id;
let eos = item.data.eos();
let packet = bytes_or_err.map_err(|e| match e {
ReadExactError::Stream(err) => err,
_ => unreachable!(),
});
let is_err = packet.is_err();
let data_frame = DataFrame::from_packet(packet, !eos);
item.sent += data_frame.data().len();
if eos || is_err {
// If item had an order tag, remove it from the corresponding ordering list
if let Some(OrderTag(stream, order)) = item.order_tag {
let order_stream = self.order.get_mut(&stream).unwrap();
assert_eq!(order_stream.pop_front(), Some(order));
if order_stream.is_empty() {
self.order.remove(&stream);
}
}
// Remove item from sending queue
self.items.remove(j);
} else {
// Move item later in send queue to implement LAS scheduling
// (LAS = Least Attained Service)
for k in j..self.items.len() - 1 {
if self.items[k].sent >= self.items[k + 1].sent {
self.items.swap(k, k + 1);
} else {
break;
}
}
}
return Poll::Ready((id, data_frame));
}
}
Poll::Pending
}
fn dump(&self, prio: u8) -> String {
self.items
.iter()
.map(|i| format!("[{} {} {:?} @{}]", prio, i.id, i.order_tag, i.sent))
.collect::<Vec<_>>()
.join(" ")
}
}
struct SendQueuePollNextReady<'a> {
queue: &'a mut SendQueue,
}
impl<'a> futures::Future for SendQueuePollNextReady<'a> {
type Output = (RequestID, DataFrame);
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() {
if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) {
if items_at_prio.is_empty() {
self.queue.items.remove(i);
}
return Poll::Ready(res);
}
}
// If the queue is empty, this future is eternally pending.
// This is ok because we use it in a select with another future
// that can interrupt it.
Poll::Pending
}
}
enum DataFrame {
/// a fixed size buffer containing some data + a boolean indicating whether
/// there may be more data coming from this stream. Can be used for some
/// optimization. It's an error to set it to false if there is more data, but it is correct
/// (albeit sub-optimal) to set it to true if there is nothing coming after
Data(Bytes, bool),
/// An error code automatically signals the end of the stream
Error(Bytes),
}
impl DataFrame {
fn from_packet(p: Packet, has_cont: bool) -> Self {
match p {
Ok(bytes) => {
assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize);
Self::Data(bytes, has_cont)
}
Err(e) => {
let mut buf = BytesMut::new();
buf.put_u8(io_errorkind_to_u8(e.kind()));
let msg = format!("{}", e).into_bytes();
if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize {
buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]);
} else {
buf.put(&msg[..]);
}
Self::Error(buf.freeze())
}
}
}
fn header(&self) -> [u8; 2] {
let header_u16 = match self {
DataFrame::Data(data, false) => data.len() as u16,
DataFrame::Data(data, true) => data.len() as u16 | CHUNK_FLAG_HAS_CONTINUATION,
DataFrame::Error(msg) => msg.len() as u16 | CHUNK_FLAG_ERROR,
};
ChunkLength::to_be_bytes(header_u16)
}
fn data(&self) -> &[u8] {
match self {
DataFrame::Data(ref data, _) => &data[..],
DataFrame::Error(ref msg) => &msg[..],
}
}
}
/// The SendLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn), adds a method `.send_loop()`
/// that takes a channel of messages to send and an asynchronous writer,
/// and sends messages from the channel to the async writer, putting them in a queue
/// before they are sent and applying the round-robin sending strategy.
///
/// The `.send_loop()` exits when the sending end of the channel is closed,
/// or if there is an error at any time writing to the async writer.
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
msg_recv: mpsc::UnboundedReceiver<SendItem>,
mut write: BoxStreamWrite<W>,
debug_name: String,
) -> Result<(), Error>
where
W: AsyncWriteExt + Unpin + Send + Sync,
{
let mut sending = SendQueue::new();
let mut msg_recv = Some(msg_recv);
while msg_recv.is_some() || !sending.is_empty() {
trace!(
"send_loop({}): queue = {:?}",
debug_name,
sending
.items
.iter()
.map(|(prio, i)| i.dump(*prio))
.collect::<Vec<_>>()
.join(" ; ")
);
let recv_fut = async {
if let Some(chan) = &mut msg_recv {
chan.recv().await
} else {
futures::future::pending().await
}
};
let send_fut = sending.next_ready();
// recv_fut is cancellation-safe according to tokio doc,
// send_fut is cancellation-safe as implemented above?
tokio::select! {
biased; // always read the incoming channel first if it has data
sth = recv_fut => {
match sth {
Some(SendItem::Stream(id, prio, order_tag, data)) => {
trace!("send_loop({}): add stream {} to send", debug_name, id);
sending.push(SendQueueItem {
id,
prio,
order_tag,
data: ByteStreamReader::new(data),
sent: 0,
})
}
Some(SendItem::Cancel(id)) => {
trace!("send_loop({}): cancelling {}", debug_name, id);
sending.remove(id);
let header_id = RequestID::to_be_bytes(id);
write.write_all(&header_id[..]).await?;
write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?;
write.flush().await?;
}
None => {
msg_recv = None;
}
};
}
(id, data) = send_fut => {
trace!(
"send_loop({}): id {}, send {} bytes, header_size {}",
debug_name,
id,
data.data().len(),
hex::encode(data.header())
);
let header_id = RequestID::to_be_bytes(id);
write.write_all(&header_id[..]).await?;
write.write_all(&data.header()).await?;
write.write_all(data.data()).await?;
write.flush().await?;
}
}
}
let _ = write.goodbye().await;
Ok(())
}
}

View file

@ -1,17 +1,8 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait; use log::{debug, trace};
use log::*;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use kuska_handshake::async_std::{handshake_server, BoxStream};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;
#[cfg(feature = "telemetry")] #[cfg(feature = "telemetry")]
use opentelemetry::{ use opentelemetry::{
@ -23,12 +14,22 @@ use opentelemetry_contrib::trace::propagator::binary::*;
#[cfg(feature = "telemetry")] #[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;
use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use async_trait::async_trait;
use kuska_handshake::async_std::{handshake_server, BoxStream};
use crate::error::*; use crate::error::*;
use crate::message::*;
use crate::netapp::*; use crate::netapp::*;
use crate::recv::*; use crate::proto::*;
use crate::send::*; use crate::proto2::*;
use crate::stream::*;
use crate::util::*; use crate::util::*;
// The client and server connection structs (client.rs and server.rs) // The client and server connection structs (client.rs and server.rs)
@ -54,8 +55,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>, netapp: Arc<NetApp>,
resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
} }
impl ServerConn { impl ServerConn {
@ -79,7 +79,7 @@ impl ServerConn {
debug!( debug!(
"Handshake complete (server) with {}@{}", "Handshake complete (server) with {}@{}",
hex::encode(peer_id), hex::encode(&peer_id),
remote_addr remote_addr
); );
@ -101,22 +101,18 @@ impl ServerConn {
remote_addr, remote_addr,
peer_id, peer_id,
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
running_handlers: Mutex::new(HashMap::new()),
}); });
netapp.connected_as_server(peer_id, conn.clone()); netapp.connected_as_server(peer_id, conn.clone());
let debug_name = format!("SRV {}", hex::encode(&peer_id[..8]));
let debug_name_2 = debug_name.clone();
let conn2 = conn.clone(); let conn2 = conn.clone();
let recv_future = tokio::spawn(async move { let recv_future = tokio::spawn(async move {
select! { select! {
r = conn2.recv_loop(read, debug_name_2) => r, r = conn2.recv_loop(read) => r,
_ = await_exit(must_exit) => Ok(()) _ = await_exit(must_exit) => Ok(())
} }
}); });
let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write, debug_name)); let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));
recv_future.await.log_err("ServerConn recv_loop"); recv_future.await.log_err("ServerConn recv_loop");
conn.resp_send.store(None); conn.resp_send.store(None);
@ -127,8 +123,13 @@ impl ServerConn {
Ok(()) Ok(())
} }
async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> { async fn recv_handler_aux(
let path = String::from_utf8(req_enc.path.to_vec())?; self: &Arc<Self>,
bytes: &[u8],
stream: AssociatedStream,
) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
let msg = QueryMessage::decode(bytes)?;
let path = String::from_utf8(msg.path.to_vec())?;
let handler_opt = { let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap(); let endpoints = self.netapp.endpoints.read().unwrap();
@ -140,9 +141,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] { if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp"); let tracer = opentelemetry::global::tracer("netapp");
let mut span = if !req_enc.telemetry_id.is_empty() { let mut span = if let Some(telemetry_id) = msg.telemetry_id {
let propagator = BinaryPropagator::new(); let propagator = BinaryPropagator::new();
let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = propagator.from_bytes(telemetry_id);
let context = Context::new().with_remote_span_context(context); let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path)) tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
@ -157,13 +158,13 @@ impl ServerConn {
.start(&tracer) .start(&tracer)
}; };
span.set_attribute(KeyValue::new("path", path.to_string())); span.set_attribute(KeyValue::new("path", path.to_string()));
span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
handler.handle(req_enc, self.peer_id) handler.handle(msg.body, stream, self.peer_id)
.with_context(Context::current_with_span(span)) .with_context(Context::current_with_span(span))
.await .await
} else { } else {
handler.handle(req_enc, self.peer_id).await handler.handle(msg.body, stream, self.peer_id).await
} }
} }
} else { } else {
@ -176,47 +177,40 @@ impl SendLoop for ServerConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ServerConn { impl RecvLoop for ServerConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
let resp_send = match self.resp_send.load_full() { let resp_send = self.resp_send.load_full().unwrap();
Some(c) => c,
None => return,
};
let mut rh = self.running_handlers.lock().unwrap();
let self2 = self.clone(); let self2 = self.clone();
let jh = tokio::spawn(async move { tokio::spawn(async move {
debug!("server: recv_handler got {}", id); trace!("ServerConn recv_handler {}", id);
let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();
let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), let resp = self2.recv_handler_aux(&bytes[..], stream).await;
Err(e) => (PRIO_HIGH, Err(e)),
let (resp_bytes, resp_stream) = match resp {
Ok((rb, rs)) => {
let mut resp_bytes = vec![0u8];
resp_bytes.extend(rb);
(resp_bytes, rs)
}
Err(e) => {
let mut resp_bytes = vec![e.code()];
resp_bytes.extend(e.to_string().into_bytes());
(resp_bytes, None)
}
}; };
debug!("server: sending response to {}", id); trace!("ServerConn sending response to {}: ", id);
let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
resp_send resp_send
.send(SendItem::Stream(id, prio, resp_order, resp_stream)) .send((
id,
prio,
Framing::new(resp_bytes, resp_stream).into_stream(),
))
.log_err("ServerConn recv_handler send resp bytes"); .log_err("ServerConn recv_handler send resp bytes");
Ok::<_, Error>(())
self2.running_handlers.lock().unwrap().remove(&id);
}); });
rh.insert(id, jh);
}
fn cancel_handler(self: &Arc<Self>, id: RequestID) {
trace!("received cancel for request {}", id);
// If the handler is still running, abort it now
if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
jh.abort();
}
// Inform the response sender that we don't need to send the response
if let Some(resp_send) = self.resp_send.load_full() {
let _ = resp_send.send(SendItem::Cancel(id));
}
} }
} }

View file

@ -1,202 +0,0 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::Future;
use futures::{Stream, StreamExt};
use tokio::io::AsyncRead;
use crate::bytes_buf::BytesBuf;
/// A stream of bytes (click to read more).
///
/// When sent through Netapp, the data may be split into smaller chunks in such a way
/// that consecutive chunks may get merged, but chunks and error codes may not be reordered
///
/// Items sent in the ByteStream may be errors of type `std::io::Error`.
/// An error indicates the end of the ByteStream: a reader should no longer read
/// after receiving an error, and a writer should stop writing after sending an error.
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
/// A packet sent in a ByteStream, which may contain either
/// a Bytes object or an error
pub type Packet = Result<Bytes, std::io::Error>;
// ----
/// A helper struct to read defined lengths of data from a BytesStream
pub struct ByteStreamReader {
stream: ByteStream,
buf: BytesBuf,
eos: bool,
err: Option<std::io::Error>,
}
impl ByteStreamReader {
/// Creates a new `ByteStreamReader` from a `ByteStream`
pub fn new(stream: ByteStream) -> Self {
ByteStreamReader {
stream,
buf: BytesBuf::new(),
eos: false,
err: None,
}
}
/// Read exactly `read_len` bytes from the underlying stream
/// (returns a future)
pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: true,
}
}
/// Read at most `read_len` bytes from the underlying stream, or less
/// if the end of the stream is reached (returns a future)
pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: false,
}
}
/// Read exactly one byte from the underlying stream and returns it
/// as an u8
pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
Ok(self.read_exact(1).await?[0])
}
/// Read exactly two bytes from the underlying stream and returns them as an u16 (using
/// big-endian decoding)
pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
let bytes = self.read_exact(2).await?;
let mut b = [0u8; 2];
b.copy_from_slice(&bytes[..]);
Ok(u16::from_be_bytes(b))
}
/// Read exactly four bytes from the underlying stream and returns them as an u32 (using
/// big-endian decoding)
pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
let bytes = self.read_exact(4).await?;
let mut b = [0u8; 4];
b.copy_from_slice(&bytes[..]);
Ok(u32::from_be_bytes(b))
}
/// Transforms the stream reader back into the underlying stream (starting
/// after everything that the reader has read)
pub fn into_stream(self) -> ByteStream {
let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok));
if let Some(err) = self.err {
Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
} else if self.eos {
Box::pin(buf_stream)
} else {
Box::pin(buf_stream.chain(self.stream))
}
}
/// Tries to fill the internal read buffer from the underlying stream if it is empty.
/// Calling this might be necessary to ensure that `.eos()` returns a correct
/// result, otherwise the reader might not be aware that the underlying
/// stream has nothing left to return.
pub async fn fill_buffer(&mut self) {
if self.buf.is_empty() {
let packet = self.stream.next().await;
self.add_stream_next(packet);
}
}
/// Clears the internal read buffer and returns its content
pub fn take_buffer(&mut self) -> Bytes {
self.buf.take_all()
}
/// Returns true if the end of the underlying stream has been reached
pub fn eos(&self) -> bool {
self.buf.is_empty() && self.eos
}
fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
self.buf.take_exact(read_len)
}
fn add_stream_next(&mut self, packet: Option<Packet>) {
match packet {
Some(Ok(slice)) => {
self.buf.extend(slice);
}
Some(Err(e)) => {
self.err = Some(e);
self.eos = true;
}
None => {
self.eos = true;
}
}
}
}
/// The error kind that can be returned by `ByteStreamReader::read_exact` and
/// `ByteStreamReader::read_exact_or_eos`
pub enum ReadExactError {
/// The end of the stream was reached before the requested number of bytes could be read
UnexpectedEos,
/// The underlying data stream returned an IO error when trying to read
Stream(std::io::Error),
}
/// The future returned by `ByteStreamReader::read_exact` and
/// `ByteStreamReader::read_exact_or_eos`
#[pin_project::pin_project]
pub struct ByteStreamReadExact<'a> {
#[pin]
reader: &'a mut ByteStreamReader,
read_len: usize,
fail_on_eos: bool,
}
impl<'a> Future for ByteStreamReadExact<'a> {
type Output = Result<Bytes, ReadExactError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
let mut this = self.project();
loop {
if let Some(bytes) = this.reader.try_get(*this.read_len) {
return Poll::Ready(Ok(bytes));
}
if let Some(err) = &this.reader.err {
let err = std::io::Error::new(err.kind(), format!("{}", err));
return Poll::Ready(Err(ReadExactError::Stream(err)));
}
if this.reader.eos {
if *this.fail_on_eos {
return Poll::Ready(Err(ReadExactError::UnexpectedEos));
} else {
return Poll::Ready(Ok(this.reader.take_buffer()));
}
}
let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx));
this.reader.add_stream_next(next_packet);
}
}
}
// ----
/// Turns a `tokio::io::AsyncRead` asynchronous reader into a `ByteStream`
pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream {
Box::pin(tokio_util::io::ReaderStream::new(reader))
}
/// Turns a `ByteStream` into a `tokio::io::AsyncRead` asynchronous reader
pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static {
tokio_util::io::StreamReader::new(stream)
}
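A hedged usage sketch for the helpers above: wrap an in-memory reader as a `ByteStream`, then pull a length-prefixed payload out of it with `ByteStreamReader` (the helper function and the framing are assumptions for illustration):
// Sketch only.
async fn read_len_prefixed(data: Vec<u8>) -> Result<bytes::Bytes, ReadExactError> {
    let stream = asyncread_stream(std::io::Cursor::new(data));
    let mut reader = ByteStreamReader::new(stream);
    let len = reader.read_u32().await? as usize; // 4-byte big-endian length prefix
    reader.read_exact(len).await
}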

View file

@ -14,7 +14,7 @@ use crate::NodeID;
#[tokio::test(flavor = "current_thread")] #[tokio::test(flavor = "current_thread")]
async fn test_with_basic_scheduler() { async fn test_with_basic_scheduler() {
env_logger::init(); pretty_env_logger::init();
run_test().await run_test().await
} }

View file

@ -1,25 +1,54 @@
use crate::endpoint::SerializeMessage;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use futures::Stream;
use log::info; use log::info;
use serde::Serialize; use serde::Serialize;
use tokio::sync::watch; use tokio::sync::watch;
use crate::netapp::*; /// A node's identifier, which is also its public cryptographic key
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
/// A node's secret key
pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
/// A network key
pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// A stream of associated data.
///
/// The Stream can continue after receiving an error.
/// When sent through Netapp, the Vec may be split into smaller chunks in such a way
/// that consecutive Vecs may get merged, but Vecs and error codes may not be reordered
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning, it's up to your application to define their semantic.
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
pub type Packet = Result<Vec<u8>, u8>;
/// Utility function: encodes any serializable value in MessagePack binary format /// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library. /// using the RMP library.
/// ///
/// Field names and variant names are included in the serialization. /// Field names and variant names are included in the serialization.
/// This is used internally by the netapp communication protocol. /// This is used internally by the netapp communication protocol.
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> pub fn rmp_to_vec_all_named<T>(
val: &T,
) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
where where
T: Serialize + ?Sized, T: SerializeMessage + ?Sized,
{ {
let mut wr = Vec::with_capacity(128); let mut wr = Vec::with_capacity(128);
let mut se = rmp_serde::Serializer::new(&mut wr).with_struct_map(); let mut se = rmp_serde::Serializer::new(&mut wr)
.with_struct_map()
.with_string_variants();
let (val, stream) = val.serialize_msg();
val.serialize(&mut se)?; val.serialize(&mut se)?;
Ok(wr) Ok((wr, stream))
} }
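A sketch of calling this helper with the main-branch signature shown on the left of the hunk, where it returns a plain serialized buffer (the `Example` struct is hypothetical):
// Sketch only, assuming the plain Vec<u8> return type.
#[derive(serde::Serialize)]
struct Example {
    x: u32,
    name: String,
}

let bytes = rmp_to_vec_all_named(&Example { x: 7, name: "seven".into() })
    .expect("serialization of a simple struct should not fail");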
/// This async function returns only when a true signal was received /// This async function returns only when a true signal was received
@ -60,7 +89,7 @@ pub fn watch_ctrl_c() -> watch::Receiver<bool> {
pub fn parse_peer_addr(peer: &str) -> Option<(NodeID, SocketAddr)> { pub fn parse_peer_addr(peer: &str) -> Option<(NodeID, SocketAddr)> {
let delim = peer.find('@')?; let delim = peer.find('@')?;
let (key, ip) = peer.split_at(delim); let (key, ip) = peer.split_at(delim);
let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?; let pubkey = NodeID::from_slice(&hex::decode(&key).ok()?)?;
let ip = ip[1..].parse::<SocketAddr>().ok()?; let ip = ip[1..].parse::<SocketAddr>().ok()?;
Some((pubkey, ip)) Some((pubkey, ip))
} }
@ -68,29 +97,12 @@ pub fn parse_peer_addr(peer: &str) -> Option<(NodeID, SocketAddr)> {
/// Parse and resolve a peer's address including public key, written in the format: /// Parse and resolve a peer's address including public key, written in the format:
/// `<public key hex>@<ip or hostname>:<port>` /// `<public key hex>@<ip or hostname>:<port>`
pub fn parse_and_resolve_peer_addr(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> { pub fn parse_and_resolve_peer_addr(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> {
use std::net::ToSocketAddrs;
let delim = peer.find('@')?; let delim = peer.find('@')?;
let (key, host) = peer.split_at(delim); let (key, host) = peer.split_at(delim);
let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?; let pubkey = NodeID::from_slice(&hex::decode(&key).ok()?)?;
let hosts = host[1..].to_socket_addrs().ok()?.collect::<Vec<_>>(); let hosts = host[1..].to_socket_addrs().ok()?.collect::<Vec<_>>();
if hosts.is_empty() { if hosts.is_empty() {
return None; return None;
} }
Some((pubkey, hosts)) Some((pubkey, hosts))
} }
/// async version of parse_and_resolve_peer_addr
pub async fn parse_and_resolve_peer_addr_async(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> {
let delim = peer.find('@')?;
let (key, host) = peer.split_at(delim);
let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?;
let hosts = tokio::net::lookup_host(&host[1..])
.await
.ok()?
.collect::<Vec<_>>();
if hosts.is_empty() {
return None;
}
Some((pubkey, hosts))
}
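A usage sketch for the address-parsing helpers above (the key below is a placeholder, not a real node ID):
// Sketch only.
let peer = format!("{}@127.0.0.1:3901", "ab".repeat(32));
if let Some((node_id, addr)) = parse_peer_addr(&peer) {
    println!("would connect to {} at {}", hex::encode(node_id), addr);
}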

1
target
View file

@ -1 +0,0 @@
/home/lx.nobackup/rust/netapp.target/