use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::Arc;

use log::trace;

use async_trait::async_trait;

use async_std::io::prelude::WriteExt;
use async_std::io::ReadExt;

use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, watch};

use crate::error::*;

use kuska_handshake::async_std::{BoxStreamRead, BoxStreamWrite, TokioCompat};

/// Priority of a request (click to read more about priorities).
///
/// This priority value is used to prioritize messages
/// in the send queue of the client, and their responses in the send queue of the
/// server. Lower values mean higher priority.
///
/// This mechanism is useful for messages bigger than the maximum chunk size
/// (set at `0x4000` bytes), such as large file transfers.
/// In such cases, all of the messages in the send queue with the highest priority
/// take turns sending individual chunks, in a round-robin fashion.
/// Once all highest-priority messages are sent successfully, the messages with
/// the next highest priority begin being sent in the same way.
///
/// The same priority value is given to a request and to its associated response.
pub type RequestPriority = u8;

/// Priority class: high
pub const PRIO_HIGH: RequestPriority = 0x20;
/// Priority class: normal
pub const PRIO_NORMAL: RequestPriority = 0x40;
/// Priority class: background
pub const PRIO_BACKGROUND: RequestPriority = 0x80;
/// Priority: primary among given class
pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;

const MAX_CHUNK_SIZE: usize = 0x4000;

pub(crate) type RequestID = u16;

struct SendQueueItem {
    id: RequestID,
    prio: RequestPriority,
    data: Vec<u8>,
    cursor: usize,
}

struct SendQueue {
    items: BTreeMap<RequestPriority, VecDeque<SendQueueItem>>,
}

impl SendQueue {
    fn new() -> Self {
        Self {
            items: BTreeMap::new(),
        }
    }
    fn push(&mut self, item: SendQueueItem) {
        let prio = item.prio;
        let mut items_at_prio = self
            .items
            .remove(&prio)
            .unwrap_or_else(|| VecDeque::with_capacity(4));
        items_at_prio.push_back(item);
        self.items.insert(prio, items_at_prio);
    }
    // Pops the front item of the non-empty queue with the lowest priority
    // value, i.e. the oldest message of the highest-priority class.
    fn pop(&mut self) -> Option<SendQueueItem> {
        match self.items.pop_first() {
            None => None,
            Some((prio, mut items_at_prio)) => {
                let ret = items_at_prio.pop_front();
                if !items_at_prio.is_empty() {
                    self.items.insert(prio, items_at_prio);
                }
                ret
            }
        }
    }
}
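// Wire framing used by send_loop and recv_loop below: each chunk is
//
//   [ id: u16, big-endian ][ size: u16, big-endian ][ size & 0x7FFF data bytes ]
//
// The top bit of the size field (0x8000) is a continuation flag: when set,
// more chunks follow for the same request id. For example, a 0x6000-byte
// message goes out as a 0x4000-byte chunk whose size field reads 0xC000
// (0x4000 | 0x8000), then a final 0x2000-byte chunk whose size field reads
// 0x2000.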
#[async_trait]
pub(crate) trait SendLoop: Sync {
    async fn send_loop(
        self: Arc<Self>,
        mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
        mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
        mut must_exit: watch::Receiver<bool>,
    ) -> Result<(), Error> {
        let mut sending = SendQueue::new();
        while !*must_exit.borrow() {
            if let Ok((id, prio, data)) = msg_recv.try_recv() {
                trace!("send_loop: got {}, {} bytes", id, data.len());
                sending.push(SendQueueItem {
                    id,
                    prio,
                    data,
                    cursor: 0,
                });
            } else if let Some(mut item) = sending.pop() {
                trace!(
                    "send_loop: sending bytes for {} ({} bytes, {} already sent)",
                    item.id,
                    item.data.len(),
                    item.cursor
                );
                let header_id = u16::to_be_bytes(item.id);
                if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
                    .await?
                    .is_none()
                {
                    break;
                }
                if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
                    // More data remains after this chunk: send a full-size chunk
                    // with the continuation bit (0x8000) set in the size header,
                    // then re-queue the item so other messages get their turn.
                    let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
                    if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
                        .await?
                        .is_none()
                    {
                        break;
                    }
                    let new_cursor = item.cursor + MAX_CHUNK_SIZE;
                    if write_all_or_exit(
                        &item.data[item.cursor..new_cursor],
                        &mut write,
                        &mut must_exit,
                    )
                    .await?
                    .is_none()
                    {
                        break;
                    }
                    item.cursor = new_cursor;
                    sending.push(item);
                } else {
                    // Last chunk of this message: send the remainder with the
                    // continuation bit unset.
                    let send_len = (item.data.len() - item.cursor) as u16;
                    let header_size = u16::to_be_bytes(send_len);
                    if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
                        .await?
                        .is_none()
                    {
                        break;
                    }
                    if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
                        .await?
                        .is_none()
                    {
                        break;
                    }
                }
                write.flush().await.log_err("Could not flush in send_loop");
            } else {
                // Nothing queued: block until a new message arrives.
                let (id, prio, data) = msg_recv
                    .recv()
                    .await
                    .ok_or(Error::Message("Connection closed.".into()))?;
                trace!("send_loop: got {}, {} bytes", id, data.len());
                sending.push(SendQueueItem {
                    id,
                    prio,
                    data,
                    cursor: 0,
                });
            }
        }
        Ok(())
    }
}

#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
    async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);

    async fn recv_loop(
        self: Arc<Self>,
        mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
        mut must_exit: watch::Receiver<bool>,
    ) -> Result<(), Error> {
        let mut receiving = HashMap::<RequestID, Vec<u8>>::new();
        while !*must_exit.borrow() {
            trace!("recv_loop: reading packet");
            let mut header_id = [0u8; 2];
            if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
                .await?
                .is_none()
            {
                break;
            }
            let id = RequestID::from_be_bytes(header_id);
            trace!("recv_loop: got header id: {:04x}", id);

            let mut header_size = [0u8; 2];
            if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
                .await?
                .is_none()
            {
                break;
            }
            let size = u16::from_be_bytes(header_size);
            trace!("recv_loop: got header size: {:04x}", size);

            let has_cont = (size & 0x8000) != 0;
            let size = size & !0x8000;

            let mut next_slice = vec![0; size as usize];
            if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
                .await?
                .is_none()
            {
                break;
            }
            trace!("recv_loop: read {} bytes", size);

            // Accumulate chunks for this request id until the continuation
            // bit is unset, then hand the complete message to the handler.
            let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
            msg_bytes.extend_from_slice(&next_slice[..]);

            if has_cont {
                receiving.insert(id, msg_bytes);
            } else {
                tokio::spawn(self.clone().recv_handler(id, msg_bytes));
            }
        }
        Ok(())
    }
}

async fn read_exact_or_exit(
    buf: &mut [u8],
    read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
    must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
    tokio::select!(
        res = read.read_exact(buf) => Ok(Some(res?)),
        _ = await_exit(must_exit) => Ok(None),
    )
}

async fn write_all_or_exit(
    buf: &[u8],
    write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
    must_exit: &mut watch::Receiver<bool>,
) -> Result<Option<()>, Error> {
    tokio::select!(
        res = write.write_all(buf) => Ok(Some(res?)),
        _ = await_exit(must_exit) => Ok(None),
    )
}

async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
    loop {
        if must_exit.recv().await == Some(true) {
            return;
        }
    }
}
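// Illustrative test sketch for the queue ordering and header arithmetic
// above; the ids, priorities and sizes used here are arbitrary examples,
// not values mandated by the protocol.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn send_queue_orders_by_priority_then_fifo() {
        let mut queue = SendQueue::new();
        for (id, prio) in vec![
            (1u16, PRIO_BACKGROUND),
            (2u16, PRIO_HIGH),
            (3u16, PRIO_NORMAL),
            (4u16, PRIO_HIGH),
        ] {
            queue.push(SendQueueItem {
                id,
                prio,
                data: vec![],
                cursor: 0,
            });
        }
        // Lower priority values come out first; items of equal priority
        // keep their insertion (FIFO) order.
        let order: Vec<RequestID> =
            std::iter::from_fn(|| queue.pop().map(|item| item.id)).collect();
        assert_eq!(order, vec![2, 4, 3, 1]);
    }

    #[test]
    fn chunk_size_header_continuation_bit() {
        // A full-size chunk with the continuation bit set...
        let size = u16::from_be_bytes(u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000));
        assert!((size & 0x8000) != 0);
        assert_eq!((size & !0x8000) as usize, MAX_CHUNK_SIZE);
        // ...and a final chunk without it.
        let size = u16::from_be_bytes(u16::to_be_bytes(0x1234));
        assert!((size & 0x8000) == 0);
        assert_eq!(size, 0x1234);
    }
}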