Do not close connections immediately on close signal, await for remaining responses

This commit is contained in:
Alex 2020-12-07 13:35:24 +01:00
parent 83789a3076
commit 5a9ae8615e
9 changed files with 192 additions and 235 deletions

View file

@ -1,20 +1,20 @@
use std::io::Write;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::io::Write;
use log::{debug, info, warn};
use structopt::StructOpt;
use serde::{Deserialize, Serialize};
use structopt::StructOpt;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use netapp::NetApp;
use netapp::peering::basalt::*;
use netapp::message::*;
use netapp::peering::basalt::*;
use netapp::proto::*;
use netapp::NetApp;
#[derive(StructOpt, Debug)]
#[structopt(name = "netapp")]
@ -52,17 +52,17 @@ async fn main() {
env_logger::Builder::new()
.parse_env("RUST_LOG")
.format(|buf, record| {
writeln!(buf,
"{} {} {} {}",
chrono::Local::now().format("%s%.6f"),
record.module_path().unwrap_or("_"),
record.level(),
record.args()
)
writeln!(
buf,
"{} {} {} {}",
chrono::Local::now().format("%s%.6f"),
record.module_path().unwrap_or("_"),
record.level(),
record.args()
)
})
.init();
let opt = Opt::from_args();
let netid = match &opt.network_key {
@ -108,10 +108,12 @@ async fn main() {
|_from: ed25519::PublicKey, msg: ExampleMessage| {
debug!("Got example message: {:?}, sending example response", msg);
async {
ExampleResponse{example_field: false}
ExampleResponse {
example_field: false,
}
}
}
);
},
);
tokio::join!(
sampling_loop(netapp.clone(), peering.clone()),
@ -120,8 +122,6 @@ async fn main() {
);
}
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
loop {
tokio::time::delay_for(Duration::from_secs(10)).await;
@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
let netapp2 = netapp.clone();
tokio::spawn(async move {
match netapp2.request(&p, ExampleMessage{
example_field: 42,
}, PRIO_NORMAL).await {
match netapp2
.request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
.await
{
Ok(resp) => debug!("Got example response: {:?}", resp),
Err(e) => warn!("Error with example request: {}", e),
}

View file

@ -1,5 +1,5 @@
use std::net::SocketAddr;
use std::io::Write;
use std::net::SocketAddr;
use log::info;
@ -32,13 +32,14 @@ async fn main() {
env_logger::Builder::new()
.parse_env("RUST_LOG")
.format(|buf, record| {
writeln!(buf,
"{} {} {} {}",
chrono::Local::now().format("%s%.6f"),
record.module_path().unwrap_or("_"),
record.level(),
record.args()
)
writeln!(
buf,
"{} {} {} {}",
chrono::Local::now().format("%s%.6f"),
record.module_path().unwrap_or("_"),
record.level(),
record.args()
)
})
.init();

View file

@ -1,17 +1,18 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU16};
use std::sync::Arc;
use std::sync::atomic::{self, AtomicBool, AtomicU16};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, trace};
use log::{debug, error, trace};
use sodiumoxide::crypto::sign::ed25519;
use tokio::io::split;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, watch};
use async_trait::async_trait;
use kuska_handshake::async_std::{
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
TokioCompatExtWrite,
@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
close_send: watch::Sender<bool>,
}
@ -78,9 +79,20 @@ impl ServerConn {
let conn2 = conn.clone();
let conn3 = conn.clone();
let close_recv2 = close_recv.clone();
tokio::try_join!(
conn2.recv_loop(box_stream_read, close_recv.clone()),
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
async move {
tokio::select!(
r = conn2.recv_loop(box_stream_read) => r,
_ = await_exit(close_recv) => Ok(()),
)
},
async move {
tokio::select!(
r = conn3.send_loop(resp_recv, box_stream_write) => r,
_ = await_exit(close_recv2) => Ok(()),
)
},
)
.map(|_| ())
.log_err("ServerConn recv_loop/send_loop");
@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
let net_handler = &handler.net_handler;
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send
.send((id, prio, resp))
.send(Some((id, prio, resp)))
.log_err("ServerConn recv_handler send resp");
}
}
@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_pk: ed25519::PublicKey,
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
next_query_number: AtomicU16,
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
close_send: watch::Sender<bool>,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
must_exit: AtomicBool,
stop_recv_loop: watch::Sender<bool>,
}
impl ClientConn {
@ -163,19 +176,17 @@ impl ClientConn {
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (query_send, query_recv) = mpsc::unbounded_channel();
let (resp_send, resp_recv) = mpsc::unbounded_channel();
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
let (close_send, close_recv) = watch::channel(false);
let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
let conn = Arc::new(ClientConn {
remote_addr,
peer_pk: remote_pk.clone(),
next_query_number: AtomicU16::from(0u16),
query_send,
resp_send,
resp_notify_send,
close_send,
inflight: Mutex::new(HashMap::new()),
must_exit: AtomicBool::new(false),
stop_recv_loop,
});
netapp.connected_as_client(remote_pk.clone(), conn.clone());
@ -183,11 +194,14 @@ impl ClientConn {
tokio::spawn(async move {
let conn2 = conn.clone();
let conn3 = conn.clone();
let conn4 = conn.clone();
tokio::try_join!(
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
conn3.recv_loop(box_stream_read, close_recv.clone()),
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
conn2.send_loop(query_recv, box_stream_write),
async move {
tokio::select!(
r = conn3.recv_loop(box_stream_read) => r,
_ = await_exit(stop_recv_loop_recv) => Ok(()),
)
}
)
.map(|_| ())
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
@ -199,51 +213,15 @@ impl ClientConn {
}
pub fn close(&self) {
self.close_send.broadcast(true).unwrap();
}
async fn dispatch_resp(
self: Arc<Self>,
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
while !*must_exit.borrow() {
tokio::select! {
resp = resp_recv.recv() => {
if let Some((id, resp)) = resp {
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
if let Some(ch) = resp_notify.remove(&id) {
if ch.send(resp).is_err() {
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
}
} else {
resps.insert(id, resp);
}
}
}
resp_ch = resp_notify_recv.recv() => {
if let Some((id, resp_ch)) = resp_ch {
trace!("dispatch_resp: got resp_ch {}", id);
if let Some(rs) = resps.remove(&id) {
if resp_ch.send(rs).is_err() {
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
}
} else {
resp_notify.insert(id, resp_ch);
}
}
}
exit = must_exit.recv() => {
if exit == Some(true) {
break;
}
}
}
self.must_exit.store(true, atomic::Ordering::SeqCst);
self.query_send
.send(None)
.log_err("could not write None in query_send");
if self.inflight.lock().unwrap().is_empty() {
self.stop_recv_loop
.broadcast(true)
.log_err("could not write true to stop_recv_loop");
}
Ok(())
}
pub(crate) async fn request<T>(
@ -262,10 +240,18 @@ impl ClientConn {
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel();
self.resp_notify_send.send((id, resp_send))?;
let old = self.inflight.lock().unwrap().insert(id, resp_send);
if let Some(old_ch) = old {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
if old_ch.send(vec![]).is_err() {
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
}
}
trace!("request: query_send {}, {} bytes", id, bytes.len());
self.query_send.send((id, prio, bytes))?;
self.query_send.send(Some((id, prio, bytes)))?;
let resp = resp_recv.await?;
@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
self.resp_send
.send((id, msg))
.log_err("ClientConn::recv_handler");
let mut inflight = self.inflight.lock().unwrap();
if let Some(ch) = inflight.remove(&id) {
if ch.send(msg).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
}
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
self.stop_recv_loop
.broadcast(true)
.log_err("could not write true to stop_recv_loop");
}
}
}

View file

@ -18,8 +18,8 @@
pub mod error;
pub mod util;
pub mod proto;
pub mod message;
pub mod proto;
mod conn;

View file

@ -133,18 +133,22 @@ impl NetApp {
/// been successfully established. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself.
pub fn on_connected<F>(&self, handler: F)
where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static
{
self.on_connected_handler.store(Some(Arc::new(Box::new(handler))));
where
F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static,
{
self.on_connected_handler
.store(Some(Arc::new(Box::new(handler))));
}
/// Set the handler to be called when an existing connection (incoming or outgoing) has
/// been closed by either party. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself.
pub fn on_disconnected<F>(&self, handler: F)
where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static
{
self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler))));
where
F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static,
{
self.on_disconnected_handler
.store(Some(Arc::new(Box::new(handler))));
}
/// Add a handler for a certain message type. Note that only one handler
@ -240,11 +244,13 @@ impl NetApp {
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
// If pk is ourself, we're not supposed to have a connection open
if *pk != self.pubkey {
let conn = self.client_conns.read().unwrap().remove(pk);
let conn = self.client_conns.write().unwrap().remove(pk);
if let Some(c) = conn {
debug!("Closing connection to {} ({})",
hex::encode(c.peer_pk),
c.remote_addr);
debug!(
"Closing connection to {} ({})",
hex::encode(c.peer_pk),
c.remote_addr
);
c.close();
} else {
return;
@ -268,9 +274,11 @@ impl NetApp {
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
let conn = self.server_conns.read().unwrap().get(pk).cloned();
if let Some(c) = conn {
debug!("Closing incoming connection from {} ({})",
hex::encode(c.peer_pk),
c.remote_addr);
debug!(
"Closing incoming connection from {} ({})",
hex::encode(c.peer_pk),
c.remote_addr
);
c.close();
}
}

View file

@ -3,7 +3,7 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use log::{trace, debug, info, warn};
use log::{debug, info, trace, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
@ -267,15 +267,13 @@ impl Basalt {
netapp.on_connected(
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
basalt2.on_connected(pk, addr, is_incoming);
}
},
);
let basalt2 = basalt.clone();
netapp.on_disconnected(
move |pk: ed25519::PublicKey, is_incoming: bool| {
basalt2.on_disconnected(pk, is_incoming);
},
);
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
basalt2.on_disconnected(pk, is_incoming);
});
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PullMessage, _, _>(

View file

@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy {
);
let strat2 = strat.clone();
netapp.on_disconnected(
move |pk: ed25519::PublicKey, is_incoming: bool| {
let strat2 = strat2.clone();
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
},
);
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
let strat2 = strat2.clone();
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
});
strat
}

View file

@ -3,14 +3,14 @@ use std::sync::Arc;
use log::trace;
use async_trait::async_trait;
use async_std::io::prelude::WriteExt;
use async_std::io::ReadExt;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, watch};
use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*;
@ -85,26 +85,33 @@ impl SendQueue {
}
}
}
fn is_empty(&self) -> bool {
self.items.iter().all(|(_k, v)| v.is_empty())
}
}
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop(
self: Arc<Self>,
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut sending = SendQueue::new();
while !*must_exit.borrow() {
if let Ok((id, prio, data)) = msg_recv.try_recv() {
trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
id,
prio,
data,
cursor: 0,
});
let mut should_exit = false;
while !should_exit || !sending.is_empty() {
if let Ok(sth) = msg_recv.try_recv() {
if let Some((id, prio, data)) = sth {
trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
id,
prio,
data,
cursor: 0,
});
} else {
should_exit = true;
}
} else if let Some(mut item) = sending.pop() {
trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
item.cursor
);
let header_id = u16::to_be_bytes(item.id);
if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
.await?
.is_none()
{
break;
}
write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
.await?
.is_none()
{
break;
}
write.write_all(&header_size[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
if write_all_or_exit(
&item.data[item.cursor..new_cursor],
&mut write,
&mut must_exit,
)
.await?
.is_none()
{
break;
}
write.write_all(&item.data[item.cursor..new_cursor]).await?;
item.cursor = new_cursor;
sending.push(item);
@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync {
let send_len = (item.data.len() - item.cursor) as u16;
let header_size = u16::to_be_bytes(send_len);
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
.await?
.is_none()
{
break;
}
write.write_all(&header_size[..]).await?;
if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
.await?
.is_none()
{
break;
}
write.write_all(&item.data[item.cursor..]).await?;
}
write.flush().await.log_err("Could not flush in send_loop");
} else {
let (id, prio, data) = msg_recv
let sth = msg_recv
.recv()
.await
.ok_or(Error::Message("Connection closed.".into()))?;
trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
id,
prio,
data,
cursor: 0,
});
if let Some((id, prio, data)) = sth {
trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
id,
prio,
data,
cursor: 0,
});
} else {
should_exit = true;
}
}
}
Ok(())
@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
// Returns true if we should stop receiving after this
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop(
self: Arc<Self>,
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut receiving = HashMap::new();
while !*must_exit.borrow() {
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; 2];
if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
.await?
.is_none()
{
break;
}
read.read_exact(&mut header_id[..]).await?;
let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; 2];
if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
.await?
.is_none()
{
break;
}
read.read_exact(&mut header_size[..]).await?;
let size = RequestID::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", id);
@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
let size = size & !0x8000;
let mut next_slice = vec![0; size as usize];
if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
.await?
.is_none()
{
break;
}
read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", size);
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
}
}
Ok(())
}
}
async fn read_exact_or_exit(
buf: &mut [u8],
read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
tokio::select!(
res = read.read_exact(buf) => Ok(Some(res?)),
_ = await_exit(must_exit) => Ok(None),
)
}
async fn write_all_or_exit(
buf: &[u8],
write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
tokio::select!(
res = write.write_all(buf) => Ok(Some(res?)),
_ = await_exit(must_exit) => Ok(None),
)
}
async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
loop {
if must_exit.recv().await == Some(true) {
return;
}
}
}

View file

@ -1,5 +1,7 @@
use serde::Serialize;
use tokio::sync::watch;
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.
///
@ -16,3 +18,21 @@ where
val.serialize(&mut se)?;
Ok(wr)
}
/// This async function returns only when a true signal was received
/// from a watcher that tells us when to exit.
/// Usefull in a select statement to interrupt another
/// future:
/// ```
/// select!(
/// _ = a_long_task() => Success,
/// _ = await_exit(must_exit) => Interrupted,
/// )
/// ```
pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
loop {
if must_exit.recv().await == Some(true) {
return;
}
}
}