forked from lx/netapp
First commit
This commit is contained in:
commit
d4de2ffc40
17 changed files with 3342 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
target/
|
1409
Cargo.lock
generated
Normal file
1409
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
27
Cargo.toml
Normal file
27
Cargo.toml
Normal file
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "netapp"
|
||||
version = "0.1.0"
|
||||
authors = ["Alex Auvolat <alex.auvolat@inria.fr>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-std = { version = "1.5.0", features=["unstable","attributes"] }
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
kuska-handshake = { path = "../../handshake", features = ["default", "tokio_compat"] }
|
||||
hex = "0.4.2"
|
||||
log = "0.4.8"
|
||||
pretty_env_logger = "0.4"
|
||||
sodiumoxide = { git = "https://github.com/Dhole/sodiumoxidez", branch = "extra" }
|
||||
env_logger = "0.7.1"
|
||||
base64 = "0.12.1"
|
||||
rmp-serde = "0.14.3"
|
||||
serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
|
||||
arc-swap = "1.0"
|
||||
structopt = { version = "0.3", default-features = false }
|
||||
async-trait = "0.1.7"
|
||||
err-derive = "0.2.3"
|
||||
bytes = "0.6.0"
|
||||
lru = "0.6"
|
||||
rand = "0.5.5"
|
3
Makefile
Normal file
3
Makefile
Normal file
|
@ -0,0 +1,3 @@
|
|||
all:
|
||||
cargo build
|
||||
RUST_LOG=netapp=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
|
76
examples/basalt.rs
Normal file
76
examples/basalt.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use log::info;
|
||||
|
||||
use structopt::StructOpt;
|
||||
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use netapp::netapp::*;
|
||||
use netapp::peering::basalt::*;
|
||||
|
||||
#[derive(StructOpt, Debug)]
|
||||
#[structopt(name = "netapp")]
|
||||
pub struct Opt {
|
||||
#[structopt(long = "network-key", short = "n")]
|
||||
network_key: Option<String>,
|
||||
|
||||
#[structopt(long = "private-key", short = "p")]
|
||||
private_key: Option<String>,
|
||||
|
||||
#[structopt(long = "bootstrap-peer", short = "b")]
|
||||
bootstrap_peers: Vec<String>,
|
||||
|
||||
#[structopt(long = "listen-addr", short = "l", default_value = "127.0.0.1:1980")]
|
||||
listen_addr: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
pretty_env_logger::init();
|
||||
|
||||
let opt = Opt::from_args();
|
||||
|
||||
let netid = match &opt.network_key {
|
||||
Some(k) => auth::Key::from_slice(&hex::decode(k).unwrap()).unwrap(),
|
||||
None => auth::gen_key(),
|
||||
};
|
||||
info!("Network key: {}", hex::encode(&netid));
|
||||
|
||||
let privkey = match &opt.private_key {
|
||||
Some(k) => ed25519::SecretKey::from_slice(&hex::decode(k).unwrap()).unwrap(),
|
||||
None => {
|
||||
let (_pk, sk) = ed25519::gen_keypair();
|
||||
sk
|
||||
}
|
||||
};
|
||||
|
||||
info!("Node private key: {}", hex::encode(&privkey));
|
||||
info!("Node public key: {}", hex::encode(&privkey.public_key()));
|
||||
|
||||
let listen_addr = opt.listen_addr.parse().unwrap();
|
||||
let netapp = NetApp::new(listen_addr, netid, privkey);
|
||||
|
||||
let mut bootstrap_peers = vec![];
|
||||
for peer in opt.bootstrap_peers.iter() {
|
||||
if let Some(delim) = peer.find('@') {
|
||||
let (key, ip) = peer.split_at(delim);
|
||||
let pubkey = ed25519::PublicKey::from_slice(&hex::decode(&key).unwrap()).unwrap();
|
||||
let ip = ip[1..].parse::<SocketAddr>().unwrap();
|
||||
bootstrap_peers.push((pubkey, ip));
|
||||
}
|
||||
}
|
||||
|
||||
let basalt_params = BasaltParams{
|
||||
view_size: 100,
|
||||
cache_size: 1000,
|
||||
exchange_interval: Duration::from_secs(1),
|
||||
reset_interval: Duration::from_secs(10),
|
||||
reset_count: 20,
|
||||
};
|
||||
let peering = Basalt::new(netapp.clone(), bootstrap_peers, basalt_params);
|
||||
|
||||
tokio::join!(netapp.listen(), peering.run(),);
|
||||
}
|
68
examples/fullmesh.rs
Normal file
68
examples/fullmesh.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
use std::net::SocketAddr;
|
||||
|
||||
use log::info;
|
||||
|
||||
use structopt::StructOpt;
|
||||
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use netapp::netapp::*;
|
||||
use netapp::peering::fullmesh::*;
|
||||
|
||||
#[derive(StructOpt, Debug)]
|
||||
#[structopt(name = "netapp")]
|
||||
pub struct Opt {
|
||||
#[structopt(long = "network-key", short = "n")]
|
||||
network_key: Option<String>,
|
||||
|
||||
#[structopt(long = "private-key", short = "p")]
|
||||
private_key: Option<String>,
|
||||
|
||||
#[structopt(long = "bootstrap-peer", short = "b")]
|
||||
bootstrap_peers: Vec<String>,
|
||||
|
||||
#[structopt(long = "listen-addr", short = "l", default_value = "127.0.0.1:1980")]
|
||||
listen_addr: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
pretty_env_logger::init();
|
||||
|
||||
let opt = Opt::from_args();
|
||||
|
||||
let netid = match &opt.network_key {
|
||||
Some(k) => auth::Key::from_slice(&hex::decode(k).unwrap()).unwrap(),
|
||||
None => auth::gen_key(),
|
||||
};
|
||||
info!("Network key: {}", hex::encode(&netid));
|
||||
|
||||
let privkey = match &opt.private_key {
|
||||
Some(k) => ed25519::SecretKey::from_slice(&hex::decode(k).unwrap()).unwrap(),
|
||||
None => {
|
||||
let (_pk, sk) = ed25519::gen_keypair();
|
||||
sk
|
||||
}
|
||||
};
|
||||
|
||||
info!("Node private key: {}", hex::encode(&privkey));
|
||||
info!("Node public key: {}", hex::encode(&privkey.public_key()));
|
||||
|
||||
let listen_addr = opt.listen_addr.parse().unwrap();
|
||||
let netapp = NetApp::new(listen_addr, netid, privkey);
|
||||
|
||||
let mut bootstrap_peers = vec![];
|
||||
for peer in opt.bootstrap_peers.iter() {
|
||||
if let Some(delim) = peer.find('@') {
|
||||
let (key, ip) = peer.split_at(delim);
|
||||
let pubkey = ed25519::PublicKey::from_slice(&hex::decode(&key).unwrap()).unwrap();
|
||||
let ip = ip[1..].parse::<SocketAddr>().unwrap();
|
||||
bootstrap_peers.push((pubkey, ip));
|
||||
}
|
||||
}
|
||||
|
||||
let peering = FullMeshPeeringStrategy::new(netapp.clone(), bootstrap_peers);
|
||||
|
||||
tokio::join!(netapp.listen(), peering.run(),);
|
||||
}
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
|
@ -0,0 +1 @@
|
|||
hard_tabs = true
|
280
src/conn.rs
Normal file
280
src/conn.rs
Normal file
|
@ -0,0 +1,280 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{self, AtomicU16};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, trace};
|
||||
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
use tokio::io::split;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
|
||||
use kuska_handshake::async_std::{
|
||||
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
|
||||
TokioCompatExtWrite,
|
||||
};
|
||||
|
||||
use crate::error::*;
|
||||
use crate::message::*;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
pub struct ServerConn {
|
||||
netapp: Arc<NetApp>,
|
||||
pub remote_addr: SocketAddr,
|
||||
pub peer_pk: ed25519::PublicKey,
|
||||
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
close_send: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl ServerConn {
|
||||
pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> {
|
||||
let mut asyncstd_socket = TokioCompatExt::wrap(socket);
|
||||
let handshake = handshake_server(
|
||||
&mut asyncstd_socket,
|
||||
netapp.netid.clone(),
|
||||
netapp.pubkey.clone(),
|
||||
netapp.privkey.clone(),
|
||||
)
|
||||
.await?;
|
||||
let peer_pk = handshake.peer_pk.clone();
|
||||
|
||||
let tokio_socket = asyncstd_socket.into_inner();
|
||||
let remote_addr = tokio_socket.peer_addr().unwrap();
|
||||
|
||||
debug!(
|
||||
"Handshake complete (server) with {}@{}",
|
||||
hex::encode(&peer_pk),
|
||||
remote_addr
|
||||
);
|
||||
|
||||
let (read, write) = split(tokio_socket);
|
||||
|
||||
let read = TokioCompatExtRead::wrap(read);
|
||||
let write = TokioCompatExtWrite::wrap(write);
|
||||
|
||||
let (box_stream_read, box_stream_write) =
|
||||
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
|
||||
|
||||
let (resp_send, resp_recv) = mpsc::unbounded_channel();
|
||||
|
||||
let (close_send, close_recv) = watch::channel(false);
|
||||
|
||||
let conn = Arc::new(ServerConn {
|
||||
netapp: netapp.clone(),
|
||||
remote_addr,
|
||||
peer_pk: peer_pk.clone(),
|
||||
resp_send,
|
||||
close_send,
|
||||
});
|
||||
|
||||
netapp.connected_as_server(peer_pk.clone(), conn.clone());
|
||||
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
tokio::try_join!(
|
||||
conn2.recv_loop(box_stream_read, close_recv.clone()),
|
||||
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
|
||||
)
|
||||
.map(|_| ())
|
||||
.log_err("ServerConn recv_loop/send_loop");
|
||||
|
||||
netapp.disconnected_as_server(&peer_pk, conn);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.close_send.broadcast(true).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl SendLoop for ServerConn {}
|
||||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ServerConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) {
|
||||
let bytes: Bytes = bytes.into();
|
||||
|
||||
let prio = bytes[0];
|
||||
|
||||
let mut kind_bytes = [0u8; 4];
|
||||
kind_bytes.copy_from_slice(&bytes[1..5]);
|
||||
let kind = u32::from_be_bytes(kind_bytes);
|
||||
|
||||
if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) {
|
||||
let resp = handler(self.peer_pk.clone(), bytes.slice(5..)).await;
|
||||
self.resp_send
|
||||
.send((id, prio, resp))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
}
|
||||
}
|
||||
}
|
||||
pub struct ClientConn {
|
||||
pub netapp: Arc<NetApp>,
|
||||
pub remote_addr: SocketAddr,
|
||||
pub peer_pk: ed25519::PublicKey,
|
||||
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
next_query_number: AtomicU16,
|
||||
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
|
||||
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
||||
close_send: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl ClientConn {
|
||||
pub(crate) async fn init(
|
||||
netapp: Arc<NetApp>,
|
||||
socket: TcpStream,
|
||||
remote_pk: ed25519::PublicKey,
|
||||
) -> Result<(), Error> {
|
||||
let mut asyncstd_socket = TokioCompatExt::wrap(socket);
|
||||
|
||||
let handshake = handshake_client(
|
||||
&mut asyncstd_socket,
|
||||
netapp.netid.clone(),
|
||||
netapp.pubkey.clone(),
|
||||
netapp.privkey.clone(),
|
||||
remote_pk.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let tokio_socket = asyncstd_socket.into_inner();
|
||||
let remote_addr = tokio_socket.peer_addr().unwrap();
|
||||
|
||||
debug!(
|
||||
"Handshake complete (client) with {}@{}",
|
||||
hex::encode(&remote_pk),
|
||||
remote_addr
|
||||
);
|
||||
|
||||
let (read, write) = split(tokio_socket);
|
||||
|
||||
let read = TokioCompatExtRead::wrap(read);
|
||||
let write = TokioCompatExtWrite::wrap(write);
|
||||
|
||||
let (box_stream_read, box_stream_write) =
|
||||
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
|
||||
|
||||
let (query_send, query_recv) = mpsc::unbounded_channel();
|
||||
let (resp_send, resp_recv) = mpsc::unbounded_channel();
|
||||
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
|
||||
|
||||
let (close_send, close_recv) = watch::channel(false);
|
||||
|
||||
let conn = Arc::new(ClientConn {
|
||||
netapp: netapp.clone(),
|
||||
remote_addr,
|
||||
peer_pk: remote_pk.clone(),
|
||||
next_query_number: AtomicU16::from(0u16),
|
||||
query_send,
|
||||
resp_send,
|
||||
resp_notify_send,
|
||||
close_send,
|
||||
});
|
||||
|
||||
netapp.connected_as_client(remote_pk.clone(), conn.clone());
|
||||
|
||||
tokio::spawn(async move {
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
let conn4 = conn.clone();
|
||||
tokio::try_join!(
|
||||
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
|
||||
conn3.recv_loop(box_stream_read, close_recv.clone()),
|
||||
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
|
||||
)
|
||||
.map(|_| ())
|
||||
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
|
||||
|
||||
netapp.disconnected_as_client(&remote_pk, conn);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.close_send.broadcast(true).unwrap();
|
||||
}
|
||||
|
||||
async fn dispatch_resp(
|
||||
self: Arc<Self>,
|
||||
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
|
||||
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
|
||||
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
|
||||
while !*must_exit.borrow() {
|
||||
tokio::select! {
|
||||
resp = resp_recv.recv() => {
|
||||
if let Some((id, resp)) = resp {
|
||||
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
|
||||
if let Some(ch) = resp_notify.remove(&id) {
|
||||
ch.send(resp).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?;
|
||||
} else {
|
||||
resps.insert(id, resp);
|
||||
}
|
||||
}
|
||||
}
|
||||
resp_ch = resp_notify_recv.recv() => {
|
||||
if let Some((id, resp_ch)) = resp_ch {
|
||||
trace!("dispatch_resp: got resp_ch {}", id);
|
||||
if let Some(rs) = resps.remove(&id) {
|
||||
resp_ch.send(rs).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?;
|
||||
} else {
|
||||
resp_notify.insert(id, resp_ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
exit = must_exit.recv() => {
|
||||
if exit == Some(true) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn request<T>(
|
||||
self: Arc<Self>,
|
||||
rq: T,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<T as Message>::Response, Error>
|
||||
where
|
||||
T: Message,
|
||||
{
|
||||
let id = self
|
||||
.next_query_number
|
||||
.fetch_add(1u16, atomic::Ordering::Relaxed);
|
||||
let mut bytes = vec![prio];
|
||||
bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]);
|
||||
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
|
||||
|
||||
let (resp_send, resp_recv) = oneshot::channel();
|
||||
self.resp_notify_send.send((id, resp_send))?;
|
||||
|
||||
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
||||
self.query_send.send((id, prio, bytes))?;
|
||||
|
||||
let resp = resp_recv.await?;
|
||||
|
||||
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])?
|
||||
.map_err(Error::Remote)
|
||||
}
|
||||
}
|
||||
|
||||
impl SendLoop for ClientConn {}
|
||||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ClientConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||
self.resp_send
|
||||
.send((id, msg))
|
||||
.log_err("ClientConn::recv_handler");
|
||||
}
|
||||
}
|
57
src/error.rs
Normal file
57
src/error.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use err_derive::Error;
|
||||
use std::io;
|
||||
|
||||
use log::error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error(display = "IO error: {}", _0)]
|
||||
Io(#[error(source)] io::Error),
|
||||
|
||||
#[error(display = "Messagepack encode error: {}", _0)]
|
||||
RMPEncode(#[error(source)] rmp_serde::encode::Error),
|
||||
#[error(display = "Messagepack decode error: {}", _0)]
|
||||
RMPDecode(#[error(source)] rmp_serde::decode::Error),
|
||||
|
||||
#[error(display = "Tokio join error: {}", _0)]
|
||||
TokioJoin(#[error(source)] tokio::task::JoinError),
|
||||
|
||||
#[error(display = "oneshot receive error: {}", _0)]
|
||||
OneshotRecv(#[error(source)] tokio::sync::oneshot::error::RecvError),
|
||||
|
||||
#[error(display = "Handshake error: {}", _0)]
|
||||
Handshake(#[error(source)] kuska_handshake::async_std::Error),
|
||||
|
||||
#[error(display = "{}", _0)]
|
||||
Message(String),
|
||||
|
||||
#[error(display = "Remote error: {}", _0)]
|
||||
Remote(String),
|
||||
}
|
||||
|
||||
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
|
||||
fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
|
||||
Error::Message(format!("Watch send error"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
|
||||
fn from(_e: tokio::sync::mpsc::error::SendError<T>) -> Error {
|
||||
Error::Message(format!("MPSC send error"))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LogError {
|
||||
fn log_err(self, msg: &'static str);
|
||||
}
|
||||
|
||||
impl<E> LogError for Result<(), E>
|
||||
where
|
||||
E: Into<Error>,
|
||||
{
|
||||
fn log_err(self, msg: &'static str) {
|
||||
if let Err(e) = self {
|
||||
error!("Error: {}: {}", msg, Into::<Error>::into(e));
|
||||
};
|
||||
}
|
||||
}
|
9
src/lib.rs
Normal file
9
src/lib.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
#![feature(map_first_last)]
|
||||
|
||||
pub mod conn;
|
||||
pub mod error;
|
||||
pub mod message;
|
||||
pub mod netapp;
|
||||
pub mod peering;
|
||||
pub mod proto;
|
||||
pub mod util;
|
18
src/message.rs
Normal file
18
src/message.rs
Normal file
|
@ -0,0 +1,18 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub type MessageKind = u32;
|
||||
|
||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
||||
const KIND: MessageKind;
|
||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct HelloMessage {
|
||||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl Message for HelloMessage {
|
||||
const KIND: MessageKind = 0x42000001;
|
||||
type Response = ();
|
||||
}
|
214
src/netapp.rs
Normal file
214
src/netapp.rs
Normal file
|
@ -0,0 +1,214 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use std::future::Future;
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
use arc_swap::{ArcSwap, ArcSwapOption};
|
||||
use bytes::Bytes;
|
||||
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::conn::*;
|
||||
use crate::error::*;
|
||||
use crate::message::*;
|
||||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
pub struct NetApp {
|
||||
pub listen_addr: SocketAddr,
|
||||
pub netid: auth::Key,
|
||||
pub pubkey: ed25519::PublicKey,
|
||||
pub privkey: ed25519::SecretKey,
|
||||
pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
|
||||
pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
|
||||
pub(crate) msg_handlers: ArcSwap<
|
||||
HashMap<
|
||||
MessageKind,
|
||||
Arc<
|
||||
dyn Fn(
|
||||
ed25519::PublicKey,
|
||||
Bytes,
|
||||
) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
|
||||
+ Sync
|
||||
+ Send,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
pub(crate) on_connected:
|
||||
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
|
||||
pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
|
||||
}
|
||||
|
||||
async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8>
|
||||
where
|
||||
M: Message + 'static,
|
||||
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync,
|
||||
{
|
||||
debug!(
|
||||
"Handling message of kind {:08x} from {}",
|
||||
M::KIND,
|
||||
hex::encode(remote)
|
||||
);
|
||||
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
|
||||
Ok(msg) => handler(remote.clone(), msg).await,
|
||||
Err(e) => Err(e.into()),
|
||||
};
|
||||
let res = res.map_err(|e| format!("{}", e));
|
||||
rmp_to_vec_all_named(&res).unwrap_or(vec![])
|
||||
}
|
||||
|
||||
impl NetApp {
|
||||
pub fn new(
|
||||
listen_addr: SocketAddr,
|
||||
netid: auth::Key,
|
||||
privkey: ed25519::SecretKey,
|
||||
) -> Arc<Self> {
|
||||
let pubkey = privkey.public_key();
|
||||
let netapp = Arc::new(Self {
|
||||
listen_addr,
|
||||
netid,
|
||||
pubkey,
|
||||
privkey,
|
||||
server_conns: RwLock::new(HashMap::new()),
|
||||
client_conns: RwLock::new(HashMap::new()),
|
||||
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
|
||||
on_connected: ArcSwapOption::new(None),
|
||||
on_disconnected: ArcSwapOption::new(None),
|
||||
});
|
||||
|
||||
let netapp2 = netapp.clone();
|
||||
netapp.add_msg_handler::<HelloMessage, _, _>(
|
||||
move |from: ed25519::PublicKey, msg: HelloMessage| {
|
||||
netapp2.handle_hello_message(from, msg);
|
||||
async { Ok(()) }
|
||||
},
|
||||
);
|
||||
|
||||
netapp
|
||||
}
|
||||
|
||||
pub fn add_msg_handler<M, F, R>(&self, handler: F)
|
||||
where
|
||||
M: Message + 'static,
|
||||
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static,
|
||||
{
|
||||
let handler = Arc::new(handler);
|
||||
let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
|
||||
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
|
||||
Box::pin(handler_aux(handler.clone(), remote, bytes));
|
||||
fun
|
||||
});
|
||||
let mut handlers = self.msg_handlers.load().as_ref().clone();
|
||||
handlers.insert(M::KIND, fun);
|
||||
self.msg_handlers.store(Arc::new(handlers));
|
||||
}
|
||||
|
||||
pub async fn listen(self: Arc<Self>) {
|
||||
let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
|
||||
info!("Listening on {}", self.listen_addr);
|
||||
|
||||
loop {
|
||||
// The second item contains the IP and port of the new connection.
|
||||
let (socket, _) = listener.accept().await.unwrap();
|
||||
info!(
|
||||
"Incoming connection from {}, negotiating handshake...",
|
||||
socket.peer_addr().unwrap()
|
||||
);
|
||||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
ServerConn::run(self2, socket)
|
||||
.await
|
||||
.log_err("ServerConn::run");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn try_connect(
|
||||
self: Arc<Self>,
|
||||
ip: SocketAddr,
|
||||
pk: ed25519::PublicKey,
|
||||
) -> Result<(), Error> {
|
||||
if self.client_conns.read().unwrap().contains_key(&pk) {
|
||||
return Ok(());
|
||||
}
|
||||
let socket = TcpStream::connect(ip).await?;
|
||||
info!("Connected to {}, negotiating handshake...", ip);
|
||||
ClientConn::init(self, socket, pk.clone()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) {
|
||||
let conn = self.client_conns.read().unwrap().get(id).cloned();
|
||||
if let Some(c) = conn {
|
||||
c.close();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
|
||||
let mut conn_list = self.server_conns.write().unwrap();
|
||||
conn_list.insert(id.clone(), conn);
|
||||
}
|
||||
|
||||
fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
|
||||
if let Some(h) = self.on_connected.load().as_ref() {
|
||||
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
|
||||
let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port);
|
||||
h(id, remote_addr, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
|
||||
let mut conn_list = self.server_conns.write().unwrap();
|
||||
if let Some(c) = conn_list.get(id) {
|
||||
if Arc::ptr_eq(c, &conn) {
|
||||
conn_list.remove(id);
|
||||
}
|
||||
|
||||
if let Some(h) = self.on_disconnected.load().as_ref() {
|
||||
h(conn.peer_pk, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
|
||||
{
|
||||
let mut conn_list = self.client_conns.write().unwrap();
|
||||
if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
|
||||
tokio::spawn(async move { old_c.close() });
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(h) = self.on_connected.load().as_ref() {
|
||||
h(conn.peer_pk, conn.remote_addr, false);
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
let server_port = conn.netapp.listen_addr.port();
|
||||
conn.request(HelloMessage { server_port }, prio::NORMAL)
|
||||
.await
|
||||
.log_err("Sending hello message");
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
|
||||
let mut conn_list = self.client_conns.write().unwrap();
|
||||
if let Some(c) = conn_list.get(id) {
|
||||
if Arc::ptr_eq(c, &conn) {
|
||||
conn_list.remove(id);
|
||||
}
|
||||
|
||||
if let Some(h) = self.on_disconnected.load().as_ref() {
|
||||
h(conn.peer_pk, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
475
src/peering/basalt.rs
Normal file
475
src/peering/basalt.rs
Normal file
|
@ -0,0 +1,475 @@
|
|||
use std::collections::HashSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use log::{debug, warn};
|
||||
use lru::LruCache;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use sodiumoxide::crypto::hash;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use crate::conn::*;
|
||||
use crate::message::*;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
|
||||
// -- Protocol messages --
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PullMessage {}
|
||||
|
||||
impl Message for PullMessage {
|
||||
const KIND: MessageKind = 0x42001100;
|
||||
type Response = PushMessage;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PushMessage {
|
||||
peers: Vec<Peer>,
|
||||
}
|
||||
|
||||
impl Message for PushMessage {
|
||||
const KIND: MessageKind = 0x42001101;
|
||||
type Response = ();
|
||||
}
|
||||
|
||||
// -- Algorithm data structures --
|
||||
|
||||
type Seed = [u8; 32];
|
||||
|
||||
#[derive(Hash, Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Serialize, Deserialize)]
|
||||
struct Peer {
|
||||
id: ed25519::PublicKey,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
type Cost = [u8; 40];
|
||||
const MAX_COST: Cost = [0xffu8; 40];
|
||||
|
||||
impl Peer {
|
||||
fn cost(&self, seed: &Seed) -> Cost {
|
||||
let mut hasher = hash::State::new();
|
||||
hasher.update(&seed[..]);
|
||||
|
||||
let mut cost = [0u8; 40];
|
||||
match self.addr {
|
||||
SocketAddr::V4(v4addr) => {
|
||||
let v4ip = v4addr.ip().octets();
|
||||
|
||||
for i in 0..4 {
|
||||
let mut h = hasher.clone();
|
||||
h.update(&v4ip[..i + 1]);
|
||||
cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]);
|
||||
}
|
||||
}
|
||||
SocketAddr::V6(v6addr) => {
|
||||
let v6ip = v6addr.ip().octets();
|
||||
|
||||
for i in 0..4 {
|
||||
let mut h = hasher.clone();
|
||||
h.update(&v6ip[..i + 2]);
|
||||
cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut h5 = hasher.clone();
|
||||
h5.update(&format!("{}", self.addr).into_bytes()[..]);
|
||||
cost[32..40].copy_from_slice(&h5.finalize()[..8]);
|
||||
}
|
||||
|
||||
cost
|
||||
}
|
||||
}
|
||||
|
||||
struct BasaltSlot {
|
||||
seed: Seed,
|
||||
peer: Option<Peer>,
|
||||
}
|
||||
|
||||
impl BasaltSlot {
|
||||
fn cost(&self) -> Cost {
|
||||
self.peer.map(|p| p.cost(&self.seed)).unwrap_or(MAX_COST)
|
||||
}
|
||||
}
|
||||
|
||||
struct BasaltView {
|
||||
i_reset: usize,
|
||||
slots: Vec<BasaltSlot>,
|
||||
}
|
||||
|
||||
impl BasaltView {
|
||||
fn new(size: usize) -> Self {
|
||||
let slots = (0..size)
|
||||
.map(|_| BasaltSlot {
|
||||
seed: rand_seed(),
|
||||
peer: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Self { i_reset: 0, slots }
|
||||
}
|
||||
|
||||
fn current_peers(&self) -> HashSet<Peer> {
|
||||
self.slots
|
||||
.iter()
|
||||
.filter(|s| s.peer.is_some())
|
||||
.map(|s| s.peer.unwrap().clone())
|
||||
.collect::<HashSet<_>>()
|
||||
}
|
||||
fn current_peers_vec(&self) -> Vec<Peer> {
|
||||
self.current_peers().drain().collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn sample(&self, count: usize) -> Vec<Peer> {
|
||||
let possibles = self
|
||||
.slots
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_i, s)| s.peer.is_some())
|
||||
.map(|(i, _s)| i)
|
||||
.collect::<Vec<_>>();
|
||||
if possibles.len() == 0 {
|
||||
vec![]
|
||||
} else {
|
||||
let mut ret = vec![];
|
||||
let mut rng = thread_rng();
|
||||
for _i in 0..count {
|
||||
let idx = rng.gen_range(0, possibles.len());
|
||||
ret.push(self.slots[possibles[idx]].peer.unwrap());
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
fn update_slot(&mut self, i: usize, peers: &[Peer]) {
|
||||
let mut slot_cost = self.slots[i].cost();
|
||||
|
||||
for peer in peers.iter() {
|
||||
let peer_cost = peer.cost(&self.slots[i].seed);
|
||||
if self.slots[i].peer.is_none() || peer_cost < slot_cost {
|
||||
self.slots[i].peer = Some(*peer);
|
||||
slot_cost = peer_cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
fn update_all_slots(&mut self, peers: &[Peer]) {
|
||||
for i in 0..self.slots.len() {
|
||||
self.update_slot(i, peers);
|
||||
}
|
||||
}
|
||||
|
||||
fn disconnected(&mut self, id: ed25519::PublicKey) {
|
||||
let mut cleared_slots = vec![];
|
||||
for i in 0..self.slots.len() {
|
||||
if let Some(p) = self.slots[i].peer {
|
||||
if p.id == id {
|
||||
self.slots[i].peer = None;
|
||||
cleared_slots.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let remaining_peers = self.current_peers_vec();
|
||||
|
||||
for i in cleared_slots {
|
||||
self.update_slot(i, &remaining_peers[..]);
|
||||
}
|
||||
}
|
||||
|
||||
fn should_try_list(&self, peers: &[Peer]) -> Vec<Peer> {
|
||||
// Select peers that have lower cost than any of our slots
|
||||
let mut ret = HashSet::new();
|
||||
|
||||
for i in 0..self.slots.len() {
|
||||
if self.slots[i].peer.is_none() {
|
||||
return peers.to_vec();
|
||||
}
|
||||
let mut min_cost = self.slots[i].cost();
|
||||
let mut min_peer = None;
|
||||
for peer in peers.iter() {
|
||||
if ret.contains(peer) {
|
||||
continue;
|
||||
}
|
||||
let peer_cost = peer.cost(&self.slots[i].seed);
|
||||
if peer_cost < min_cost {
|
||||
min_cost = peer_cost;
|
||||
min_peer = Some(*peer);
|
||||
}
|
||||
}
|
||||
if let Some(p) = min_peer {
|
||||
ret.insert(p);
|
||||
if ret.len() == peers.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret.drain().collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn reset_some_slots(&mut self, count: usize) {
|
||||
for _i in 0..count {
|
||||
self.slots[self.i_reset].seed = rand_seed();
|
||||
self.i_reset = (self.i_reset + 1) % self.slots.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BasaltParams {
|
||||
pub view_size: usize,
|
||||
pub cache_size: usize,
|
||||
pub exchange_interval: Duration,
|
||||
pub reset_interval: Duration,
|
||||
pub reset_count: usize,
|
||||
}
|
||||
|
||||
pub struct Basalt {
|
||||
netapp: Arc<NetApp>,
|
||||
|
||||
param: BasaltParams,
|
||||
bootstrap_peers: Vec<Peer>,
|
||||
|
||||
view: RwLock<BasaltView>,
|
||||
current_attempts: RwLock<HashSet<Peer>>,
|
||||
backlog: RwLock<LruCache<Peer, ()>>,
|
||||
}
|
||||
|
||||
impl Basalt {
|
||||
pub fn new(
|
||||
netapp: Arc<NetApp>,
|
||||
bootstrap_list: Vec<(ed25519::PublicKey, SocketAddr)>,
|
||||
param: BasaltParams,
|
||||
) -> Arc<Self> {
|
||||
let bootstrap_peers = bootstrap_list
|
||||
.iter()
|
||||
.map(|(id, addr)| Peer {
|
||||
id: *id,
|
||||
addr: *addr,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let view = BasaltView::new(param.view_size);
|
||||
let backlog = LruCache::new(param.cache_size);
|
||||
|
||||
let basalt = Arc::new(Self {
|
||||
netapp: netapp.clone(),
|
||||
param,
|
||||
bootstrap_peers,
|
||||
view: RwLock::new(view),
|
||||
current_attempts: RwLock::new(HashSet::new()),
|
||||
backlog: RwLock::new(backlog),
|
||||
});
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.on_connected.store(Some(Arc::new(Box::new(
|
||||
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
|
||||
basalt2.on_connected(pk, addr, is_incoming);
|
||||
},
|
||||
))));
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.on_disconnected.store(Some(Arc::new(Box::new(
|
||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
basalt2.on_disconnected(pk, is_incoming);
|
||||
},
|
||||
))));
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.add_msg_handler::<PullMessage, _, _>(
|
||||
move |_from: ed25519::PublicKey, _pullmsg: PullMessage| {
|
||||
let push_msg = basalt2.make_push_message();
|
||||
async move { Ok(push_msg) }
|
||||
},
|
||||
);
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.add_msg_handler::<PushMessage, _, _>(
|
||||
move |_from: ed25519::PublicKey, push_msg: PushMessage| {
|
||||
basalt2.handle_peer_list(&push_msg.peers[..]);
|
||||
async move { Ok(()) }
|
||||
},
|
||||
);
|
||||
|
||||
basalt
|
||||
}
|
||||
|
||||
pub fn sample(&self, count: usize) -> Vec<ed25519::PublicKey> {
|
||||
self.view
|
||||
.read()
|
||||
.unwrap()
|
||||
.sample(count)
|
||||
.iter()
|
||||
.map(|p| p.id)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
for peer in self.bootstrap_peers.iter() {
|
||||
tokio::spawn(self.clone().try_connect(*peer));
|
||||
}
|
||||
|
||||
let pushpull_loop = self.clone().run_pushpull_loop();
|
||||
let reset_loop = self.run_reset_loop();
|
||||
tokio::join!(pushpull_loop, reset_loop);
|
||||
}
|
||||
|
||||
async fn run_pushpull_loop(self: Arc<Self>) {
|
||||
loop {
|
||||
tokio::time::delay_for(self.param.exchange_interval).await;
|
||||
|
||||
let peers = self.view.read().unwrap().sample(2);
|
||||
if peers.len() == 2 {
|
||||
let (c1, c2) = {
|
||||
let client_conns = self.netapp.client_conns.read().unwrap();
|
||||
(
|
||||
client_conns.get(&peers[0].id).cloned(),
|
||||
client_conns.get(&peers[1].id).cloned(),
|
||||
)
|
||||
};
|
||||
if let Some(c) = c1 {
|
||||
tokio::spawn(self.clone().do_pull(c));
|
||||
}
|
||||
if let Some(c) = c2 {
|
||||
tokio::spawn(self.clone().do_push(c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_pull(self: Arc<Self>, peer: Arc<ClientConn>) {
|
||||
match peer.request(PullMessage {}, prio::NORMAL).await {
|
||||
Ok(resp) => {
|
||||
self.handle_peer_list(&resp.peers[..]);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error during pull exchange: {}", e);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async fn do_push(self: Arc<Self>, peer: Arc<ClientConn>) {
|
||||
let push_msg = self.make_push_message();
|
||||
if let Err(e) = peer.request(push_msg, prio::NORMAL).await {
|
||||
warn!("Error during push exchange: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
fn make_push_message(&self) -> PushMessage {
|
||||
let current_peers = self.view.read().unwrap().current_peers_vec();
|
||||
PushMessage {
|
||||
peers: current_peers,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_reset_loop(self: Arc<Self>) {
|
||||
loop {
|
||||
tokio::time::delay_for(self.param.reset_interval).await;
|
||||
|
||||
{
|
||||
let mut view = self.view.write().unwrap();
|
||||
let prev_peers = view.current_peers();
|
||||
let prev_peers_vec = prev_peers.iter().cloned().collect::<Vec<_>>();
|
||||
|
||||
view.reset_some_slots(self.param.reset_count);
|
||||
view.update_all_slots(&prev_peers_vec[..]);
|
||||
|
||||
let new_peers = view.current_peers();
|
||||
drop(view);
|
||||
|
||||
self.close_all_diff(&prev_peers, &new_peers);
|
||||
}
|
||||
|
||||
let mut to_retry_maybe = self.bootstrap_peers.clone();
|
||||
for (peer, _) in self.backlog.read().unwrap().iter() {
|
||||
if !self.bootstrap_peers.contains(peer) {
|
||||
to_retry_maybe.push(*peer);
|
||||
}
|
||||
}
|
||||
self.handle_peer_list(&to_retry_maybe[..]);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_peer_list(self: &Arc<Self>, peers: &[Peer]) {
|
||||
let to_connect = self.view.read().unwrap().should_try_list(peers);
|
||||
|
||||
for peer in to_connect.iter() {
|
||||
tokio::spawn(self.clone().try_connect(*peer));
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_connect(self: Arc<Self>, peer: Peer) {
|
||||
{
|
||||
let view = self.view.read().unwrap();
|
||||
let mut attempts = self.current_attempts.write().unwrap();
|
||||
|
||||
if view.slots.iter().any(|x| x.peer == Some(peer)) {
|
||||
return;
|
||||
}
|
||||
if attempts.contains(&peer) {
|
||||
return;
|
||||
}
|
||||
|
||||
attempts.insert(peer);
|
||||
}
|
||||
let res = self.netapp.clone().try_connect(peer.addr, peer.id).await;
|
||||
debug!("Connection attempt to {}: {:?}", peer.addr, res);
|
||||
|
||||
self.current_attempts.write().unwrap().remove(&peer);
|
||||
|
||||
if res.is_err() {
|
||||
self.backlog.write().unwrap().pop(&peer);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connected(self: &Arc<Self>, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) {
|
||||
if is_incoming {
|
||||
self.handle_peer_list(&[Peer{id: pk, addr}][..]);
|
||||
} else {
|
||||
let peer = Peer { id: pk, addr };
|
||||
|
||||
let mut backlog = self.backlog.write().unwrap();
|
||||
if backlog.get(&peer).is_none() {
|
||||
backlog.put(peer, ());
|
||||
}
|
||||
drop(backlog);
|
||||
|
||||
let mut view = self.view.write().unwrap();
|
||||
let prev_peers = view.current_peers();
|
||||
|
||||
view.update_all_slots(&[peer][..]);
|
||||
|
||||
let new_peers = view.current_peers();
|
||||
drop(view);
|
||||
|
||||
self.close_all_diff(&prev_peers, &new_peers);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_disconnected(&self, pk: ed25519::PublicKey, is_incoming: bool) {
|
||||
if !is_incoming {
|
||||
self.view.write().unwrap().disconnected(pk);
|
||||
}
|
||||
}
|
||||
|
||||
fn close_all_diff(&self, prev_peers: &HashSet<Peer>, new_peers: &HashSet<Peer>) {
|
||||
let client_conns = self.netapp.client_conns.read().unwrap();
|
||||
for peer in prev_peers.iter() {
|
||||
if !new_peers.contains(peer) {
|
||||
if let Some(c) = client_conns.get(&peer.id) {
|
||||
debug!("Closing connection to {} ({})", hex::encode(peer.id), peer.addr);
|
||||
c.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rand_seed() -> Seed {
|
||||
let mut seed = [0u8; 32];
|
||||
sodiumoxide::randombytes::randombytes_into(&mut seed[..]);
|
||||
seed
|
||||
}
|
437
src/peering/fullmesh.rs
Normal file
437
src/peering/fullmesh.rs
Normal file
|
@ -0,0 +1,437 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{self, AtomicU64};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::{debug, info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use sodiumoxide::crypto::hash;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use crate::conn::*;
|
||||
use crate::message::*;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
|
||||
const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const CONN_MAX_RETRIES: usize = 10;
|
||||
const PING_INTERVAL: Duration = Duration::from_secs(10);
|
||||
const LOOP_DELAY: Duration = Duration::from_secs(1);
|
||||
|
||||
// -- Protocol messages --
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PingMessage {
|
||||
pub id: u64,
|
||||
pub peer_list_hash: hash::Digest,
|
||||
}
|
||||
|
||||
impl Message for PingMessage {
|
||||
const KIND: MessageKind = 0x42001000;
|
||||
type Response = PingMessage;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PeerListMessage {
|
||||
pub list: Vec<(ed25519::PublicKey, SocketAddr)>,
|
||||
}
|
||||
|
||||
impl Message for PeerListMessage {
|
||||
const KIND: MessageKind = 0x42001001;
|
||||
type Response = PeerListMessage;
|
||||
}
|
||||
|
||||
// -- Algorithm data structures --
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PeerInfo {
|
||||
addr: SocketAddr,
|
||||
state: PeerConnState,
|
||||
last_seen: Option<Instant>,
|
||||
ping: VecDeque<Duration>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct PeerInfoPub {
|
||||
pub id: ed25519::PublicKey,
|
||||
pub addr: SocketAddr,
|
||||
pub state: PeerConnState,
|
||||
pub last_seen: Option<Instant>,
|
||||
pub avg_ping: Option<Duration>,
|
||||
pub max_ping: Option<Duration>,
|
||||
pub med_ping: Option<Duration>,
|
||||
}
|
||||
|
||||
// PeerConnState: possible states for our tentative connections to given peer
|
||||
// This module is only interested in recording connection info for outgoing
|
||||
// TCP connections
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
pub enum PeerConnState {
|
||||
// This entry represents ourself
|
||||
Ourself,
|
||||
|
||||
// We currently have a connection to this peer
|
||||
Connected,
|
||||
|
||||
// Our next connection tentative (the nth, where n is the first value)
|
||||
// will be at given Instant
|
||||
Waiting(usize, Instant),
|
||||
|
||||
// A connection tentative is in progress
|
||||
Trying(usize),
|
||||
|
||||
// We abandonned trying to connect to this peer (too many failed attempts)
|
||||
Abandonned,
|
||||
}
|
||||
|
||||
struct KnownHosts {
|
||||
list: HashMap<ed25519::PublicKey, PeerInfo>,
|
||||
hash: hash::Digest,
|
||||
}
|
||||
|
||||
impl KnownHosts {
|
||||
fn new() -> Self {
|
||||
let list = HashMap::new();
|
||||
let hash = Self::calculate_hash(&list);
|
||||
Self { list, hash }
|
||||
}
|
||||
fn update_hash(&mut self) {
|
||||
self.hash = Self::calculate_hash(&self.list);
|
||||
}
|
||||
fn map_into_vec(
|
||||
input: &HashMap<ed25519::PublicKey, PeerInfo>,
|
||||
) -> Vec<(ed25519::PublicKey, SocketAddr)> {
|
||||
let mut list = Vec::with_capacity(input.len());
|
||||
for (id, peer) in input.iter() {
|
||||
if peer.state == PeerConnState::Connected || peer.state == PeerConnState::Ourself {
|
||||
list.push((id.clone(), peer.addr));
|
||||
}
|
||||
}
|
||||
list
|
||||
}
|
||||
fn calculate_hash(input: &HashMap<ed25519::PublicKey, PeerInfo>) -> hash::Digest {
|
||||
let mut list = Self::map_into_vec(input);
|
||||
list.sort();
|
||||
let mut hash_state = hash::State::new();
|
||||
for (id, addr) in list {
|
||||
hash_state.update(&id[..]);
|
||||
hash_state.update(&format!("{}", addr).into_bytes()[..]);
|
||||
}
|
||||
hash_state.finalize()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FullMeshPeeringStrategy {
|
||||
netapp: Arc<NetApp>,
|
||||
known_hosts: RwLock<KnownHosts>,
|
||||
next_ping_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl FullMeshPeeringStrategy {
|
||||
pub fn new(
|
||||
netapp: Arc<NetApp>,
|
||||
bootstrap_list: Vec<(ed25519::PublicKey, SocketAddr)>,
|
||||
) -> Arc<Self> {
|
||||
let mut known_hosts = KnownHosts::new();
|
||||
for (pk, addr) in bootstrap_list {
|
||||
if pk != netapp.pubkey {
|
||||
known_hosts.list.insert(
|
||||
pk,
|
||||
PeerInfo {
|
||||
addr: addr,
|
||||
state: PeerConnState::Waiting(0, Instant::now()),
|
||||
last_seen: None,
|
||||
ping: VecDeque::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let strat = Arc::new(Self {
|
||||
netapp: netapp.clone(),
|
||||
known_hosts: RwLock::new(known_hosts),
|
||||
next_ping_id: AtomicU64::new(42),
|
||||
});
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.add_msg_handler::<PingMessage, _, _>(
|
||||
move |from: ed25519::PublicKey, ping: PingMessage| {
|
||||
let ping_resp = PingMessage {
|
||||
id: ping.id,
|
||||
peer_list_hash: strat2.known_hosts.read().unwrap().hash,
|
||||
};
|
||||
async move {
|
||||
debug!("Ping from {}", hex::encode(&from));
|
||||
Ok(ping_resp)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.add_msg_handler::<PeerListMessage, _, _>(
|
||||
move |_from: ed25519::PublicKey, peer_list: PeerListMessage| {
|
||||
strat2.handle_peer_list(&peer_list.list[..]);
|
||||
let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list);
|
||||
let resp = PeerListMessage { list: peer_list };
|
||||
async move { Ok(resp) }
|
||||
},
|
||||
);
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.on_connected.store(Some(Arc::new(Box::new(
|
||||
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
|
||||
let strat2 = strat2.clone();
|
||||
tokio::spawn(strat2.on_connected(pk, addr, is_incoming));
|
||||
},
|
||||
))));
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.on_disconnected.store(Some(Arc::new(Box::new(
|
||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
let strat2 = strat2.clone();
|
||||
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
|
||||
},
|
||||
))));
|
||||
|
||||
strat
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
loop {
|
||||
// 1. Read current state: get list of connected peers (ping them)
|
||||
let known_hosts = self.known_hosts.read().unwrap();
|
||||
debug!("known_hosts: {} peers", known_hosts.list.len());
|
||||
|
||||
let mut to_ping = vec![];
|
||||
let mut to_retry = vec![];
|
||||
for (id, info) in known_hosts.list.iter() {
|
||||
debug!("{}, {:?}", hex::encode(id), info);
|
||||
match info.state {
|
||||
PeerConnState::Connected => {
|
||||
let must_ping = match info.last_seen {
|
||||
None => true,
|
||||
Some(t) => Instant::now() - t > PING_INTERVAL,
|
||||
};
|
||||
if must_ping {
|
||||
to_ping.push(id.clone());
|
||||
}
|
||||
}
|
||||
PeerConnState::Waiting(_, t) => {
|
||||
if Instant::now() >= t {
|
||||
to_retry.push(id.clone());
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
drop(known_hosts);
|
||||
|
||||
// 2. Dispatch ping to hosts
|
||||
trace!("to_ping: {} peers", to_retry.len());
|
||||
for id in to_ping {
|
||||
tokio::spawn(self.clone().ping(id));
|
||||
}
|
||||
|
||||
// 3. Try reconnects
|
||||
trace!("to_retry: {} peers", to_retry.len());
|
||||
if !to_retry.is_empty() {
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
for id in to_retry {
|
||||
if let Some(h) = known_hosts.list.get_mut(&id) {
|
||||
if let PeerConnState::Waiting(i, _) = h.state {
|
||||
info!(
|
||||
"Retrying connection to {} at {} ({})",
|
||||
hex::encode(&id),
|
||||
h.addr,
|
||||
i + 1
|
||||
);
|
||||
h.state = PeerConnState::Trying(i);
|
||||
tokio::spawn(self.clone().try_connect(id, h.addr.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Sleep before next loop iteration
|
||||
tokio::time::delay_for(LOOP_DELAY).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn ping(self: Arc<Self>, id: ed25519::PublicKey) {
|
||||
let peer = {
|
||||
match self.netapp.client_conns.read().unwrap().get(&id) {
|
||||
None => {
|
||||
warn!("Should ping {}, but no connection", hex::encode(id));
|
||||
return;
|
||||
}
|
||||
Some(peer) => peer.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
let peer_list_hash = self.known_hosts.read().unwrap().hash;
|
||||
let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed);
|
||||
let ping_time = Instant::now();
|
||||
let ping_msg = PingMessage {
|
||||
id: ping_id,
|
||||
peer_list_hash,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Sending ping {} to {} at {:?}",
|
||||
ping_id,
|
||||
hex::encode(id),
|
||||
ping_time
|
||||
);
|
||||
match peer.clone().request(ping_msg, prio::HIGH).await {
|
||||
Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e),
|
||||
Ok(ping_resp) => {
|
||||
let resp_time = Instant::now();
|
||||
debug!(
|
||||
"Got ping response from {} at {:?}",
|
||||
hex::encode(id),
|
||||
resp_time
|
||||
);
|
||||
{
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
if let Some(host) = known_hosts.list.get_mut(&id) {
|
||||
host.last_seen = Some(resp_time);
|
||||
host.ping.push_back(resp_time - ping_time);
|
||||
while host.ping.len() > 10 {
|
||||
host.ping.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
if ping_resp.peer_list_hash != peer_list_hash {
|
||||
self.exchange_peers(peer).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exchange_peers(self: Arc<Self>, peer: Arc<ClientConn>) {
|
||||
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
|
||||
let pex_message = PeerListMessage { list: peer_list };
|
||||
match peer.request(pex_message, prio::BACKGROUND).await {
|
||||
Err(e) => warn!("Error doing peer exchange: {}", e),
|
||||
Ok(resp) => {
|
||||
self.handle_peer_list(&resp.list[..]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_peer_list(&self, list: &[(ed25519::PublicKey, SocketAddr)]) {
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
for (id, addr) in list.iter() {
|
||||
if !known_hosts.list.contains_key(id) {
|
||||
known_hosts.list.insert(*id, self.new_peer(id, *addr));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_connect(self: Arc<Self>, id: ed25519::PublicKey, addr: SocketAddr) {
|
||||
let conn_result = self.netapp.clone().try_connect(addr, id.clone()).await;
|
||||
if let Err(e) = conn_result {
|
||||
warn!("Error connecting to {}: {}", hex::encode(id), e);
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
if let Some(host) = known_hosts.list.get_mut(&id) {
|
||||
host.state = match host.state {
|
||||
PeerConnState::Trying(i) => {
|
||||
if i >= CONN_MAX_RETRIES {
|
||||
PeerConnState::Abandonned
|
||||
} else {
|
||||
PeerConnState::Waiting(i + 1, Instant::now() + CONN_RETRY_INTERVAL)
|
||||
}
|
||||
}
|
||||
_ => PeerConnState::Waiting(0, Instant::now() + CONN_RETRY_INTERVAL),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_connected(
|
||||
self: Arc<Self>,
|
||||
pk: ed25519::PublicKey,
|
||||
addr: SocketAddr,
|
||||
is_incoming: bool,
|
||||
) {
|
||||
if is_incoming {
|
||||
if !self.known_hosts.read().unwrap().list.contains_key(&pk) {
|
||||
self.known_hosts
|
||||
.write()
|
||||
.unwrap()
|
||||
.list
|
||||
.insert(pk, self.new_peer(&pk, addr));
|
||||
}
|
||||
} else {
|
||||
info!("Successfully connected to {} at {}", hex::encode(&pk), addr);
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
if let Some(host) = known_hosts.list.get_mut(&pk) {
|
||||
host.state = PeerConnState::Connected;
|
||||
known_hosts.update_hash();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_disconnected(self: Arc<Self>, pk: ed25519::PublicKey, is_incoming: bool) {
|
||||
if !is_incoming {
|
||||
info!("Connection to {} was closed", hex::encode(pk));
|
||||
let mut known_hosts = self.known_hosts.write().unwrap();
|
||||
if let Some(host) = known_hosts.list.get_mut(&pk) {
|
||||
host.state = PeerConnState::Waiting(0, Instant::now());
|
||||
known_hosts.update_hash();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_peer_list(&self) -> Vec<PeerInfoPub> {
|
||||
let known_hosts = self.known_hosts.read().unwrap();
|
||||
let mut ret = Vec::with_capacity(known_hosts.list.len());
|
||||
for (id, info) in known_hosts.list.iter() {
|
||||
let mut pings = info.ping.iter().cloned().collect::<Vec<_>>();
|
||||
pings.sort();
|
||||
if pings.len() > 0 {
|
||||
ret.push(PeerInfoPub {
|
||||
id: id.clone(),
|
||||
addr: info.addr,
|
||||
state: info.state,
|
||||
last_seen: info.last_seen,
|
||||
avg_ping: Some(
|
||||
pings
|
||||
.iter()
|
||||
.fold(Duration::from_secs(0), |x, y| x + *y)
|
||||
.div_f64(pings.len() as f64),
|
||||
),
|
||||
max_ping: pings.last().cloned(),
|
||||
med_ping: Some(pings[pings.len() / 2]),
|
||||
});
|
||||
} else {
|
||||
ret.push(PeerInfoPub {
|
||||
id: id.clone(),
|
||||
addr: info.addr,
|
||||
state: info.state,
|
||||
last_seen: info.last_seen,
|
||||
avg_ping: None,
|
||||
max_ping: None,
|
||||
med_ping: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
fn new_peer(&self, id: &ed25519::PublicKey, addr: SocketAddr) -> PeerInfo {
|
||||
let state = if *id == self.netapp.pubkey {
|
||||
PeerConnState::Ourself
|
||||
} else {
|
||||
PeerConnState::Waiting(0, Instant::now())
|
||||
};
|
||||
PeerInfo {
|
||||
addr,
|
||||
state,
|
||||
last_seen: None,
|
||||
ping: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
}
|
2
src/peering/mod.rs
Normal file
2
src/peering/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod basalt;
|
||||
pub mod fullmesh;
|
251
src/proto.rs
Normal file
251
src/proto.rs
Normal file
|
@ -0,0 +1,251 @@
|
|||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::trace;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use async_std::io::prelude::WriteExt;
|
||||
use async_std::io::ReadExt;
|
||||
|
||||
use tokio::io::{ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::error::*;
|
||||
|
||||
use kuska_handshake::async_std::{BoxStreamRead, BoxStreamWrite, TokioCompat};
|
||||
|
||||
const MAX_CHUNK_SIZE: usize = 0x4000;
|
||||
|
||||
pub mod prio {
|
||||
pub const HIGH: u8 = 0x20;
|
||||
pub const NORMAL: u8 = 0x40;
|
||||
pub const BACKGROUND: u8 = 0x80;
|
||||
|
||||
pub const PRIMARY: u8 = 0x00;
|
||||
pub const SECONDARY: u8 = 0x01;
|
||||
}
|
||||
|
||||
pub type RequestID = u16;
|
||||
pub type RequestPriority = u8;
|
||||
|
||||
struct SendQueueItem {
|
||||
id: RequestID,
|
||||
prio: RequestPriority,
|
||||
data: Vec<u8>,
|
||||
cursor: usize,
|
||||
}
|
||||
|
||||
struct SendQueue {
|
||||
items: BTreeMap<u8, VecDeque<SendQueueItem>>,
|
||||
}
|
||||
|
||||
impl SendQueue {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
items: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
fn push(&mut self, item: SendQueueItem) {
|
||||
let prio = item.prio;
|
||||
let mut items_at_prio = self
|
||||
.items
|
||||
.remove(&prio)
|
||||
.unwrap_or(VecDeque::with_capacity(4));
|
||||
items_at_prio.push_back(item);
|
||||
self.items.insert(prio, items_at_prio);
|
||||
}
|
||||
fn pop(&mut self) -> Option<SendQueueItem> {
|
||||
match self.items.pop_first() {
|
||||
None => None,
|
||||
Some((prio, mut items_at_prio)) => {
|
||||
let ret = items_at_prio.pop_front();
|
||||
if !items_at_prio.is_empty() {
|
||||
self.items.insert(prio, items_at_prio);
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait SendLoop: Sync {
|
||||
async fn send_loop(
|
||||
self: Arc<Self>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut sending = SendQueue::new();
|
||||
while !*must_exit.borrow() {
|
||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else if let Some(mut item) = sending.pop() {
|
||||
trace!(
|
||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
||||
item.id,
|
||||
item.data.len(),
|
||||
item.cursor
|
||||
);
|
||||
let header_id = u16::to_be_bytes(item.id);
|
||||
if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
|
||||
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
|
||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
|
||||
if write_all_or_exit(
|
||||
&item.data[item.cursor..new_cursor],
|
||||
&mut write,
|
||||
&mut must_exit,
|
||||
)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
item.cursor = new_cursor;
|
||||
|
||||
sending.push(item);
|
||||
} else {
|
||||
let send_len = (item.data.len() - item.cursor) as u16;
|
||||
|
||||
let header_size = u16::to_be_bytes(send_len);
|
||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
write.flush().await.log_err("Could not flush in send_loop");
|
||||
} else {
|
||||
let (id, prio, data) = msg_recv
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(Error::Message("Connection closed.".into()))?;
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait RecvLoop: Sync + 'static {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||
|
||||
async fn recv_loop(
|
||||
self: Arc<Self>,
|
||||
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut receiving = HashMap::new();
|
||||
while !*must_exit.borrow() {
|
||||
trace!("recv_loop: reading packet");
|
||||
let mut header_id = [0u8; 2];
|
||||
if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
let id = RequestID::from_be_bytes(header_id);
|
||||
trace!("recv_loop: got header id: {:04x}", id);
|
||||
|
||||
let mut header_size = [0u8; 2];
|
||||
if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
let size = RequestID::from_be_bytes(header_size);
|
||||
trace!("recv_loop: got header size: {:04x}", id);
|
||||
|
||||
let has_cont = (size & 0x8000) != 0;
|
||||
let size = size & !0x8000;
|
||||
|
||||
let mut next_slice = vec![0; size as usize];
|
||||
if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
trace!("recv_loop: read {} bytes", size);
|
||||
|
||||
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
|
||||
msg_bytes.extend_from_slice(&next_slice[..]);
|
||||
|
||||
if has_cont {
|
||||
receiving.insert(id, msg_bytes);
|
||||
} else {
|
||||
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_exact_or_exit(
|
||||
buf: &mut [u8],
|
||||
read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
||||
must_exit: &mut watch::Receiver<bool>,
|
||||
) -> Result<Option<()>, Error> {
|
||||
tokio::select!(
|
||||
res = read.read_exact(buf) => Ok(Some(res?)),
|
||||
_ = await_exit(must_exit) => Ok(None),
|
||||
)
|
||||
}
|
||||
|
||||
async fn write_all_or_exit(
|
||||
buf: &[u8],
|
||||
write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
||||
must_exit: &mut watch::Receiver<bool>,
|
||||
) -> Result<Option<()>, Error> {
|
||||
tokio::select!(
|
||||
res = write.write_all(buf) => Ok(Some(res?)),
|
||||
_ = await_exit(must_exit) => Ok(None),
|
||||
)
|
||||
}
|
||||
|
||||
async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
|
||||
loop {
|
||||
if must_exit.recv().await == Some(true) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
14
src/util.rs
Normal file
14
src/util.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use serde::Serialize;
|
||||
|
||||
// util
|
||||
pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
|
||||
where
|
||||
T: Serialize + ?Sized,
|
||||
{
|
||||
let mut wr = Vec::with_capacity(128);
|
||||
let mut se = rmp_serde::Serializer::new(&mut wr)
|
||||
.with_struct_map()
|
||||
.with_string_variants();
|
||||
val.serialize(&mut se)?;
|
||||
Ok(wr)
|
||||
}
|
Loading…
Reference in a new issue