First commit

0.2
Alex 2 years ago
commit d4de2ffc40
  1. 1
      .gitignore
  2. 1409
      Cargo.lock
  3. 27
      Cargo.toml
  4. 3
      Makefile
  5. 76
      examples/basalt.rs
  6. 68
      examples/fullmesh.rs
  7. 1
      rustfmt.toml
  8. 280
      src/conn.rs
  9. 57
      src/error.rs
  10. 9
      src/lib.rs
  11. 18
      src/message.rs
  12. 214
      src/netapp.rs
  13. 475
      src/peering/basalt.rs
  14. 437
      src/peering/fullmesh.rs
  15. 2
      src/peering/mod.rs
  16. 251
      src/proto.rs
  17. 14
      src/util.rs

1
.gitignore vendored

@ -0,0 +1 @@
target/

1409
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -0,0 +1,27 @@
[package]
name = "netapp"
version = "0.1.0"
authors = ["Alex Auvolat <alex.auvolat@inria.fr>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
async-std = { version = "1.5.0", features=["unstable","attributes"] }
tokio = { version = "0.2", features = ["full"] }
kuska-handshake = { path = "../../handshake", features = ["default", "tokio_compat"] }
hex = "0.4.2"
log = "0.4.8"
pretty_env_logger = "0.4"
sodiumoxide = { git = "https://github.com/Dhole/sodiumoxidez", branch = "extra" }
env_logger = "0.7.1"
base64 = "0.12.1"
rmp-serde = "0.14.3"
serde = { version = "1.0", default-features = false, features = ["derive", "rc"] }
arc-swap = "1.0"
structopt = { version = "0.3", default-features = false }
async-trait = "0.1.7"
err-derive = "0.2.3"
bytes = "0.6.0"
lru = "0.6"
rand = "0.5.5"

@ -0,0 +1,3 @@
all:
cargo build
RUST_LOG=netapp=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7

@ -0,0 +1,76 @@
use std::net::SocketAddr;
use std::time::Duration;
use log::info;
use structopt::StructOpt;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use netapp::netapp::*;
use netapp::peering::basalt::*;
#[derive(StructOpt, Debug)]
#[structopt(name = "netapp")]
pub struct Opt {
#[structopt(long = "network-key", short = "n")]
network_key: Option<String>,
#[structopt(long = "private-key", short = "p")]
private_key: Option<String>,
#[structopt(long = "bootstrap-peer", short = "b")]
bootstrap_peers: Vec<String>,
#[structopt(long = "listen-addr", short = "l", default_value = "127.0.0.1:1980")]
listen_addr: String,
}
#[tokio::main]
async fn main() {
pretty_env_logger::init();
let opt = Opt::from_args();
let netid = match &opt.network_key {
Some(k) => auth::Key::from_slice(&hex::decode(k).unwrap()).unwrap(),
None => auth::gen_key(),
};
info!("Network key: {}", hex::encode(&netid));
let privkey = match &opt.private_key {
Some(k) => ed25519::SecretKey::from_slice(&hex::decode(k).unwrap()).unwrap(),
None => {
let (_pk, sk) = ed25519::gen_keypair();
sk
}
};
info!("Node private key: {}", hex::encode(&privkey));
info!("Node public key: {}", hex::encode(&privkey.public_key()));
let listen_addr = opt.listen_addr.parse().unwrap();
let netapp = NetApp::new(listen_addr, netid, privkey);
let mut bootstrap_peers = vec![];
for peer in opt.bootstrap_peers.iter() {
if let Some(delim) = peer.find('@') {
let (key, ip) = peer.split_at(delim);
let pubkey = ed25519::PublicKey::from_slice(&hex::decode(&key).unwrap()).unwrap();
let ip = ip[1..].parse::<SocketAddr>().unwrap();
bootstrap_peers.push((pubkey, ip));
}
}
let basalt_params = BasaltParams{
view_size: 100,
cache_size: 1000,
exchange_interval: Duration::from_secs(1),
reset_interval: Duration::from_secs(10),
reset_count: 20,
};
let peering = Basalt::new(netapp.clone(), bootstrap_peers, basalt_params);
tokio::join!(netapp.listen(), peering.run(),);
}

@ -0,0 +1,68 @@
use std::net::SocketAddr;
use log::info;
use structopt::StructOpt;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use netapp::netapp::*;
use netapp::peering::fullmesh::*;
#[derive(StructOpt, Debug)]
#[structopt(name = "netapp")]
pub struct Opt {
#[structopt(long = "network-key", short = "n")]
network_key: Option<String>,
#[structopt(long = "private-key", short = "p")]
private_key: Option<String>,
#[structopt(long = "bootstrap-peer", short = "b")]
bootstrap_peers: Vec<String>,
#[structopt(long = "listen-addr", short = "l", default_value = "127.0.0.1:1980")]
listen_addr: String,
}
#[tokio::main]
async fn main() {
pretty_env_logger::init();
let opt = Opt::from_args();
let netid = match &opt.network_key {
Some(k) => auth::Key::from_slice(&hex::decode(k).unwrap()).unwrap(),
None => auth::gen_key(),
};
info!("Network key: {}", hex::encode(&netid));
let privkey = match &opt.private_key {
Some(k) => ed25519::SecretKey::from_slice(&hex::decode(k).unwrap()).unwrap(),
None => {
let (_pk, sk) = ed25519::gen_keypair();
sk
}
};
info!("Node private key: {}", hex::encode(&privkey));
info!("Node public key: {}", hex::encode(&privkey.public_key()));
let listen_addr = opt.listen_addr.parse().unwrap();
let netapp = NetApp::new(listen_addr, netid, privkey);
let mut bootstrap_peers = vec![];
for peer in opt.bootstrap_peers.iter() {
if let Some(delim) = peer.find('@') {
let (key, ip) = peer.split_at(delim);
let pubkey = ed25519::PublicKey::from_slice(&hex::decode(&key).unwrap()).unwrap();
let ip = ip[1..].parse::<SocketAddr>().unwrap();
bootstrap_peers.push((pubkey, ip));
}
}
let peering = FullMeshPeeringStrategy::new(netapp.clone(), bootstrap_peers);
tokio::join!(netapp.listen(), peering.run(),);
}

@ -0,0 +1 @@
hard_tabs = true

@ -0,0 +1,280 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU16};
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, trace};
use sodiumoxide::crypto::sign::ed25519;
use tokio::io::split;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, watch};
use kuska_handshake::async_std::{
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
TokioCompatExtWrite,
};
use crate::error::*;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
use crate::util::*;
pub struct ServerConn {
netapp: Arc<NetApp>,
pub remote_addr: SocketAddr,
pub peer_pk: ed25519::PublicKey,
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
close_send: watch::Sender<bool>,
}
impl ServerConn {
pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> {
let mut asyncstd_socket = TokioCompatExt::wrap(socket);
let handshake = handshake_server(
&mut asyncstd_socket,
netapp.netid.clone(),
netapp.pubkey.clone(),
netapp.privkey.clone(),
)
.await?;
let peer_pk = handshake.peer_pk.clone();
let tokio_socket = asyncstd_socket.into_inner();
let remote_addr = tokio_socket.peer_addr().unwrap();
debug!(
"Handshake complete (server) with {}@{}",
hex::encode(&peer_pk),
remote_addr
);
let (read, write) = split(tokio_socket);
let read = TokioCompatExtRead::wrap(read);
let write = TokioCompatExtWrite::wrap(write);
let (box_stream_read, box_stream_write) =
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (resp_send, resp_recv) = mpsc::unbounded_channel();
let (close_send, close_recv) = watch::channel(false);
let conn = Arc::new(ServerConn {
netapp: netapp.clone(),
remote_addr,
peer_pk: peer_pk.clone(),
resp_send,
close_send,
});
netapp.connected_as_server(peer_pk.clone(), conn.clone());
let conn2 = conn.clone();
let conn3 = conn.clone();
tokio::try_join!(
conn2.recv_loop(box_stream_read, close_recv.clone()),
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
)
.map(|_| ())
.log_err("ServerConn recv_loop/send_loop");
netapp.disconnected_as_server(&peer_pk, conn);
Ok(())
}
pub fn close(&self) {
self.close_send.broadcast(true).unwrap();
}
}
impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) {
let bytes: Bytes = bytes.into();
let prio = bytes[0];
let mut kind_bytes = [0u8; 4];
kind_bytes.copy_from_slice(&bytes[1..5]);
let kind = u32::from_be_bytes(kind_bytes);
if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) {
let resp = handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send
.send((id, prio, resp))
.log_err("ServerConn recv_handler send resp");
}
}
}
pub struct ClientConn {
pub netapp: Arc<NetApp>,
pub remote_addr: SocketAddr,
pub peer_pk: ed25519::PublicKey,
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
next_query_number: AtomicU16,
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
close_send: watch::Sender<bool>,
}
impl ClientConn {
pub(crate) async fn init(
netapp: Arc<NetApp>,
socket: TcpStream,
remote_pk: ed25519::PublicKey,
) -> Result<(), Error> {
let mut asyncstd_socket = TokioCompatExt::wrap(socket);
let handshake = handshake_client(
&mut asyncstd_socket,
netapp.netid.clone(),
netapp.pubkey.clone(),
netapp.privkey.clone(),
remote_pk.clone(),
)
.await?;
let tokio_socket = asyncstd_socket.into_inner();
let remote_addr = tokio_socket.peer_addr().unwrap();
debug!(
"Handshake complete (client) with {}@{}",
hex::encode(&remote_pk),
remote_addr
);
let (read, write) = split(tokio_socket);
let read = TokioCompatExtRead::wrap(read);
let write = TokioCompatExtWrite::wrap(write);
let (box_stream_read, box_stream_write) =
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (query_send, query_recv) = mpsc::unbounded_channel();
let (resp_send, resp_recv) = mpsc::unbounded_channel();
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
let (close_send, close_recv) = watch::channel(false);
let conn = Arc::new(ClientConn {
netapp: netapp.clone(),
remote_addr,
peer_pk: remote_pk.clone(),
next_query_number: AtomicU16::from(0u16),
query_send,
resp_send,
resp_notify_send,
close_send,
});
netapp.connected_as_client(remote_pk.clone(), conn.clone());
tokio::spawn(async move {
let conn2 = conn.clone();
let conn3 = conn.clone();
let conn4 = conn.clone();
tokio::try_join!(
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
conn3.recv_loop(box_stream_read, close_recv.clone()),
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
)
.map(|_| ())
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
netapp.disconnected_as_client(&remote_pk, conn);
});
Ok(())
}
pub fn close(&self) {
self.close_send.broadcast(true).unwrap();
}
async fn dispatch_resp(
self: Arc<Self>,
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
while !*must_exit.borrow() {
tokio::select! {
resp = resp_recv.recv() => {
if let Some((id, resp)) = resp {
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
if let Some(ch) = resp_notify.remove(&id) {
ch.send(resp).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?;
} else {
resps.insert(id, resp);
}
}
}
resp_ch = resp_notify_recv.recv() => {
if let Some((id, resp_ch)) = resp_ch {
trace!("dispatch_resp: got resp_ch {}", id);
if let Some(rs) = resps.remove(&id) {
resp_ch.send(rs).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?;
} else {
resp_notify.insert(id, resp_ch);
}
}
}
exit = must_exit.recv() => {
if exit == Some(true) {
break;
}
}
}
}
Ok(())
}
pub async fn request<T>(
self: Arc<Self>,
rq: T,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message,
{
let id = self
.next_query_number
.fetch_add(1u16, atomic::Ordering::Relaxed);
let mut bytes = vec![prio];
bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]);
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel();
self.resp_notify_send.send((id, resp_send))?;
trace!("request: query_send {}, {} bytes", id, bytes.len());
self.query_send.send((id, prio, bytes))?;
let resp = resp_recv.await?;
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])?
.map_err(Error::Remote)
}
}
impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
self.resp_send
.send((id, msg))
.log_err("ClientConn::recv_handler");
}
}

