use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; use log::trace; use futures::AsyncReadExt; use tokio::sync::mpsc; use crate::error::*; use crate::send::*; use crate::stream::*; /// Structure to warn when the sender is dropped before end of stream was reached, like when /// connection to some remote drops while transmitting data struct Sender { inner: Option>, } impl Sender { fn new(inner: mpsc::Sender) -> Self { Sender { inner: Some(inner) } } async fn send(&self, packet: Packet) { let _ = self.inner.as_ref().unwrap().send(packet).await; } fn end(&mut self) { self.inner = None; } } impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { tokio::spawn(async move { let _ = inner.send(Err(255)).await; }); } } } /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that /// must be filled by implementors. `.recv_loop()` receives messages in a loop /// according to the protocol defined above: chunks of message in progress of being /// received are stored in a buffer, and when the last chunk of a message is received, /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { fn recv_handler(self: &Arc, id: RequestID, stream: ByteStream); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { let mut streams: HashMap = HashMap::new(); loop { let mut header_id = [0u8; RequestID::BITS as usize / 8]; match read.read_exact(&mut header_id[..]).await { Ok(_) => (), Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, Err(e) => return Err(e.into()), }; let id = RequestID::from_be_bytes(header_id); let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; read.read_exact(&mut header_size[..]).await?; let size = ChunkLength::from_be_bytes(header_size); let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; let packet = if is_error { trace!( "recv_loop: got id {}, header_size {:04x}, error {}", id, size, size & !ERROR_MARKER ); Err((size & !ERROR_MARKER) as u8) } else { let size = size & !CHUNK_HAS_CONTINUATION; let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; trace!( "recv_loop: got id {}, header_size {:04x}, {} bytes", id, size, next_slice.len() ); Ok(Bytes::from(next_slice)) }; let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = mpsc::channel(4); trace!("recv_loop: id {} is new channel", id); self.recv_handler( id, Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)), ); Sender::new(send) }; // If we get an error, the receiving end is disconnected. // We still need to reach eos before dropping this sender let _ = sender.send(packet).await; if has_cont { assert!(!is_error); streams.insert(id, sender); } else { trace!("recv_loop: close channel id {}", id); sender.end(); } } Ok(()) } }