Use bounded channels on receive side for backpressure

This commit is contained in:
Alex 2022-07-22 13:01:52 +02:00
parent 0b71ca12f9
commit 9cb28c21b4
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
4 changed files with 37 additions and 32 deletions

View file

@ -8,7 +8,6 @@ use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::io::AsyncReadExt;
use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
@ -39,7 +38,7 @@ pub(crate) struct ClientConn {
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
}
impl ClientConn {
@ -175,7 +174,9 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
let _ = old_ch.send(unbounded().1);
let _ = old_ch.send(Box::pin(futures::stream::once(async move {
Err(Error::IdCollision.code())
})));
}
trace!(
@ -199,7 +200,7 @@ impl ClientConn {
}
}
let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
let resp_enc = RespEnc::decode(stream).await?;
trace!("request response {}", id);
Resp::from_enc(resp_enc)
}
@ -209,7 +210,7 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap();

View file

@ -28,6 +28,9 @@ pub enum Error {
#[error(display = "Framing protocol error")]
Framing,
#[error(display = "Request ID collision")]
IdCollision,
#[error(display = "{}", _0)]
Message(String),
@ -56,6 +59,7 @@ impl Error {
Self::Framing => 13,
Self::NoHandler => 20,
Self::ConnectionClosed => 21,
Self::IdCollision => 22,
Self::Handshake(_) => 30,
Self::VersionMismatch(_) => 31,
Self::Remote(c, _) => *c,

View file

@ -5,8 +5,8 @@ use async_trait::async_trait;
use bytes::Bytes;
use log::trace;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::AsyncReadExt;
use tokio::sync::mpsc;
use crate::error::*;
use crate::send::*;
@ -15,33 +15,28 @@ use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
struct Sender {
inner: UnboundedSender<Packet>,
closed: bool,
inner: Option<mpsc::Sender<Packet>>,
}
impl Sender {
fn new(inner: UnboundedSender<Packet>) -> Self {
Sender {
inner,
closed: false,
}
fn new(inner: mpsc::Sender<Packet>) -> Self {
Sender { inner: Some(inner) }
}
fn send(&self, packet: Packet) {
let _ = self.inner.unbounded_send(packet);
async fn send(&self, packet: Packet) {
let _ = self.inner.as_ref().unwrap().send(packet).await;
}
fn end(&mut self) {
self.closed = true;
self.inner = None;
}
}
impl Drop for Sender {
fn drop(&mut self) {
if !self.closed {
self.send(Err(255));
if let Some(inner) = self.inner.take() {
let _ = inner.blocking_send(Err(255));
}
self.inner.close_channel();
}
}
@ -54,7 +49,7 @@ impl Drop for Sender {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static {
let mut sender = if let Some(send) = streams.remove(&(id)) {
send
} else {
let (send, recv) = unbounded();
self.recv_handler(id, recv);
let (send, recv) = mpsc::channel(4);
self.recv_handler(
id,
Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
);
Sender::new(send)
};
// if we get an error, the receiving end is disconnected. We still need to
// reach eos before dropping this sender
sender.send(packet);
// If we get an error, the receiving end is disconnected.
// We still need to reach eos before dropping this sender
let _ = sender.send(packet).await;
if has_cont {
streams.insert(id, sender);

View file

@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use log::{debug, trace};
use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use kuska_handshake::async_std::{handshake_server, BoxStream};
use tokio::net::TcpStream;
@ -171,21 +170,24 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
let resp_send = self.resp_send.load_full().unwrap();
let self2 = self.clone();
tokio::spawn(async move {
trace!("ServerConn recv_handler {}", id);
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
let (prio, resp_enc) = match ReqEnc::decode(stream).await {
Ok(req_enc) => {
let prio = req_enc.prio;
let resp = self2.recv_handler_aux(req_enc).await;
(prio, match resp {
(
prio,
match resp {
Ok(resp_enc) => resp_enc,
Err(e) => RespEnc::from_err(e),
})
},
)
}
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
};