forked from lx/netapp
Try to handle termination and closing of stuff properly
This commit is contained in:
parent
8dede69dee
commit
70839d70d8
15 changed files with 264 additions and 166 deletions
23
Cargo.lock
generated
23
Cargo.lock
generated
|
@ -446,6 +446,7 @@ dependencies = [
|
|||
"serde",
|
||||
"structopt",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
|
@ -666,6 +667,15 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.4"
|
||||
|
@ -779,7 +789,9 @@ dependencies = [
|
|||
"memchr",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"tokio-macros",
|
||||
"winapi",
|
||||
]
|
||||
|
@ -795,6 +807,17 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.8"
|
||||
|
|
|
@ -20,8 +20,9 @@ basalt = ["lru", "rand"]
|
|||
|
||||
[dependencies]
|
||||
futures = "0.3.17"
|
||||
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util"] }
|
||||
tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
|
||||
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
|
||||
tokio-stream = "0.1.7"
|
||||
|
||||
serde = { version = "1.0", default-features = false, features = ["derive"] }
|
||||
rmp-serde = "0.14.3"
|
||||
|
|
2
Makefile
2
Makefile
|
@ -2,6 +2,6 @@ all:
|
|||
cargo build --all-features
|
||||
cargo build --example fullmesh
|
||||
cargo build --all-features --example basalt
|
||||
RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
|
||||
RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
|
||||
#RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh
|
||||
|
||||
|
|
|
@ -5,9 +5,9 @@ use std::time::Duration;
|
|||
|
||||
use log::{debug, info, warn};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use structopt::StructOpt;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
@ -122,9 +122,15 @@ async fn main() {
|
|||
|
||||
let listen_addr = opt.listen_addr.parse().unwrap();
|
||||
let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
|
||||
|
||||
let watch_cancel = netapp::util::watch_ctrl_c();
|
||||
|
||||
tokio::join!(
|
||||
example.clone().sampling_loop(),
|
||||
example.netapp.clone().listen(listen_addr, public_addr),
|
||||
example
|
||||
.netapp
|
||||
.clone()
|
||||
.listen(listen_addr, public_addr, watch_cancel),
|
||||
example.basalt.clone().run(),
|
||||
);
|
||||
}
|
||||
|
@ -141,7 +147,8 @@ impl Example {
|
|||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
match self2
|
||||
.example_endpoint.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
|
||||
.example_endpoint
|
||||
.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => debug!("Got example response: {:?}", resp),
|
||||
|
|
|
@ -87,6 +87,11 @@ async fn main() {
|
|||
hex::encode(&privkey.public_key()),
|
||||
listen_addr);
|
||||
|
||||
let watch_cancel = netapp::util::watch_ctrl_c();
|
||||
|
||||
let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
|
||||
tokio::join!(netapp.listen(listen_addr, public_addr), peering.run(),);
|
||||
tokio::join!(
|
||||
netapp.listen(listen_addr, public_addr, watch_cancel.clone()),
|
||||
peering.run(watch_cancel),
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicU32};
|
||||
use std::sync::atomic::{self, AtomicU32};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio_util::compat::*;
|
||||
|
||||
|
@ -21,17 +23,14 @@ use crate::netapp::*;
|
|||
use crate::proto::*;
|
||||
use crate::util::*;
|
||||
|
||||
|
||||
pub(crate) struct ClientConn {
|
||||
pub(crate) remote_addr: SocketAddr,
|
||||
pub(crate) peer_id: NodeID,
|
||||
|
||||
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
|
||||
next_query_number: AtomicU32,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
|
||||
must_exit: AtomicBool,
|
||||
stop_recv_loop: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl ClientConn {
|
||||
|
@ -71,25 +70,35 @@ impl ClientConn {
|
|||
remote_addr,
|
||||
peer_id,
|
||||
next_query_number: AtomicU32::from(RequestID::default()),
|
||||
query_send,
|
||||
query_send: ArcSwapOption::new(Some(Arc::new(query_send))),
|
||||
inflight: Mutex::new(HashMap::new()),
|
||||
must_exit: AtomicBool::new(false),
|
||||
stop_recv_loop,
|
||||
});
|
||||
|
||||
netapp.connected_as_client(peer_id, conn.clone());
|
||||
|
||||
tokio::spawn(async move {
|
||||
let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write));
|
||||
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
tokio::try_join!(conn2.send_loop(query_recv, write), async move {
|
||||
tokio::select!(
|
||||
r = conn3.recv_loop(read) => r,
|
||||
_ = await_exit(stop_recv_loop_recv) => Ok(()),
|
||||
)
|
||||
})
|
||||
.map(|_| ())
|
||||
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
|
||||
let recv_future = tokio::spawn(async move {
|
||||
select! {
|
||||
r = conn2.recv_loop(read) => r,
|
||||
_ = await_exit(stop_recv_loop_recv) => Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
send_future.await.log_err("ClientConn send_loop");
|
||||
|
||||
// TODO here: wait for inflight requests to all have their response
|
||||
stop_recv_loop
|
||||
.send(true)
|
||||
.log_err("ClientConn send true to stop_recv_loop");
|
||||
|
||||
recv_future.await.log_err("ClientConn recv_loop");
|
||||
|
||||
// Make sure we don't wait on any more requests that won't
|
||||
// have a response
|
||||
conn.inflight.lock().unwrap().clear();
|
||||
|
||||
netapp.disconnected_as_client(&peer_id, conn);
|
||||
});
|
||||
|
@ -98,15 +107,7 @@ impl ClientConn {
|
|||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.must_exit.store(true, atomic::Ordering::SeqCst);
|
||||
self.query_send
|
||||
.send(None)
|
||||
.log_err("could not write None in query_send");
|
||||
if self.inflight.lock().unwrap().is_empty() {
|
||||
self.stop_recv_loop
|
||||
.send(true)
|
||||
.log_err("could not write true to stop_recv_loop");
|
||||
}
|
||||
self.query_send.store(None);
|
||||
}
|
||||
|
||||
pub(crate) async fn call<T>(
|
||||
|
@ -118,6 +119,8 @@ impl ClientConn {
|
|||
where
|
||||
T: Message,
|
||||
{
|
||||
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
|
||||
|
||||
let id = self
|
||||
.next_query_number
|
||||
.fetch_add(1, atomic::Ordering::Relaxed);
|
||||
|
@ -138,20 +141,23 @@ impl ClientConn {
|
|||
}
|
||||
|
||||
trace!("request: query_send {}, {} bytes", id, bytes.len());
|
||||
self.query_send.send(Some((id, prio, bytes)))?;
|
||||
query_send.send((id, prio, bytes))?;
|
||||
|
||||
let resp = resp_recv.await?;
|
||||
if resp.len() == 0 {
|
||||
return Err(Error::Message("Response is 0 bytes, either a collision or a protocol error".into()));
|
||||
if resp.is_empty() {
|
||||
return Err(Error::Message(
|
||||
"Response is 0 bytes, either a collision or a protocol error".into(),
|
||||
));
|
||||
}
|
||||
|
||||
trace!("request response {}: ", id);
|
||||
|
||||
let code = resp[0];
|
||||
if code == 0 {
|
||||
Ok(rmp_serde::decode::from_read_ref::<_, <T as Message>::Response>(
|
||||
&resp[1..],
|
||||
)?)
|
||||
Ok(rmp_serde::decode::from_read_ref::<
|
||||
_,
|
||||
<T as Message>::Response,
|
||||
>(&resp[1..])?)
|
||||
} else {
|
||||
Err(Error::Remote(format!("Remote error code {}", code)))
|
||||
}
|
||||
|
@ -162,7 +168,7 @@ impl SendLoop for ClientConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ClientConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
|
||||
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
|
||||
|
||||
let mut inflight = self.inflight.lock().unwrap();
|
||||
|
@ -171,11 +177,5 @@ impl RecvLoop for ClientConn {
|
|||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||
}
|
||||
}
|
||||
|
||||
if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
|
||||
self.stop_recv_loop
|
||||
.send(true)
|
||||
.log_err("could not write true to stop_recv_loop");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -123,4 +123,3 @@ where
|
|||
Box::new(Self(self.0.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
17
src/error.rs
17
src/error.rs
|
@ -31,6 +31,9 @@ pub enum Error {
|
|||
#[error(display = "No handler / shutting down")]
|
||||
NoHandler,
|
||||
|
||||
#[error(display = "Connection closed")]
|
||||
ConnectionClosed,
|
||||
|
||||
#[error(display = "Remote error: {}", _0)]
|
||||
Remote(String),
|
||||
}
|
||||
|
@ -45,6 +48,7 @@ impl Error {
|
|||
Self::RMPDecode(_) => 11,
|
||||
Self::UTF8(_) => 12,
|
||||
Self::NoHandler => 20,
|
||||
Self::ConnectionClosed => 21,
|
||||
Self::Handshake(_) => 30,
|
||||
Self::Remote(_) => 40,
|
||||
Self::Message(_) => 99,
|
||||
|
@ -80,3 +84,16 @@ where
|
|||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, T> LogError for Result<T, E>
|
||||
where
|
||||
T: LogError,
|
||||
E: Into<Error>,
|
||||
{
|
||||
fn log_err(self, msg: &'static str) {
|
||||
match self {
|
||||
Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)),
|
||||
Ok(x) => x.log_err(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,16 +13,14 @@
|
|||
//! about message priorization.
|
||||
//! Also check out the examples to learn how to use this crate.
|
||||
|
||||
#![feature(map_first_last)]
|
||||
|
||||
pub mod error;
|
||||
pub mod util;
|
||||
|
||||
pub mod endpoint;
|
||||
pub mod proto;
|
||||
|
||||
mod server;
|
||||
mod client;
|
||||
mod server;
|
||||
|
||||
pub mod netapp;
|
||||
pub mod peering;
|
||||
|
|
107
src/netapp.rs
107
src/netapp.rs
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use log::{debug, info, error};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
|
@ -10,13 +10,18 @@ use async_trait::async_trait;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use sodiumoxide::crypto::auth;
|
||||
use sodiumoxide::crypto::sign::ed25519;
|
||||
|
||||
use futures::stream::futures_unordered::FuturesUnordered;
|
||||
use futures::stream::StreamExt;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::client::*;
|
||||
use crate::server::*;
|
||||
use crate::endpoint::*;
|
||||
use crate::error::*;
|
||||
use crate::proto::*;
|
||||
use crate::server::*;
|
||||
use crate::util::*;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
@ -142,35 +147,91 @@ impl NetApp {
|
|||
/// Main listening process for our app. This future runs during the whole
|
||||
/// run time of our application.
|
||||
/// If this is not called, the NetApp instance remains a passive client.
|
||||
pub async fn listen(self: Arc<Self>, listen_addr: SocketAddr, public_addr: Option<IpAddr>) {
|
||||
pub async fn listen(
|
||||
self: Arc<Self>,
|
||||
listen_addr: SocketAddr,
|
||||
public_addr: Option<IpAddr>,
|
||||
mut must_exit: watch::Receiver<bool>,
|
||||
) {
|
||||
let listen_params = ListenParams {
|
||||
listen_addr,
|
||||
public_addr,
|
||||
};
|
||||
if self.listen_params.swap(Some(Arc::new(listen_params))).is_some() {
|
||||
if self
|
||||
.listen_params
|
||||
.swap(Some(Arc::new(listen_params)))
|
||||
.is_some()
|
||||
{
|
||||
error!("Trying to listen on NetApp but we're already listening!");
|
||||
}
|
||||
|
||||
let listener = TcpListener::bind(listen_addr).await.unwrap();
|
||||
info!("Listening on {}", listen_addr);
|
||||
|
||||
loop {
|
||||
// The second item contains the IP and port of the new connection.
|
||||
let (socket, _) = listener.accept().await.unwrap();
|
||||
let (conn_in, mut conn_out) = mpsc::unbounded_channel();
|
||||
let connection_collector = tokio::spawn(async move {
|
||||
let mut collection = FuturesUnordered::new();
|
||||
loop {
|
||||
if collection.is_empty() {
|
||||
match conn_out.recv().await {
|
||||
Some(f) => collection.push(f),
|
||||
None => break,
|
||||
}
|
||||
} else {
|
||||
select! {
|
||||
new_fut = conn_out.recv() => {
|
||||
match new_fut {
|
||||
Some(f) => collection.push(f),
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
result = collection.next() => {
|
||||
trace!("Collected connection: {:?}", result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Collecting last open server connections.");
|
||||
while let Some(conn_res) = collection.next().await {
|
||||
trace!("Collected connection: {:?}", conn_res);
|
||||
}
|
||||
debug!("No more server connections to collect");
|
||||
});
|
||||
|
||||
while !*must_exit.borrow_and_update() {
|
||||
let (socket, peer_addr) = select! {
|
||||
sockres = listener.accept() => {
|
||||
match sockres {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
warn!("Error in listener.accept: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = must_exit.changed() => continue,
|
||||
};
|
||||
|
||||
info!(
|
||||
"Incoming connection from {}, negotiating handshake...",
|
||||
match socket.peer_addr() {
|
||||
Ok(x) => format!("{}", x),
|
||||
Err(e) => format!("<invalid addr: {}>", e),
|
||||
}
|
||||
peer_addr
|
||||
);
|
||||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
ServerConn::run(self2, socket)
|
||||
.await
|
||||
.log_err("ServerConn::run");
|
||||
});
|
||||
let must_exit2 = must_exit.clone();
|
||||
conn_in
|
||||
.send(tokio::spawn(async move {
|
||||
ServerConn::run(self2, socket, must_exit2)
|
||||
.await
|
||||
.log_err("ServerConn::run");
|
||||
}))
|
||||
.log_err("Failed to send connection to connection collector");
|
||||
}
|
||||
|
||||
drop(conn_in);
|
||||
|
||||
connection_collector
|
||||
.await
|
||||
.log_err("Failed to await for connection collector");
|
||||
}
|
||||
|
||||
/// Attempt to connect to a peer, given by its ip:port and its public key.
|
||||
|
@ -231,20 +292,6 @@ impl NetApp {
|
|||
});
|
||||
}
|
||||
|
||||
/// Close the incoming connection from a certain client to us,
|
||||
/// if such a connection is currently open.
|
||||
pub fn server_disconnect(self: &Arc<Self>, id: &NodeID) {
|
||||
let conn = self.server_conns.read().unwrap().get(id).cloned();
|
||||
if let Some(c) = conn {
|
||||
debug!(
|
||||
"Closing incoming connection from {} ({})",
|
||||
hex::encode(c.peer_id),
|
||||
c.remote_addr
|
||||
);
|
||||
c.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Called from conn.rs when an incoming connection is successfully established
|
||||
// Registers the connection in our list of connections
|
||||
// Do not yet call the on_connected handler, because we don't know if the remote
|
||||
|
|
|
@ -3,11 +3,11 @@ use std::net::SocketAddr;
|
|||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::{debug, info, trace, warn};
|
||||
use lru::LruCache;
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use sodiumoxide::crypto::hash;
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@ use async_trait::async_trait;
|
|||
use log::{debug, info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
use sodiumoxide::crypto::hash;
|
||||
|
||||
use crate::endpoint::*;
|
||||
|
@ -171,8 +173,8 @@ impl FullMeshPeeringStrategy {
|
|||
strat
|
||||
}
|
||||
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
loop {
|
||||
pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
|
||||
while !*must_exit.borrow() {
|
||||
// 1. Read current state: get list of connected peers (ping them)
|
||||
let (to_ping, to_retry) = {
|
||||
let known_hosts = self.known_hosts.read().unwrap();
|
||||
|
|
56
src/proto.rs
56
src/proto.rs
|
@ -1,4 +1,4 @@
|
|||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
|
||||
use log::trace;
|
||||
|
@ -50,7 +50,6 @@ type ChunkLength = u16;
|
|||
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
|
||||
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||
|
||||
|
||||
struct SendQueueItem {
|
||||
id: RequestID,
|
||||
prio: RequestPriority,
|
||||
|
@ -59,31 +58,33 @@ struct SendQueueItem {
|
|||
}
|
||||
|
||||
struct SendQueue {
|
||||
items: BTreeMap<u8, VecDeque<SendQueueItem>>,
|
||||
items: VecDeque<(u8, VecDeque<SendQueueItem>)>,
|
||||
}
|
||||
|
||||
impl SendQueue {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
items: BTreeMap::new(),
|
||||
items: VecDeque::with_capacity(64),
|
||||
}
|
||||
}
|
||||
fn push(&mut self, item: SendQueueItem) {
|
||||
let prio = item.prio;
|
||||
let mut items_at_prio = self
|
||||
.items
|
||||
.remove(&prio)
|
||||
.unwrap_or_else(|| VecDeque::with_capacity(4));
|
||||
items_at_prio.push_back(item);
|
||||
self.items.insert(prio, items_at_prio);
|
||||
let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
|
||||
Ok(i) => i,
|
||||
Err(i) => {
|
||||
self.items.insert(i, (prio, VecDeque::new()));
|
||||
i
|
||||
}
|
||||
};
|
||||
self.items[pos_prio].1.push_back(item);
|
||||
}
|
||||
fn pop(&mut self) -> Option<SendQueueItem> {
|
||||
match self.items.pop_first() {
|
||||
match self.items.pop_front() {
|
||||
None => None,
|
||||
Some((prio, mut items_at_prio)) => {
|
||||
let ret = items_at_prio.pop_front();
|
||||
if !items_at_prio.is_empty() {
|
||||
self.items.insert(prio, items_at_prio);
|
||||
self.items.push_front((prio, items_at_prio));
|
||||
}
|
||||
ret.or_else(|| self.pop())
|
||||
}
|
||||
|
@ -98,7 +99,7 @@ impl SendQueue {
|
|||
pub(crate) trait SendLoop: Sync {
|
||||
async fn send_loop<W>(
|
||||
self: Arc<Self>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
|
||||
mut write: W,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
|
@ -107,18 +108,14 @@ pub(crate) trait SendLoop: Sync {
|
|||
let mut sending = SendQueue::new();
|
||||
let mut should_exit = false;
|
||||
while !should_exit || !sending.is_empty() {
|
||||
if let Ok(sth) = msg_recv.try_recv() {
|
||||
if let Some((id, prio, data)) = sth {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else {
|
||||
should_exit = true;
|
||||
}
|
||||
if let Ok((id, prio, data)) = msg_recv.try_recv() {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
id,
|
||||
prio,
|
||||
data,
|
||||
cursor: 0,
|
||||
});
|
||||
} else if let Some(mut item) = sending.pop() {
|
||||
trace!(
|
||||
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
|
||||
|
@ -149,10 +146,7 @@ pub(crate) trait SendLoop: Sync {
|
|||
}
|
||||
write.flush().await?;
|
||||
} else {
|
||||
let sth = msg_recv
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| Error::Message("Connection closed.".into()))?;
|
||||
let sth = msg_recv.recv().await;
|
||||
if let Some((id, prio, data)) = sth {
|
||||
trace!("send_loop: got {}, {} bytes", id, data.len());
|
||||
sending.push(SendQueueItem {
|
||||
|
@ -173,7 +167,7 @@ pub(crate) trait SendLoop: Sync {
|
|||
#[async_trait]
|
||||
pub(crate) trait RecvLoop: Sync + 'static {
|
||||
// Returns true if we should stop receiving after this
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
|
||||
|
||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||
where
|
||||
|
@ -205,7 +199,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
if has_cont {
|
||||
receiving.insert(id, msg_bytes);
|
||||
} else {
|
||||
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
|
||||
self.recv_handler(id, msg_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use bytes::Bytes;
|
||||
use log::{debug, trace};
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio_util::compat::*;
|
||||
|
||||
|
@ -42,12 +44,15 @@ pub(crate) struct ServerConn {
|
|||
|
||||
netapp: Arc<NetApp>,
|
||||
|
||||
resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
close_send: watch::Sender<bool>,
|
||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl ServerConn {
|
||||
pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> {
|
||||
pub(crate) async fn run(
|
||||
netapp: Arc<NetApp>,
|
||||
socket: TcpStream,
|
||||
must_exit: watch::Receiver<bool>,
|
||||
) -> Result<(), Error> {
|
||||
let remote_addr = socket.peer_addr()?;
|
||||
let mut socket = socket.compat();
|
||||
|
||||
|
@ -73,47 +78,33 @@ impl ServerConn {
|
|||
|
||||
let (resp_send, resp_recv) = mpsc::unbounded_channel();
|
||||
|
||||
let (close_send, close_recv) = watch::channel(false);
|
||||
|
||||
let conn = Arc::new(ServerConn {
|
||||
netapp: netapp.clone(),
|
||||
remote_addr,
|
||||
peer_id,
|
||||
resp_send,
|
||||
close_send,
|
||||
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
|
||||
});
|
||||
|
||||
netapp.connected_as_server(peer_id, conn.clone());
|
||||
|
||||
let conn2 = conn.clone();
|
||||
let conn3 = conn.clone();
|
||||
let close_recv2 = close_recv.clone();
|
||||
tokio::try_join!(
|
||||
async move {
|
||||
tokio::select!(
|
||||
r = conn2.recv_loop(read) => r,
|
||||
_ = await_exit(close_recv) => Ok(()),
|
||||
)
|
||||
},
|
||||
async move {
|
||||
tokio::select!(
|
||||
r = conn3.send_loop(resp_recv, write) => r,
|
||||
_ = await_exit(close_recv2) => Ok(()),
|
||||
)
|
||||
},
|
||||
)
|
||||
.map(|_| ())
|
||||
.log_err("ServerConn recv_loop/send_loop");
|
||||
let recv_future = tokio::spawn(async move {
|
||||
select! {
|
||||
r = conn2.recv_loop(read) => r,
|
||||
_ = await_exit(must_exit) => Ok(())
|
||||
}
|
||||
});
|
||||
let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));
|
||||
|
||||
recv_future.await.log_err("ServerConn recv_loop");
|
||||
conn.resp_send.store(None);
|
||||
send_future.await.log_err("ServerConn send_loop");
|
||||
|
||||
netapp.disconnected_as_server(&peer_id, conn);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.close_send.send(true).unwrap();
|
||||
}
|
||||
|
||||
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
if bytes.len() < 2 {
|
||||
return Err(Error::Message("Invalid protocol message".into()));
|
||||
|
@ -146,33 +137,33 @@ impl SendLoop for ServerConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ServerConn {
|
||||
async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
||||
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
|
||||
let bytes: Bytes = bytes.into();
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
|
||||
let resp_send = self.resp_send.load_full().unwrap();
|
||||
|
||||
let prio = if !bytes.is_empty() {
|
||||
bytes[0]
|
||||
} else {
|
||||
0u8
|
||||
};
|
||||
let resp = self.recv_handler_aux(&bytes[..]).await;
|
||||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
|
||||
let bytes: Bytes = bytes.into();
|
||||
|
||||
let mut resp_bytes = vec![];
|
||||
match resp {
|
||||
Ok(rb) => {
|
||||
resp_bytes.push(0u8);
|
||||
resp_bytes.extend(&rb[..]);
|
||||
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
|
||||
let resp = self2.recv_handler_aux(&bytes[..]).await;
|
||||
|
||||
let mut resp_bytes = vec![];
|
||||
match resp {
|
||||
Ok(rb) => {
|
||||
resp_bytes.push(0u8);
|
||||
resp_bytes.extend(&rb[..]);
|
||||
}
|
||||
Err(e) => {
|
||||
resp_bytes.push(e.code());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
resp_bytes.push(e.code());
|
||||
}
|
||||
}
|
||||
|
||||
trace!("ServerConn sending response to {}: ", id);
|
||||
trace!("ServerConn sending response to {}: ", id);
|
||||
|
||||
self.resp_send
|
||||
.send(Some((id, prio, resp_bytes)))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
resp_send
|
||||
.send((id, prio, resp_bytes))
|
||||
.log_err("ServerConn recv_handler send resp");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
14
src/util.rs
14
src/util.rs
|
@ -1,5 +1,7 @@
|
|||
use serde::Serialize;
|
||||
|
||||
use log::info;
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
|
||||
|
@ -38,3 +40,15 @@ pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn watch_ctrl_c() -> watch::Receiver<bool> {
|
||||
let (send_cancel, watch_cancel) = watch::channel(false);
|
||||
tokio::spawn(async move {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install CTRL+C signal handler");
|
||||
info!("Received CTRL+C, shutting down.");
|
||||
send_cancel.send(true).unwrap();
|
||||
});
|
||||
watch_cancel
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue