forked from lx/netapp
Do not close connections immediately on close signal, await for remaining responses
This commit is contained in:
parent
83789a3076
commit
5a9ae8615e
9 changed files with 192 additions and 235 deletions
|
@ -1,20 +1,20 @@
|
|||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::io::Write;
|
||||
|
||||
use log::{debug, info, warn};
|
||||
|
||||
use structopt::StructOpt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use structopt::StructOpt;
|
||||
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use netapp::NetApp;
|
||||
use netapp::peering::basalt::*;
|
||||
use netapp::message::*;
|
||||
use netapp::peering::basalt::*;
|
||||
use netapp::proto::*;
|
||||
use netapp::NetApp;
|
||||
|
||||
#[derive(StructOpt, Debug)]
|
||||
#[structopt(name = "netapp")]
|
||||
|
@ -52,17 +52,17 @@ async fn main() {
|
|||
env_logger::Builder::new()
|
||||
.parse_env("RUST_LOG")
|
||||
.format(|buf, record| {
|
||||
writeln!(buf,
|
||||
"{} {} {} {}",
|
||||
chrono::Local::now().format("%s%.6f"),
|
||||
record.module_path().unwrap_or("_"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
writeln!(
|
||||
buf,
|
||||
"{} {} {} {}",
|
||||
chrono::Local::now().format("%s%.6f"),
|
||||
record.module_path().unwrap_or("_"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
})
|
||||
.init();
|
||||
|
||||
|
||||
let opt = Opt::from_args();
|
||||
|
||||
let netid = match &opt.network_key {
|
||||
|
@ -108,10 +108,12 @@ async fn main() {
|
|||
|_from: ed25519::PublicKey, msg: ExampleMessage| {
|
||||
debug!("Got example message: {:?}, sending example response", msg);
|
||||
async {
|
||||
ExampleResponse{example_field: false}
|
||||
ExampleResponse {
|
||||
example_field: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
tokio::join!(
|
||||
sampling_loop(netapp.clone(), peering.clone()),
|
||||
|
@ -120,8 +122,6 @@ async fn main() {
|
|||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
|
||||
loop {
|
||||
tokio::time::delay_for(Duration::from_secs(10)).await;
|
||||
|
@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
|
|||
|
||||
let netapp2 = netapp.clone();
|
||||
tokio::spawn(async move {
|
||||
match netapp2.request(&p, ExampleMessage{
|
||||
example_field: 42,
|
||||
}, PRIO_NORMAL).await {
|
||||
match netapp2
|
||||
.request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => debug!("Got example response: {:?}", resp),
|
||||
Err(e) => warn!("Error with example request: {}", e),
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use log::info;
|
||||
|
||||
|
@ -32,13 +32,14 @@ async fn main() {
|
|||
env_logger::Builder::new()
|
||||
.parse_env("RUST_LOG")
|
||||
.format(|buf, record| {
|
||||
writeln!(buf,
|
||||
"{} {} {} {}",
|
||||
chrono::Local::now().format("%s%.6f"),
|
||||
record.module_path().unwrap_or("_"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
writeln!(
|
||||
buf,
|
||||
"{} {} {} {}",
|
||||
chrono::Local::now().format("%s%.6f"),
|
||||
record.module_path().unwrap_or("_"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
})
|
||||
.init();
|
||||
|
||||
|
|
137
src/conn.rs
137
src/conn.rs
|
@ -1,17 +1,18 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{self, AtomicU16};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicU16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, trace};
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
use tokio::io::split;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use kuska_handshake::async_std::{
|
||||
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
|
||||
TokioCompatExtWrite,
|
||||
|
@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
|
|||
|
||||
netapp: Arc<NetApp>,
|
||||
|
||||
resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
close_send: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
|
@ -78,9 +79,20 @@ impl ServerConn {
|
|||
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
let close_recv2 = close_recv.clone();
|
||||
tokio::try_join!(
|
||||
conn2.recv_loop(box_stream_read, close_recv.clone()),
|
||||
conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
|
||||
async move {
|
||||
tokio::select!(
|
||||
r = conn2.recv_loop(box_stream_read) => r,
|
||||
_ = await_exit(close_recv) => Ok(()),
|
||||
)
|
||||
},
|
||||
async move {
|
||||
tokio::select!(
|
||||
r = conn3.send_loop(resp_recv, box_stream_write) => r,
|
||||
_ = await_exit(close_recv2) => Ok(()),
|
||||
)
|
||||
},
|
||||
)
|
||||
.map(|_| ())
|
||||
.log_err("ServerConn recv_loop/send_loop");
|
||||
|
@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
|
|||
let net_handler = &handler.net_handler;
|
||||
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
|
||||
self.resp_send
|
||||
.send((id, prio, resp))
|
||||
.send(Some((id, prio, resp)))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
}
|
||||
}
|
||||
|
@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
|
|||
pub(crate) remote_addr: SocketAddr,
|
||||
pub(crate) peer_pk: ed25519::PublicKey,
|
||||
|
||||
query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
|
||||
next_query_number: AtomicU16,
|
||||
resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
|
||||
resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
||||
close_send: watch::Sender<bool>,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
||||
must_exit: AtomicBool,
|
||||
stop_recv_loop: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl ClientConn {
|
||||
|
@ -163,19 +176,17 @@ impl ClientConn {
|
|||
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
|
||||
|
||||
let (query_send, query_recv) = mpsc::unbounded_channel();
|
||||
let (resp_send, resp_recv) = mpsc::unbounded_channel();
|
||||
let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
|
||||
|
||||
let (close_send, close_recv) = watch::channel(false);
|
||||
let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
|
||||
|
||||
let conn = Arc::new(ClientConn {
|
||||
remote_addr,
|
||||
peer_pk: remote_pk.clone(),
|
||||
next_query_number: AtomicU16::from(0u16),
|
||||
query_send,
|
||||
resp_send,
|
||||
resp_notify_send,
|
||||
close_send,
|
||||
inflight: Mutex::new(HashMap::new()),
|
||||
must_exit: AtomicBool::new(false),
|
||||
stop_recv_loop,
|
||||
});
|
||||
|
||||
netapp.connected_as_client(remote_pk.clone(), conn.clone());
|
||||
|
@ -183,11 +194,14 @@ impl ClientConn {
|
|||
tokio::spawn(async move {
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
let conn4 = conn.clone();
|
||||
tokio::try_join!(
|
||||
conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
|
||||
conn3.recv_loop(box_stream_read, close_recv.clone()),
|
||||
conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
|
||||
conn2.send_loop(query_recv, box_stream_write),
|
||||
async move {
|
||||
tokio::select!(
|
||||
r = conn3.recv_loop(box_stream_read) => r,
|
||||
_ = await_exit(stop_recv_loop_recv) => Ok(()),
|
||||
)
|
||||
}
|
||||
)
|
||||
.map(|_| ())
|
||||
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
|
||||
|
@ -199,51 +213,15 @@ impl ClientConn {
|
|||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.close_send.broadcast(true).unwrap();
|
||||
}
|
||||
|
||||
async fn dispatch_resp(
|
||||
self: Arc<Self>,
|
||||
mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
|
||||
mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
|
||||
let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
|
||||
while !*must_exit.borrow() {
|
||||
tokio::select! {
|
||||
resp = resp_recv.recv() => {
|
||||
if let Some((id, resp)) = resp {
|
||||
trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
|
||||
if let Some(ch) = resp_notify.remove(&id) {
|
||||
if ch.send(resp).is_err() {
|
||||
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
|
||||
}
|
||||
} else {
|
||||
resps.insert(id, resp);
|
||||
}
|
||||
}
|
||||
}
|
||||
resp_ch = resp_notify_recv.recv() => {
|
||||
if let Some((id, resp_ch)) = resp_ch {
|
||||
trace!("dispatch_resp: got resp_ch {}", id);
|
||||
if let Some(rs) = resps.remove(&id) {
|
||||
if resp_ch.send(rs).is_err() {
|
||||
debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
|
||||
}
|
||||
} else {
|
||||
resp_notify.insert(id, resp_ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
exit = must_exit.recv() => {
|
||||
if exit == Some(true) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.must_exit.store(true, atomic::Ordering::SeqCst);
|
||||
self.query_send
|
||||
.send(None)
|
||||
.log_err("could not write None in query_send");
|
||||
if self.inflight.lock().unwrap().is_empty() {
|
||||
self.stop_recv_loop
|
||||
.broadcast(true)
|
||||
.log_err("could not write true to stop_recv_loop");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn request<T>(
|
||||
|
@ -262,10 +240,18 @@ impl ClientConn {
|
|||
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
|
||||
|
||||
let (resp_send, resp_recv) = oneshot::channel();
|
||||
self.resp_notify_send.send((id, resp_send))?;
|
||||
let old = self.inflight.lock().unwrap().insert(id, resp_send);
|
||||
if let Some(old_ch) = old {
|
||||
error!(
|
||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||
);
|
||||
if old_ch.send(vec![]).is_err() {
|
||||
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
}
|
||||
|
||||
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
||||
self.query_send.send((id, prio, bytes))?;
|
||||
self.query_send.send(Some((id, prio, bytes)))?;
|
||||
|
||||
let resp = resp_recv.await?;
|
||||
|
||||
|
@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
|
|||
#[async_trait]
|
||||
impl RecvLoop for ClientConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||
self.resp_send
|
||||
.send((id, msg))
|
||||
.log_err("ClientConn::recv_handler");
|
||||
let mut inflight = self.inflight.lock().unwrap();
|
||||
if let Some(ch) = inflight.remove(&id) {
|
||||
if ch.send(msg).is_err() {
|
||||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
}
|
||||
|
||||
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
|
||||
self.stop_recv_loop
|
||||
.broadcast(true)
|
||||
.log_err("could not write true to stop_recv_loop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//! Netapp is a Rust library that takes care of a few common tasks in distributed software:
|
||||
//!
|
||||
//!
|
||||
//! - establishing secure connections
|
||||
//! - managing connection lifetime, reconnecting on failure
|
||||
//! - checking peer's state
|
||||
|
@ -18,8 +18,8 @@
|
|||
pub mod error;
|
||||
pub mod util;
|
||||
|
||||
pub mod proto;
|
||||
pub mod message;
|
||||
pub mod proto;
|
||||
|
||||
mod conn;
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ pub struct NetApp {
|
|||
|
||||
server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
|
||||
client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
|
||||
|
||||
|
||||
pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
|
||||
on_connected_handler:
|
||||
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
|
||||
|
@ -133,18 +133,22 @@ impl NetApp {
|
|||
/// been successfully established. Do not set this if using a peering strategy,
|
||||
/// as the peering strategy will need to set this itself.
|
||||
pub fn on_connected<F>(&self, handler: F)
|
||||
where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static
|
||||
{
|
||||
self.on_connected_handler.store(Some(Arc::new(Box::new(handler))));
|
||||
where
|
||||
F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static,
|
||||
{
|
||||
self.on_connected_handler
|
||||
.store(Some(Arc::new(Box::new(handler))));
|
||||
}
|
||||
|
||||
/// Set the handler to be called when an existing connection (incoming or outgoing) has
|
||||
/// been closed by either party. Do not set this if using a peering strategy,
|
||||
/// as the peering strategy will need to set this itself.
|
||||
pub fn on_disconnected<F>(&self, handler: F)
|
||||
where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static
|
||||
{
|
||||
self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler))));
|
||||
where
|
||||
F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static,
|
||||
{
|
||||
self.on_disconnected_handler
|
||||
.store(Some(Arc::new(Box::new(handler))));
|
||||
}
|
||||
|
||||
/// Add a handler for a certain message type. Note that only one handler
|
||||
|
@ -240,11 +244,13 @@ impl NetApp {
|
|||
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
||||
// If pk is ourself, we're not supposed to have a connection open
|
||||
if *pk != self.pubkey {
|
||||
let conn = self.client_conns.read().unwrap().remove(pk);
|
||||
let conn = self.client_conns.write().unwrap().remove(pk);
|
||||
if let Some(c) = conn {
|
||||
debug!("Closing connection to {} ({})",
|
||||
hex::encode(c.peer_pk),
|
||||
c.remote_addr);
|
||||
debug!(
|
||||
"Closing connection to {} ({})",
|
||||
hex::encode(c.peer_pk),
|
||||
c.remote_addr
|
||||
);
|
||||
c.close();
|
||||
} else {
|
||||
return;
|
||||
|
@ -268,9 +274,11 @@ impl NetApp {
|
|||
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
|
||||
let conn = self.server_conns.read().unwrap().get(pk).cloned();
|
||||
if let Some(c) = conn {
|
||||
debug!("Closing incoming connection from {} ({})",
|
||||
hex::encode(c.peer_pk),
|
||||
c.remote_addr);
|
||||
debug!(
|
||||
"Closing incoming connection from {} ({})",
|
||||
hex::encode(c.peer_pk),
|
||||
c.remote_addr
|
||||
);
|
||||
c.close();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::net::SocketAddr;
|
|||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use log::{trace, debug, info, warn};
|
||||
use log::{debug, info, trace, warn};
|
||||
use lru::LruCache;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -267,15 +267,13 @@ impl Basalt {
|
|||
netapp.on_connected(
|
||||
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
|
||||
basalt2.on_connected(pk, addr, is_incoming);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.on_disconnected(
|
||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
basalt2.on_disconnected(pk, is_incoming);
|
||||
},
|
||||
);
|
||||
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
basalt2.on_disconnected(pk, is_incoming);
|
||||
});
|
||||
|
||||
let basalt2 = basalt.clone();
|
||||
netapp.add_msg_handler::<PullMessage, _, _>(
|
||||
|
|
|
@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy {
|
|||
);
|
||||
|
||||
let strat2 = strat.clone();
|
||||
netapp.on_disconnected(
|
||||
move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
let strat2 = strat2.clone();
|
||||
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
|
||||
},
|
||||
);
|
||||
netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
|
||||
let strat2 = strat2.clone();
|
||||
tokio::spawn(strat2.on_disconnected(pk, is_incoming));
|
||||
});
|
||||
|
||||
strat
|
||||
}
|
||||
|
|
150
src/proto.rs
150
src/proto.rs
|
@ -3,14 +3,14 @@ use std::sync::Arc;
|
|||
|
||||
use log::trace;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use async_std::io::prelude::WriteExt;
|
||||
use async_std::io::ReadExt;
|
||||
|
||||
use tokio::io::{ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::error::*;
|
||||
|
||||
|
@ -85,26 +85,33 @@ impl SendQueue {
|
|||
}
|
||||
}
|
||||
}
|
||||
fn is_empty(&self) -> bool {
|
||||
self.items.iter().all(|(_k, v)| v.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait SendLoop: Sync {
|
||||
async fn send_loop(
|
||||
self: Arc<Self>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut sending = SendQueue::new();
|
||||
while !*must_exit.borrow() {
|
||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
let mut should_exit = false;
|
||||
while !should_exit || !sending.is_empty() {
|
||||
if let Ok(sth) = msg_recv.try_recv() {
|
||||
if let Some((id, prio, data)) = sth {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else {
|
||||
should_exit = true;
|
||||
}
|
||||
} else if let Some(mut item) = sending.pop() {
|
||||
trace!(
|
||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
||||
|
@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
|
|||
item.cursor
|
||||
);
|
||||
let header_id = u16::to_be_bytes(item.id);
|
||||
if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
write.write_all(&header_id[..]).await?;
|
||||
|
||||
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
|
||||
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
|
||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
write.write_all(&header_size[..]).await?;
|
||||
|
||||
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
|
||||
if write_all_or_exit(
|
||||
&item.data[item.cursor..new_cursor],
|
||||
&mut write,
|
||||
&mut must_exit,
|
||||
)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
write.write_all(&item.data[item.cursor..new_cursor]).await?;
|
||||
item.cursor = new_cursor;
|
||||
|
||||
sending.push(item);
|
||||
|
@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync {
|
|||
let send_len = (item.data.len() - item.cursor) as u16;
|
||||
|
||||
let header_size = u16::to_be_bytes(send_len);
|
||||
if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
write.write_all(&header_size[..]).await?;
|
||||
|
||||
if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
write.write_all(&item.data[item.cursor..]).await?;
|
||||
}
|
||||
write.flush().await.log_err("Could not flush in send_loop");
|
||||
} else {
|
||||
let (id, prio, data) = msg_recv
|
||||
let sth = msg_recv
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(Error::Message("Connection closed.".into()))?;
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
if let Some((id, prio, data)) = sth {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else {
|
||||
should_exit = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
|
|||
|
||||
#[async_trait]
|
||||
pub(crate) trait RecvLoop: Sync + 'static {
|
||||
// Returns true if we should stop receiving after this
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||
|
||||
async fn recv_loop(
|
||||
self: Arc<Self>,
|
||||
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let mut receiving = HashMap::new();
|
||||
while !*must_exit.borrow() {
|
||||
loop {
|
||||
trace!("recv_loop: reading packet");
|
||||
let mut header_id = [0u8; 2];
|
||||
if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
read.read_exact(&mut header_id[..]).await?;
|
||||
let id = RequestID::from_be_bytes(header_id);
|
||||
trace!("recv_loop: got header id: {:04x}", id);
|
||||
|
||||
let mut header_size = [0u8; 2];
|
||||
if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
read.read_exact(&mut header_size[..]).await?;
|
||||
let size = RequestID::from_be_bytes(header_size);
|
||||
trace!("recv_loop: got header size: {:04x}", id);
|
||||
|
||||
|
@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
let size = size & !0x8000;
|
||||
|
||||
let mut next_slice = vec![0; size as usize];
|
||||
if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
break;
|
||||
}
|
||||
read.read_exact(&mut next_slice[..]).await?;
|
||||
trace!("recv_loop: read {} bytes", size);
|
||||
|
||||
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
|
||||
|
@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_exact_or_exit(
|
||||
buf: &mut [u8],
|
||||
read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
|
||||
must_exit: &mut watch::Receiver<bool>,
|
||||
) -> Result<Option<()>, Error> {
|
||||
tokio::select!(
|
||||
res = read.read_exact(buf) => Ok(Some(res?)),
|
||||
_ = await_exit(must_exit) => Ok(None),
|
||||
)
|
||||
}
|
||||
|
||||
async fn write_all_or_exit(
|
||||
buf: &[u8],
|
||||
write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
|
||||
must_exit: &mut watch::Receiver<bool>,
|
||||
) -> Result<Option<()>, Error> {
|
||||
tokio::select!(
|
||||
res = write.write_all(buf) => Ok(Some(res?)),
|
||||
_ = await_exit(must_exit) => Ok(None),
|
||||
)
|
||||
}
|
||||
|
||||
async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
|
||||
loop {
|
||||
if must_exit.recv().await == Some(true) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
20
src/util.rs
20
src/util.rs
|
@ -1,5 +1,7 @@
|
|||
use serde::Serialize;
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
/// Utility function: encodes any serializable value in MessagePack binary format
|
||||
/// using the RMP library.
|
||||
///
|
||||
|
@ -16,3 +18,21 @@ where
|
|||
val.serialize(&mut se)?;
|
||||
Ok(wr)
|
||||
}
|
||||
|
||||
/// This async function returns only when a true signal was received
|
||||
/// from a watcher that tells us when to exit.
|
||||
/// Usefull in a select statement to interrupt another
|
||||
/// future:
|
||||
/// ```
|
||||
/// select!(
|
||||
/// _ = a_long_task() => Success,
|
||||
/// _ = await_exit(must_exit) => Interrupted,
|
||||
/// )
|
||||
/// ```
|
||||
pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
|
||||
loop {
|
||||
if must_exit.recv().await == Some(true) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue