diff --git a/src/client.rs b/src/client.rs index 42eeaa3..d51236b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,6 @@ use async_trait::async_trait; use bytes::Bytes; use log::{debug, error, trace}; -use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use futures::io::AsyncReadExt; use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; @@ -39,7 +38,7 @@ pub(crate) struct ClientConn { query_send: ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>>, + inflight: Mutex>>, } impl ClientConn { @@ -175,7 +174,9 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - let _ = old_ch.send(unbounded().1); + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(Error::IdCollision.code()) + }))); } trace!( @@ -199,7 +200,7 @@ impl ClientConn { } } - let resp_enc = RespEnc::decode(Box::pin(stream)).await?; + let resp_enc = RespEnc::decode(stream).await?; trace!("request response {}", id); Resp::from_enc(resp_enc) } @@ -209,7 +210,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/error.rs b/src/error.rs index 665647c..f374341 100644 --- a/src/error.rs +++ b/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(display = "Framing protocol error")] Framing, + #[error(display = "Request ID collision")] + IdCollision, + #[error(display = "{}", _0)] Message(String), @@ -56,6 +59,7 @@ impl Error { Self::Framing => 13, Self::NoHandler => 20, Self::ConnectionClosed => 21, + Self::IdCollision => 22, Self::Handshake(_) => 30, Self::VersionMismatch(_) => 31, Self::Remote(c, _) => *c, diff --git a/src/recv.rs b/src/recv.rs index 19288f2..b2f5530 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -5,8 +5,8 @@ use async_trait::async_trait; use bytes::Bytes; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::AsyncReadExt; +use tokio::sync::mpsc; use crate::error::*; use crate::send::*; @@ -15,33 +15,28 @@ use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data struct Sender { - inner: UnboundedSender, - closed: bool, + inner: Option>, } impl Sender { - fn new(inner: UnboundedSender) -> Self { - Sender { - inner, - closed: false, - } + fn new(inner: mpsc::Sender) -> Self { + Sender { inner: Some(inner) } } - fn send(&self, packet: Packet) { - let _ = self.inner.unbounded_send(packet); + async fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet).await; } fn end(&mut self) { - self.closed = true; + self.inner = None; } } impl Drop for Sender { fn drop(&mut self) { - if !self.closed { - self.send(Err(255)); + if let Some(inner) = self.inner.take() { + let _ = inner.blocking_send(Err(255)); } - self.inner.close_channel(); } } @@ -54,7 +49,7 @@ impl Drop for Sender { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where @@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { - let (send, recv) = unbounded(); - self.recv_handler(id, recv); + let (send, recv) = mpsc::channel(4); + self.recv_handler( + id, + Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)), + ); Sender::new(send) }; - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - sender.send(packet); + // If we get an error, the receiving end is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet).await; if has_cont { streams.insert(id, sender); diff --git a/src/server.rs b/src/server.rs index ae1196c..4b232af 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption; use async_trait::async_trait; use log::{debug, trace}; -use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::{handshake_server, BoxStream}; use tokio::net::TcpStream; @@ -171,21 +170,24 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { + fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { trace!("ServerConn recv_handler {}", id); - let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await { + let (prio, resp_enc) = match ReqEnc::decode(stream).await { Ok(req_enc) => { let prio = req_enc.prio; let resp = self2.recv_handler_aux(req_enc).await; - (prio, match resp { - Ok(resp_enc) => resp_enc, - Err(e) => RespEnc::from_err(e), - }) + ( + prio, + match resp { + Ok(resp_enc) => resp_enc, + Err(e) => RespEnc::from_err(e), + }, + ) } Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)), };