Try to handle termination and closing of stuff properly

This commit is contained in:
Alex 2021-10-13 17:12:13 +02:00
parent 8dede69dee
commit 70839d70d8
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
15 changed files with 264 additions and 166 deletions

23
Cargo.lock generated
View file

@ -446,6 +446,7 @@ dependencies = [
"serde", "serde",
"structopt", "structopt",
"tokio", "tokio",
"tokio-stream",
"tokio-util", "tokio-util",
] ]
@ -666,6 +667,15 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.4" version = "0.4.4"
@ -779,7 +789,9 @@ dependencies = [
"memchr", "memchr",
"mio", "mio",
"num_cpus", "num_cpus",
"once_cell",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry",
"tokio-macros", "tokio-macros",
"winapi", "winapi",
] ]
@ -795,6 +807,17 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-stream"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.8" version = "0.6.8"

View file

@ -20,8 +20,9 @@ basalt = ["lru", "rand"]
[dependencies] [dependencies]
futures = "0.3.17" futures = "0.3.17"
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util"] } tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
tokio-stream = "0.1.7"
serde = { version = "1.0", default-features = false, features = ["derive"] } serde = { version = "1.0", default-features = false, features = ["derive"] }
rmp-serde = "0.14.3" rmp-serde = "0.14.3"

View file

@ -2,6 +2,6 @@ all:
cargo build --all-features cargo build --all-features
cargo build --example fullmesh cargo build --example fullmesh
cargo build --all-features --example basalt cargo build --all-features --example basalt
RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
#RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh #RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh

View file

