From 5a9ae8615ee616b11460a046deaa6981b10d69ab Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 7 Dec 2020 13:35:24 +0100 Subject: [PATCH] Do not close connections immediately on close signal, await for remaining responses --- examples/basalt.rs | 41 +++++------ examples/fullmesh.rs | 17 ++--- src/conn.rs | 137 ++++++++++++++++++------------------ src/lib.rs | 4 +- src/netapp.rs | 36 ++++++---- src/peering/basalt.rs | 12 ++-- src/peering/fullmesh.rs | 10 ++- src/proto.rs | 150 ++++++++++++---------------------------- src/util.rs | 20 ++++++ 9 files changed, 192 insertions(+), 235 deletions(-) diff --git a/examples/basalt.rs b/examples/basalt.rs index 4c86cf8..eaf056b 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -1,20 +1,20 @@ +use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use std::io::Write; use log::{debug, info, warn}; -use structopt::StructOpt; use serde::{Deserialize, Serialize}; +use structopt::StructOpt; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; -use netapp::NetApp; -use netapp::peering::basalt::*; use netapp::message::*; +use netapp::peering::basalt::*; use netapp::proto::*; +use netapp::NetApp; #[derive(StructOpt, Debug)] #[structopt(name = "netapp")] @@ -52,17 +52,17 @@ async fn main() { env_logger::Builder::new() .parse_env("RUST_LOG") .format(|buf, record| { - writeln!(buf, - "{} {} {} {}", - chrono::Local::now().format("%s%.6f"), - record.module_path().unwrap_or("_"), - record.level(), - record.args() - ) + writeln!( + buf, + "{} {} {} {}", + chrono::Local::now().format("%s%.6f"), + record.module_path().unwrap_or("_"), + record.level(), + record.args() + ) }) .init(); - let opt = Opt::from_args(); let netid = match &opt.network_key { @@ -108,10 +108,12 @@ async fn main() { |_from: ed25519::PublicKey, msg: ExampleMessage| { debug!("Got example message: {:?}, sending example response", msg); async { - ExampleResponse{example_field: false} + ExampleResponse { + 
example_field: false, + } } - } - ); + }, + ); tokio::join!( sampling_loop(netapp.clone(), peering.clone()), @@ -120,8 +122,6 @@ async fn main() { ); } - - async fn sampling_loop(netapp: Arc, basalt: Arc) { loop { tokio::time::delay_for(Duration::from_secs(10)).await; @@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc, basalt: Arc) { let netapp2 = netapp.clone(); tokio::spawn(async move { - match netapp2.request(&p, ExampleMessage{ - example_field: 42, - }, PRIO_NORMAL).await { + match netapp2 + .request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) + .await + { Ok(resp) => debug!("Got example response: {:?}", resp), Err(e) => warn!("Error with example request: {}", e), } diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs index dfacb89..5addcea 100644 --- a/examples/fullmesh.rs +++ b/examples/fullmesh.rs @@ -1,5 +1,5 @@ -use std::net::SocketAddr; use std::io::Write; +use std::net::SocketAddr; use log::info; @@ -32,13 +32,14 @@ async fn main() { env_logger::Builder::new() .parse_env("RUST_LOG") .format(|buf, record| { - writeln!(buf, - "{} {} {} {}", - chrono::Local::now().format("%s%.6f"), - record.module_path().unwrap_or("_"), - record.level(), - record.args() - ) + writeln!( + buf, + "{} {} {} {}", + chrono::Local::now().format("%s%.6f"), + record.module_path().unwrap_or("_"), + record.level(), + record.args() + ) }) .init(); diff --git a/src/conn.rs b/src/conn.rs index d4362e5..89bf654 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -1,17 +1,18 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::atomic::{self, AtomicU16}; -use std::sync::Arc; +use std::sync::atomic::{self, AtomicBool, AtomicU16}; +use std::sync::{Arc, Mutex}; -use async_trait::async_trait; use bytes::Bytes; -use log::{debug, trace}; +use log::{debug, error, trace}; use sodiumoxide::crypto::sign::ed25519; use tokio::io::split; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot, watch}; +use async_trait::async_trait; + use 
kuska_handshake::async_std::{ handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead, TokioCompatExtWrite, @@ -29,7 +30,7 @@ pub(crate) struct ServerConn { netapp: Arc, - resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec)>, + resp_send: mpsc::UnboundedSender)>>, close_send: watch::Sender, } @@ -78,9 +79,20 @@ impl ServerConn { let conn2 = conn.clone(); let conn3 = conn.clone(); + let close_recv2 = close_recv.clone(); tokio::try_join!( - conn2.recv_loop(box_stream_read, close_recv.clone()), - conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()), + async move { + tokio::select!( + r = conn2.recv_loop(box_stream_read) => r, + _ = await_exit(close_recv) => Ok(()), + ) + }, + async move { + tokio::select!( + r = conn3.send_loop(resp_recv, box_stream_write) => r, + _ = await_exit(close_recv2) => Ok(()), + ) + }, ) .map(|_| ()) .log_err("ServerConn recv_loop/send_loop"); @@ -112,7 +124,7 @@ impl RecvLoop for ServerConn { let net_handler = &handler.net_handler; let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await; self.resp_send - .send((id, prio, resp)) + .send(Some((id, prio, resp))) .log_err("ServerConn recv_handler send resp"); } } @@ -121,11 +133,12 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_pk: ed25519::PublicKey, - query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec)>, + query_send: mpsc::UnboundedSender)>>, + next_query_number: AtomicU16, - resp_send: mpsc::UnboundedSender<(RequestID, Vec)>, - resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender>)>, - close_send: watch::Sender, + inflight: Mutex>>>, + must_exit: AtomicBool, + stop_recv_loop: watch::Sender, } impl ClientConn { @@ -163,19 +176,17 @@ impl ClientConn { BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); let (query_send, query_recv) = mpsc::unbounded_channel(); - let (resp_send, resp_recv) = mpsc::unbounded_channel(); - let 
(resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel(); - let (close_send, close_recv) = watch::channel(false); + let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false); let conn = Arc::new(ClientConn { remote_addr, peer_pk: remote_pk.clone(), next_query_number: AtomicU16::from(0u16), query_send, - resp_send, - resp_notify_send, - close_send, + inflight: Mutex::new(HashMap::new()), + must_exit: AtomicBool::new(false), + stop_recv_loop, }); netapp.connected_as_client(remote_pk.clone(), conn.clone()); @@ -183,11 +194,14 @@ impl ClientConn { tokio::spawn(async move { let conn2 = conn.clone(); let conn3 = conn.clone(); - let conn4 = conn.clone(); tokio::try_join!( - conn2.send_loop(query_recv, box_stream_write, close_recv.clone()), - conn3.recv_loop(box_stream_read, close_recv.clone()), - conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()), + conn2.send_loop(query_recv, box_stream_write), + async move { + tokio::select!( + r = conn3.recv_loop(box_stream_read) => r, + _ = await_exit(stop_recv_loop_recv) => Ok(()), + ) + } ) .map(|_| ()) .log_err("ClientConn send_loop/recv_loop/dispatch_loop"); @@ -199,51 +213,15 @@ impl ClientConn { } pub fn close(&self) { - self.close_send.broadcast(true).unwrap(); - } - - async fn dispatch_resp( - self: Arc, - mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec)>, - mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender>)>, - mut must_exit: watch::Receiver, - ) -> Result<(), Error> { - let mut resps: HashMap> = HashMap::new(); - let mut resp_notify: HashMap>> = HashMap::new(); - while !*must_exit.borrow() { - tokio::select! 
{ - resp = resp_recv.recv() => { - if let Some((id, resp)) = resp { - trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len()); - if let Some(ch) = resp_notify.remove(&id) { - if ch.send(resp).is_err() { - debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)"); - } - } else { - resps.insert(id, resp); - } - } - } - resp_ch = resp_notify_recv.recv() => { - if let Some((id, resp_ch)) = resp_ch { - trace!("dispatch_resp: got resp_ch {}", id); - if let Some(rs) = resps.remove(&id) { - if resp_ch.send(rs).is_err() { - debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)"); - } - } else { - resp_notify.insert(id, resp_ch); - } - } - } - exit = must_exit.recv() => { - if exit == Some(true) { - break; - } - } - } + self.must_exit.store(true, atomic::Ordering::SeqCst); + self.query_send + .send(None) + .log_err("could not write None in query_send"); + if self.inflight.lock().unwrap().is_empty() { + self.stop_recv_loop + .broadcast(true) + .log_err("could not write true to stop_recv_loop"); } - Ok(()) } pub(crate) async fn request( @@ -262,10 +240,18 @@ impl ClientConn { bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]); let (resp_send, resp_recv) = oneshot::channel(); - self.resp_notify_send.send((id, resp_send))?; + let old = self.inflight.lock().unwrap().insert(id, resp_send); + if let Some(old_ch) = old { + error!( + "Too many inflight requests! RequestID collision. Interrupting previous request." + ); + if old_ch.send(vec![]).is_err() { + debug!("Could not send empty response to collisionned request, probably because request was interrupted. 
Dropping response."); + } + } trace!("request: query_send {}, {} bytes", id, bytes.len()); - self.query_send.send((id, prio, bytes))?; + self.query_send.send(Some((id, prio, bytes)))?; let resp = resp_recv.await?; @@ -279,8 +265,17 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { async fn recv_handler(self: Arc, id: RequestID, msg: Vec) { - self.resp_send - .send((id, msg)) - .log_err("ClientConn::recv_handler"); + let mut inflight = self.inflight.lock().unwrap(); + if let Some(ch) = inflight.remove(&id) { + if ch.send(msg).is_err() { + debug!("Could not send request response, probably because request was interrupted. Dropping response."); + } + } + + if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) { + self.stop_recv_loop + .broadcast(true) + .log_err("could not write true to stop_recv_loop"); + } } } diff --git a/src/lib.rs b/src/lib.rs index ba365c7..af8fbb8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ //! Netapp is a Rust library that takes care of a few common tasks in distributed software: -//! +//! //! - establishing secure connections //! - managing connection lifetime, reconnecting on failure //! - checking peer's state @@ -18,8 +18,8 @@ pub mod error; pub mod util; -pub mod proto; pub mod message; +pub mod proto; mod conn; diff --git a/src/netapp.rs b/src/netapp.rs index bf9a3f0..967105e 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -53,7 +53,7 @@ pub struct NetApp { server_conns: RwLock>>, client_conns: RwLock>>, - + pub(crate) msg_handlers: ArcSwap>>, on_connected_handler: ArcSwapOption>, @@ -133,18 +133,22 @@ impl NetApp { /// been successfully established. Do not set this if using a peering strategy, /// as the peering strategy will need to set this itself. 
pub fn on_connected(&self, handler: F) - where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static - { - self.on_connected_handler.store(Some(Arc::new(Box::new(handler)))); + where + F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static, + { + self.on_connected_handler + .store(Some(Arc::new(Box::new(handler)))); } /// Set the handler to be called when an existing connection (incoming or outgoing) has /// been closed by either party. Do not set this if using a peering strategy, /// as the peering strategy will need to set this itself. pub fn on_disconnected(&self, handler: F) - where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static - { - self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler)))); + where + F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static, + { + self.on_disconnected_handler + .store(Some(Arc::new(Box::new(handler)))); } /// Add a handler for a certain message type. Note that only one handler @@ -240,11 +244,13 @@ impl NetApp { pub fn disconnect(self: &Arc, pk: &ed25519::PublicKey) { // If pk is ourself, we're not supposed to have a connection open if *pk != self.pubkey { - let conn = self.client_conns.read().unwrap().remove(pk); + let conn = self.client_conns.write().unwrap().remove(pk); if let Some(c) = conn { - debug!("Closing connection to {} ({})", - hex::encode(c.peer_pk), - c.remote_addr); + debug!( + "Closing connection to {} ({})", + hex::encode(c.peer_pk), + c.remote_addr + ); c.close(); } else { return; @@ -268,9 +274,11 @@ impl NetApp { pub fn server_disconnect(self: &Arc, pk: &ed25519::PublicKey) { let conn = self.server_conns.read().unwrap().get(pk).cloned(); if let Some(c) = conn { - debug!("Closing incoming connection from {} ({})", - hex::encode(c.peer_pk), - c.remote_addr); + debug!( + "Closing incoming connection from {} ({})", + hex::encode(c.peer_pk), + c.remote_addr + ); c.close(); } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs 
index 615b559..4aa34f6 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -3,7 +3,7 @@ use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; -use log::{trace, debug, info, warn}; +use log::{debug, info, trace, warn}; use lru::LruCache; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; @@ -267,15 +267,13 @@ impl Basalt { netapp.on_connected( move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| { basalt2.on_connected(pk, addr, is_incoming); - } + }, ); let basalt2 = basalt.clone(); - netapp.on_disconnected( - move |pk: ed25519::PublicKey, is_incoming: bool| { - basalt2.on_disconnected(pk, is_incoming); - }, - ); + netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| { + basalt2.on_disconnected(pk, is_incoming); + }); let basalt2 = basalt.clone(); netapp.add_msg_handler::( diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 1b26489..d6ca08a 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy { ); let strat2 = strat.clone(); - netapp.on_disconnected( - move |pk: ed25519::PublicKey, is_incoming: bool| { - let strat2 = strat2.clone(); - tokio::spawn(strat2.on_disconnected(pk, is_incoming)); - }, - ); + netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| { + let strat2 = strat2.clone(); + tokio::spawn(strat2.on_disconnected(pk, is_incoming)); + }); strat } diff --git a/src/proto.rs b/src/proto.rs index b044280..d90042f 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,14 +3,14 @@ use std::sync::Arc; use log::trace; -use async_trait::async_trait; - use async_std::io::prelude::WriteExt; use async_std::io::ReadExt; use tokio::io::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; + +use async_trait::async_trait; use crate::error::*; @@ -85,26 +85,33 @@ impl SendQueue { } } } + fn is_empty(&self) -> bool { + 
self.items.iter().all(|(_k, v)| v.is_empty()) + } } #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, + mut msg_recv: mpsc::UnboundedReceiver)>>, mut write: BoxStreamWrite>>, - mut must_exit: watch::Receiver, ) -> Result<(), Error> { let mut sending = SendQueue::new(); - while !*must_exit.borrow() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); + let mut should_exit = false; + while !should_exit || !sending.is_empty() { + if let Ok(sth) = msg_recv.try_recv() { + if let Some((id, prio, data)) = sth { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } else { + should_exit = true; + } } else if let Some(mut item) = sending.pop() { trace!( "send_loop: sending bytes for {} ({} bytes, {} already sent)", @@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync { item.cursor ); let header_id = u16::to_be_bytes(item.id); - if write_all_or_exit(&header_id[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_id[..]).await?; if item.data.len() - item.cursor > MAX_CHUNK_SIZE { let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); - if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_size[..]).await?; let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; - if write_all_or_exit( - &item.data[item.cursor..new_cursor], - &mut write, - &mut must_exit, - ) - .await? 
- .is_none() - { - break; - } + write.write_all(&item.data[item.cursor..new_cursor]).await?; item.cursor = new_cursor; sending.push(item); @@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync { let send_len = (item.data.len() - item.cursor) as u16; let header_size = u16::to_be_bytes(send_len); - if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_size[..]).await?; - if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&item.data[item.cursor..]).await?; } write.flush().await.log_err("Could not flush in send_loop"); } else { - let (id, prio, data) = msg_recv + let sth = msg_recv .recv() .await .ok_or(Error::Message("Connection closed.".into()))?; - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); + if let Some((id, prio, data)) = sth { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } else { + should_exit = true; + } } } Ok(()) @@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync { #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { + // Returns true if we should stop receiving after this async fn recv_handler(self: Arc, id: RequestID, msg: Vec); async fn recv_loop( self: Arc, mut read: BoxStreamRead>>, - mut must_exit: watch::Receiver, ) -> Result<(), Error> { let mut receiving = HashMap::new(); - while !*must_exit.borrow() { + loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; 2]; - if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit) - .await? 
- .is_none() - { - break; - } + read.read_exact(&mut header_id[..]).await?; let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); let mut header_size = [0u8; 2]; - if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit) - .await? - .is_none() - { - break; - } + read.read_exact(&mut header_size[..]).await?; let size = RequestID::from_be_bytes(header_size); trace!("recv_loop: got header size: {:04x}", id); @@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static { let size = size & !0x8000; let mut next_slice = vec![0; size as usize]; - if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit) - .await? - .is_none() - { - break; - } + read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", size); let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]); @@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static { tokio::spawn(self.clone().recv_handler(id, msg_bytes)); } } - Ok(()) - } -} - -async fn read_exact_or_exit( - buf: &mut [u8], - read: &mut BoxStreamRead>>, - must_exit: &mut watch::Receiver, -) -> Result, Error> { - tokio::select!( - res = read.read_exact(buf) => Ok(Some(res?)), - _ = await_exit(must_exit) => Ok(None), - ) -} - -async fn write_all_or_exit( - buf: &[u8], - write: &mut BoxStreamWrite>>, - must_exit: &mut watch::Receiver, -) -> Result, Error> { - tokio::select!( - res = write.write_all(buf) => Ok(Some(res?)), - _ = await_exit(must_exit) => Ok(None), - ) -} - -async fn await_exit(must_exit: &mut watch::Receiver) { - loop { - if must_exit.recv().await == Some(true) { - return; - } } } diff --git a/src/util.rs b/src/util.rs index f09a3bc..017ef00 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,5 +1,7 @@ use serde::Serialize; +use tokio::sync::watch; + /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. 
/// @@ -16,3 +18,21 @@ where val.serialize(&mut se)?; Ok(wr) } + +/// This async function returns only when a true signal was received +/// from a watcher that tells us when to exit. +/// Useful in a select statement to interrupt another +/// future: +/// ``` +/// select!( +/// _ = a_long_task() => Success, +/// _ = await_exit(must_exit) => Interrupted, +/// ) +/// ``` +pub async fn await_exit(mut must_exit: watch::Receiver) { + loop { + if must_exit.recv().await == Some(true) { + return; + } + } +}