Use bounded channels on receive side for backpressure
Some checks failed
continuous-integration/drone/push Build is failing
continuous-integration/drone/pr Build is failing

This commit is contained in:
Alex 2022-07-22 13:01:52 +02:00
parent 0b71ca12f9
commit 9cb28c21b4
Signed by: lx
GPG key ID: 0E496D15096376BE
4 changed files with 37 additions and 32 deletions

View file

@ -8,7 +8,6 @@ use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use log::{debug, error, trace}; use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::io::AsyncReadExt; use futures::io::AsyncReadExt;
use kuska_handshake::async_std::{handshake_client, BoxStream}; use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -39,7 +38,7 @@ pub(crate) struct ClientConn {
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>, query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
next_query_number: AtomicU32, next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>, inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
} }
impl ClientConn { impl ClientConn {
@ -175,7 +174,9 @@ impl ClientConn {
error!( error!(
"Too many inflight requests! RequestID collision. Interrupting previous request." "Too many inflight requests! RequestID collision. Interrupting previous request."
); );
let _ = old_ch.send(unbounded().1); let _ = old_ch.send(Box::pin(futures::stream::once(async move {
Err(Error::IdCollision.code())
})));
} }
trace!( trace!(
@ -199,7 +200,7 @@ impl ClientConn {
} }
} }
let resp_enc = RespEnc::decode(Box::pin(stream)).await?; let resp_enc = RespEnc::decode(stream).await?;
trace!("request response {}", id); trace!("request response {}", id);
Resp::from_enc(resp_enc) Resp::from_enc(resp_enc)
} }
@ -209,7 +210,7 @@ impl SendLoop for ClientConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ClientConn { impl RecvLoop for ClientConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
trace!("ClientConn recv_handler {}", id); trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap(); let mut inflight = self.inflight.lock().unwrap();

View file

@ -28,6 +28,9 @@ pub enum Error {
#[error(display = "Framing protocol error")] #[error(display = "Framing protocol error")]
Framing, Framing,
#[error(display = "Request ID collision")]
IdCollision,
#[error(display = "{}", _0)] #[error(display = "{}", _0)]
Message(String), Message(String),
@ -56,6 +59,7 @@ impl Error {
Self::Framing => 13, Self::Framing => 13,
Self::NoHandler => 20, Self::NoHandler => 20,
Self::ConnectionClosed => 21, Self::ConnectionClosed => 21,
Self::IdCollision => 22,
Self::Handshake(_) => 30, Self::Handshake(_) => 30,
Self::VersionMismatch(_) => 31, Self::VersionMismatch(_) => 31,
Self::Remote(c, _) => *c, Self::Remote(c, _) => *c,

View file

@ -5,8 +5,8 @@ use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use log::trace; use log::trace;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::AsyncReadExt; use futures::AsyncReadExt;
use tokio::sync::mpsc;
use crate::error::*; use crate::error::*;
use crate::send::*; use crate::send::*;
@ -15,33 +15,28 @@ use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when /// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data /// connection to some remote drops while transmitting data
struct Sender { struct Sender {
inner: UnboundedSender<Packet>, inner: Option<mpsc::Sender<Packet>>,
closed: bool,
} }
impl Sender { impl Sender {
fn new(inner: UnboundedSender<Packet>) -> Self { fn new(inner: mpsc::Sender<Packet>) -> Self {
Sender { Sender { inner: Some(inner) }
inner,
closed: false,
}
} }
fn send(&self, packet: Packet) { async fn send(&self, packet: Packet) {
let _ = self.inner.unbounded_send(packet); let _ = self.inner.as_ref().unwrap().send(packet).await;
} }
fn end(&mut self) { fn end(&mut self) {
self.closed = true; self.inner = None;
} }
} }
impl Drop for Sender { impl Drop for Sender {
fn drop(&mut self) { fn drop(&mut self) {
if !self.closed { if let Some(inner) = self.inner.take() {
self.send(Err(255)); let _ = inner.blocking_send(Err(255));
} }
self.inner.close_channel();
} }
} }
@ -54,7 +49,7 @@ impl Drop for Sender {
/// the full message is passed to the receive handler. /// the full message is passed to the receive handler.
#[async_trait] #[async_trait]
pub(crate) trait RecvLoop: Sync + 'static { pub(crate) trait RecvLoop: Sync + 'static {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>); fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where where
@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static {
let mut sender = if let Some(send) = streams.remove(&(id)) { let mut sender = if let Some(send) = streams.remove(&(id)) {
send send
} else { } else {
let (send, recv) = unbounded(); let (send, recv) = mpsc::channel(4);
self.recv_handler(id, recv); self.recv_handler(
id,
Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
);
Sender::new(send) Sender::new(send)
}; };
// if we get an error, the receiving end is disconnected. We still need to // If we get an error, the receiving end is disconnected.
// reach eos before dropping this sender // We still need to reach eos before dropping this sender
sender.send(packet); let _ = sender.send(packet).await;
if has_cont { if has_cont {
streams.insert(id, sender); streams.insert(id, sender);

View file

@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption;
use async_trait::async_trait; use async_trait::async_trait;
use log::{debug, trace}; use log::{debug, trace};
use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::io::{AsyncReadExt, AsyncWriteExt};
use kuska_handshake::async_std::{handshake_server, BoxStream}; use kuska_handshake::async_std::{handshake_server, BoxStream};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -171,21 +170,24 @@ impl SendLoop for ServerConn {}
#[async_trait] #[async_trait]
impl RecvLoop for ServerConn { impl RecvLoop for ServerConn {
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) { fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
let resp_send = self.resp_send.load_full().unwrap(); let resp_send = self.resp_send.load_full().unwrap();
let self2 = self.clone(); let self2 = self.clone();
tokio::spawn(async move { tokio::spawn(async move {
trace!("ServerConn recv_handler {}", id); trace!("ServerConn recv_handler {}", id);
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await { let (prio, resp_enc) = match ReqEnc::decode(stream).await {
Ok(req_enc) => { Ok(req_enc) => {
let prio = req_enc.prio; let prio = req_enc.prio;
let resp = self2.recv_handler_aux(req_enc).await; let resp = self2.recv_handler_aux(req_enc).await;
(prio, match resp { (
prio,
match resp {
Ok(resp_enc) => resp_enc, Ok(resp_enc) => resp_enc,
Err(e) => RespEnc::from_err(e), Err(e) => RespEnc::from_err(e),
}) },
)
} }
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
}; };