@ -5,9 +5,9 @@ use std::time::Duration;
use log::{debug, info, warn}; use log::{debug, info, warn};
use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use structopt::StructOpt; use structopt::StructOpt;
use async_trait::async_trait;
use sodiumoxide::crypto::auth; use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
@ -122,9 +122,15 @@ async fn main() {
let listen_addr = opt.listen_addr.parse().unwrap(); let listen_addr = opt.listen_addr.parse().unwrap();
let public_addr = opt.public_addr.map(|x| x.parse().unwrap()); let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
let watch_cancel = netapp::util::watch_ctrl_c();
tokio::join!( tokio::join!(
example.clone().sampling_loop(), example.clone().sampling_loop(),
example.netapp.clone().listen(listen_addr, public_addr), example
.netapp
.clone()
.listen(listen_addr, public_addr, watch_cancel),
example.basalt.clone().run(), example.basalt.clone().run(),
); );
} }
@ -141,7 +147,8 @@ impl Example {
let self2 = self.clone(); let self2 = self.clone();
tokio::spawn(async move { tokio::spawn(async move {
match self2 match self2
.example_endpoint.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) .example_endpoint
.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
.await .await
{ {
Ok(resp) => debug!("Got example response: {:?}", resp), Ok(resp) => debug!("Got example response: {:?}", resp),

View file

@ -87,6 +87,11 @@ async fn main() {
hex::encode(&privkey.public_key()), hex::encode(&privkey.public_key()),
listen_addr); listen_addr);
let watch_cancel = netapp::util::watch_ctrl_c();
let public_addr = opt.public_addr.map(|x| x.parse().unwrap()); let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
tokio::join!(netapp.listen(listen_addr, public_addr), peering.run(),); tokio::join!(
netapp.listen(listen_addr, public_addr, watch_cancel.clone()),
peering.run(watch_cancel),
);
} }

View file

@ -1,11 +1,13 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicBool, AtomicU32}; use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
use log::{debug, error, trace}; use log::{debug, error, trace};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::compat::*; use tokio_util::compat::*;
@ -21,17 +23,14 @@ use crate::netapp::*;
use crate::proto::*; use crate::proto::*;
use crate::util::*; use crate::util::*;
pub(crate) struct ClientConn { pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr, pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID, pub(crate) peer_id: NodeID,
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
next_query_number: AtomicU32, next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
must_exit: AtomicBool,
stop_recv_loop: watch::Sender<bool>,
} }
impl ClientConn { impl ClientConn {
@ -71,25 +70,35 @@ impl ClientConn {
remote_addr, remote_addr,
peer_id, peer_id,
next_query_number: AtomicU32::from(RequestID::default()), next_query_number: AtomicU32::from(RequestID::default()),
query_send, query_send: ArcSwapOption::new(Some(Arc::new(query_send))),
inflight: Mutex::new(HashMap::new()), inflight: Mutex::new(HashMap::new()),
must_exit: AtomicBool::new(false),
stop_recv_loop,
}); });
netapp.connected_as_client(peer_id, conn.clone()); netapp.connected_as_client(peer_id, conn.clone());
tokio::spawn(async move { tokio::spawn(async move {
let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write));
let conn2 = conn.clone(); let conn2 = conn.clone();
let conn3 = conn.clone(); let recv_future = tokio::spawn(async move {
tokio::try_join!(conn2.send_loop(query_recv, write), async move { select! {
tokio::select!( r = conn2.recv_loop(read) => r,
r = conn3.recv_loop(read) => r, _ = await_exit(stop_recv_loop_recv) => Ok(())
_ = await_exit(stop_recv_loop_recv) => Ok(()), }
) });
})
.map(|_| ()) send_future.await.log_err("ClientConn send_loop");
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
// TODO here: wait for inflight requests to all have their response
stop_recv_loop
.send(true)
.log_err("ClientConn send true to stop_recv_loop");
recv_future.await.log_err("ClientConn recv_loop");
// Make sure we don't wait on any more requests that won't
// have a response
conn.inflight.lock().unwrap().clear();
netapp.disconnected_as_client(&peer_id, conn); netapp.disconnected_as_client(&peer_id, conn);
}); });
@ -98,15 +107,7 @@ impl ClientConn {
} }
pub fn close(&self) { pub fn close(&self) {
self.must_exit.store(true, atomic::Ordering::SeqCst); self.query_send.store(None);
self.query_send
.send(None)
.log_err("could not write None in query_send");
if self.inflight.lock().unwrap().is_empty() {
self.stop_recv_loop
.send(true)
.log_err("could not write true to stop_recv_loop");
}
} }
pub(crate) async fn call<T>( pub(crate) async fn call<T>(
@ -118,6 +119,8 @@ impl ClientConn {
where where
T: Message, T: Message,
{ {
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
let id = self let id = self
.next_query_number .next_query_number
.fetch_add(1, atomic::Ordering::Relaxed); .fetch_add(1, atomic::Ordering::Relaxed);
@ -138,20 +141,23 @@ impl ClientConn {
} }
trace!("request: query_send {}, {} bytes", id, bytes.len()); trace!("request: query_send {}, {} bytes", id, bytes.len());
self.query_send.send(Some((id, prio, bytes)))?; query_send.send((id, prio, bytes))?;
let resp = resp_recv.await?; let resp = resp_recv.await?;
if resp.len() == 0 { if resp.is_empty() {
return Err(Error::Message("Response is 0 bytes, either a collision or a protocol error".into())); return Err(Error::Message(
"Response is 0 bytes, either a collision or a protocol error".into(),
));
} }
trace!("request response {}: ", id); trace!("request response {}: ", id);
let code = resp[0]; let code = resp[0];
if code == 0 { if code == 0 {
Ok(rmp_serde::decode::from_read_ref::<_, <T as Message>::Response>( Ok(rmp_serde::decode::from_read_ref::<
&resp[1..], _,
)?) <T as Message>::Response,
>(&resp[1..])?)
} else { } else {
Err(Error::Remote(format!("Remote error code {}", code))) Err(Error::Remote(format!("Remote error code {}", code)))
} }
@ -162,7 +168,7 @@ impl SendLoop for ClientConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ClientConn { impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) { fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
let mut inflight = self.inflight.lock().unwrap(); let mut inflight = self.inflight.lock().unwrap();
@ -171,11 +177,5 @@ impl RecvLoop for ClientConn {
debug!("Could not send request response, probably because request was interrupted. Dropping response."); debug!("Could not send request response, probably because request was interrupted. Dropping response.");
} }
} }
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
self.stop_recv_loop
.send(true)
.log_err("could not write true to stop_recv_loop");
}
} }
} }