@ -0,0 +1,57 @@
use err_derive::Error;
use std::io;
use log::error;
#[derive(Debug, Error)]
pub enum Error {
#[error(display = "IO error: {}", _0)]
Io(#[error(source)] io::Error),
#[error(display = "Messagepack encode error: {}", _0)]
RMPEncode(#[error(source)] rmp_serde::encode::Error),
#[error(display = "Messagepack decode error: {}", _0)]
RMPDecode(#[error(source)] rmp_serde::decode::Error),
#[error(display = "Tokio join error: {}", _0)]
TokioJoin(#[error(source)] tokio::task::JoinError),
#[error(display = "oneshot receive error: {}", _0)]
OneshotRecv(#[error(source)] tokio::sync::oneshot::error::RecvError),
#[error(display = "Handshake error: {}", _0)]
Handshake(#[error(source)] kuska_handshake::async_std::Error),
#[error(display = "{}", _0)]
Message(String),
#[error(display = "Remote error: {}", _0)]
Remote(String),
}
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
Error::Message(format!("Watch send error"))
}
}
impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
fn from(_e: tokio::sync::mpsc::error::SendError<T>) -> Error {
Error::Message(format!("MPSC send error"))
}
}
pub trait LogError {
fn log_err(self, msg: &'static str);
}
impl<E> LogError for Result<(), E>
where
E: Into<Error>,
{
fn log_err(self, msg: &'static str) {
if let Err(e) = self {
error!("Error: {}: {}", msg, Into::<Error>::into(e));
};
}
}

@ -0,0 +1,9 @@
#![feature(map_first_last)]
pub mod conn;
pub mod error;
pub mod message;
pub mod netapp;
pub mod peering;
pub mod proto;
pub mod util;

@ -0,0 +1,18 @@
use serde::{Deserialize, Serialize};
pub type MessageKind = u32;
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
const KIND: MessageKind;
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
#[derive(Serialize, Deserialize)]
pub(crate) struct HelloMessage {
pub server_port: u16,
}
impl Message for HelloMessage {
const KIND: MessageKind = 0x42000001;
type Response = ();
}

@ -0,0 +1,214 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::future::Future;
use log::{debug, info};
use arc_swap::{ArcSwap, ArcSwapOption};
use bytes::Bytes;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use tokio::net::{TcpListener, TcpStream};
use crate::conn::*;
use crate::error::*;
use crate::message::*;
use crate::proto::*;
use crate::util::*;
pub struct NetApp {
pub listen_addr: SocketAddr,
pub netid: auth::Key,
pub pubkey: ed25519::PublicKey,
pub privkey: ed25519::SecretKey,
pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
pub(crate) msg_handlers: ArcSwap<
HashMap<
MessageKind,
Arc<
dyn Fn(
ed25519::PublicKey,
Bytes,
) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
+ Sync
+ Send,
>,
>,
>,
pub(crate) on_connected:
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
}
async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8>
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync,
{
debug!(
"Handling message of kind {:08x} from {}",
M::KIND,
hex::encode(remote)
);
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
Ok(msg) => handler(remote.clone(), msg).await,
Err(e) => Err(e.into()),
};
let res = res.map_err(|e| format!("{}", e));
rmp_to_vec_all_named(&res).unwrap_or(vec![])
}
impl NetApp {
pub fn new(
listen_addr: SocketAddr,
netid: auth::Key,
privkey: ed25519::SecretKey,
) -> Arc<Self> {
let pubkey = privkey.public_key();
let netapp = Arc::new(Self {
listen_addr,
netid,
pubkey,
privkey,
server_conns: RwLock::new(HashMap::new()),
client_conns: RwLock::new(HashMap::new()),
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
on_connected: ArcSwapOption::new(None),
on_disconnected: ArcSwapOption::new(None),
});
let netapp2 = netapp.clone();
netapp.add_msg_handler::<HelloMessage, _, _>(
move |from: ed25519::PublicKey, msg: HelloMessage| {
netapp2.handle_hello_message(from, msg);
async { Ok(()) }
},
);
netapp
}
pub fn add_msg_handler<M, F, R>(&self, handler: F)
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static,
{
let handler = Arc::new(handler);
let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
Box::pin(handler_aux(handler.clone(), remote, bytes));
fun
});
let mut handlers = self.msg_handlers.load().as_ref().clone();
handlers.insert(M::KIND, fun);
self.msg_handlers.store(Arc::new(handlers));
}
pub async fn listen(self: Arc<Self>) {
let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
info!("Listening on {}", self.listen_addr);
loop {
// The second item contains the IP and port of the new connection.
let (socket, _) = listener.accept().await.unwrap();
info!(
"Incoming connection from {}, negotiating handshake...",
socket.peer_addr().unwrap()
);
let self2 = self.clone();
tokio::spawn(async move {
ServerConn::run(self2, socket)
.await
.log_err("ServerConn::run");
});
}
}
pub async fn try_connect(
self: Arc<Self>,
ip: SocketAddr,
pk: ed25519::PublicKey,
) -> Result<(), Error> {
if self.client_conns.read().unwrap().contains_key(&pk) {
return Ok(());
}
let socket = TcpStream::connect(ip).await?;
info!("Connected to {}, negotiating handshake...", ip);
ClientConn::init(self, socket, pk.clone()).await?;
Ok(())
}
pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) {
let conn = self.client_conns.read().unwrap().get(id).cloned();
if let Some(c) = conn {
c.close();
}
}
pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
let mut conn_list = self.server_conns.write().unwrap();
conn_list.insert(id.clone(), conn);
}
fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
if let Some(h) = self.on_connected.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port);
h(id, remote_addr, true);
}
}
}
pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
let mut conn_list = self.server_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
conn_list.remove(id);
}
if let Some(h) = self.on_disconnected.load().as_ref() {
h(conn.peer_pk, true);
}
}
}
pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
{
let mut conn_list = self.client_conns.write().unwrap();
if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
tokio::spawn(async move { old_c.close() });
}
}
if let Some(h) = self.on_connected.load().as_ref() {
h(conn.peer_pk, conn.remote_addr, false);
}
tokio::spawn(async move {
let server_port = conn.netapp.listen_addr.port();
conn.request(HelloMessage { server_port }, prio::NORMAL)
.await
.log_err("Sending hello message");
});
}
pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
let mut conn_list = self.client_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
conn_list.remove(id);
}
if let Some(h) = self.on_disconnected.load().as_ref() {
h(conn.peer_pk, false);
}
}
}
}

