use std::collections::{BTreeMap, HashMap, VecDeque}; use std::sync::Arc; use log::trace; use futures::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; /// Priority of a request (click to read more about priorities). /// /// This priority value is used to priorize messages /// in the send queue of the client, and their responses in the send queue of the /// server. Lower values mean higher priority. /// /// This mechanism is usefull for messages bigger than the maximum chunk size /// (set at `0x4000` bytes), such as large file transfers. /// In such case, all of the messages in the send queue with the highest priority /// will take turns to send individual chunks, in a round-robin fashion. /// Once all highest priority messages are sent successfully, the messages with /// the next highest priority will begin being sent in the same way. /// /// The same priority value is given to a request and to its associated response. pub type RequestPriority = u8; /// Priority class: high pub const PRIO_HIGH: RequestPriority = 0x20; /// Priority class: normal pub const PRIO_NORMAL: RequestPriority = 0x40; /// Priority class: background pub const PRIO_BACKGROUND: RequestPriority = 0x80; /// Priority: primary among given class pub const PRIO_PRIMARY: RequestPriority = 0x00; /// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) pub const PRIO_SECONDARY: RequestPriority = 0x01; const MAX_CHUNK_SIZE: usize = 0x4000; pub(crate) type RequestID = u16; struct SendQueueItem { id: RequestID, prio: RequestPriority, data: Vec, cursor: usize, } struct SendQueue { items: BTreeMap>, } impl SendQueue { fn new() -> Self { Self { items: BTreeMap::new(), } } fn push(&mut self, item: SendQueueItem) { let prio = item.prio; let mut items_at_prio = self .items .remove(&prio) .unwrap_or_else(|| VecDeque::with_capacity(4)); items_at_prio.push_back(item); self.items.insert(prio, items_at_prio); } fn pop(&mut self) -> Option { match self.items.pop_first() { None => None, Some((prio, mut items_at_prio)) => { let ret = items_at_prio.pop_front(); if !items_at_prio.is_empty() { self.items.insert(prio, items_at_prio); } ret.or_else(|| self.pop()) } } } fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } } #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, mut msg_recv: mpsc::UnboundedReceiver)>>, mut write: W, ) -> Result<(), Error> where W: AsyncWriteExt + Unpin + Send + Sync, { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok(sth) = msg_recv.try_recv() { if let Some((id, prio, data)) = sth { trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem { id, prio, data, cursor: 0, }); } else { should_exit = true; } } else if let Some(mut item) = sending.pop() { trace!( "send_loop: sending bytes for {} ({} bytes, {} already sent)", item.id, item.data.len(), item.cursor ); let header_id = u16::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; if item.data.len() - item.cursor > MAX_CHUNK_SIZE { let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); write.write_all(&header_size[..]).await?; let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; write.write_all(&item.data[item.cursor..new_cursor]).await?; item.cursor = new_cursor; sending.push(item); } else { let send_len = (item.data.len() - item.cursor) as u16; let header_size = u16::to_be_bytes(send_len); write.write_all(&header_size[..]).await?; write.write_all(&item.data[item.cursor..]).await?; } write.flush().await?; } else { let sth = msg_recv .recv() .await .ok_or_else(|| Error::Message("Connection closed.".into()))?; if let Some((id, prio, data)) = sth { trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem { id, prio, data, cursor: 0, }); } else { should_exit = true; } } } Ok(()) } } #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { // Returns true if we should stop receiving after this async fn recv_handler(self: Arc, id: RequestID, msg: Vec); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { let mut receiving = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; 2]; read.read_exact(&mut header_id[..]).await?; let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); let mut header_size = [0u8; 2]; read.read_exact(&mut header_size[..]).await?; let size = RequestID::from_be_bytes(header_size); trace!("recv_loop: got header size: {:04x}", size); let has_cont = (size & 0x8000) != 0; let size = size & !0x8000; let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); msg_bytes.extend_from_slice(&next_slice[..]); if has_cont { receiving.insert(id, msg_bytes); } else { tokio::spawn(self.clone().recv_handler(id, msg_bytes)); } } } }