forked from lx/netapp
215 lines
5.9 KiB
Rust
215 lines
5.9 KiB
Rust
|
use std::collections::HashMap;
|
||
|
use std::net::SocketAddr;
|
||
|
use std::pin::Pin;
|
||
|
use std::sync::{Arc, RwLock};
|
||
|
|
||
|
use std::future::Future;
|
||
|
|
||
|
use log::{debug, info};
|
||
|
|
||
|
use arc_swap::{ArcSwap, ArcSwapOption};
|
||
|
use bytes::Bytes;
|
||
|
|
||
|
use sodiumoxide::crypto::auth;
|
||
|
use sodiumoxide::crypto::sign::ed25519;
|
||
|
use tokio::net::{TcpListener, TcpStream};
|
||
|
|
||
|
use crate::conn::*;
|
||
|
use crate::error::*;
|
||
|
use crate::message::*;
|
||
|
use crate::proto::*;
|
||
|
use crate::util::*;
|
||
|
|
||
|
pub struct NetApp {
|
||
|
pub listen_addr: SocketAddr,
|
||
|
pub netid: auth::Key,
|
||
|
pub pubkey: ed25519::PublicKey,
|
||
|
pub privkey: ed25519::SecretKey,
|
||
|
pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
|
||
|
pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
|
||
|
pub(crate) msg_handlers: ArcSwap<
|
||
|
HashMap<
|
||
|
MessageKind,
|
||
|
Arc<
|
||
|
dyn Fn(
|
||
|
ed25519::PublicKey,
|
||
|
Bytes,
|
||
|
) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
|
||
|
+ Sync
|
||
|
+ Send,
|
||
|
>,
|
||
|
>,
|
||
|
>,
|
||
|
pub(crate) on_connected:
|
||
|
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
|
||
|
pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
|
||
|
}
|
||
|
|
||
|
async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8>
|
||
|
where
|
||
|
M: Message + 'static,
|
||
|
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
|
||
|
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync,
|
||
|
{
|
||
|
debug!(
|
||
|
"Handling message of kind {:08x} from {}",
|
||
|
M::KIND,
|
||
|
hex::encode(remote)
|
||
|
);
|
||
|
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
|
||
|
Ok(msg) => handler(remote.clone(), msg).await,
|
||
|
Err(e) => Err(e.into()),
|
||
|
};
|
||
|
let res = res.map_err(|e| format!("{}", e));
|
||
|
rmp_to_vec_all_named(&res).unwrap_or(vec![])
|
||
|
}
|
||
|
|
||
|
impl NetApp {
|
||
|
pub fn new(
|
||
|
listen_addr: SocketAddr,
|
||
|
netid: auth::Key,
|
||
|
privkey: ed25519::SecretKey,
|
||
|
) -> Arc<Self> {
|
||
|
let pubkey = privkey.public_key();
|
||
|
let netapp = Arc::new(Self {
|
||
|
listen_addr,
|
||
|
netid,
|
||
|
pubkey,
|
||
|
privkey,
|
||
|
server_conns: RwLock::new(HashMap::new()),
|
||
|
client_conns: RwLock::new(HashMap::new()),
|
||
|
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
|
||
|
on_connected: ArcSwapOption::new(None),
|
||
|
on_disconnected: ArcSwapOption::new(None),
|
||
|
});
|
||
|
|
||
|
let netapp2 = netapp.clone();
|
||
|
netapp.add_msg_handler::<HelloMessage, _, _>(
|
||
|
move |from: ed25519::PublicKey, msg: HelloMessage| {
|
||
|
netapp2.handle_hello_message(from, msg);
|
||
|
async { Ok(()) }
|
||
|
},
|
||
|
);
|
||
|
|
||
|
netapp
|
||
|
}
|
||
|
|
||
|
pub fn add_msg_handler<M, F, R>(&self, handler: F)
|
||
|
where
|
||
|
M: Message + 'static,
|
||
|
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
|
||
|
R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static,
|
||
|
{
|
||
|
let handler = Arc::new(handler);
|
||
|
let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
|
||
|
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
|
||
|
Box::pin(handler_aux(handler.clone(), remote, bytes));
|
||
|
fun
|
||
|
});
|
||
|
let mut handlers = self.msg_handlers.load().as_ref().clone();
|
||
|
handlers.insert(M::KIND, fun);
|
||
|
self.msg_handlers.store(Arc::new(handlers));
|
||
|
}
|
||
|
|
||
|
pub async fn listen(self: Arc<Self>) {
|
||
|
let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
|
||
|
info!("Listening on {}", self.listen_addr);
|
||
|
|
||
|
loop {
|
||
|
// The second item contains the IP and port of the new connection.
|
||
|
let (socket, _) = listener.accept().await.unwrap();
|
||
|
info!(
|
||
|
"Incoming connection from {}, negotiating handshake...",
|
||
|
socket.peer_addr().unwrap()
|
||
|
);
|
||
|
let self2 = self.clone();
|
||
|
tokio::spawn(async move {
|
||
|
ServerConn::run(self2, socket)
|
||
|
.await
|
||
|
.log_err("ServerConn::run");
|
||
|
});
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub async fn try_connect(
|
||
|
self: Arc<Self>,
|
||
|
ip: SocketAddr,
|
||
|
pk: ed25519::PublicKey,
|
||
|
) -> Result<(), Error> {
|
||
|
if self.client_conns.read().unwrap().contains_key(&pk) {
|
||
|
return Ok(());
|
||
|
}
|
||
|
let socket = TcpStream::connect(ip).await?;
|
||
|
info!("Connected to {}, negotiating handshake...", ip);
|
||
|
ClientConn::init(self, socket, pk.clone()).await?;
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) {
|
||
|
let conn = self.client_conns.read().unwrap().get(id).cloned();
|
||
|
if let Some(c) = conn {
|
||
|
c.close();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
|
||
|
let mut conn_list = self.server_conns.write().unwrap();
|
||
|
conn_list.insert(id.clone(), conn);
|
||
|
}
|
||
|
|
||
|
fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
|
||
|
if let Some(h) = self.on_connected.load().as_ref() {
|
||
|
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
|
||
|
let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port);
|
||
|
h(id, remote_addr, true);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
|
||
|
let mut conn_list = self.server_conns.write().unwrap();
|
||
|
if let Some(c) = conn_list.get(id) {
|
||
|
if Arc::ptr_eq(c, &conn) {
|
||
|
conn_list.remove(id);
|
||
|
}
|
||
|
|
||
|
if let Some(h) = self.on_disconnected.load().as_ref() {
|
||
|
h(conn.peer_pk, true);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
|
||
|
{
|
||
|
let mut conn_list = self.client_conns.write().unwrap();
|
||
|
if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
|
||
|
tokio::spawn(async move { old_c.close() });
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if let Some(h) = self.on_connected.load().as_ref() {
|
||
|
h(conn.peer_pk, conn.remote_addr, false);
|
||
|
}
|
||
|
|
||
|
tokio::spawn(async move {
|
||
|
let server_port = conn.netapp.listen_addr.port();
|
||
|
conn.request(HelloMessage { server_port }, prio::NORMAL)
|
||
|
.await
|
||
|
.log_err("Sending hello message");
|
||
|
});
|
||
|
}
|
||
|
|
||
|
pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
|
||
|
let mut conn_list = self.client_conns.write().unwrap();
|
||
|
if let Some(c) = conn_list.get(id) {
|
||
|
if Arc::ptr_eq(c, &conn) {
|
||
|
conn_list.remove(id);
|
||
|
}
|
||
|
|
||
|
if let Some(h) = self.on_disconnected.load().as_ref() {
|
||
|
h(conn.peer_pk, false);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|