View file

@ -123,4 +123,3 @@ where
Box::new(Self(self.0.clone())) Box::new(Self(self.0.clone()))
} }
} }

View file

@ -31,6 +31,9 @@ pub enum Error {
#[error(display = "No handler / shutting down")] #[error(display = "No handler / shutting down")]
NoHandler, NoHandler,
#[error(display = "Connection closed")]
ConnectionClosed,
#[error(display = "Remote error: {}", _0)] #[error(display = "Remote error: {}", _0)]
Remote(String), Remote(String),
} }
@ -45,6 +48,7 @@ impl Error {
Self::RMPDecode(_) => 11, Self::RMPDecode(_) => 11,
Self::UTF8(_) => 12, Self::UTF8(_) => 12,
Self::NoHandler => 20, Self::NoHandler => 20,
Self::ConnectionClosed => 21,
Self::Handshake(_) => 30, Self::Handshake(_) => 30,
Self::Remote(_) => 40, Self::Remote(_) => 40,
Self::Message(_) => 99, Self::Message(_) => 99,
@ -80,3 +84,16 @@ where
}; };
} }
} }
impl<E, T> LogError for Result<T, E>
where
T: LogError,
E: Into<Error>,
{
fn log_err(self, msg: &'static str) {
match self {
Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)),
Ok(x) => x.log_err(msg),
}
}
}

View file

@ -13,16 +13,14 @@
//! about message priorization. //! about message priorization.
//! Also check out the examples to learn how to use this crate. //! Also check out the examples to learn how to use this crate.
#![feature(map_first_last)]
pub mod error; pub mod error;
pub mod util; pub mod util;
pub mod endpoint; pub mod endpoint;
pub mod proto; pub mod proto;
mod server;
mod client; mod client;
mod server;
pub mod netapp; pub mod netapp;
pub mod peering; pub mod peering;

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use log::{debug, info, error}; use log::{debug, error, info, trace, warn};
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use async_trait::async_trait; use async_trait::async_trait;
@ -10,13 +10,18 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::auth; use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::select;
use tokio::sync::{mpsc, watch};
use crate::client::*; use crate::client::*;
use crate::server::*;
use crate::endpoint::*; use crate::endpoint::*;
use crate::error::*; use crate::error::*;
use crate::proto::*; use crate::proto::*;
use crate::server::*;
use crate::util::*; use crate::util::*;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -142,35 +147,91 @@ impl NetApp {
/// Main listening process for our app. This future runs during the whole /// Main listening process for our app. This future runs during the whole
/// run time of our application. /// run time of our application.
/// If this is not called, the NetApp instance remains a passive client. /// If this is not called, the NetApp instance remains a passive client.
pub async fn listen(self: Arc<Self>, listen_addr: SocketAddr, public_addr: Option<IpAddr>) { pub async fn listen(
self: Arc<Self>,
listen_addr: SocketAddr,
public_addr: Option<IpAddr>,
mut must_exit: watch::Receiver<bool>,
) {
let listen_params = ListenParams { let listen_params = ListenParams {
listen_addr, listen_addr,
public_addr, public_addr,
}; };
if self.listen_params.swap(Some(Arc::new(listen_params))).is_some() { if self
.listen_params
.swap(Some(Arc::new(listen_params)))
.is_some()
{
error!("Trying to listen on NetApp but we're already listening!"); error!("Trying to listen on NetApp but we're already listening!");
} }
let listener = TcpListener::bind(listen_addr).await.unwrap(); let listener = TcpListener::bind(listen_addr).await.unwrap();
info!("Listening on {}", listen_addr); info!("Listening on {}", listen_addr);
loop { let (conn_in, mut conn_out) = mpsc::unbounded_channel();
// The second item contains the IP and port of the new connection. let connection_collector = tokio::spawn(async move {
let (socket, _) = listener.accept().await.unwrap(); let mut collection = FuturesUnordered::new();
loop {
if collection.is_empty() {
match conn_out.recv().await {
Some(f) => collection.push(f),
None => break,
}
} else {
select! {
new_fut = conn_out.recv() => {
match new_fut {
Some(f) => collection.push(f),
None => break,
}
}
result = collection.next() => {
trace!("Collected connection: {:?}", result);
}
}
}
}
debug!("Collecting last open server connections.");
while let Some(conn_res) = collection.next().await {
trace!("Collected connection: {:?}", conn_res);
}
debug!("No more server connections to collect");
});
while !*must_exit.borrow_and_update() {
let (socket, peer_addr) = select! {
sockres = listener.accept() => {
match sockres {
Ok(x) => x,
Err(e) => {
warn!("Error in listener.accept: {}", e);
continue;
}
}
},
_ = must_exit.changed() => continue,
};
info!( info!(
"Incoming connection from {}, negotiating handshake...", "Incoming connection from {}, negotiating handshake...",
match socket.peer_addr() { peer_addr
Ok(x) => format!("{}", x),
Err(e) => format!("<invalid addr: {}>", e),
}
); );
let self2 = self.clone(); let self2 = self.clone();
tokio::spawn(async move { let must_exit2 = must_exit.clone();
ServerConn::run(self2, socket) conn_in
.await .send(tokio::spawn(async move {
.log_err("ServerConn::run"); ServerConn::run(self2, socket, must_exit2)
}); .await
.log_err("ServerConn::run");
}))
.log_err("Failed to send connection to connection collector");
} }
drop(conn_in);
connection_collector
.await
.log_err("Failed to await for connection collector");
} }
/// Attempt to connect to a peer, given by its ip:port and its public key. /// Attempt to connect to a peer, given by its ip:port and its public key.
@ -231,20 +292,6 @@ impl NetApp {
}); });
} }
/// Close the incoming connection from a certain client to us,
/// if such a connection is currently open.
pub fn server_disconnect(self: &Arc<Self>, id: &NodeID) {
let conn = self.server_conns.read().unwrap().get(id).cloned();
if let Some(c) = conn {
debug!(
"Closing incoming connection from {} ({})",
hex::encode(c.peer_id),
c.remote_addr
);
c.close();
}
}
// Called from conn.rs when an incoming connection is successfully established // Called from conn.rs when an incoming connection is successfully established
// Registers the connection in our list of connections // Registers the connection in our list of connections
// Do not yet call the on_connected handler, because we don't know if the remote // Do not yet call the on_connected handler, because we don't know if the remote

