forked from lx/netapp
Do not close connections immediately on close signal, await for remaining responses
This commit is contained in:
parent
83789a3076
commit
5a9ae8615e
9 changed files with 192 additions and 235 deletions
|
@ -1,20 +1,20 @@
|
||||||
|
use std::io::Write;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::io::Write;
|
|
||||||
|
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, warn};
|
||||||
|
|
||||||
use structopt::StructOpt;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use structopt::StructOpt;
|
||||||
|
|
||||||
use sodiumoxide::crypto::auth;
|
use sodiumoxide::crypto::auth;
|
||||||
use sodiumoxide::crypto::sign::ed25519;
|
use sodiumoxide::crypto::sign::ed25519;
|
||||||
|
|
||||||
use netapp::NetApp;
|
|
||||||
use netapp::peering::basalt::*;
|
|
||||||
use netapp::message::*;
|
use netapp::message::*;
|
||||||
|
use netapp::peering::basalt::*;
|
||||||
use netapp::proto::*;
|
use netapp::proto::*;
|
||||||
|
use netapp::NetApp;
|
||||||
|
|
||||||
#[derive(StructOpt, Debug)]
|
#[derive(StructOpt, Debug)]
|
||||||
#[structopt(name = "netapp")]
|
#[structopt(name = "netapp")]
|
||||||
|
@ -52,7 +52,8 @@ async fn main() {
|
||||||
env_logger::Builder::new()
|
env_logger::Builder::new()
|
||||||
.parse_env("RUST_LOG")
|
.parse_env("RUST_LOG")
|
||||||
.format(|buf, record| {
|
.format(|buf, record| {
|
||||||
writeln!(buf,
|
writeln!(
|
||||||
|
buf,
|
||||||
"{} {} {} {}",
|
"{} {} {} {}",
|
||||||
chrono::Local::now().format("%s%.6f"),
|
chrono::Local::now().format("%s%.6f"),
|
||||||
record.module_path().unwrap_or("_"),
|
record.module_path().unwrap_or("_"),
|
||||||
|
@ -62,7 +63,6 @@ async fn main() {
|
||||||
})
|
})
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
|
|
||||||
let opt = Opt::from_args();
|
let opt = Opt::from_args();
|
||||||
|
|
||||||
let netid = match &opt.network_key {
|
let netid = match &opt.network_key {
|
||||||
|
@ -108,9 +108,11 @@ async fn main() {
|
||||||
|_from: ed25519::PublicKey, msg: ExampleMessage| {
|
|_from: ed25519::PublicKey, msg: ExampleMessage| {
|
||||||
debug!("Got example message: {:?}, sending example response", msg);
|
debug!("Got example message: {:?}, sending example response", msg);
|
||||||
async {
|
async {
|
||||||
ExampleResponse{example_field: false}
|
ExampleResponse {
|
||||||
|
example_field: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
tokio::join!(
|
tokio::join!(
|
||||||
|
@ -120,8 +122,6 @@ async fn main() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
|
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
|
||||||
loop {
|
loop {
|
||||||
tokio::time::delay_for(Duration::from_secs(10)).await;
|
tokio::time::delay_for(Duration::from_secs(10)).await;
|
||||||
|
@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
|
||||||
|
|
||||||
let netapp2 = netapp.clone();
|
let netapp2 = netapp.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match netapp2.request(&p, ExampleMessage{
|
match netapp2
|
||||||
example_field: 42,
|
.request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
|
||||||
}, PRIO_NORMAL).await {
|
.await
|
||||||
|
{
|
||||||
Ok(resp) => debug!("Got example response: {:?}", resp),
|
Ok(resp) => debug!("Got example response: {:?}", resp),
|
||||||
Err(e) => warn!("Error with example request: {}", e),
|
Err(e) => warn!("Error with example request: {}", e),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use log::info;
|
use log::info;
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ async fn main() {
|
||||||
env_logger::Builder::new()
|
env_logger::Builder::new()
|
||||||
.parse_env("RUST_LOG")
|
.parse_env("RUST_LOG")
|
||||||
.format(|buf, record| {
|
.format(|buf, record| {
|
||||||
writeln!(buf,
|
writeln!(
|
||||||
|
buf,
|
||||||
"{} {} {} {}",
|
"{} {} {} {}",
|
||||||
chrono::Local::now().format("%s%.6f"),
|
chrono::Local::now().format("%s%.6f"),
|
||||||
record.module_path().unwrap_or("_"),
|
record.module_path().unwrap_or("_"),
|
||||||
|
|
137
src/conn.rs
137
src/conn.rs
|
@ -1,17 +1,18 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::atomic::{self, AtomicU16};
|
use std::sync::atomic::{self, AtomicBool, AtomicU16};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use log::{debug, trace};
|
use log::{debug, error, trace};
|
||||||
|
|
||||||
use sodiumoxide::crypto::sign::ed25519;
|
use sodiumoxide::crypto::sign::ed25519;
|
||||||
use tokio::io::split;
|
use tokio::io::split;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use kuska_handshake::async_std::{
|
use kuska_handshake::async_std::{
|
||||||
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
|
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
|
||||||
TokioCompatExtWrite,
|
TokioCompatExtWrite,
|
||||||
|
@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
|
||||||
|
|
||||||
netapp: Arc<NetApp>,
|
netapp: Arc<NetApp>,
|
||||||
|
|
||||||
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||||
close_send: watch::Sender<bool>,
|
close_send: watch::Sender<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,9 +79,20 @@ impl ServerConn {
|
||||||
|
|
||||||
let conn2 = conn.clone();
|
let conn2 = conn.clone();
|
||||||
let conn3 = conn.clone();
|
let conn3 = conn.clone();
|
||||||
|
let close_recv2 = close_recv.clone();
|
||||||
tokio::try_join!(
|
tokio::try_join!(
|
||||||
conn2.recv_loop(box_stream_read, close_recv.clone()),
|
async move {
|
||||||
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
|
tokio::select!(
|
||||||
|
r = conn2.recv_loop(box_stream_read) => r,
|
||||||
|
_ = await_exit(close_recv) => Ok(()),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
async move {
|
||||||
|
tokio::select!(
|
||||||
|
r = conn3.send_loop(resp_recv, box_stream_write) => r,
|
||||||
|
_ = await_exit(close_recv2) => Ok(()),
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.log_err("ServerConn recv_loop/send_loop");
|
.log_err("ServerConn recv_loop/send_loop");
|
||||||
|
@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
|
||||||
let net_handler = &handler.net_handler;
|
let net_handler = &handler.net_handler;
|
||||||
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
|
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
|
||||||
self.resp_send
|
self.resp_send
|
||||||
.send((id, prio, resp))
|
.send(Some((id, prio, resp)))
|
||||||
.log_err("ServerConn recv_handler send resp");
|
.log_err("ServerConn recv_handler send resp");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
|
||||||
pub(crate) remote_addr: SocketAddr,
|
pub(crate) remote_addr: SocketAddr,
|
||||||
pub(crate) peer_pk: ed25519::PublicKey,
|
pub(crate) peer_pk: ed25519::PublicKey,
|
||||||
|
|
||||||
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||||
|
|
||||||
next_query_number: AtomicU16,
|
next_query_number: AtomicU16,
|
||||||
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
|
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
||||||
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
must_exit: AtomicBool,
|
||||||
close_send: watch::Sender<bool>,
|
stop_recv_loop: watch::Sender<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientConn {
|
impl ClientConn {
|
||||||
|
@ -163,19 +176,17 @@ impl ClientConn {
|
||||||
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
|
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
|
||||||
|
|
||||||
let (query_send, query_recv) = mpsc::unbounded_channel();
|
let (query_send, query_recv) = mpsc::unbounded_channel();
|
||||||
let (resp_send, resp_recv) = mpsc::unbounded_channel();
|
|
||||||
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
|
|
||||||
|
|
||||||
let (close_send, close_recv) = watch::channel(false);
|
let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
|
||||||
|
|
||||||
let conn = Arc::new(ClientConn {
|
let conn = Arc::new(ClientConn {
|
||||||
remote_addr,
|
remote_addr,
|
||||||
peer_pk: remote_pk.clone(),
|
peer_pk: remote_pk.clone(),
|
||||||
next_query_number: AtomicU16::from(0u16),
|
next_query_number: AtomicU16::from(0u16),
|
||||||
query_send,
|
query_send,
|
||||||
resp_send,
|
inflight: Mutex::new(HashMap::new()),
|
||||||
resp_notify_send,
|
must_exit: AtomicBool::new(false),
|
||||||
close_send,
|
stop_recv_loop,
|
||||||
});
|
});
|
||||||
|
|
||||||
netapp.connected_as_client(remote_pk.clone(), conn.clone());
|
netapp.connected_as_client(remote_pk.clone(), conn.clone());
|
||||||
|
@ -183,11 +194,14 @@ impl ClientConn {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let conn2 = conn.clone();
|
let conn2 = conn.clone();
|
||||||
let conn3 = conn.clone();
|
let conn3 = conn.clone();
|
||||||
let conn4 = conn.clone();
|
|
||||||
tokio::try_join!(
|
tokio::try_join!(
|
||||||
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
|
conn2.send_loop(query_recv, box_stream_write),
|
||||||
conn3.recv_loop(box_stream_read, close_recv.clone()),
|
async move {
|
||||||
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
|
tokio::select!(
|
||||||
|
r = conn3.recv_loop(box_stream_read) => r,
|
||||||
|
_ = await_exit(stop_recv_loop_recv) => Ok(()),
|
||||||
|
)
|
||||||
|
}
|
||||||
)
|
)
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
|
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
|
||||||
|
@ -199,51 +213,15 @@ impl ClientConn {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn close(&self) {
|
pub fn close(&self) {
|
||||||
self.close_send.broadcast(true).unwrap();
|
self.must_exit.store(true, atomic::Ordering::SeqCst);
|
||||||
|
self.query_send
|
||||||
|
.send(None)
|
||||||
|
.log_err("could not write None in query_send");
|
||||||
|
if self.inflight.lock().unwrap().is_empty() {
|
||||||
|
self.stop_recv_loop
|
||||||
|
.broadcast(true)
|
||||||
|
.log_err("could not write true to stop_recv_loop");
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn dispatch_resp(
|
|
||||||
self: Arc<Self>,
|
|
||||||
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
|
|
||||||
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
|
||||||
mut must_exit: watch::Receiver<bool>,
|
|
||||||
) -> Result<(), Error> {
|
|
||||||
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
|
|
||||||
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
|
|
||||||
while !*must_exit.borrow() {
|
|
||||||
tokio::select! {
|
|
||||||
resp = resp_recv.recv() => {
|
|
||||||
if let Some((id, resp)) = resp {
|
|
||||||
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
|
|
||||||
if let Some(ch) = resp_notify.remove(&id) {
|
|
||||||
if ch.send(resp).is_err() {
|
|
||||||
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resps.insert(id, resp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp_ch = resp_notify_recv.recv() => {
|
|
||||||
if let Some((id, resp_ch)) = resp_ch {
|
|
||||||
trace!("dispatch_resp: got resp_ch {}", id);
|
|
||||||
if let Some(rs) = resps.remove(&id) {
|
|
||||||
if resp_ch.send(rs).is_err() {
|
|
||||||
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resp_notify.insert(id, resp_ch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
exit = must_exit.recv() => {
|
|
||||||
if exit == Some(true) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn request<T>(
|
pub(crate) async fn request<T>(
|
||||||
|
@ -262,10 +240,18 @@ impl ClientConn {
|
||||||
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
|
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
|
||||||
|
|
||||||
let (resp_send, resp_recv) = oneshot::channel();
|
let (resp_send, resp_recv) = oneshot::channel();
|
||||||
self.resp_notify_send.send((id, resp_send))?;
|
let old = self.inflight.lock().unwrap().insert(id, resp_send);
|
||||||
|
if let Some(old_ch) = old {
|
||||||
|
error!(
|
||||||
|
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||||
|
);
|
||||||
|
if old_ch.send(vec![]).is_err() {
|
||||||
|
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
||||||
self.query_send.send((id, prio, bytes))?;
|
self.query_send.send(Some((id, prio, bytes)))?;
|
||||||
|
|
||||||
let resp = resp_recv.await?;
|
let resp = resp_recv.await?;
|
||||||
|
|
||||||
|
@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ClientConn {
|
impl RecvLoop for ClientConn {
|
||||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||||
self.resp_send
|
let mut inflight = self.inflight.lock().unwrap();
|
||||||
.send((id, msg))
|
if let Some(ch) = inflight.remove(&id) {
|
||||||
.log_err("ClientConn::recv_handler");
|
if ch.send(msg).is_err() {
|
||||||
|
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
|
||||||
|
self.stop_recv_loop
|
||||||
|
.broadcast(true)
|
||||||
|
.log_err("could not write true to stop_recv_loop");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
||||||
pub mod proto;
|
|
||||||
pub mod message;
|
pub mod message;
|
||||||
|
pub mod proto;
|
||||||
|
|
||||||
mod conn;
|
mod conn;
|
||||||
|
|
||||||
|
|
|
@ -133,18 +133,22 @@ impl NetApp {
|
||||||
/// been successfully established. Do not set this if using a peering strategy,
|
/// been successfully established. Do not set this if using a peering strategy,
|
||||||
/// as the peering strategy will need to set this itself.
|
/// as the peering strategy will need to set this itself.
|
||||||
pub fn on_connected<F>(&self, handler: F)
|
pub fn on_connected<F>(&self, handler: F)
|
||||||
where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static
|
where
|
||||||
|
F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
self.on_connected_handler.store(Some(Arc::new(Box::new(handler))));
|
self.on_connected_handler
|
||||||
|
.store(Some(Arc::new(Box::new(handler))));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the handler to be called when an existing connection (incoming or outgoing) has
|
/// Set the handler to be called when an existing connection (incoming or outgoing) has
|
||||||
/// been closed by either party. Do not set this if using a peering strategy,
|
/// been closed by either party. Do not set this if using a peering strategy,
|
||||||
/// as the peering strategy will need to set this itself.
|
/// as the peering strategy will need to set this itself.
|
||||||
pub fn on_disconnected<F>(&self, handler: F)
|
pub fn on_disconnected<F>(&self, handler: F)
|
||||||
where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static
|
where
|
||||||
|
F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler))));
|
self.on_disconnected_handler
|
||||||
|
.store(Some(Arc::new(Box::new(handler))));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a handler for a certain message type. Note that only one handler
|
/// Add a handler for a certain message type. Note that only one handler
|
||||||
|
@ -240,11 +244,13 @@ impl NetApp {
|
||||||
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
||||||
// If pk is ourself, we're not supposed to have a connection open
|
// If pk is ourself, we're not supposed to have a connection open
|
||||||
if *pk != self.pubkey {
|
if *pk != self.pubkey {
|
||||||
let conn = self.client_conns.read().unwrap().remove(pk);
|
let conn = self.client_conns.write().unwrap().remove(pk);
|
||||||
if let Some(c) = conn {
|
if let Some(c) = conn {
|
||||||
debug!("Closing connection to {} ({})",
|
debug!(
|
||||||
|
"Closing connection to {} ({})",
|
||||||
hex::encode(c.peer_pk),
|
hex::encode(c.peer_pk),
|
||||||
c.remote_addr);
|
c.remote_addr
|
||||||
|
);
|
||||||
c.close();
|
c.close();
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
|
@ -268,9 +274,11 @@ impl NetApp {
|
||||||
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
||||||
let conn = self.server_conns.read().unwrap().get(pk).cloned();
|
let conn = self.server_conns.read().unwrap().get(pk).cloned();
|
||||||
if let Some(c) = conn {
|
if let Some(c) = conn {
|
||||||
debug!("Closing incoming connection from {} ({})",
|
debug!(
|
||||||
|
"Closing incoming connection from {} ({})",
|
||||||
hex::encode(c.peer_pk),
|
hex::encode(c.peer_pk),
|
||||||
c.remote_addr);
|
c.remote_addr
|
||||||
|
);
|
||||||
c.close();
|
c.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ use std::net::SocketAddr;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use log::{trace, debug, info, warn};
|
use log::{debug, info, trace, warn};
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -267,15 +267,13 @@ impl Basalt {
|
||||||
netapp.on_connected(
|
netapp.on_connected(
|
||||||
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
|
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
|
||||||
basalt2.on_connected(pk, addr, is_incoming);
|
basalt2.on_connected(pk, addr, is_incoming);
|
||||||
}
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
let basalt2 = basalt.clone();
|
let basalt2 = basalt.clone();
|
||||||
netapp.on_disconnected(
|
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
|
||||||
basalt2.on_disconnected(pk, is_incoming);
|
basalt2.on_disconnected(pk, is_incoming);
|
||||||
},
|
});
|
||||||
);
|
|
||||||
|
|
||||||
let basalt2 = basalt.clone();
|
let basalt2 = basalt.clone();
|
||||||
netapp.add_msg_handler::<PullMessage, _, _>(
|
netapp.add_msg_handler::<PullMessage, _, _>(
|
||||||
|
|
|
@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy {
|
||||||
);
|
);
|
||||||
|
|
||||||
let strat2 = strat.clone();
|
let strat2 = strat.clone();
|
||||||
netapp.on_disconnected(
|
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
|
||||||
let strat2 = strat2.clone();
|
let strat2 = strat2.clone();
|
||||||
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
|
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
|
||||||
},
|
});
|
||||||
);
|
|
||||||
|
|
||||||
strat
|
strat
|
||||||
}
|
}
|
||||||
|
|
122
src/proto.rs
122
src/proto.rs
|
@ -3,14 +3,14 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use log::trace;
|
use log::trace;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
use async_std::io::prelude::WriteExt;
|
use async_std::io::prelude::WriteExt;
|
||||||
use async_std::io::ReadExt;
|
use async_std::io::ReadExt;
|
||||||
|
|
||||||
use tokio::io::{ReadHalf, WriteHalf};
|
use tokio::io::{ReadHalf, WriteHalf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use crate::error::*;
|
use crate::error::*;
|
||||||
|
|
||||||
|
@ -85,19 +85,23 @@ impl SendQueue {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn is_empty(&self) -> bool {
|
||||||
|
self.items.iter().all(|(_k, v)| v.is_empty())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait SendLoop: Sync {
|
pub(crate) trait SendLoop: Sync {
|
||||||
async fn send_loop(
|
async fn send_loop(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||||
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
||||||
mut must_exit: watch::Receiver<bool>,
|
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let mut sending = SendQueue::new();
|
let mut sending = SendQueue::new();
|
||||||
while !*must_exit.borrow() {
|
let mut should_exit = false;
|
||||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
while !should_exit || !sending.is_empty() {
|
||||||
|
if let Ok(sth) = msg_recv.try_recv() {
|
||||||
|
if let Some((id, prio, data)) = sth {
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||||
sending.push(SendQueueItem {
|
sending.push(SendQueueItem {
|
||||||
id,
|
id,
|
||||||
|
@ -105,6 +109,9 @@ pub(crate) trait SendLoop: Sync {
|
||||||
data,
|
data,
|
||||||
cursor: 0,
|
cursor: 0,
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
should_exit = true;
|
||||||
|
}
|
||||||
} else if let Some(mut item) = sending.pop() {
|
} else if let Some(mut item) = sending.pop() {
|
||||||
trace!(
|
trace!(
|
||||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
||||||
|
@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
|
||||||
item.cursor
|
item.cursor
|
||||||
);
|
);
|
||||||
let header_id = u16::to_be_bytes(item.id);
|
let header_id = u16::to_be_bytes(item.id);
|
||||||
if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
|
write.write_all(&header_id[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
|
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
|
||||||
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
|
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
|
||||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
write.write_all(&header_size[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
|
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
|
||||||
if write_all_or_exit(
|
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
||||||
&item.data[item.cursor..new_cursor],
|
|
||||||
&mut write,
|
|
||||||
&mut must_exit,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
item.cursor = new_cursor;
|
item.cursor = new_cursor;
|
||||||
|
|
||||||
sending.push(item);
|
sending.push(item);
|
||||||
|
@ -147,26 +135,17 @@ pub(crate) trait SendLoop: Sync {
|
||||||
let send_len = (item.data.len() - item.cursor) as u16;
|
let send_len = (item.data.len() - item.cursor) as u16;
|
||||||
|
|
||||||
let header_size = u16::to_be_bytes(send_len);
|
let header_size = u16::to_be_bytes(send_len);
|
||||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
write.write_all(&header_size[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
|
write.write_all(&item.data[item.cursor..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
write.flush().await.log_err("Could not flush in send_loop");
|
write.flush().await.log_err("Could not flush in send_loop");
|
||||||
} else {
|
} else {
|
||||||
let (id, prio, data) = msg_recv
|
let sth = msg_recv
|
||||||
.recv()
|
.recv()
|
||||||
.await
|
.await
|
||||||
.ok_or(Error::Message("Connection closed.".into()))?;
|
.ok_or(Error::Message("Connection closed.".into()))?;
|
||||||
|
if let Some((id, prio, data)) = sth {
|
||||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||||
sending.push(SendQueueItem {
|
sending.push(SendQueueItem {
|
||||||
id,
|
id,
|
||||||
|
@ -174,6 +153,9 @@ pub(crate) trait SendLoop: Sync {
|
||||||
data,
|
data,
|
||||||
cursor: 0,
|
cursor: 0,
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
should_exit = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait RecvLoop: Sync + 'static {
|
pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
|
// Returns true if we should stop receiving after this
|
||||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
|
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||||
|
|
||||||
async fn recv_loop(
|
async fn recv_loop(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
||||||
mut must_exit: watch::Receiver<bool>,
|
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
let mut receiving = HashMap::new();
|
let mut receiving = HashMap::new();
|
||||||
while !*must_exit.borrow() {
|
loop {
|
||||||
trace!("recv_loop: reading packet");
|
trace!("recv_loop: reading packet");
|
||||||
let mut header_id = [0u8; 2];
|
let mut header_id = [0u8; 2];
|
||||||
if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
|
read.read_exact(&mut header_id[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let id = RequestID::from_be_bytes(header_id);
|
let id = RequestID::from_be_bytes(header_id);
|
||||||
trace!("recv_loop: got header id: {:04x}", id);
|
trace!("recv_loop: got header id: {:04x}", id);
|
||||||
|
|
||||||
let mut header_size = [0u8; 2];
|
let mut header_size = [0u8; 2];
|
||||||
if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
|
read.read_exact(&mut header_size[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let size = RequestID::from_be_bytes(header_size);
|
let size = RequestID::from_be_bytes(header_size);
|
||||||
trace!("recv_loop: got header size: {:04x}", id);
|
trace!("recv_loop: got header size: {:04x}", id);
|
||||||
|
|
||||||
|
@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
let size = size & !0x8000;
|
let size = size & !0x8000;
|
||||||
|
|
||||||
let mut next_slice = vec![0; size as usize];
|
let mut next_slice = vec![0; size as usize];
|
||||||
if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
|
read.read_exact(&mut next_slice[..]).await?;
|
||||||
.await?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
trace!("recv_loop: read {} bytes", size);
|
trace!("recv_loop: read {} bytes", size);
|
||||||
|
|
||||||
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
|
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
|
||||||
|
@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
|
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn read_exact_or_exit(
|
|
||||||
buf: &mut [u8],
|
|
||||||
read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
|
||||||
must_exit: &mut watch::Receiver<bool>,
|
|
||||||
) -> Result<Option<()>, Error> {
|
|
||||||
tokio::select!(
|
|
||||||
res = read.read_exact(buf) => Ok(Some(res?)),
|
|
||||||
_ = await_exit(must_exit) => Ok(None),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn write_all_or_exit(
|
|
||||||
buf: &[u8],
|
|
||||||
write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
|
||||||
must_exit: &mut watch::Receiver<bool>,
|
|
||||||
) -> Result<Option<()>, Error> {
|
|
||||||
tokio::select!(
|
|
||||||
res = write.write_all(buf) => Ok(Some(res?)),
|
|
||||||
_ = await_exit(must_exit) => Ok(None),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
|
|
||||||
loop {
|
|
||||||
if must_exit.recv().await == Some(true) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
20
src/util.rs
20
src/util.rs
|
@ -1,5 +1,7 @@
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use tokio::sync::watch;
|
||||||
|
|
||||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||||
/// using the RMP library.
|
/// using the RMP library.
|
||||||
///
|
///
|
||||||
|
@ -16,3 +18,21 @@ where
|
||||||
val.serialize(&mut se)?;
|
val.serialize(&mut se)?;
|
||||||
Ok(wr)
|
Ok(wr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This async function returns only when a true signal was received
|
||||||
|
/// from a watcher that tells us when to exit.
|
||||||
|
/// Usefull in a select statement to interrupt another
|
||||||
|
/// future:
|
||||||
|
/// ```
|
||||||
|
/// select!(
|
||||||
|
/// _ = a_long_task() => Success,
|
||||||
|
/// _ = await_exit(must_exit) => Interrupted,
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
|
||||||
|
loop {
|
||||||
|
if must_exit.recv().await == Some(true) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue