Do not close connections immediately on close signal, await for remaining responses

This commit is contained in:
Alex 2020-12-07 13:35:24 +01:00
parent 83789a3076
commit 5a9ae8615e
9 changed files with 192 additions and 235 deletions

View file

@ -1,20 +1,20 @@
use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::io::Write;
use log::{debug, info, warn}; use log::{debug, info, warn};
use structopt::StructOpt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use structopt::StructOpt;
use sodiumoxide::crypto::auth; use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
use netapp::NetApp;
use netapp::peering::basalt::*;
use netapp::message::*; use netapp::message::*;
use netapp::peering::basalt::*;
use netapp::proto::*; use netapp::proto::*;
use netapp::NetApp;
#[derive(StructOpt, Debug)] #[derive(StructOpt, Debug)]
#[structopt(name = "netapp")] #[structopt(name = "netapp")]
@ -52,17 +52,17 @@ async fn main() {
env_logger::Builder::new() env_logger::Builder::new()
.parse_env("RUST_LOG") .parse_env("RUST_LOG")
.format(|buf, record| { .format(|buf, record| {
writeln!(buf, writeln!(
"{} {} {} {}", buf,
chrono::Local::now().format("%s%.6f"), "{} {} {} {}",
record.module_path().unwrap_or("_"), chrono::Local::now().format("%s%.6f"),
record.level(), record.module_path().unwrap_or("_"),
record.args() record.level(),
) record.args()
)
}) })
.init(); .init();
let opt = Opt::from_args(); let opt = Opt::from_args();
let netid = match &opt.network_key { let netid = match &opt.network_key {
@ -108,10 +108,12 @@ async fn main() {
|_from: ed25519::PublicKey, msg: ExampleMessage| { |_from: ed25519::PublicKey, msg: ExampleMessage| {
debug!("Got example message: {:?}, sending example response", msg); debug!("Got example message: {:?}, sending example response", msg);
async { async {
ExampleResponse{example_field: false} ExampleResponse {
example_field: false,
}
} }
} },
); );
tokio::join!( tokio::join!(
sampling_loop(netapp.clone(), peering.clone()), sampling_loop(netapp.clone(), peering.clone()),
@ -120,8 +122,6 @@ async fn main() {
); );
} }
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) { async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
loop { loop {
tokio::time::delay_for(Duration::from_secs(10)).await; tokio::time::delay_for(Duration::from_secs(10)).await;
@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
let netapp2 = netapp.clone(); let netapp2 = netapp.clone();
tokio::spawn(async move { tokio::spawn(async move {
match netapp2.request(&p, ExampleMessage{ match netapp2
example_field: 42, .request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
}, PRIO_NORMAL).await { .await
{
Ok(resp) => debug!("Got example response: {:?}", resp), Ok(resp) => debug!("Got example response: {:?}", resp),
Err(e) => warn!("Error with example request: {}", e), Err(e) => warn!("Error with example request: {}", e),
} }

View file

@ -1,5 +1,5 @@
use std::net::SocketAddr;
use std::io::Write; use std::io::Write;
use std::net::SocketAddr;
use log::info; use log::info;
@ -32,13 +32,14 @@ async fn main() {
env_logger::Builder::new() env_logger::Builder::new()
.parse_env("RUST_LOG") .parse_env("RUST_LOG")
.format(|buf, record| { .format(|buf, record| {
writeln!(buf, writeln!(
"{} {} {} {}", buf,
chrono::Local::now().format("%s%.6f"), "{} {} {} {}",
record.module_path().unwrap_or("_"), chrono::Local::now().format("%s%.6f"),
record.level(), record.module_path().unwrap_or("_"),
record.args() record.level(),
) record.args()
)
}) })
.init(); .init();

View file

@ -1,17 +1,18 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU16}; use std::sync::atomic::{self, AtomicBool, AtomicU16};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use log::{debug, trace}; use log::{debug, error, trace};
use sodiumoxide::crypto::sign::ed25519; use sodiumoxide::crypto::sign::ed25519;
use tokio::io::split; use tokio::io::split;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
use async_trait::async_trait;
use kuska_handshake::async_std::{ use kuska_handshake::async_std::{
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead, handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
TokioCompatExtWrite, TokioCompatExtWrite,
@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>, netapp: Arc<NetApp>,
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>, resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
close_send: watch::Sender<bool>, close_send: watch::Sender<bool>,
} }
@ -78,9 +79,20 @@ impl ServerConn {
let conn2 = conn.clone(); let conn2 = conn.clone();
let conn3 = conn.clone(); let conn3 = conn.clone();
let close_recv2 = close_recv.clone();
tokio::try_join!( tokio::try_join!(
conn2.recv_loop(box_stream_read, close_recv.clone()), async move {
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()), tokio::select!(
r = conn2.recv_loop(box_stream_read) => r,
_ = await_exit(close_recv) => Ok(()),
)
},
async move {
tokio::select!(
r = conn3.send_loop(resp_recv, box_stream_write) => r,
_ = await_exit(close_recv2) => Ok(()),
)
},
) )
.map(|_| ()) .map(|_| ())
.log_err("ServerConn recv_loop/send_loop"); .log_err("ServerConn recv_loop/send_loop");
@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
let net_handler = &handler.net_handler; let net_handler = &handler.net_handler;
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await; let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send self.resp_send
.send((id, prio, resp)) .send(Some((id, prio, resp)))
.log_err("ServerConn recv_handler send resp"); .log_err("ServerConn recv_handler send resp");
} }
} }
@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr, pub(crate) remote_addr: SocketAddr,
pub(crate) peer_pk: ed25519::PublicKey, pub(crate) peer_pk: ed25519::PublicKey,
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>, query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
next_query_number: AtomicU16, next_query_number: AtomicU16,
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>, must_exit: AtomicBool,
close_send: watch::Sender<bool>, stop_recv_loop: watch::Sender<bool>,
} }
impl ClientConn { impl ClientConn {
@ -163,19 +176,17 @@ impl ClientConn {
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (query_send, query_recv) = mpsc::unbounded_channel(); let (query_send, query_recv) = mpsc::unbounded_channel();
let (resp_send, resp_recv) = mpsc::unbounded_channel();
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
let (close_send, close_recv) = watch::channel(false); let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
let conn = Arc::new(ClientConn { let conn = Arc::new(ClientConn {
remote_addr, remote_addr,
peer_pk: remote_pk.clone(), peer_pk: remote_pk.clone(),
next_query_number: AtomicU16::from(0u16), next_query_number: AtomicU16::from(0u16),
query_send, query_send,
resp_send, inflight: Mutex::new(HashMap::new()),
resp_notify_send, must_exit: AtomicBool::new(false),
close_send, stop_recv_loop,
}); });
netapp.connected_as_client(remote_pk.clone(), conn.clone()); netapp.connected_as_client(remote_pk.clone(), conn.clone());
@ -183,11 +194,14 @@ impl ClientConn {
tokio::spawn(async move { tokio::spawn(async move {
let conn2 = conn.clone(); let conn2 = conn.clone();
let conn3 = conn.clone(); let conn3 = conn.clone();
let conn4 = conn.clone();
tokio::try_join!( tokio::try_join!(
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()), conn2.send_loop(query_recv, box_stream_write),
conn3.recv_loop(box_stream_read, close_recv.clone()), async move {
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()), tokio::select!(
r = conn3.recv_loop(box_stream_read) => r,
_ = await_exit(stop_recv_loop_recv) => Ok(()),
)
}
) )
.map(|_| ()) .map(|_| ())
.log_err("ClientConn send_loop/recv_loop/dispatch_loop"); .log_err("ClientConn send_loop/recv_loop/dispatch_loop");
@ -199,51 +213,15 @@ impl ClientConn {
} }
pub fn close(&self) { pub fn close(&self) {
self.close_send.broadcast(true).unwrap(); self.must_exit.store(true, atomic::Ordering::SeqCst);
} self.query_send
.send(None)
async fn dispatch_resp( .log_err("could not write None in query_send");
self: Arc<Self>, if self.inflight.lock().unwrap().is_empty() {
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>, self.stop_recv_loop
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>, .broadcast(true)
mut must_exit: watch::Receiver<bool>, .log_err("could not write true to stop_recv_loop");
) -> Result<(), Error> {
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
while !*must_exit.borrow() {
tokio::select! {
resp = resp_recv.recv() => {
if let Some((id, resp)) = resp {
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
if let Some(ch) = resp_notify.remove(&id) {
if ch.send(resp).is_err() {
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
}
} else {
resps.insert(id, resp);
}
}
}
resp_ch = resp_notify_recv.recv() => {
if let Some((id, resp_ch)) = resp_ch {
trace!("dispatch_resp: got resp_ch {}", id);
if let Some(rs) = resps.remove(&id) {
if resp_ch.send(rs).is_err() {
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
}
} else {
resp_notify.insert(id, resp_ch);
}
}
}
exit = must_exit.recv() => {
if exit == Some(true) {
break;
}
}
}
} }
Ok(())
} }
pub(crate) async fn request<T>( pub(crate) async fn request<T>(
@ -262,10 +240,18 @@ impl ClientConn {
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]); bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel(); let (resp_send, resp_recv) = oneshot::channel();
self.resp_notify_send.send((id, resp_send))?; let old = self.inflight.lock().unwrap().insert(id, resp_send);
if let Some(old_ch) = old {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
if old_ch.send(vec![]).is_err() {
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
}
}
trace!("request: query_send {}, {} bytes", id, bytes.len()); trace!("request: query_send {}, {} bytes", id, bytes.len());
self.query_send.send((id, prio, bytes))?; self.query_send.send(Some((id, prio, bytes)))?;
let resp = resp_recv.await?; let resp = resp_recv.await?;
@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ClientConn { impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) { async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
self.resp_send let mut inflight = self.inflight.lock().unwrap();
.send((id, msg)) if let Some(ch) = inflight.remove(&id) {
.log_err("ClientConn::recv_handler"); if ch.send(msg).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
}
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
self.stop_recv_loop
.broadcast(true)
.log_err("could not write true to stop_recv_loop");
}
} }
} }

View file

@ -18,8 +18,8 @@
pub mod error; pub mod error;
pub mod util; pub mod util;
pub mod proto;
pub mod message; pub mod message;
pub mod proto;
mod conn; mod conn;

View file

@ -133,18 +133,22 @@ impl NetApp {
/// been successfully established. Do not set this if using a peering strategy, /// been successfully established. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself. /// as the peering strategy will need to set this itself.
pub fn on_connected<F>(&self, handler: F) pub fn on_connected<F>(&self, handler: F)
where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static where
{ F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static,
self.on_connected_handler.store(Some(Arc::new(Box::new(handler)))); {
self.on_connected_handler
.store(Some(Arc::new(Box::new(handler))));
} }
/// Set the handler to be called when an existing connection (incoming or outgoing) has /// Set the handler to be called when an existing connection (incoming or outgoing) has
/// been closed by either party. Do not set this if using a peering strategy, /// been closed by either party. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself. /// as the peering strategy will need to set this itself.
pub fn on_disconnected<F>(&self, handler: F) pub fn on_disconnected<F>(&self, handler: F)
where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static where
{ F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static,
self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler)))); {
self.on_disconnected_handler
.store(Some(Arc::new(Box::new(handler))));
} }
/// Add a handler for a certain message type. Note that only one handler /// Add a handler for a certain message type. Note that only one handler
@ -240,11 +244,13 @@ impl NetApp {
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) { pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
// If pk is ourself, we're not supposed to have a connection open // If pk is ourself, we're not supposed to have a connection open
if *pk != self.pubkey { if *pk != self.pubkey {
let conn = self.client_conns.read().unwrap().remove(pk); let conn = self.client_conns.write().unwrap().remove(pk);
if let Some(c) = conn { if let Some(c) = conn {
debug!("Closing connection to {} ({})", debug!(
hex::encode(c.peer_pk), "Closing connection to {} ({})",
c.remote_addr); hex::encode(c.peer_pk),
c.remote_addr
);
c.close(); c.close();
} else { } else {
return; return;
@ -268,9 +274,11 @@ impl NetApp {
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) { pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
let conn = self.server_conns.read().unwrap().get(pk).cloned(); let conn = self.server_conns.read().unwrap().get(pk).cloned();
if let Some(c) = conn { if let Some(c) = conn {
debug!("Closing incoming connection from {} ({})", debug!(
hex::encode(c.peer_pk), "Closing incoming connection from {} ({})",
c.remote_addr); hex::encode(c.peer_pk),
c.remote_addr
);
c.close(); c.close();
} }
} }

View file

@ -3,7 +3,7 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use log::{trace, debug, info, warn}; use log::{debug, info, trace, warn};
use lru::LruCache; use lru::LruCache;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -267,15 +267,13 @@ impl Basalt {
netapp.on_connected( netapp.on_connected(
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| { move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
basalt2.on_connected(pk, addr, is_incoming); basalt2.on_connected(pk, addr, is_incoming);
} },
); );
let basalt2 = basalt.clone(); let basalt2 = basalt.clone();
netapp.on_disconnected( netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
move |pk: ed25519::PublicKey, is_incoming: bool| { basalt2.on_disconnected(pk, is_incoming);
basalt2.on_disconnected(pk, is_incoming); });
},
);
let basalt2 = basalt.clone(); let basalt2 = basalt.clone();
netapp.add_msg_handler::<PullMessage, _, _>( netapp.add_msg_handler::<PullMessage, _, _>(

View file

@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy {
); );
let strat2 = strat.clone(); let strat2 = strat.clone();
netapp.on_disconnected( netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
move |pk: ed25519::PublicKey, is_incoming: bool| { let strat2 = strat2.clone();
let strat2 = strat2.clone(); tokio::spawn(strat2.on_disconnected(pk, is_incoming));
tokio::spawn(strat2.on_disconnected(pk, is_incoming)); });
},
);
strat strat
} }

View file

@ -3,14 +3,14 @@ use std::sync::Arc;
use log::trace; use log::trace;
use async_trait::async_trait;
use async_std::io::prelude::WriteExt; use async_std::io::prelude::WriteExt;
use async_std::io::ReadExt; use async_std::io::ReadExt;
use tokio::io::{ReadHalf, WriteHalf}; use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::{mpsc, watch}; use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*; use crate::error::*;
@ -85,26 +85,33 @@ impl SendQueue {
} }
} }
} }
fn is_empty(&self) -> bool {
self.items.iter().all(|(_k, v)| v.is_empty())
}
} }
#[async_trait] #[async_trait]
pub(crate) trait SendLoop: Sync { pub(crate) trait SendLoop: Sync {
async fn send_loop( async fn send_loop(
self: Arc<Self>, self: Arc<Self>,
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>, mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut sending = SendQueue::new(); let mut sending = SendQueue::new();
while !*must_exit.borrow() { let mut should_exit = false;
if let Ok((id, prio, data)) = msg_recv.try_recv() { while !should_exit || !sending.is_empty() {
trace!("send_loop: got {}, {} bytes", id, data.len()); if let Ok(sth) = msg_recv.try_recv() {
sending.push(SendQueueItem { if let Some((id, prio, data)) = sth {
id, trace!("send_loop: got {}, {} bytes", id, data.len());
prio, sending.push(SendQueueItem {
data, id,
cursor: 0, prio,
}); data,
cursor: 0,
});
} else {
should_exit = true;
}
} else if let Some(mut item) = sending.pop() { } else if let Some(mut item) = sending.pop() {
trace!( trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)", "send_loop: sending bytes for {} ({} bytes, {} already sent)",
@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
item.cursor item.cursor
); );
let header_id = u16::to_be_bytes(item.id); let header_id = u16::to_be_bytes(item.id);
if write_all_or_exit(&header_id[..], &mut write, &mut must_exit) write.write_all(&header_id[..]).await?;
.await?
.is_none()
{
break;
}
if item.data.len() - item.cursor > MAX_CHUNK_SIZE { if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) write.write_all(&header_size[..]).await?;
.await?
.is_none()
{
break;
}
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
if write_all_or_exit( write.write_all(&item.data[item.cursor..new_cursor]).await?;
&item.data[item.cursor..new_cursor],
&mut write,
&mut must_exit,
)
.await?
.is_none()
{
break;
}
item.cursor = new_cursor; item.cursor = new_cursor;
sending.push(item); sending.push(item);
@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync {
let send_len = (item.data.len() - item.cursor) as u16; let send_len = (item.data.len() - item.cursor) as u16;
let header_size = u16::to_be_bytes(send_len); let header_size = u16::to_be_bytes(send_len);
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) write.write_all(&header_size[..]).await?;
.await?
.is_none()
{
break;
}
if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit) write.write_all(&item.data[item.cursor..]).await?;
.await?
.is_none()
{
break;
}
} }
write.flush().await.log_err("Could not flush in send_loop"); write.flush().await.log_err("Could not flush in send_loop");
} else { } else {
let (id, prio, data) = msg_recv let sth = msg_recv
.recv() .recv()
.await .await
.ok_or(Error::Message("Connection closed.".into()))?; .ok_or(Error::Message("Connection closed.".into()))?;
trace!("send_loop: got {}, {} bytes", id, data.len()); if let Some((id, prio, data)) = sth {
sending.push(SendQueueItem { trace!("send_loop: got {}, {} bytes", id, data.len());
id, sending.push(SendQueueItem {
prio, id,
data, prio,
cursor: 0, data,
}); cursor: 0,
});
} else {
should_exit = true;
}
} }
} }
Ok(()) Ok(())
@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
#[async_trait] #[async_trait]
pub(crate) trait RecvLoop: Sync + 'static { pub(crate) trait RecvLoop: Sync + 'static {
// Returns true if we should stop receiving after this
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>); async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop( async fn recv_loop(
self: Arc<Self>, self: Arc<Self>,
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>, mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut receiving = HashMap::new(); let mut receiving = HashMap::new();
while !*must_exit.borrow() { loop {
trace!("recv_loop: reading packet"); trace!("recv_loop: reading packet");
let mut header_id = [0u8; 2]; let mut header_id = [0u8; 2];
if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit) read.read_exact(&mut header_id[..]).await?;
.await?
.is_none()
{
break;
}
let id = RequestID::from_be_bytes(header_id); let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id); trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; 2]; let mut header_size = [0u8; 2];
if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit) read.read_exact(&mut header_size[..]).await?;
.await?
.is_none()
{
break;
}
let size = RequestID::from_be_bytes(header_size); let size = RequestID::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", id); trace!("recv_loop: got header size: {:04x}", id);
@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
let size = size & !0x8000; let size = size & !0x8000;
let mut next_slice = vec![0; size as usize]; let mut next_slice = vec![0; size as usize];
if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit) read.read_exact(&mut next_slice[..]).await?;
.await?
.is_none()
{
break;
}
trace!("recv_loop: read {} bytes", size); trace!("recv_loop: read {} bytes", size);
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]); let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
tokio::spawn(self.clone().recv_handler(id, msg_bytes)); tokio::spawn(self.clone().recv_handler(id, msg_bytes));
} }
} }
Ok(())
}
}
async fn read_exact_or_exit(
buf: &mut [u8],
read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
tokio::select!(
res = read.read_exact(buf) => Ok(Some(res?)),
_ = await_exit(must_exit) => Ok(None),
)
}
async fn write_all_or_exit(
buf: &[u8],
write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
tokio::select!(
res = write.write_all(buf) => Ok(Some(res?)),
_ = await_exit(must_exit) => Ok(None),
)
}
async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
loop {
if must_exit.recv().await == Some(true) {
return;
}
} }
} }

View file

@ -1,5 +1,7 @@
use serde::Serialize; use serde::Serialize;
use tokio::sync::watch;
/// Utility function: encodes any serializable value in MessagePack binary format /// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library. /// using the RMP library.
/// ///
@ -16,3 +18,21 @@ where
val.serialize(&mut se)?; val.serialize(&mut se)?;
Ok(wr) Ok(wr)
} }
/// This async function returns only when a true signal was received
/// from a watcher that tells us when to exit.
/// Usefull in a select statement to interrupt another
/// future:
/// ```
/// select!(
/// _ = a_long_task() => Success,
/// _ = await_exit(must_exit) => Interrupted,
/// )
/// ```
pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
loop {
if must_exit.recv().await == Some(true) {
return;
}
}
}