use std::collections::{BTreeMap, HashMap, VecDeque}; use std::sync::Arc; use log::trace; use async_trait::async_trait; use async_std::io::prelude::WriteExt; use async_std::io::ReadExt; use tokio::io::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio::sync::{mpsc, watch}; use crate::error::*; use kuska_handshake::async_std::{BoxStreamRead, BoxStreamWrite, TokioCompat}; const MAX_CHUNK_SIZE: usize = 0x4000; pub mod prio { pub const HIGH: u8 = 0x20; pub const NORMAL: u8 = 0x40; pub const BACKGROUND: u8 = 0x80; pub const PRIMARY: u8 = 0x00; pub const SECONDARY: u8 = 0x01; } pub type RequestID = u16; pub type RequestPriority = u8; struct SendQueueItem { id: RequestID, prio: RequestPriority, data: Vec, cursor: usize, } struct SendQueue { items: BTreeMap>, } impl SendQueue { fn new() -> Self { Self { items: BTreeMap::new(), } } fn push(&mut self, item: SendQueueItem) { let prio = item.prio; let mut items_at_prio = self .items .remove(&prio) .unwrap_or(VecDeque::with_capacity(4)); items_at_prio.push_back(item); self.items.insert(prio, items_at_prio); } fn pop(&mut self) -> Option { match self.items.pop_first() { None => None, Some((prio, mut items_at_prio)) => { let ret = items_at_prio.pop_front(); if !items_at_prio.is_empty() { self.items.insert(prio, items_at_prio); } ret } } } } #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc, mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec)>, mut write: BoxStreamWrite>>, mut must_exit: watch::Receiver, ) -> Result<(), Error> { let mut sending = SendQueue::new(); while !*must_exit.borrow() { if let Ok((id, prio, data)) = msg_recv.try_recv() { trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem { id, prio, data, cursor: 0, }); } else if let Some(mut item) = sending.pop() { trace!( "send_loop: sending bytes for {} ({} bytes, {} already sent)", item.id, item.data.len(), item.cursor ); let header_id = u16::to_be_bytes(item.id); if write_all_or_exit(&header_id[..], &mut write, &mut must_exit) .await? .is_none() { break; } if item.data.len() - item.cursor > MAX_CHUNK_SIZE { let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) .await? .is_none() { break; } let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; if write_all_or_exit( &item.data[item.cursor..new_cursor], &mut write, &mut must_exit, ) .await? .is_none() { break; } item.cursor = new_cursor; sending.push(item); } else { let send_len = (item.data.len() - item.cursor) as u16; let header_size = u16::to_be_bytes(send_len); if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) .await? .is_none() { break; } if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit) .await? .is_none() { break; } } write.flush().await.log_err("Could not flush in send_loop"); } else { let (id, prio, data) = msg_recv .recv() .await .ok_or(Error::Message("Connection closed.".into()))?; trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem { id, prio, data, cursor: 0, }); } } Ok(()) } } #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { async fn recv_handler(self: Arc, id: RequestID, msg: Vec); async fn recv_loop( self: Arc, mut read: BoxStreamRead>>, mut must_exit: watch::Receiver, ) -> Result<(), Error> { let mut receiving = HashMap::new(); while !*must_exit.borrow() { trace!("recv_loop: reading packet"); let mut header_id = [0u8; 2]; if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit) .await? .is_none() { break; } let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); let mut header_size = [0u8; 2]; if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit) .await? .is_none() { break; } let size = RequestID::from_be_bytes(header_size); trace!("recv_loop: got header size: {:04x}", id); let has_cont = (size & 0x8000) != 0; let size = size & !0x8000; let mut next_slice = vec![0; size as usize]; if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit) .await? .is_none() { break; } trace!("recv_loop: read {} bytes", size); let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]); msg_bytes.extend_from_slice(&next_slice[..]); if has_cont { receiving.insert(id, msg_bytes); } else { tokio::spawn(self.clone().recv_handler(id, msg_bytes)); } } Ok(()) } } async fn read_exact_or_exit( buf: &mut [u8], read: &mut BoxStreamRead>>, must_exit: &mut watch::Receiver, ) -> Result, Error> { tokio::select!( res = read.read_exact(buf) => Ok(Some(res?)), _ = await_exit(must_exit) => Ok(None), ) } async fn write_all_or_exit( buf: &[u8], write: &mut BoxStreamWrite>>, must_exit: &mut watch::Receiver, ) -> Result, Error> { tokio::select!( res = write.write_all(buf) => Ok(Some(res?)), _ = await_exit(must_exit) => Ok(None), ) } async fn await_exit(must_exit: &mut watch::Receiver) { loop { if must_exit.recv().await == Some(true) { return; } } }