@ -0,0 +1,475 @@
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use log::{debug, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
use sodiumoxide::crypto::sign::ed25519;
use crate::conn::*;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
// -- Protocol messages --
#[derive(Serialize, Deserialize)]
struct PullMessage {}
impl Message for PullMessage {
const KIND: MessageKind = 0x42001100;
type Response = PushMessage;
}
#[derive(Serialize, Deserialize)]
struct PushMessage {
peers: Vec<Peer>,
}
impl Message for PushMessage {
const KIND: MessageKind = 0x42001101;
type Response = ();
}
// -- Algorithm data structures --
type Seed = [u8; 32];
#[derive(Hash, Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Serialize, Deserialize)]
struct Peer {
id: ed25519::PublicKey,
addr: SocketAddr,
}
type Cost = [u8; 40];
const MAX_COST: Cost = [0xffu8; 40];
impl Peer {
fn cost(&self, seed: &Seed) -> Cost {
let mut hasher = hash::State::new();
hasher.update(&seed[..]);
let mut cost = [0u8; 40];
match self.addr {
SocketAddr::V4(v4addr) => {
let v4ip = v4addr.ip().octets();
for i in 0..4 {
let mut h = hasher.clone();
h.update(&v4ip[..i + 1]);
cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]);
}
}
SocketAddr::V6(v6addr) => {
let v6ip = v6addr.ip().octets();
for i in 0..4 {
let mut h = hasher.clone();
h.update(&v6ip[..i + 2]);
cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]);
}
}
}
{
let mut h5 = hasher.clone();
h5.update(&format!("{}", self.addr).into_bytes()[..]);
cost[32..40].copy_from_slice(&h5.finalize()[..8]);
}
cost
}
}
struct BasaltSlot {
seed: Seed,
peer: Option<Peer>,
}
impl BasaltSlot {
fn cost(&self) -> Cost {
self.peer.map(|p| p.cost(&self.seed)).unwrap_or(MAX_COST)
}
}
struct BasaltView {
i_reset: usize,
slots: Vec<BasaltSlot>,
}
impl BasaltView {
fn new(size: usize) -> Self {
let slots = (0..size)
.map(|_| BasaltSlot {
seed: rand_seed(),
peer: None,
})
.collect::<Vec<_>>();
Self { i_reset: 0, slots }
}
fn current_peers(&self) -> HashSet<Peer> {
self.slots
.iter()
.filter(|s| s.peer.is_some())
.map(|s| s.peer.unwrap().clone())
.collect::<HashSet<_>>()
}
fn current_peers_vec(&self) -> Vec<Peer> {
self.current_peers().drain().collect::<Vec<_>>()
}
fn sample(&self, count: usize) -> Vec<Peer> {
let possibles = self
.slots
.iter()
.enumerate()
.filter(|(_i, s)| s.peer.is_some())
.map(|(i, _s)| i)
.collect::<Vec<_>>();
if possibles.len() == 0 {
vec![]
} else {
let mut ret = vec![];
let mut rng = thread_rng();
for _i in 0..count {
let idx = rng.gen_range(0, possibles.len());
ret.push(self.slots[possibles[idx]].peer.unwrap());
}
ret
}
}
fn update_slot(&mut self, i: usize, peers: &[Peer]) {
let mut slot_cost = self.slots[i].cost();
for peer in peers.iter() {
let peer_cost = peer.cost(&self.slots[i].seed);
if self.slots[i].peer.is_none() || peer_cost < slot_cost {
self.slots[i].peer = Some(*peer);
slot_cost = peer_cost;
}
}
}
fn update_all_slots(&mut self, peers: &[Peer]) {
for i in 0..self.slots.len() {
self.update_slot(i, peers);
}
}
fn disconnected(&mut self, id: ed25519::PublicKey) {
let mut cleared_slots = vec![];
for i in 0..self.slots.len() {
if let Some(p) = self.slots[i].peer {
if p.id == id {
self.slots[i].peer = None;
cleared_slots.push(i);
}
}
}
let remaining_peers = self.current_peers_vec();
for i in cleared_slots {
self.update_slot(i, &remaining_peers[..]);
}
}
fn should_try_list(&self, peers: &[Peer]) -> Vec<Peer> {
// Select peers that have lower cost than any of our slots
let mut ret = HashSet::new();
for i in 0..self.slots.len() {
if self.slots[i].peer.is_none() {
return peers.to_vec();
}
let mut min_cost = self.slots[i].cost();
let mut min_peer = None;
for peer in peers.iter() {
if ret.contains(peer) {
continue;
}
let peer_cost = peer.cost(&self.slots[i].seed);
if peer_cost < min_cost {
min_cost = peer_cost;
min_peer = Some(*peer);
}
}
if let Some(p) = min_peer {
ret.insert(p);
if ret.len() == peers.len() {
break;
}
}
}
ret.drain().collect::<Vec<_>>()
}
fn reset_some_slots(&mut self, count: usize) {
for _i in 0..count {
self.slots[self.i_reset].seed = rand_seed();
self.i_reset = (self.i_reset + 1) % self.slots.len();
}
}
}
pub struct BasaltParams {
pub view_size: usize,
pub cache_size: usize,
pub exchange_interval: Duration,
pub reset_interval: Duration,
pub reset_count: usize,
}
pub struct Basalt {
netapp: Arc<NetApp>,
param: BasaltParams,
bootstrap_peers: Vec<Peer>,
view: RwLock<BasaltView>,
current_attempts: RwLock<HashSet<Peer>>,
backlog: RwLock<LruCache<Peer, ()>>,
}
impl Basalt {
pub fn new(
netapp: Arc<NetApp>,
bootstrap_list: Vec<(ed25519::PublicKey, SocketAddr)>,
param: BasaltParams,
) -> Arc<Self> {
let bootstrap_peers = bootstrap_list
.iter()
.map(|(id, addr)| Peer {
id: *id,
addr: *addr,
})
.collect::<Vec<_>>();
let view = BasaltView::new(param.view_size);
let backlog = LruCache::new(param.cache_size);
let basalt = Arc::new(Self {
netapp: netapp.clone(),
param,
bootstrap_peers,
view: RwLock::new(view),
current_attempts: RwLock::new(HashSet::new()),
backlog: RwLock::new(backlog),
});
let basalt2 = basalt.clone();
netapp.on_connected.store(Some(Arc::new(Box::new(
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
basalt2.on_connected(pk, addr, is_incoming);
},
))));
let basalt2 = basalt.clone();
netapp.on_disconnected.store(Some(Arc::new(Box::new(
move |pk: ed25519::PublicKey, is_incoming: bool| {
basalt2.on_disconnected(pk, is_incoming);
},
))));
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PullMessage, _, _>(
move |_from: ed25519::PublicKey, _pullmsg: PullMessage| {
let push_msg = basalt2.make_push_message();
async move { Ok(push_msg) }
},
);
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PushMessage, _, _>(
move |_from: ed25519::PublicKey, push_msg: PushMessage| {
basalt2.handle_peer_list(&push_msg.peers[..]);
async move { Ok(()) }
},
);
basalt
}
pub fn sample(&self, count: usize) -> Vec<ed25519::PublicKey> {
self.view
.read()
.unwrap()
.sample(count)
.iter()
.map(|p| p.id)
.collect::<Vec<_>>()
}
pub async fn run(self: Arc<Self>) {
for peer in self.bootstrap_peers.iter() {
tokio::spawn(self.clone().try_connect(*peer));
}
let pushpull_loop = self.clone().run_pushpull_loop();
let reset_loop = self.run_reset_loop();
tokio::join!(pushpull_loop, reset_loop);
}
async fn run_pushpull_loop(self: Arc<Self>) {
loop {
tokio::time::delay_for(self.param.exchange_interval).await;
let peers = self.view.read().unwrap().sample(2);
if peers.len() == 2 {
let (c1, c2) = {
let client_conns = self.netapp.client_conns.read().unwrap();
(
client_conns.get(&peers[0].id).cloned(),
client_conns.get(&peers[1].id).cloned(),
)
};
if let Some(c) = c1 {
tokio::spawn(self.clone().do_pull(c));
}
if let Some(c) = c2 {
tokio::spawn(self.clone().do_push(c));
}
}
}
}
async fn do_pull(self: Arc<Self>, peer: Arc<ClientConn>) {
match peer.request(PullMessage {}, prio::NORMAL).await {
Ok(resp) => {
self.handle_peer_list(&resp.peers[..]);
}
Err(e) => {
warn!("Error during pull exchange: {}", e);
}
};
}
async fn do_push(self: Arc<Self>, peer: Arc<ClientConn>) {
let push_msg = self.make_push_message();
if let Err(e) = peer.request(push_msg, prio::NORMAL).await {
warn!("Error during push exchange: {}", e);
}
}
fn make_push_message(&self) -> PushMessage {
let current_peers = self.view.read().unwrap().current_peers_vec();
PushMessage {
peers: current_peers,
}
}
async fn run_reset_loop(self: Arc<Self>) {
loop {
tokio::time::delay_for(self.param.reset_interval).await;
{
let mut view = self.view.write().unwrap();
let prev_peers = view.current_peers();
let prev_peers_vec = prev_peers.iter().cloned().collect::<Vec<_>>();
view.reset_some_slots(self.param.reset_count);
view.update_all_slots(&prev_peers_vec[..]);
let new_peers = view.current_peers();
drop(view);
self.close_all_diff(&prev_peers, &new_peers);
}
let mut to_retry_maybe = self.bootstrap_peers.clone();
for (peer, _) in self.backlog.read().unwrap().iter() {
if !self.bootstrap_peers.contains(peer) {
to_retry_maybe.push(*peer);
}
}
self.handle_peer_list(&to_retry_maybe[..]);
}
}
fn handle_peer_list(self: &Arc<Self>, peers: &[Peer]) {
let to_connect = self.view.read().unwrap().should_try_list(peers);
for peer in to_connect.iter() {
tokio::spawn(self.clone().try_connect(*peer));
}
}
async fn try_connect(self: Arc<Self>, peer: Peer) {
{
let view = self.view.read().unwrap();
let mut attempts = self.current_attempts.write().unwrap();
if view.slots.iter().any(|x| x.peer == Some(peer)) {
return;
}
if attempts.contains(&peer) {
return;
}
attempts.insert(peer);
}
let res = self.netapp.clone().try_connect(peer.addr, peer.id).await;
debug!("Connection attempt to {}: {:?}", peer.addr, res);
self.current_attempts.write().unwrap().remove(&peer);
if res.is_err() {
self.backlog.write().unwrap().pop(&peer);
}
}
fn on_connected(self: &Arc<Self>, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) {
if is_incoming {
self.handle_peer_list(&[Peer{id: pk, addr}][..]);
} else {
let peer = Peer { id: pk, addr };
let mut backlog = self.backlog.write().unwrap();
if backlog.get(&peer).is_none() {
backlog.put(peer, ());
}
drop(backlog);
let mut view = self.view.write().unwrap();
let prev_peers = view.current_peers();
view.update_all_slots(&[peer][..]);
let new_peers = view.current_peers();
drop(view);
self.close_all_diff(&prev_peers, &new_peers);
}
}
fn on_disconnected(&self, pk: ed25519::PublicKey, is_incoming: bool) {
if !is_incoming {
self.view.write().unwrap().disconnected(pk);
}
}
fn close_all_diff(&self, prev_peers: &HashSet<Peer>, new_peers: &HashSet<Peer>) {
let client_conns = self.netapp.client_conns.read().unwrap();
for peer in prev_peers.iter() {
if !new_peers.contains(peer) {
if let Some(c) = client_conns.get(&peer.id) {
debug!("Closing connection to {} ({})", hex::encode(peer.id), peer.addr);
c.close();
}
}
}
}
}
fn rand_seed() -> Seed {
let mut seed = [0u8; 32];
sodiumoxide::randombytes::randombytes_into(&mut seed[..]);
seed
}

@ -0,0 +1,437 @@
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU64};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
use sodiumoxide::crypto::sign::ed25519;
use crate::conn::*;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
const CONN_MAX_RETRIES: usize = 10;
const PING_INTERVAL: Duration = Duration::from_secs(10);
const LOOP_DELAY: Duration = Duration::from_secs(1);
// -- Protocol messages --
#[derive(Serialize, Deserialize)]
struct PingMessage {
pub id: u64,
pub peer_list_hash: hash::Digest,
}
impl Message for PingMessage {
const KIND: MessageKind = 0x42001000;
type Response = PingMessage;
}
#[derive(Serialize, Deserialize)]
struct PeerListMessage {
pub list: Vec<(ed25519::PublicKey, SocketAddr)>,
}
impl Message for PeerListMessage {
const KIND: MessageKind = 0x42001001;
type Response = PeerListMessage;
}
// -- Algorithm data structures --
#[derive(Debug)]