WIP v0.3.0 with changed API

This commit is contained in:
Alex 2021-10-12 17:59:46 +02:00
parent 040231d554
commit f87dbe73dc
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
9 changed files with 344 additions and 250 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "netapp" name = "netapp"
version = "0.2.0" version = "0.3.0"
authors = ["Alex Auvolat <alex@adnab.me>"] authors = ["Alex Auvolat <alex@adnab.me>"]
edition = "2018" edition = "2018"
license-file = "LICENSE" license-file = "LICENSE"

View file

@ -1,6 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicBool, AtomicU16}; use std::sync::atomic::{self, AtomicBool, AtomicU32};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use bytes::Bytes; use bytes::Bytes;
@ -16,12 +16,22 @@ use async_trait::async_trait;
use kuska_handshake::async_std::{handshake_client, handshake_server, BoxStream}; use kuska_handshake::async_std::{handshake_client, handshake_server, BoxStream};
use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::message::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*; use crate::proto::*;
use crate::util::*; use crate::util::*;
// Request message format (client -> server):
// - u8 priority
// - u8 path length
// - [u8; path length] path
// - [u8; *] data
// Response message format (server -> client):
// - u8 response code
// - [u8; *] response
pub(crate) struct ServerConn { pub(crate) struct ServerConn {
pub(crate) remote_addr: SocketAddr, pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID, pub(crate) peer_id: NodeID,
@ -99,30 +109,60 @@ impl ServerConn {
pub fn close(&self) { pub fn close(&self) {
self.close_send.send(true).unwrap(); self.close_send.send(true).unwrap();
} }
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
if bytes.len() < 2 {
return Err(Error::Message("Invalid protocol message".into()));
}
// byte 0 is the request priority, we don't care here
let path_length = bytes[1] as usize;
if bytes.len() < 2 + path_length {
return Err(Error::Message("Invalid protocol message".into()));
}
let path = &bytes[2..2 + path_length];
let path = String::from_utf8(path.to_vec())?;
let data = &bytes[2 + path_length..];
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
endpoints.get(&path).map(|e| e.clone_endpoint())
};
if let Some(handler) = handler_opt {
handler.handle(data, self.peer_id).await
} else {
Err(Error::NoHandler)
}
}
} }
impl SendLoop for ServerConn {} impl SendLoop for ServerConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ServerConn { impl RecvLoop for ServerConn {
async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) { async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) {
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
let bytes: Bytes = bytes.into(); let bytes: Bytes = bytes.into();
let resp = self.recv_handler_aux(&bytes[..]).await;
let prio = bytes[0]; let prio = bytes[0];
let mut kind_bytes = [0u8; 4]; let mut resp_bytes = vec![];
kind_bytes.copy_from_slice(&bytes[1..5]); match resp {
let kind = u32::from_be_bytes(kind_bytes); Ok(rb) => {
resp_bytes.push(0u8);
if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) { resp_bytes.extend(&rb[..]);
let net_handler = &handler.net_handler;
let resp = net_handler(self.peer_id, bytes.slice(5..)).await;
self.resp_send
.send(Some((id, prio, resp)))
.log_err("ServerConn recv_handler send resp");
} }
Err(e) => {
resp_bytes.push(e.code());
}
}
self.resp_send
.send(Some((id, prio, resp_bytes)))
.log_err("ServerConn recv_handler send resp");
} }
} }
pub(crate) struct ClientConn { pub(crate) struct ClientConn {
@ -131,7 +171,7 @@ pub(crate) struct ClientConn {
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
next_query_number: AtomicU16, next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
must_exit: AtomicBool, must_exit: AtomicBool,
stop_recv_loop: watch::Sender<bool>, stop_recv_loop: watch::Sender<bool>,
@ -173,7 +213,7 @@ impl ClientConn {
let conn = Arc::new(ClientConn { let conn = Arc::new(ClientConn {
remote_addr, remote_addr,
peer_id, peer_id,
next_query_number: AtomicU16::from(0u16), next_query_number: AtomicU32::from(RequestID::default()),
query_send, query_send,
inflight: Mutex::new(HashMap::new()), inflight: Mutex::new(HashMap::new()),
must_exit: AtomicBool::new(false), must_exit: AtomicBool::new(false),
@ -212,9 +252,10 @@ impl ClientConn {
} }
} }
pub(crate) async fn request<T>( pub(crate) async fn call<T>(
self: Arc<Self>, self: Arc<Self>,
rq: T, rq: T,
path: &str,
prio: RequestPriority, prio: RequestPriority,
) -> Result<<T as Message>::Response, Error> ) -> Result<<T as Message>::Response, Error>
where where
@ -222,9 +263,9 @@ impl ClientConn {
{ {
let id = self let id = self
.next_query_number .next_query_number
.fetch_add(1u16, atomic::Ordering::Relaxed); .fetch_add(1, atomic::Ordering::Relaxed);
let mut bytes = vec![prio]; let mut bytes = vec![prio, path.as_bytes().len() as u8];
bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]); bytes.extend_from_slice(path.as_bytes());
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]); bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel(); let (resp_send, resp_recv) = oneshot::channel();
@ -243,8 +284,15 @@ impl ClientConn {
let resp = resp_recv.await?; let resp = resp_recv.await?;
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])? let code = resp[0];
if code == 0 {
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(
&resp[1..],
)?
.map_err(Error::Remote) .map_err(Error::Remote)
} else {
Err(Error::Remote(format!("Remote error code {}", code)))
}
} }
} }

