diff --git a/src/proto.rs b/src/proto.rs index bf82e47..47480a9 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use log::trace; use futures::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -100,7 +101,7 @@ pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, - mut write: W, + mut write: BoxStreamWrite, ) -> Result<(), Error> where W: AsyncWriteExt + Unpin + Send + Sync, @@ -160,6 +161,7 @@ pub(crate) trait SendLoop: Sync { } } } + write.goodbye().await?; Ok(()) } } @@ -177,7 +179,11 @@ pub(crate) trait RecvLoop: Sync + 'static { loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; - read.read_exact(&mut header_id[..]).await?; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); @@ -202,6 +208,7 @@ pub(crate) trait RecvLoop: Sync + 'static { self.recv_handler(id, msg_bytes); } } + Ok(()) } }