Use bounded channels on receive side for backpressure
This commit is contained in:
parent
0b71ca12f9
commit
9cb28c21b4
4 changed files with 37 additions and 32 deletions
|
@ -8,7 +8,6 @@ use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use log::{debug, error, trace};
|
use log::{debug, error, trace};
|
||||||
|
|
||||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
|
|
||||||
use futures::io::AsyncReadExt;
|
use futures::io::AsyncReadExt;
|
||||||
use kuska_handshake::async_std::{handshake_client, BoxStream};
|
use kuska_handshake::async_std::{handshake_client, BoxStream};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
@ -39,7 +38,7 @@ pub(crate) struct ClientConn {
|
||||||
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
|
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
|
||||||
|
|
||||||
next_query_number: AtomicU32,
|
next_query_number: AtomicU32,
|
||||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
|
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientConn {
|
impl ClientConn {
|
||||||
|
@ -175,7 +174,9 @@ impl ClientConn {
|
||||||
error!(
|
error!(
|
||||||
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
"Too many inflight requests! RequestID collision. Interrupting previous request."
|
||||||
);
|
);
|
||||||
let _ = old_ch.send(unbounded().1);
|
let _ = old_ch.send(Box::pin(futures::stream::once(async move {
|
||||||
|
Err(Error::IdCollision.code())
|
||||||
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
|
@ -199,7 +200,7 @@ impl ClientConn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
|
let resp_enc = RespEnc::decode(stream).await?;
|
||||||
trace!("request response {}", id);
|
trace!("request response {}", id);
|
||||||
Resp::from_enc(resp_enc)
|
Resp::from_enc(resp_enc)
|
||||||
}
|
}
|
||||||
|
@ -209,7 +210,7 @@ impl SendLoop for ClientConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ClientConn {
|
impl RecvLoop for ClientConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
||||||
trace!("ClientConn recv_handler {}", id);
|
trace!("ClientConn recv_handler {}", id);
|
||||||
|
|
||||||
let mut inflight = self.inflight.lock().unwrap();
|
let mut inflight = self.inflight.lock().unwrap();
|
||||||
|
|
|
@ -28,6 +28,9 @@ pub enum Error {
|
||||||
#[error(display = "Framing protocol error")]
|
#[error(display = "Framing protocol error")]
|
||||||
Framing,
|
Framing,
|
||||||
|
|
||||||
|
#[error(display = "Request ID collision")]
|
||||||
|
IdCollision,
|
||||||
|
|
||||||
#[error(display = "{}", _0)]
|
#[error(display = "{}", _0)]
|
||||||
Message(String),
|
Message(String),
|
||||||
|
|
||||||
|
@ -56,6 +59,7 @@ impl Error {
|
||||||
Self::Framing => 13,
|
Self::Framing => 13,
|
||||||
Self::NoHandler => 20,
|
Self::NoHandler => 20,
|
||||||
Self::ConnectionClosed => 21,
|
Self::ConnectionClosed => 21,
|
||||||
|
Self::IdCollision => 22,
|
||||||
Self::Handshake(_) => 30,
|
Self::Handshake(_) => 30,
|
||||||
Self::VersionMismatch(_) => 31,
|
Self::VersionMismatch(_) => 31,
|
||||||
Self::Remote(c, _) => *c,
|
Self::Remote(c, _) => *c,
|
||||||
|
|
38
src/recv.rs
38
src/recv.rs
|
@ -5,8 +5,8 @@ use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use log::trace;
|
use log::trace;
|
||||||
|
|
||||||
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
|
|
||||||
use futures::AsyncReadExt;
|
use futures::AsyncReadExt;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use crate::error::*;
|
use crate::error::*;
|
||||||
use crate::send::*;
|
use crate::send::*;
|
||||||
|
@ -15,33 +15,28 @@ use crate::stream::*;
|
||||||
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
/// Structure to warn when the sender is dropped before end of stream was reached, like when
|
||||||
/// connection to some remote drops while transmitting data
|
/// connection to some remote drops while transmitting data
|
||||||
struct Sender {
|
struct Sender {
|
||||||
inner: UnboundedSender<Packet>,
|
inner: Option<mpsc::Sender<Packet>>,
|
||||||
closed: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Sender {
|
impl Sender {
|
||||||
fn new(inner: UnboundedSender<Packet>) -> Self {
|
fn new(inner: mpsc::Sender<Packet>) -> Self {
|
||||||
Sender {
|
Sender { inner: Some(inner) }
|
||||||
inner,
|
|
||||||
closed: false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send(&self, packet: Packet) {
|
async fn send(&self, packet: Packet) {
|
||||||
let _ = self.inner.unbounded_send(packet);
|
let _ = self.inner.as_ref().unwrap().send(packet).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn end(&mut self) {
|
fn end(&mut self) {
|
||||||
self.closed = true;
|
self.inner = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for Sender {
|
impl Drop for Sender {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if !self.closed {
|
if let Some(inner) = self.inner.take() {
|
||||||
self.send(Err(255));
|
let _ = inner.blocking_send(Err(255));
|
||||||
}
|
}
|
||||||
self.inner.close_channel();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +49,7 @@ impl Drop for Sender {
|
||||||
/// the full message is passed to the receive handler.
|
/// the full message is passed to the receive handler.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait RecvLoop: Sync + 'static {
|
pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
|
||||||
|
|
||||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
|
@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
let mut sender = if let Some(send) = streams.remove(&(id)) {
|
let mut sender = if let Some(send) = streams.remove(&(id)) {
|
||||||
send
|
send
|
||||||
} else {
|
} else {
|
||||||
let (send, recv) = unbounded();
|
let (send, recv) = mpsc::channel(4);
|
||||||
self.recv_handler(id, recv);
|
self.recv_handler(
|
||||||
|
id,
|
||||||
|
Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
|
||||||
|
);
|
||||||
Sender::new(send)
|
Sender::new(send)
|
||||||
};
|
};
|
||||||
|
|
||||||
// if we get an error, the receiving end is disconnected. We still need to
|
// If we get an error, the receiving end is disconnected.
|
||||||
// reach eos before dropping this sender
|
// We still need to reach eos before dropping this sender
|
||||||
sender.send(packet);
|
let _ = sender.send(packet).await;
|
||||||
|
|
||||||
if has_cont {
|
if has_cont {
|
||||||
streams.insert(id, sender);
|
streams.insert(id, sender);
|
||||||
|
|
|
@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
|
|
||||||
use futures::channel::mpsc::UnboundedReceiver;
|
|
||||||
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
use futures::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use kuska_handshake::async_std::{handshake_server, BoxStream};
|
use kuska_handshake::async_std::{handshake_server, BoxStream};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
@ -171,21 +170,24 @@ impl SendLoop for ServerConn {}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ServerConn {
|
impl RecvLoop for ServerConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
||||||
let resp_send = self.resp_send.load_full().unwrap();
|
let resp_send = self.resp_send.load_full().unwrap();
|
||||||
|
|
||||||
let self2 = self.clone();
|
let self2 = self.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
trace!("ServerConn recv_handler {}", id);
|
trace!("ServerConn recv_handler {}", id);
|
||||||
let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
|
let (prio, resp_enc) = match ReqEnc::decode(stream).await {
|
||||||
Ok(req_enc) => {
|
Ok(req_enc) => {
|
||||||
let prio = req_enc.prio;
|
let prio = req_enc.prio;
|
||||||
let resp = self2.recv_handler_aux(req_enc).await;
|
let resp = self2.recv_handler_aux(req_enc).await;
|
||||||
|
|
||||||
(prio, match resp {
|
(
|
||||||
|
prio,
|
||||||
|
match resp {
|
||||||
Ok(resp_enc) => resp_enc,
|
Ok(resp_enc) => resp_enc,
|
||||||
Err(e) => RespEnc::from_err(e),
|
Err(e) => RespEnc::from_err(e),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
|
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in a new issue