125
src/endpoint.rs Normal file
View file

@ -0,0 +1,125 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Error;
use crate::netapp::*;
use crate::proto::*;
use crate::util::*;
/// This trait should be implemented by all messages your application
/// wants to handle (click to read more).
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait]
pub trait EndpointHandler<M>: Send + Sync
where
M: Message,
{
async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
}
pub struct Endpoint<M, H>
where
M: Message,
H: EndpointHandler<M>,
{
phantom: PhantomData<M>,
netapp: Arc<NetApp>,
path: String,
handler: ArcSwapOption<H>,
}
impl<M, H> Endpoint<M, H>
where
M: Message,
H: EndpointHandler<M>,
{
pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
Self {
phantom: PhantomData::default(),
netapp,
path,
handler: ArcSwapOption::from(None),
}
}
pub fn set_handler(&self, h: Arc<H>) {
self.handler.swap(Some(h));
}
pub async fn call(
&self,
target: &NodeID,
req: M,
prio: RequestPriority,
) -> Result<<M as Message>::Response, Error> {
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => Ok(h.handle(req, self.netapp.id).await),
}
} else {
let conn = self
.netapp
.client_conns
.read()
.unwrap()
.get(target)
.cloned();
match conn {
None => Err(Error::Message(format!(
"Not connected: {}",
hex::encode(target)
))),
Some(c) => c.call(req, self.path.as_str(), prio).await,
}
}
}
}
#[async_trait]
pub(crate) trait GenericEndpoint {
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
fn clear_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
#[derive(Clone)]
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
where
M: Message,
H: EndpointHandler<M>;
#[async_trait]
impl<M, H> GenericEndpoint for EndpointArc<M, H>
where
M: Message + 'static,
H: EndpointHandler<M> + 'static,
{
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
let res = h.handle(req, from).await;
let res_bytes = rmp_to_vec_all_named(&res)?;
Ok(res_bytes)
}
}
}
fn clear_handler(&self) {
self.0.handler.swap(None);
}
fn clone_endpoint(&self) -> DynEndpoint {
Box::new(Self(self.0.clone()))
}
}

View file

