add detection of premature eos

This commit is contained in:
trinity-1686a 2022-07-18 15:21:13 +02:00
parent d3d18b8e8b
commit cdff8ae1be
4 changed files with 58 additions and 19 deletions

View file

@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use log::{debug, error, trace}; use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select; use tokio::select;
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
@ -41,7 +42,7 @@ pub(crate) struct ClientConn {
ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>, ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
next_query_number: AtomicU32, next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<AssociatedStream>>>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
} }
impl ClientConn { impl ClientConn {
@ -186,7 +187,7 @@ impl ClientConn {
error!( error!(
"Too many inflight requests! RequestID collision. Interrupting previous request." "Too many inflight requests! RequestID collision. Interrupting previous request."
); );
if old_ch.send(Box::pin(futures::stream::empty())).is_err() { if old_ch.send(unbounded().1).is_err() {
debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
} }
} }
@ -232,7 +233,7 @@ impl SendLoop for ClientConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ClientConn { impl RecvLoop for ClientConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
trace!("ClientConn recv_handler {}", id); trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap(); let mut inflight = self.inflight.lock().unwrap();

View file

@ -5,7 +5,7 @@ use std::task::{Context, Poll};
use log::trace; use log::trace;
use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{AsyncReadExt, AsyncWriteExt};
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite; use kuska_handshake::async_std::BoxStreamWrite;
@ -15,7 +15,7 @@ use tokio::sync::mpsc;
use async_trait::async_trait; use async_trait::async_trait;
use crate::error::*; use crate::error::*;
use crate::util::AssociatedStream; use crate::util::{AssociatedStream, Packet};
/// Priority of a request (click to read more about priorities). /// Priority of a request (click to read more about priorities).
/// ///
@ -67,7 +67,7 @@ struct SendQueueItem {
struct DataReader { struct DataReader {
#[pin] #[pin]
reader: AssociatedStream, reader: AssociatedStream,
packet: Result<Vec<u8>, u8>, packet: Packet,
pos: usize, pos: usize,
buf: Vec<u8>, buf: Vec<u8>,
eos: bool, eos: bool,
@ -370,7 +370,7 @@ impl Framing {
} }
} }
pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>( pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
mut stream: S, mut stream: S,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut packet = stream let mut packet = stream
@ -422,6 +422,39 @@ impl Framing {
} }
} }
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
struct Sender {
inner: UnboundedSender<Packet>,
closed: bool,
}
impl Sender {
fn new(inner: UnboundedSender<Packet>) -> Self {
Sender {
inner,
closed: false,
}
}
fn send(&self, packet: Packet) {
let _ = self.inner.unbounded_send(packet);
}
fn end(&mut self) {
self.closed = true;
}
}
impl Drop for Sender {
fn drop(&mut self) {
if !self.closed {
self.send(Err(255));
}
self.inner.close_channel();
}
}
/// The RecvLoop trait, which is implemented both by the client and the server /// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that /// and a prototype of a handler for received messages `.recv_handler()` that
@ -431,13 +464,13 @@ impl Framing {
/// the full message is passed to the receive handler. /// the full message is passed to the receive handler.
#[async_trait] #[async_trait]
pub(crate) trait RecvLoop: Sync + 'static { pub(crate) trait RecvLoop: Sync + 'static {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream); fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where where
R: AsyncReadExt + Unpin + Send + Sync, R: AsyncReadExt + Unpin + Send + Sync,
{ {
let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new(); let mut streams: HashMap<RequestID, Sender> = HashMap::new();
loop { loop {
trace!("recv_loop: reading packet"); trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8]; let mut header_id = [0u8; RequestID::BITS as usize / 8];
@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static {
Ok(next_slice) Ok(next_slice)
}; };
let sender = if let Some(send) = streams.remove(&(id)) { let mut sender = if let Some(send) = streams.remove(&(id)) {
send send
} else { } else {
let (send, recv) = unbounded(); let (send, recv) = unbounded();
self.recv_handler(id, Box::pin(recv)); self.recv_handler(id, recv);
send Sender::new(send)
}; };
// if we get an error, the receiving end is disconnected. We still need to // if we get an error, the receiving end is disconnected. We still need to
// reach eos before dropping this sender // reach eos before dropping this sender
let _ = sender.unbounded_send(packet); sender.send(packet);
if has_cont { if has_cont {
streams.insert(id, sender); streams.insert(id, sender);
} else {
sender.end();
} }
} }
Ok(()) Ok(())
@ -491,9 +526,9 @@ mod test {
use super::*; use super::*;
fn empty_data() -> DataReader { fn empty_data() -> DataReader {
type Item = Result<Vec<u8>, u8>; type Item = Packet;
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>()); Box::pin(futures::stream::empty::<Packet>());
stream.into() stream.into()
} }

View file

@ -19,6 +19,7 @@ use tokio::select;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tokio_util::compat::*; use tokio_util::compat::*;
use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::io::{AsyncReadExt, AsyncWriteExt};
use async_trait::async_trait; use async_trait::async_trait;
@ -176,7 +177,7 @@ impl SendLoop for ServerConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ServerConn { impl RecvLoop for ServerConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
let resp_send = self.resp_send.load_full().unwrap(); let resp_send = self.resp_send.load_full().unwrap();
let self2 = self.clone(); let self2 = self.clone();

View file

@ -25,9 +25,11 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// When sent through Netapp, the Vec may be split in smaller chunk in such a way
/// consecutive Vec may get merged, but Vec and error code may not be reordered /// consecutive Vec may get merged, but Vec and error code may not be reordered
/// ///
/// The error code have no predefined meaning, it's up to you application to define their /// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// semantic. /// meaning, it's up to your application to define their semantic.
pub type AssociatedStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>, u8>> + Send>>; pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
pub type Packet = Result<Vec<u8>, u8>;
/// Utility function: encodes any serializable value in MessagePack binary format /// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library. /// using the RMP library.