Use bounded channels on receive side for backpressure
This commit is contained in:
parent
0b71ca12f9
commit
9cb28c21b4
4 changed files with 37 additions and 32 deletions
|
@ -8,7 +8,6 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
|
||||
use futures::io::AsyncReadExt;
|
||||
use kuska_handshake::async_std::{handshake_client, BoxStream};
|
||||
use tokio::net::TcpStream;
|
||||
|
@ -39,7 +38,7 @@ pub(crate) struct ClientConn {
|
|||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
|
||||
|
||||
next_query_number: AtomicU32,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
|
||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
|
||||
}
|
||||
|
||||
impl ClientConn {
|
||||
|
@ -175,7 +174,9 @@ impl ClientConn {
|
|||
error!(
|
||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||
);
|
||||
let _ = old_ch.send(unbounded().1);
|
||||
let _ = old_ch.send(Box::pin(futures::stream::once(async move {
|
||||
Err(Error::IdCollision.code())
|
||||
})));
|
||||
}
|
||||
|
||||
trace!(
|
||||
|
@ -199,7 +200,7 @@ impl ClientConn {
|
|||
}
|
||||
}
|
||||
|
||||
let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
|
||||
let resp_enc = RespEnc::decode(stream).await?;
|
||||
trace!("request response {}", id);
|
||||
Resp::from_enc(resp_enc)
|
||||
}
|
||||
|
@ -209,7 +210,7 @@ impl SendLoop for ClientConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ClientConn {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
||||
trace!("ClientConn recv_handler {}", id);
|
||||
|
||||
let mut inflight = self.inflight.lock().unwrap();
|
||||
|
|
|
@ -28,6 +28,9 @@ pub enum Error {
|
|||
#[error(display = "Framing protocol error")]
|
||||
Framing,
|
||||
|
||||
#[error(display = "Request ID collision")]
|
||||
IdCollision,
|
||||
|
||||
#[error(display = "{}", _0)]
|
||||
Message(String),
|
||||
|
||||
|
@ -56,6 +59,7 @@ impl Error {
|
|||
Self::Framing => 13,
|
||||
Self::NoHandler => 20,
|
||||
Self::ConnectionClosed => 21,
|
||||
Self::IdCollision => 22,
|
||||
Self::Handshake(_) => 30,
|
||||
Self::VersionMismatch(_) => 31,
|
||||
Self::Remote(c, _) => *c,
|
||||
|
|
38
src/recv.rs
38
src/recv.rs
|
@ -5,8 +5,8 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
use log::trace;
|
||||
|
||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
|
||||
use futures::AsyncReadExt;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::error::*;
|
||||
use crate::send::*;
|
||||
|
@ -15,33 +15,28 @@ use crate::stream::*;
|
|||
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
||||
/// connection to some remote drops while transmitting data
|
||||
struct Sender {
|
||||
inner: UnboundedSender<Packet>,
|
||||
closed: bool,
|
||||
inner: Option<mpsc::Sender<Packet>>,
|
||||
}
|
||||
|
||||
impl Sender {
|
||||
fn new(inner: UnboundedSender<Packet>) -> Self {
|
||||
Sender {
|
||||
inner,
|
||||
closed: false,
|
||||
}
|
||||
fn new(inner: mpsc::Sender<Packet>) -> Self {
|
||||
Sender { inner: Some(inner) }
|
||||
}
|
||||
|
||||
fn send(&self, packet: Packet) {
|
||||
let _ = self.inner.unbounded_send(packet);
|
||||
async fn send(&self, packet: Packet) {
|
||||
let _ = self.inner.as_ref().unwrap().send(packet).await;
|
||||
}
|
||||
|
||||
fn end(&mut self) {
|
||||
self.closed = true;
|
||||
self.inner = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Sender {
|
||||
fn drop(&mut self) {
|
||||
if !self.closed {
|
||||
self.send(Err(255));
|
||||
if let Some(inner) = self.inner.take() {
|
||||
let _ = inner.blocking_send(Err(255));
|
||||
}
|
||||
self.inner.close_channel();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -54,7 +49,7 @@ impl Drop for Sender {
|
|||
/// the full message is passed to the receive handler.
|
||||
#[async_trait]
|
||||
pub(crate) trait RecvLoop: Sync + 'static {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
|
||||
|
||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||
where
|
||||
|
@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
|||
let mut sender = if let Some(send) = streams.remove(&(id)) {
|
||||
send
|
||||
} else {
|
||||
let (send, recv) = unbounded();
|
||||
self.recv_handler(id, recv);
|
||||
let (send, recv) = mpsc::channel(4);
|
||||
self.recv_handler(
|
||||
id,
|
||||
Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
|
||||
);
|
||||
Sender::new(send)
|
||||
};
|
||||
|
||||
// if we get an error, the receiving end is disconnected. We still need to
|
||||
// reach eos before dropping this sender
|
||||
sender.send(packet);
|
||||
// If we get an error, the receiving end is disconnected.
|
||||
// We still need to reach eos before dropping this sender
|
||||
let _ = sender.send(packet).await;
|
||||
|
||||
if has_cont {
|
||||
streams.insert(id, sender);
|
||||
|
|
|
@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption;
|
|||
use async_trait::async_trait;
|
||||
use log::{debug, trace};
|
||||
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use kuska_handshake::async_std::{handshake_server, BoxStream};
|
||||
use tokio::net::TcpStream;
|
||||
|
@ -171,21 +170,24 @@ impl SendLoop for ServerConn {}
|
|||
|
||||
#[async_trait]
|
||||
impl RecvLoop for ServerConn {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
||||
let resp_send = self.resp_send.load_full().unwrap();
|
||||
|
||||
let self2 = self.clone();
|
||||
tokio::spawn(async move {
|
||||
trace!("ServerConn recv_handler {}", id);
|
||||
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
|
||||
let (prio, resp_enc) = match ReqEnc::decode(stream).await {
|
||||
Ok(req_enc) => {
|
||||
let prio = req_enc.prio;
|
||||
let resp = self2.recv_handler_aux(req_enc).await;
|
||||
|
||||
(prio, match resp {
|
||||
(
|
||||
prio,
|
||||
match resp {
|
||||
Ok(resp_enc) => resp_enc,
|
||||
Err(e) => RespEnc::from_err(e),
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
|
||||
};
|
||||
|
|
Loading…
Reference in a new issue