@ -22,13 +22,36 @@ pub enum Error {
#[error(display = "Handshake error: {}", _0)] #[error(display = "Handshake error: {}", _0)]
Handshake(#[error(source)] kuska_handshake::async_std::Error), Handshake(#[error(source)] kuska_handshake::async_std::Error),
#[error(display = "UTF8 error: {}", _0)]
UTF8(#[error(source)] std::string::FromUtf8Error),
#[error(display = "{}", _0)] #[error(display = "{}", _0)]
Message(String), Message(String),
#[error(display = "No handler / shutting down")]
NoHandler,
#[error(display = "Remote error: {}", _0)] #[error(display = "Remote error: {}", _0)]
Remote(String), Remote(String),
} }
impl Error {
pub fn code(&self) -> u8 {
match self {
Self::Io(_) => 100,
Self::TokioJoin(_) => 110,
Self::OneshotRecv(_) => 111,
Self::RMPEncode(_) => 10,
Self::RMPDecode(_) => 11,
Self::UTF8(_) => 12,
Self::NoHandler => 20,
Self::Handshake(_) => 30,
Self::Remote(_) => 40,
Self::Message(_) => 99,
}
}
}
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error { fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
Error::Message("Watch send error".into()) Error::Message("Watch send error".into())

View file

@ -18,7 +18,7 @@
pub mod error; pub mod error;
pub mod util; pub mod util;
pub mod message; pub mod endpoint;
pub mod proto; pub mod proto;
mod conn; mod conn;

View file

@ -1,36 +0,0 @@
use std::net::IpAddr;
use serde::{Deserialize, Serialize};
pub type MessageKind = u32;
/// This trait should be implemented by all messages your application
/// wants to handle (click to read more).
///
/// It defines a `KIND`, which should be a **unique**
/// `u32` that distinguishes these messages from other types of messages
/// (it is used by our communication protocol), as well as an associated
/// `Response` type that defines the type of the response that is given
/// to the message. It is your responsibility to ensure that `KIND` is a
/// unique `u32` that is not used by any other protocol messages.
/// All `KIND` values of the form `0x42xxxxxx` are reserved by the netapp
/// crate for internal purposes.
///
/// A handler for this message has type `Self -> Self::Response`.
/// If you need to return an error, the `Response` type should be
/// a `Result<_, _>`.
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
const KIND: MessageKind;
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
}
#[derive(Serialize, Deserialize)]
pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>,
pub server_port: u16,
}
impl Message for HelloMessage {
const KIND: MessageKind = 0x42000001;
type Response = ();
}

View file

