This commit is contained in:
parent
040231d554
commit
f87dbe73dc
9 changed files with 344 additions and 250 deletions
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "netapp"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
authors = ["Alex Auvolat <alex@adnab.me>"]
|
||||
edition = "2018"
|
||||
license-file = "LICENSE"
|
||||
|
|
90
src/conn.rs
90
src/conn.rs
|
@ -1,6 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicU16};
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicU32};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
@ -16,12 +16,22 @@ use async_trait::async_trait;
|
|||
|
||||
use kuska_handshake::async_std::{handshake_client, handshake_server, BoxStream};
|
||||
|
||||
use crate::endpoint::*;
|
||||
use crate::error::*;
|
||||
use crate::message::*;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
// Request message format (client -> server):
|
||||
// - u8 priority
|
||||
// - u8 path length
|
||||
// - [u8; path length] path
|
||||
// - [u8; *] data
|
||||
|
||||
// Response message format (server -> client):
|
||||
// - u8 response code
|
||||
// - [u8; *] response
|
||||
|
||||
pub(crate) struct ServerConn {
|
||||
pub(crate) remote_addr: SocketAddr,
|
||||
pub(crate) peer_id: NodeID,
|
||||
|
@ -99,30 +109,60 @@ impl ServerConn {
|
|||
pub fn close(&self) {
|
||||
self.close_send.send(true).unwrap();
|
||||
}
|
||||
|
||||
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
if bytes.len() < 2 {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
}
|
||||
|
||||
// byte 0 is the request priority, we don't care here
|
||||
let path_length = bytes[1] as usize;
|
||||
if bytes.len() < 2 + path_length {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
}
|
||||
|
||||
let path = &bytes[2..2 + path_length];
|
||||
let path = String::from_utf8(path.to_vec())?;
|
||||
let data = &bytes[2 + path_length..];
|
||||
|
||||
let handler_opt = {
|
||||
let endpoints = self.netapp.endpoints.read().unwrap();
|
||||
endpoints.get(&path).map(|e| e.clone_endpoint())
|
||||
};
|
||||
|
||||
if let Some(handler) = handler_opt {
|
||||
handler.handle(data, self.peer_id).await
|
||||
} else {
|
||||
Err(Error::NoHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SendLoop for ServerConn {}
|
||||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ServerConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
||||
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
|
||||
|
||||
let bytes: Bytes = bytes.into();
|
||||
|
||||
let resp = self.recv_handler_aux(&bytes[..]).await;
|
||||
let prio = bytes[0];
|
||||
|
||||
let mut kind_bytes = [0u8; 4];
|
||||
kind_bytes.copy_from_slice(&bytes[1..5]);
|
||||
let kind = u32::from_be_bytes(kind_bytes);
|
||||
|
||||
if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) {
|
||||
let net_handler = &handler.net_handler;
|
||||
let resp = net_handler(self.peer_id, bytes.slice(5..)).await;
|
||||
self.resp_send
|
||||
.send(Some((id, prio, resp)))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
let mut resp_bytes = vec![];
|
||||
match resp {
|
||||
Ok(rb) => {
|
||||
resp_bytes.push(0u8);
|
||||
resp_bytes.extend(&rb[..]);
|
||||
}
|
||||
Err(e) => {
|
||||
resp_bytes.push(e.code());
|
||||
}
|
||||
}
|
||||
|
||||
self.resp_send
|
||||
.send(Some((id, prio, resp_bytes)))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
}
|
||||
}
|
||||
pub(crate) struct ClientConn {
|
||||
|
@ -131,7 +171,7 @@ pub(crate) struct ClientConn {
|
|||
|
||||
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
|
||||
next_query_number: AtomicU16,
|
||||
next_query_number: AtomicU32,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
||||
must_exit: AtomicBool,
|
||||
stop_recv_loop: watch::Sender<bool>,
|
||||
|
@ -173,7 +213,7 @@ impl ClientConn {
|
|||
let conn = Arc::new(ClientConn {
|
||||
remote_addr,
|
||||
peer_id,
|
||||
next_query_number: AtomicU16::from(0u16),
|
||||
next_query_number: AtomicU32::from(RequestID::default()),
|
||||
query_send,
|
||||
inflight: Mutex::new(HashMap::new()),
|
||||
must_exit: AtomicBool::new(false),
|
||||
|
@ -212,9 +252,10 @@ impl ClientConn {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn request<T>(
|
||||
pub(crate) async fn call<T>(
|
||||
self: Arc<Self>,
|
||||
rq: T,
|
||||
path: &str,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<T as Message>::Response, Error>
|
||||
where
|
||||
|
@ -222,9 +263,9 @@ impl ClientConn {
|
|||
{
|
||||
let id = self
|
||||
.next_query_number
|
||||
.fetch_add(1u16, atomic::Ordering::Relaxed);
|
||||
let mut bytes = vec![prio];
|
||||
bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]);
|
||||
.fetch_add(1, atomic::Ordering::Relaxed);
|
||||
let mut bytes = vec![prio, path.as_bytes().len() as u8];
|
||||
bytes.extend_from_slice(path.as_bytes());
|
||||
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
|
||||
|
||||
let (resp_send, resp_recv) = oneshot::channel();
|
||||
|
@ -243,8 +284,15 @@ impl ClientConn {
|
|||
|
||||
let resp = resp_recv.await?;
|
||||
|
||||
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])?
|
||||
let code = resp[0];
|
||||
if code == 0 {
|
||||
rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(
|
||||
&resp[1..],
|
||||
)?
|
||||
.map_err(Error::Remote)
|
||||
} else {
|
||||
Err(Error::Remote(format!("Remote error code {}", code)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
125
src/endpoint.rs
Normal file
125
src/endpoint.rs
Normal file
|
@ -0,0 +1,125 @@
|
|||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
/// This trait should be implemented by all messages your application
|
||||
/// wants to handle (click to read more).
|
||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
||||
}
|
||||
|
||||
pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait EndpointHandler<M>: Send + Sync
|
||||
where
|
||||
M: Message,
|
||||
{
|
||||
async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
|
||||
}
|
||||
|
||||
pub struct Endpoint<M, H>
|
||||
where
|
||||
M: Message,
|
||||
H: EndpointHandler<M>,
|
||||
{
|
||||
phantom: PhantomData<M>,
|
||||
netapp: Arc<NetApp>,
|
||||
path: String,
|
||||
handler: ArcSwapOption<H>,
|
||||
}
|
||||
|
||||
impl<M, H> Endpoint<M, H>
|
||||
where
|
||||
M: Message,
|
||||
H: EndpointHandler<M>,
|
||||
{
|
||||
pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
|
||||
Self {
|
||||
phantom: PhantomData::default(),
|
||||
netapp,
|
||||
path,
|
||||
handler: ArcSwapOption::from(None),
|
||||
}
|
||||
}
|
||||
pub fn set_handler(&self, h: Arc<H>) {
|
||||
self.handler.swap(Some(h));
|
||||
}
|
||||
pub async fn call(
|
||||
&self,
|
||||
target: &NodeID,
|
||||
req: M,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<M as Message>::Response, Error> {
|
||||
if *target == self.netapp.id {
|
||||
match self.handler.load_full() {
|
||||
None => Err(Error::NoHandler),
|
||||
Some(h) => Ok(h.handle(req, self.netapp.id).await),
|
||||
}
|
||||
} else {
|
||||
let conn = self
|
||||
.netapp
|
||||
.client_conns
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(target)
|
||||
.cloned();
|
||||
match conn {
|
||||
None => Err(Error::Message(format!(
|
||||
"Not connected: {}",
|
||||
hex::encode(target)
|
||||
))),
|
||||
Some(c) => c.call(req, self.path.as_str(), prio).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait GenericEndpoint {
|
||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
|
||||
fn clear_handler(&self);
|
||||
fn clone_endpoint(&self) -> DynEndpoint;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
|
||||
where
|
||||
M: Message,
|
||||
H: EndpointHandler<M>;
|
||||
|
||||
#[async_trait]
|
||||
impl<M, H> GenericEndpoint for EndpointArc<M, H>
|
||||
where
|
||||
M: Message + 'static,
|
||||
H: EndpointHandler<M> + 'static,
|
||||
{
|
||||
async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
|
||||
match self.0.handler.load_full() {
|
||||
None => Err(Error::NoHandler),
|
||||
Some(h) => {
|
||||
let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
|
||||
let res = h.handle(req, from).await;
|
||||
let res_bytes = rmp_to_vec_all_named(&res)?;
|
||||
Ok(res_bytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_handler(&self) {
|
||||
self.0.handler.swap(None);
|
||||
}
|
||||
|
||||
fn clone_endpoint(&self) -> DynEndpoint {
|
||||
Box::new(Self(self.0.clone()))
|
||||
}
|
||||
}
|
23
src/error.rs
23
src/error.rs
|
@ -22,13 +22,36 @@ pub enum Error {
|
|||
#[error(display = "Handshake error: {}", _0)]
|
||||
Handshake(#[error(source)] kuska_handshake::async_std::Error),
|
||||
|
||||
#[error(display = "UTF8 error: {}", _0)]
|
||||
UTF8(#[error(source)] std::string::FromUtf8Error),
|
||||
|
||||
#[error(display = "{}", _0)]
|
||||
Message(String),
|
||||
|
||||
#[error(display = "No handler / shutting down")]
|
||||
NoHandler,
|
||||
|
||||
#[error(display = "Remote error: {}", _0)]
|
||||
Remote(String),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn code(&self) -> u8 {
|
||||
match self {
|
||||
Self::Io(_) => 100,
|
||||
Self::TokioJoin(_) => 110,
|
||||
Self::OneshotRecv(_) => 111,
|
||||
Self::RMPEncode(_) => 10,
|
||||
Self::RMPDecode(_) => 11,
|
||||
Self::UTF8(_) => 12,
|
||||
Self::NoHandler => 20,
|
||||
Self::Handshake(_) => 30,
|
||||
Self::Remote(_) => 40,
|
||||
Self::Message(_) => 99,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
|
||||
fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
|
||||
Error::Message("Watch send error".into())
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
pub mod error;
|
||||
pub mod util;
|
||||
|
||||
pub mod message;
|
||||
pub mod endpoint;
|
||||
pub mod proto;
|
||||
|
||||
mod conn;
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
use std::net::IpAddr;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub type MessageKind = u32;
|
||||
|
||||
/// This trait should be implemented by all messages your application
|
||||
/// wants to handle (click to read more).
|
||||
///
|
||||
/// It defines a `KIND`, which should be a **unique**
|
||||
/// `u32` that distinguishes these messages from other types of messages
|
||||
/// (it is used by our communication protocol), as well as an associated
|
||||
/// `Response` type that defines the type of the response that is given
|
||||
/// to the message. It is your responsibility to ensure that `KIND` is a
|
||||
/// unique `u32` that is not used by any other protocol messages.
|
||||
/// All `KIND` values of the form `0x42xxxxxx` are reserved by the netapp
|
||||
/// crate for internal purposes.
|
||||
///
|
||||
/// A handler for this message has type `Self -> Self::Response`.
|
||||
/// If you need to return an error, the `Response` type should be
|
||||
/// a `Result<_, _>`.
|
||||
pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
|
||||
const KIND: MessageKind;
|
||||
type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct HelloMessage {
|
||||
pub server_addr: Option<IpAddr>,
|
||||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl Message for HelloMessage {
|
||||
const KIND: MessageKind = 0x42000001;
|
||||
type Response = ();
|
||||
}
|
216
src/netapp.rs
216
src/netapp.rs
|
@ -1,43 +1,36 @@
|
|||
use std::any::Any;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Instant;
|
||||
|
||||
use std::future::Future;
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
use arc_swap::{ArcSwap, ArcSwapOption};
|
||||
use bytes::Bytes;
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use crate::conn::*;
|
||||
use crate::endpoint::*;
|
||||
use crate::error::*;
|
||||
use crate::message::*;
|
||||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
type DynMsg = Box<dyn Any + Send + Sync + 'static>;
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub(crate) struct HelloMessage {
|
||||
pub server_addr: Option<IpAddr>,
|
||||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl Message for HelloMessage {
|
||||
type Response = ();
|
||||
}
|
||||
|
||||
type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>;
|
||||
type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>;
|
||||
|
||||
pub(crate) type LocalHandler =
|
||||
Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>;
|
||||
pub(crate) type NetHandler = Box<
|
||||
dyn Fn(NodeID, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + Sync + Send,
|
||||
>;
|
||||
|
||||
pub(crate) struct Handler {
|
||||
pub(crate) local_handler: LocalHandler,
|
||||
pub(crate) net_handler: NetHandler,
|
||||
}
|
||||
|
||||
/// NetApp is the main class that handles incoming and outgoing connections.
|
||||
///
|
||||
/// The `request()` method can be used to send a message to any peer to which we have
|
||||
|
@ -60,10 +53,12 @@ pub struct NetApp {
|
|||
/// Private key associated with our peer ID
|
||||
pub privkey: ed25519::SecretKey,
|
||||
|
||||
server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
|
||||
client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
|
||||
pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
|
||||
pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
|
||||
|
||||
pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>,
|
||||
hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>,
|
||||
|
||||
pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
|
||||
on_connected_handler: ArcSwapOption<OnConnectHandler>,
|
||||
on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>,
|
||||
}
|
||||
|
@ -73,44 +68,6 @@ struct ListenParams {
|
|||
public_addr: Option<IpAddr>,
|
||||
}
|
||||
|
||||
async fn net_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, bytes: Bytes) -> Vec<u8>
|
||||
where
|
||||
M: Message + 'static,
|
||||
F: Fn(NodeID, M) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = <M as Message>::Response> + Send + Sync,
|
||||
{
|
||||
debug!(
|
||||
"Handling message of kind {:08x} from {}",
|
||||
M::KIND,
|
||||
hex::encode(remote)
|
||||
);
|
||||
let begin_time = Instant::now();
|
||||
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
|
||||
Ok(msg) => Ok(handler(remote, msg).await),
|
||||
Err(e) => Err(e.to_string()),
|
||||
};
|
||||
let end_time = Instant::now();
|
||||
debug!(
|
||||
"Request {:08x} from {} handled in {}msec",
|
||||
M::KIND,
|
||||
hex::encode(remote),
|
||||
(end_time - begin_time).as_millis()
|
||||
);
|
||||
rmp_to_vec_all_named(&res).unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn local_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, msg: DynMsg) -> DynMsg
|
||||
where
|
||||
M: Message + 'static,
|
||||
F: Fn(NodeID, M) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = <M as Message>::Response> + Send + Sync,
|
||||
{
|
||||
debug!("Handling message of kind {:08x} from ourself", M::KIND);
|
||||
let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap();
|
||||
let res = handler(remote, *msg).await;
|
||||
Box::new(res)
|
||||
}
|
||||
|
||||
impl NetApp {
|
||||
/// Creates a new instance of NetApp, which can serve either as a full p2p node,
|
||||
/// or just as a passive client. To upgrade to a full p2p node, spawn a listener
|
||||
|
@ -126,16 +83,20 @@ impl NetApp {
|
|||
privkey,
|
||||
server_conns: RwLock::new(HashMap::new()),
|
||||
client_conns: RwLock::new(HashMap::new()),
|
||||
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
|
||||
endpoints: RwLock::new(HashMap::new()),
|
||||
hello_endpoint: ArcSwapOption::new(None),
|
||||
on_connected_handler: ArcSwapOption::new(None),
|
||||
on_disconnected_handler: ArcSwapOption::new(None),
|
||||
});
|
||||
|
||||
let netapp2 = netapp.clone();
|
||||
netapp.add_msg_handler::<HelloMessage, _, _>(move |from: NodeID, msg: HelloMessage| {
|
||||
netapp2.handle_hello_message(from, msg);
|
||||
async {}
|
||||
});
|
||||
netapp
|
||||
.hello_endpoint
|
||||
.swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into())));
|
||||
netapp
|
||||
.hello_endpoint
|
||||
.load_full()
|
||||
.unwrap()
|
||||
.set_handler(netapp.clone());
|
||||
|
||||
netapp
|
||||
}
|
||||
|
@ -162,40 +123,23 @@ impl NetApp {
|
|||
.store(Some(Arc::new(Box::new(handler))));
|
||||
}
|
||||
|
||||
/// Add a handler for a certain message type. Note that only one handler
|
||||
/// can be specified for each message type.
|
||||
/// The handler is an asynchronous function, i.e. a function that returns
|
||||
/// a future.
|
||||
pub fn add_msg_handler<M, F, R>(&self, handler: F)
|
||||
pub fn endpoint<M, H>(self: &Arc<Self>, name: String) -> Arc<Endpoint<M, H>>
|
||||
where
|
||||
M: Message + 'static,
|
||||
F: Fn(NodeID, M) -> R + Send + Sync + 'static,
|
||||
R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
|
||||
H: EndpointHandler<M> + 'static,
|
||||
{
|
||||
let handler = Arc::new(handler);
|
||||
|
||||
let handler2 = handler.clone();
|
||||
let net_handler = Box::new(move |remote: NodeID, bytes: Bytes| {
|
||||
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
|
||||
Box::pin(net_handler_aux(handler2.clone(), remote, bytes));
|
||||
fun
|
||||
});
|
||||
|
||||
let self_id = self.id;
|
||||
let local_handler = Box::new(move |msg: DynMsg| {
|
||||
let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> =
|
||||
Box::pin(local_handler_aux(handler.clone(), self_id, msg));
|
||||
fun
|
||||
});
|
||||
|
||||
let funs = Arc::new(Handler {
|
||||
net_handler,
|
||||
local_handler,
|
||||
});
|
||||
|
||||
let mut handlers = self.msg_handlers.load().as_ref().clone();
|
||||
handlers.insert(M::KIND, funs);
|
||||
self.msg_handlers.store(Arc::new(handlers));
|
||||
let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), name.clone()));
|
||||
let endpoint_arc = EndpointArc(endpoint.clone());
|
||||
if self
|
||||
.endpoints
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(name.clone(), Box::new(endpoint_arc))
|
||||
.is_some()
|
||||
{
|
||||
panic!("Redefining endpoint: {}", name);
|
||||
};
|
||||
endpoint
|
||||
}
|
||||
|
||||
/// Main listening process for our app. This future runs during the whole
|
||||
|
@ -318,15 +262,6 @@ impl NetApp {
|
|||
// At this point we know they are a full network member, and not just a client,
|
||||
// and we call the on_connected handler so that the peering strategy knows
|
||||
// we have a new potential peer
|
||||
fn handle_hello_message(&self, id: NodeID, msg: HelloMessage) {
|
||||
if let Some(h) = self.on_connected_handler.load().as_ref() {
|
||||
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
|
||||
let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
|
||||
let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
|
||||
h(id, remote_addr, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Called from conn.rs when an incoming connection is closed.
|
||||
// We deregister the connection from server_conns and call the
|
||||
|
@ -371,16 +306,19 @@ impl NetApp {
|
|||
if let Some(lp) = self.listen_params.load_full() {
|
||||
let server_addr = lp.public_addr;
|
||||
let server_port = lp.listen_addr.port();
|
||||
let hello_endpoint = self.hello_endpoint.load_full().unwrap();
|
||||
tokio::spawn(async move {
|
||||
conn.request(
|
||||
HelloMessage {
|
||||
server_addr,
|
||||
server_port,
|
||||
},
|
||||
PRIO_NORMAL,
|
||||
)
|
||||
.await
|
||||
.log_err("Sending hello message");
|
||||
hello_endpoint
|
||||
.call(
|
||||
&conn.peer_id,
|
||||
HelloMessage {
|
||||
server_addr,
|
||||
server_port,
|
||||
},
|
||||
PRIO_NORMAL,
|
||||
)
|
||||
.await
|
||||
.log_err("Sending hello message");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -404,44 +342,16 @@ impl NetApp {
|
|||
// else case: happens if connection was removed in .disconnect()
|
||||
// in which case on_disconnected_handler was already called
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message to a remote host to which a client connection is already
|
||||
/// established, and await their response. The target is the id of the peer we
|
||||
/// want to send the message to.
|
||||
/// The priority is an `u8`, with lower numbers meaning highest priority.
|
||||
pub async fn request<T>(
|
||||
&self,
|
||||
target: &NodeID,
|
||||
rq: T,
|
||||
prio: RequestPriority,
|
||||
) -> Result<<T as Message>::Response, Error>
|
||||
where
|
||||
T: Message + 'static,
|
||||
{
|
||||
if *target == self.id {
|
||||
let handler = self.msg_handlers.load().get(&T::KIND).cloned();
|
||||
match handler {
|
||||
None => Err(Error::Message(format!(
|
||||
"No handler registered for message kind {:08x}",
|
||||
T::KIND
|
||||
))),
|
||||
Some(h) => {
|
||||
let local_handler = &h.local_handler;
|
||||
let res = local_handler(Box::new(rq)).await;
|
||||
let res_t = (res as Box<dyn Any + 'static>)
|
||||
.downcast::<<T as Message>::Response>()
|
||||
.unwrap();
|
||||
Ok(*res_t)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let conn = self.client_conns.read().unwrap().get(target).cloned();
|
||||
match conn {
|
||||
None => Err(Error::Message(format!(
|
||||
"Not connected: {}",
|
||||
hex::encode(target)
|
||||
))),
|
||||
Some(c) => c.request(rq, prio).await,
|
||||
#[async_trait]
|
||||
impl EndpointHandler<HelloMessage> for NetApp {
|
||||
async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
|
||||
if let Some(h) = self.on_connected_handler.load().as_ref() {
|
||||
if let Some(c) = self.server_conns.read().unwrap().get(&from) {
|
||||
let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
|
||||
let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
|
||||
h(from, remote_addr, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,12 +4,13 @@ use std::sync::atomic::{self, AtomicU64};
|
|||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::{debug, info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use sodiumoxide::crypto::hash;
|
||||
|
||||
use crate::message::*;
|
||||
use crate::endpoint::*;
|
||||
use crate::netapp::*;
|
||||
use crate::proto::*;
|
||||
use crate::NodeID;
|
||||
|
@ -28,7 +29,6 @@ struct PingMessage {
|
|||
}
|
||||
|
||||
impl Message for PingMessage {
|
||||
const KIND: MessageKind = 0x42001000;
|
||||
type Response = PingMessage;
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,6 @@ struct PeerListMessage {
|
|||
}
|
||||
|
||||
impl Message for PeerListMessage {
|
||||
const KIND: MessageKind = 0x42001001;
|
||||
type Response = PeerListMessage;
|
||||
}
|
||||
|
||||
|
@ -124,6 +123,9 @@ pub struct FullMeshPeeringStrategy {
|
|||
netapp: Arc<NetApp>,
|
||||
known_hosts: RwLock<KnownHosts>,
|
||||
next_ping_id: AtomicU64,
|
||||
|
||||
ping_endpoint: Arc<Endpoint<PingMessage, Self>>,
|
||||
peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>,
|
||||
}
|
||||
|
||||
impl FullMeshPeeringStrategy {
|
||||
|
@ -147,27 +149,12 @@ impl FullMeshPeeringStrategy {
|
|||
netapp: netapp.clone(),
|
||||
known_hosts: RwLock::new(known_hosts),
|
||||
next_ping_id: AtomicU64::new(42),
|
||||
ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()),
|
||||
peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()),
|
||||
});
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.add_msg_handler::<PingMessage, _, _>(move |from: NodeID, ping: PingMessage| {
|
||||
let ping_resp = PingMessage {
|
||||
id: ping.id,
|
||||
peer_list_hash: strat2.known_hosts.read().unwrap().hash,
|
||||
};
|
||||
debug!("Ping from {}", hex::encode(&from));
|
||||
async move { ping_resp }
|
||||
});
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.add_msg_handler::<PeerListMessage, _, _>(
|
||||
move |_from: NodeID, peer_list: PeerListMessage| {
|
||||
strat2.handle_peer_list(&peer_list.list[..]);
|
||||
let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list);
|
||||
let resp = PeerListMessage { list: peer_list };
|
||||
async move { resp }
|
||||
},
|
||||
);
|
||||
strat.ping_endpoint.set_handler(strat.clone());
|
||||
strat.peer_list_endpoint.set_handler(strat.clone());
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| {
|
||||
|
@ -262,7 +249,7 @@ impl FullMeshPeeringStrategy {
|
|||
hex::encode(id),
|
||||
ping_time
|
||||
);
|
||||
match self.netapp.request(&id, ping_msg, PRIO_HIGH).await {
|
||||
match self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH).await {
|
||||
Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e),
|
||||
Ok(ping_resp) => {
|
||||
let resp_time = Instant::now();
|
||||
|
@ -291,7 +278,11 @@ impl FullMeshPeeringStrategy {
|
|||
async fn exchange_peers(self: Arc<Self>, id: &NodeID) {
|
||||
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
|
||||
let pex_message = PeerListMessage { list: peer_list };
|
||||
match self.netapp.request(id, pex_message, PRIO_BACKGROUND).await {
|
||||
match self
|
||||
.peer_list_endpoint
|
||||
.call(id, pex_message, PRIO_BACKGROUND)
|
||||
.await
|
||||
{
|
||||
Err(e) => warn!("Error doing peer exchange: {}", e),
|
||||
Ok(resp) => {
|
||||
self.handle_peer_list(&resp.list[..]);
|
||||
|
@ -408,3 +399,28 @@ impl FullMeshPeeringStrategy {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
|
||||
async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
|
||||
let ping_resp = PingMessage {
|
||||
id: ping.id,
|
||||
peer_list_hash: self.known_hosts.read().unwrap().hash,
|
||||
};
|
||||
debug!("Ping from {}", hex::encode(&from));
|
||||
ping_resp
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
|
||||
async fn handle(
|
||||
self: &Arc<Self>,
|
||||
peer_list: PeerListMessage,
|
||||
_from: NodeID,
|
||||
) -> PeerListMessage {
|
||||
self.handle_peer_list(&peer_list.list[..]);
|
||||
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
|
||||
PeerListMessage { list: peer_list }
|
||||
}
|
||||
}
|
||||
|
|
36
src/proto.rs
36
src/proto.rs
|
@ -38,9 +38,10 @@ pub const PRIO_PRIMARY: RequestPriority = 0x00;
|
|||
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
|
||||
pub const PRIO_SECONDARY: RequestPriority = 0x01;
|
||||
|
||||
const MAX_CHUNK_SIZE: usize = 0x4000;
|
||||
|
||||
pub(crate) type RequestID = u16;
|
||||
pub(crate) type RequestID = u32;
|
||||
type ChunkLength = u16;
|
||||
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
||||
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||
|
||||
struct SendQueueItem {
|
||||
id: RequestID,
|
||||
|
@ -85,6 +86,12 @@ impl SendQueue {
|
|||
}
|
||||
}
|
||||
|
||||
// Messages are sent by chunks
|
||||
// Chunk format:
|
||||
// - u32 BE: request id (same for request and response)
|
||||
// - u16 BE: chunk length
|
||||
// - [u8; chunk_length] chunk data
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait SendLoop: Sync {
|
||||
async fn send_loop<W>(
|
||||
|
@ -117,22 +124,23 @@ pub(crate) trait SendLoop: Sync {
|
|||
item.data.len(),
|
||||
item.cursor
|
||||
);
|
||||
let header_id = u16::to_be_bytes(item.id);
|
||||
let header_id = RequestID::to_be_bytes(item.id);
|
||||
write.write_all(&header_id[..]).await?;
|
||||
|
||||
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
|
||||
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
|
||||
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
|
||||
let header_size =
|
||||
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
|
||||
write.write_all(&header_size[..]).await?;
|
||||
|
||||
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
|
||||
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
|
||||
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
||||
item.cursor = new_cursor;
|
||||
|
||||
sending.push(item);
|
||||
} else {
|
||||
let send_len = (item.data.len() - item.cursor) as u16;
|
||||
let send_len = (item.data.len() - item.cursor) as ChunkLength;
|
||||
|
||||
let header_size = u16::to_be_bytes(send_len);
|
||||
let header_size = ChunkLength::to_be_bytes(send_len);
|
||||
write.write_all(&header_size[..]).await?;
|
||||
|
||||
write.write_all(&item.data[item.cursor..]).await?;
|
||||
|
@ -172,18 +180,18 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
let mut receiving = HashMap::new();
|
||||
loop {
|
||||
trace!("recv_loop: reading packet");
|
||||
let mut header_id = [0u8; 2];
|
||||
let mut header_id = [0u8; RequestID::BITS as usize / 8];
|
||||
read.read_exact(&mut header_id[..]).await?;
|
||||
let id = RequestID::from_be_bytes(header_id);
|
||||
trace!("recv_loop: got header id: {:04x}", id);
|
||||
|
||||
let mut header_size = [0u8; 2];
|
||||
let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
|
||||
read.read_exact(&mut header_size[..]).await?;
|
||||
let size = RequestID::from_be_bytes(header_size);
|
||||
let size = ChunkLength::from_be_bytes(header_size);
|
||||
trace!("recv_loop: got header size: {:04x}", size);
|
||||
|
||||
let has_cont = (size & 0x8000) != 0;
|
||||
let size = size & !0x8000;
|
||||
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
||||
let size = size & !CHUNK_HAS_CONTINUATION;
|
||||
|
||||
let mut next_slice = vec![0; size as usize];
|
||||
read.read_exact(&mut next_slice[..]).await?;
|
||||
|
|
Loading…
Reference in a new issue