View file

@ -3,11 +3,11 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait;
use log::{debug, info, trace, warn}; use log::{debug, info, trace, warn};
use lru::LruCache; use lru::LruCache;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use async_trait::async_trait;
use sodiumoxide::crypto::hash; use sodiumoxide::crypto::hash;

View file

@ -8,6 +8,8 @@ use async_trait::async_trait;
use log::{debug, info, trace, warn}; use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use sodiumoxide::crypto::hash; use sodiumoxide::crypto::hash;
use crate::endpoint::*; use crate::endpoint::*;
@ -171,8 +173,8 @@ impl FullMeshPeeringStrategy {
strat strat
} }
pub async fn run(self: Arc<Self>) { pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
loop { while !*must_exit.borrow() {
// 1. Read current state: get list of connected peers (ping them) // 1. Read current state: get list of connected peers (ping them)
let (to_ping, to_retry) = { let (to_ping, to_retry) = {
let known_hosts = self.known_hosts.read().unwrap(); let known_hosts = self.known_hosts.read().unwrap();

View file

@ -1,4 +1,4 @@
use std::collections::{BTreeMap, HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::sync::Arc; use std::sync::Arc;
use log::trace; use log::trace;
@ -50,7 +50,6 @@ type ChunkLength = u16;
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem { struct SendQueueItem {
id: RequestID, id: RequestID,
prio: RequestPriority, prio: RequestPriority,
@ -59,31 +58,33 @@ struct SendQueueItem {
} }
struct SendQueue { struct SendQueue {
items: BTreeMap<u8, VecDeque<SendQueueItem>>, items: VecDeque<(u8, VecDeque<SendQueueItem>)>,
} }
impl SendQueue { impl SendQueue {
fn new() -> Self { fn new() -> Self {
Self { Self {
items: BTreeMap::new(), items: VecDeque::with_capacity(64),
} }
} }
fn push(&mut self, item: SendQueueItem) { fn push(&mut self, item: SendQueueItem) {
let prio = item.prio; let prio = item.prio;
let mut items_at_prio = self let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
.items Ok(i) => i,
.remove(&prio) Err(i) => {
.unwrap_or_else(|| VecDeque::with_capacity(4)); self.items.insert(i, (prio, VecDeque::new()));
items_at_prio.push_back(item); i
self.items.insert(prio, items_at_prio); }
};
self.items[pos_prio].1.push_back(item);
} }
fn pop(&mut self) -> Option<SendQueueItem> { fn pop(&mut self) -> Option<SendQueueItem> {
match self.items.pop_first() { match self.items.pop_front() {
None => None, None => None,
Some((prio, mut items_at_prio)) => { Some((prio, mut items_at_prio)) => {
let ret = items_at_prio.pop_front(); let ret = items_at_prio.pop_front();
if !items_at_prio.is_empty() { if !items_at_prio.is_empty() {
self.items.insert(prio, items_at_prio); self.items.push_front((prio, items_at_prio));
} }
ret.or_else(|| self.pop()) ret.or_else(|| self.pop())
} }
@ -98,7 +99,7 @@ impl SendQueue {
pub(crate) trait SendLoop: Sync { pub(crate) trait SendLoop: Sync {
async fn send_loop<W>( async fn send_loop<W>(
self: Arc<Self>, self: Arc<Self>,
mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>, mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
mut write: W, mut write: W,
) -> Result<(), Error> ) -> Result<(), Error>
where where
@ -107,18 +108,14 @@ pub(crate) trait SendLoop: Sync {
let mut sending = SendQueue::new(); let mut sending = SendQueue::new();
let mut should_exit = false; let mut should_exit = false;
while !should_exit || !sending.is_empty() { while !should_exit || !sending.is_empty() {
if let Ok(sth) = msg_recv.try_recv() { if let Ok((id, prio, data)) = msg_recv.try_recv() {
if let Some((id, prio, data)) = sth { trace!("send_loop: got {}, {} bytes", id, data.len());
trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem {
sending.push(SendQueueItem { id,
id, prio,
prio, data,
data, cursor: 0,
cursor: 0, });
});
} else {
should_exit = true;
}
} else if let Some(mut item) = sending.pop() { } else if let Some(mut item) = sending.pop() {
trace!( trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)", "send_loop: sending bytes for {} ({} bytes, {} already sent)",
@ -149,10 +146,7 @@ pub(crate) trait SendLoop: Sync {
} }
write.flush().await?; write.flush().await?;
} else { } else {
let sth = msg_recv let sth = msg_recv.recv().await;
.recv()
.await
.ok_or_else(|| Error::Message("Connection closed.".into()))?;
if let Some((id, prio, data)) = sth { if let Some((id, prio, data)) = sth {
trace!("send_loop: got {}, {} bytes", id, data.len()); trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem { sending.push(SendQueueItem {
@ -173,7 +167,7 @@ pub(crate) trait SendLoop: Sync {
#[async_trait] #[async_trait]
pub(crate) trait RecvLoop: Sync + 'static { pub(crate) trait RecvLoop: Sync + 'static {
// Returns true if we should stop receiving after this // Returns true if we should stop receiving after this
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>); fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where where
@ -205,7 +199,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
if has_cont { if has_cont {
receiving.insert(id, msg_bytes); receiving.insert(id, msg_bytes);
} else { } else {
tokio::spawn(self.clone().recv_handler(id, msg_bytes)); self.recv_handler(id, msg_bytes);
} }
} }
} }

View file

@ -1,10 +1,12 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{Arc}; use std::sync::Arc;
use arc_swap::ArcSwapOption;
use bytes::Bytes; use bytes::Bytes;
use log::{debug, trace}; use log::{debug, trace};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tokio_util::compat::*; use tokio_util::compat::*;
@ -42,12 +44,15 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>, netapp: Arc<NetApp>,
resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
close_send: watch::Sender<bool>,
} }
impl ServerConn { impl ServerConn {
pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> { pub(crate) async fn run(
netapp: Arc<NetApp>,
socket: TcpStream,
must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let remote_addr = socket.peer_addr()?; let remote_addr = socket.peer_addr()?;
let mut socket = socket.compat(); let mut socket = socket.compat();
@ -73,47 +78,33 @@ impl ServerConn {
let (resp_send, resp_recv) = mpsc::unbounded_channel(); let (resp_send, resp_recv) = mpsc::unbounded_channel();
let (close_send, close_recv) = watch::channel(false);
let conn = Arc::new(ServerConn { let conn = Arc::new(ServerConn {
netapp: netapp.clone(), netapp: netapp.clone(),
remote_addr, remote_addr,
peer_id, peer_id,
resp_send, resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
close_send,
}); });
netapp.connected_as_server(peer_id, conn.clone()); netapp.connected_as_server(peer_id, conn.clone());
let conn2 = conn.clone(); let conn2 = conn.clone();
let conn3 = conn.clone(); let recv_future = tokio::spawn(async move {
let close_recv2 = close_recv.clone(); select! {
tokio::try_join!( r = conn2.recv_loop(read) => r,
async move { _ = await_exit(must_exit) => Ok(())
tokio::select!( }
r = conn2.recv_loop(read) => r, });
_ = await_exit(close_recv) => Ok(()), let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));
)
}, recv_future.await.log_err("ServerConn recv_loop");
async move { conn.resp_send.store(None);
tokio::select!( send_future.await.log_err("ServerConn send_loop");
r = conn3.send_loop(resp_recv, write) => r,
_ = await_exit(close_recv2) => Ok(()),
)
},
)
.map(|_| ())
.log_err("ServerConn recv_loop/send_loop");
netapp.disconnected_as_server(&peer_id, conn); netapp.disconnected_as_server(&peer_id, conn);
Ok(()) Ok(())
} }
pub fn close(&self) {
self.close_send.send(true).unwrap();
}
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
if bytes.len() < 2 { if bytes.len() < 2 {
return Err(Error::Message("Invalid protocol message".into())); return Err(Error::Message("Invalid protocol message".into()));
@ -146,33 +137,33 @@ impl SendLoop for ServerConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ServerConn { impl RecvLoop for ServerConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) { fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); let resp_send = self.resp_send.load_full().unwrap();
let bytes: Bytes = bytes.into();
let prio = if !bytes.is_empty() { let self2 = self.clone();
bytes[0] tokio::spawn(async move {
} else { trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
0u8 let bytes: Bytes = bytes.into();
};
let resp = self.recv_handler_aux(&bytes[..]).await;
let mut resp_bytes = vec![]; let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
match resp { let resp = self2.recv_handler_aux(&bytes[..]).await;
Ok(rb) => {
resp_bytes.push(0u8); let mut resp_bytes = vec![];
resp_bytes.extend(&rb[..]); match resp {
Ok(rb) => {
resp_bytes.push(0u8);
resp_bytes.extend(&rb[..]);
}
Err(e) => {
resp_bytes.push(e.code());
}
} }
Err(e) => {
resp_bytes.push(e.code());
}
}
trace!("ServerConn sending response to {}: ", id); trace!("ServerConn sending response to {}: ", id);
self.resp_send resp_send
.send(Some((id, prio, resp_bytes))) .send((id, prio, resp_bytes))
.log_err("ServerConn recv_handler send resp"); .log_err("ServerConn recv_handler send resp");
});
} }
} }

View file

@ -1,5 +1,7 @@
use serde::Serialize; use serde::Serialize;
use log::info;
use tokio::sync::watch; use tokio::sync::watch;
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
@ -38,3 +40,15 @@ pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
} }
} }
} }
pub fn watch_ctrl_c() -> watch::Receiver<bool> {
let (send_cancel, watch_cancel) = watch::channel(false);
tokio::spawn(async move {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C signal handler");
info!("Received CTRL+C, shutting down.");
send_cancel.send(true).unwrap();
});
watch_cancel
}