@ -1,43 +1,36 @@
use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::future::Future;
use log::{debug, info}; use log::{debug, info};
use arc_swap::{ArcSwap, ArcSwapOption}; use arc_swap::ArcSwapOption;
use bytes::Bytes; use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::auth; use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use crate::conn::*; use crate::conn::*;
use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::message::*;
use crate::proto::*; use crate::proto::*;
use crate::util::*; use crate::util::*;
type DynMsg = Box<dyn Any + Send + Sync + 'static>; #[derive(Serialize, Deserialize)]
pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>,
pub server_port: u16,
}
impl Message for HelloMessage {
type Response = ();
}
type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>; type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>;
type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>; type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>;
pub(crate) type LocalHandler =
Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>;
pub(crate) type NetHandler = Box<
dyn Fn(NodeID, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + Sync + Send,
>;
pub(crate) struct Handler {
pub(crate) local_handler: LocalHandler,
pub(crate) net_handler: NetHandler,
}
/// NetApp is the main class that handles incoming and outgoing connections. /// NetApp is the main class that handles incoming and outgoing connections.
/// ///
/// The `request()` method can be used to send a message to any peer to which we have /// The `request()` method can be used to send a message to any peer to which we have
@ -60,10 +53,12 @@ pub struct NetApp {
/// Private key associated with our peer ID /// Private key associated with our peer ID
pub privkey: ed25519::SecretKey, pub privkey: ed25519::SecretKey,
server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>,
hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>,
pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
on_connected_handler: ArcSwapOption<OnConnectHandler>, on_connected_handler: ArcSwapOption<OnConnectHandler>,
on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>, on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>,
} }
@ -73,44 +68,6 @@ struct ListenParams {
public_addr: Option<IpAddr>, public_addr: Option<IpAddr>,
} }
async fn net_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, bytes: Bytes) -> Vec<u8>
where
M: Message + 'static,
F: Fn(NodeID, M) -> R + Send + Sync + 'static,
R: Future<Output = <M as Message>::Response> + Send + Sync,
{
debug!(
"Handling message of kind {:08x} from {}",
M::KIND,
hex::encode(remote)
);
let begin_time = Instant::now();
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
Ok(msg) => Ok(handler(remote, msg).await),
Err(e) => Err(e.to_string()),
};
let end_time = Instant::now();
debug!(
"Request {:08x} from {} handled in {}msec",
M::KIND,
hex::encode(remote),
(end_time - begin_time).as_millis()
);
rmp_to_vec_all_named(&res).unwrap_or_default()
}
async fn local_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, msg: DynMsg) -> DynMsg
where
M: Message + 'static,
F: Fn(NodeID, M) -> R + Send + Sync + 'static,
R: Future<Output = <M as Message>::Response> + Send + Sync,
{
debug!("Handling message of kind {:08x} from ourself", M::KIND);
let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap();
let res = handler(remote, *msg).await;
Box::new(res)
}
impl NetApp { impl NetApp {
/// Creates a new instance of NetApp, which can serve either as a full p2p node, /// Creates a new instance of NetApp, which can serve either as a full p2p node,
/// or just as a passive client. To upgrade to a full p2p node, spawn a listener /// or just as a passive client. To upgrade to a full p2p node, spawn a listener
@ -126,16 +83,20 @@ impl NetApp {
privkey, privkey,
server_conns: RwLock::new(HashMap::new()), server_conns: RwLock::new(HashMap::new()),
client_conns: RwLock::new(HashMap::new()), client_conns: RwLock::new(HashMap::new()),
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())), endpoints: RwLock::new(HashMap::new()),
hello_endpoint: ArcSwapOption::new(None),
on_connected_handler: ArcSwapOption::new(None), on_connected_handler: ArcSwapOption::new(None),
on_disconnected_handler: ArcSwapOption::new(None), on_disconnected_handler: ArcSwapOption::new(None),
}); });
let netapp2 = netapp.clone(); netapp
netapp.add_msg_handler::<HelloMessage, _, _>(move |from: NodeID, msg: HelloMessage| { .hello_endpoint
netapp2.handle_hello_message(from, msg); .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into())));
async {} netapp
}); .hello_endpoint
.load_full()
.unwrap()
.set_handler(netapp.clone());
netapp netapp
} }
@ -162,40 +123,23 @@ impl NetApp {
.store(Some(Arc::new(Box::new(handler)))); .store(Some(Arc::new(Box::new(handler))));
} }
/// Add a handler for a certain message type. Note that only one handler pub fn endpoint<M, H>(self: &Arc<Self>, name: String) -> Arc<Endpoint<M, H>>
/// can be specified for each message type.
/// The handler is an asynchronous function, i.e. a function that returns
/// a future.
pub fn add_msg_handler<M, F, R>(&self, handler: F)
where where
M: Message + 'static, M: Message + 'static,
F: Fn(NodeID, M) -> R + Send + Sync + 'static, H: EndpointHandler<M> + 'static,
R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
{ {
let handler = Arc::new(handler); let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), name.clone()));
let endpoint_arc = EndpointArc(endpoint.clone());
let handler2 = handler.clone(); if self
let net_handler = Box::new(move |remote: NodeID, bytes: Bytes| { .endpoints
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> = .write()
Box::pin(net_handler_aux(handler2.clone(), remote, bytes)); .unwrap()
fun .insert(name.clone(), Box::new(endpoint_arc))
}); .is_some()
{
let self_id = self.id; panic!("Redefining endpoint: {}", name);
let local_handler = Box::new(move |msg: DynMsg| { };
let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> = endpoint
Box::pin(local_handler_aux(handler.clone(), self_id, msg));
fun
});
let funs = Arc::new(Handler {
net_handler,
local_handler,
});
let mut handlers = self.msg_handlers.load().as_ref().clone();
handlers.insert(M::KIND, funs);
self.msg_handlers.store(Arc::new(handlers));
} }
/// Main listening process for our app. This future runs during the whole /// Main listening process for our app. This future runs during the whole
@ -318,15 +262,6 @@ impl NetApp {
// At this point we know they are a full network member, and not just a client, // At this point we know they are a full network member, and not just a client,
// and we call the on_connected handler so that the peering strategy knows // and we call the on_connected handler so that the peering strategy knows
// we have a new potential peer // we have a new potential peer
fn handle_hello_message(&self, id: NodeID, msg: HelloMessage) {
if let Some(h) = self.on_connected_handler.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
h(id, remote_addr, true);
}
}
}
// Called from conn.rs when an incoming connection is closed. // Called from conn.rs when an incoming connection is closed.
// We deregister the connection from server_conns and call the // We deregister the connection from server_conns and call the
@ -371,8 +306,11 @@ impl NetApp {
if let Some(lp) = self.listen_params.load_full() { if let Some(lp) = self.listen_params.load_full() {
let server_addr = lp.public_addr; let server_addr = lp.public_addr;
let server_port = lp.listen_addr.port(); let server_port = lp.listen_addr.port();
let hello_endpoint = self.hello_endpoint.load_full().unwrap();
tokio::spawn(async move { tokio::spawn(async move {
conn.request( hello_endpoint
.call(
&conn.peer_id,
HelloMessage { HelloMessage {
server_addr, server_addr,
server_port, server_port,
@ -404,44 +342,16 @@ impl NetApp {
// else case: happens if connection was removed in .disconnect() // else case: happens if connection was removed in .disconnect()
// in which case on_disconnected_handler was already called // in which case on_disconnected_handler was already called
} }
}
/// Send a message to a remote host to which a client connection is already #[async_trait]
/// established, and await their response. The target is the id of the peer we impl EndpointHandler<HelloMessage> for NetApp {
/// want to send the message to. async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
/// The priority is an `u8`, with lower numbers meaning highest priority. if let Some(h) = self.on_connected_handler.load().as_ref() {
pub async fn request<T>( if let Some(c) = self.server_conns.read().unwrap().get(&from) {
&self, let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
target: &NodeID, let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
rq: T, h(from, remote_addr, true);
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message + 'static,
{
if *target == self.id {
let handler = self.msg_handlers.load().get(&T::KIND).cloned();
match handler {
None => Err(Error::Message(format!(
"No handler registered for message kind {:08x}",
T::KIND
))),
Some(h) => {
let local_handler = &h.local_handler;
let res = local_handler(Box::new(rq)).await;
let res_t = (res as Box<dyn Any + 'static>)
.downcast::<<T as Message>::Response>()
.unwrap();
Ok(*res_t)
}
}
} else {
let conn = self.client_conns.read().unwrap().get(target).cloned();
match conn {
None => Err(Error::Message(format!(
"Not connected: {}",
hex::encode(target)
))),
Some(c) => c.request(rq, prio).await,
} }
} }
} }

View file

@ -4,12 +4,13 @@ use std::sync::atomic::{self, AtomicU64};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use async_trait::async_trait;
use log::{debug, info, trace, warn}; use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash; use sodiumoxide::crypto::hash;
use crate::message::*; use crate::endpoint::*;
use crate::netapp::*; use crate::netapp::*;
use crate::proto::*; use crate::proto::*;
use crate::NodeID; use crate::NodeID;
@ -28,7 +29,6 @@ struct PingMessage {
} }
impl Message for PingMessage { impl Message for PingMessage {
const KIND: MessageKind = 0x42001000;
type Response = PingMessage; type Response = PingMessage;
} }
@ -38,7 +38,6 @@ struct PeerListMessage {
} }
impl Message for PeerListMessage { impl Message for PeerListMessage {
const KIND: MessageKind = 0x42001001;
type Response = PeerListMessage; type Response = PeerListMessage;
} }
@ -124,6 +123,9 @@ pub struct FullMeshPeeringStrategy {
netapp: Arc<NetApp>, netapp: Arc<NetApp>,
known_hosts: RwLock<KnownHosts>, known_hosts: RwLock<KnownHosts>,
next_ping_id: AtomicU64, next_ping_id: AtomicU64,
ping_endpoint: Arc<Endpoint<PingMessage, Self>>,
peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>,
} }
impl FullMeshPeeringStrategy { impl FullMeshPeeringStrategy {
@ -147,27 +149,12 @@ impl FullMeshPeeringStrategy {
netapp: netapp.clone(), netapp: netapp.clone(),
known_hosts: RwLock::new(known_hosts), known_hosts: RwLock::new(known_hosts),
next_ping_id: AtomicU64::new(42), next_ping_id: AtomicU64::new(42),
ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()),
peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()),
}); });
let strat2 = strat.clone(); strat.ping_endpoint.set_handler(strat.clone());
netapp.add_msg_handler::<PingMessage, _, _>(move |from: NodeID, ping: PingMessage| { strat.peer_list_endpoint.set_handler(strat.clone());
let ping_resp = PingMessage {
id: ping.id,
peer_list_hash: strat2.known_hosts.read().unwrap().hash,
};
debug!("Ping from {}", hex::encode(&from));
async move { ping_resp }
});
let strat2 = strat.clone();
netapp.add_msg_handler::<PeerListMessage, _, _>(
move |_from: NodeID, peer_list: PeerListMessage| {
strat2.handle_peer_list(&peer_list.list[..]);
let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list);
let resp = PeerListMessage { list: peer_list };
async move { resp }
},
);
let strat2 = strat.clone(); let strat2 = strat.clone();
netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| { netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| {
@ -262,7 +249,7 @@ impl FullMeshPeeringStrategy {
hex::encode(id), hex::encode(id),
ping_time ping_time
); );
match self.netapp.request(&id, ping_msg, PRIO_HIGH).await { match self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH).await {
Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e), Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e),
Ok(ping_resp) => { Ok(ping_resp) => {
let resp_time = Instant::now(); let resp_time = Instant::now();
@ -291,7 +278,11 @@ impl FullMeshPeeringStrategy {
async fn exchange_peers(self: Arc<Self>, id: &NodeID) { async fn exchange_peers(self: Arc<Self>, id: &NodeID) {
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
let pex_message = PeerListMessage { list: peer_list }; let pex_message = PeerListMessage { list: peer_list };
match self.netapp.request(id, pex_message, PRIO_BACKGROUND).await { match self
.peer_list_endpoint
.call(id, pex_message, PRIO_BACKGROUND)
.await
{
Err(e) => warn!("Error doing peer exchange: {}", e), Err(e) => warn!("Error doing peer exchange: {}", e),
Ok(resp) => { Ok(resp) => {
self.handle_peer_list(&resp.list[..]); self.handle_peer_list(&resp.list[..]);
@ -408,3 +399,28 @@ impl FullMeshPeeringStrategy {
} }
} }
} }
#[async_trait]
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
let ping_resp = PingMessage {
id: ping.id,
peer_list_hash: self.known_hosts.read().unwrap().hash,
};
debug!("Ping from {}", hex::encode(&from));
ping_resp
}
}
#[async_trait]
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
async fn handle(
self: &Arc<Self>,
peer_list: PeerListMessage,
_from: NodeID,
) -> PeerListMessage {
self.handle_peer_list(&peer_list.list[..]);
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
PeerListMessage { list: peer_list }
}
}

View file

@ -38,9 +38,10 @@ pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) /// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01; pub const PRIO_SECONDARY: RequestPriority = 0x01;
const MAX_CHUNK_SIZE: usize = 0x4000; pub(crate) type RequestID = u32;
type ChunkLength = u16;
pub(crate) type RequestID = u16; const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem { struct SendQueueItem {
id: RequestID, id: RequestID,
@ -85,6 +86,12 @@ impl SendQueue {
} }
} }
// Messages are sent by chunks
// Chunk format:
// - u32 BE: request id (same for request and response)
// - u16 BE: chunk length
// - [u8; chunk_length] chunk data
#[async_trait] #[async_trait]
pub(crate) trait SendLoop: Sync { pub(crate) trait SendLoop: Sync {
async fn send_loop<W>( async fn send_loop<W>(
@ -117,22 +124,23 @@ pub(crate) trait SendLoop: Sync {
item.data.len(), item.data.len(),
item.cursor item.cursor
); );
let header_id = u16::to_be_bytes(item.id); let header_id = RequestID::to_be_bytes(item.id);
write.write_all(&header_id[..]).await?; write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_SIZE { if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); let header_size =
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
write.write_all(&header_size[..]).await?; write.write_all(&header_size[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
write.write_all(&item.data[item.cursor..new_cursor]).await?; write.write_all(&item.data[item.cursor..new_cursor]).await?;
item.cursor = new_cursor; item.cursor = new_cursor;
sending.push(item); sending.push(item);
} else { } else {
let send_len = (item.data.len() - item.cursor) as u16; let send_len = (item.data.len() - item.cursor) as ChunkLength;
let header_size = u16::to_be_bytes(send_len); let header_size = ChunkLength::to_be_bytes(send_len);
write.write_all(&header_size[..]).await?; write.write_all(&header_size[..]).await?;
write.write_all(&item.data[item.cursor..]).await?; write.write_all(&item.data[item.cursor..]).await?;
@ -172,18 +180,18 @@ pub(crate) trait RecvLoop: Sync + 'static {
let mut receiving = HashMap::new(); let mut receiving = HashMap::new();
loop { loop {
trace!("recv_loop: reading packet"); trace!("recv_loop: reading packet");
let mut header_id = [0u8; 2]; let mut header_id = [0u8; RequestID::BITS as usize / 8];
read.read_exact(&mut header_id[..]).await?; read.read_exact(&mut header_id[..]).await?;
let id = RequestID::from_be_bytes(header_id); let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id); trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; 2]; let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
read.read_exact(&mut header_size[..]).await?; read.read_exact(&mut header_size[..]).await?;
let size = RequestID::from_be_bytes(header_size); let size = ChunkLength::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", size); trace!("recv_loop: got header size: {:04x}", size);
let has_cont = (size & 0x8000) != 0; let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
let size = size & !0x8000; let size = size & !CHUNK_HAS_CONTINUATION;
let mut next_slice = vec![0; size as usize]; let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?; read.read_exact(&mut next_slice[..]).await?;