From 5d7541e13a4c3640f0dc8aead595b51775fc0ac8 Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Sun, 19 Jun 2022 17:44:07 +0200 Subject: [PATCH] wait for any ready stream instead of the highest priority one --- src/endpoint.rs | 2 +- src/proto.rs | 185 ++++++++++++++++++++++++++++++------------------ src/util.rs | 8 +++ 3 files changed, 124 insertions(+), 71 deletions(-) diff --git a/src/endpoint.rs b/src/endpoint.rs index c25365a..c430d4e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -42,7 +42,7 @@ where (self.clone(), None) } - async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self { + async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self { // TODO verify no stream ser_self } diff --git a/src/proto.rs b/src/proto.rs index 073a317..417b508 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,7 +7,7 @@ use log::{trace, warn}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::Stream; -use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; +use futures::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -53,7 +53,8 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +const ERROR_MARKER: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { @@ -99,8 +100,29 @@ impl From for DataReader { } } +struct DataReaderItem { + /// a fixed size buffer containing some data, possibly padded with 0s + data: [u8; MAX_CHUNK_LENGTH as usize], + /// actuall lenght of data + len: usize, + /// whethere there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + may_have_more: bool, +} + +impl DataReaderItem { + fn empty_last() -> Self { + DataReaderItem { + data: [0; MAX_CHUNK_LENGTH as usize], + len: 0, + may_have_more: false, + } + } +} + impl Stream for DataReader { - type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -114,7 +136,11 @@ impl Stream for DataReader { let mut body = [0; MAX_CHUNK_LENGTH as usize]; body[..len].copy_from_slice(&data[*pos..end]); *pos = end; - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: end < data.len(), + })) } } DataReaderProj::Streaming { @@ -154,7 +180,11 @@ impl Stream for DataReader { let len = buf.len(); body[..len].copy_from_slice(buf); buf.clear(); - Poll::Ready(Some((body, len))) + Poll::Ready(Some(DataReaderItem { + data: body, + len, + may_have_more: !*eos, + })) } } } @@ -181,6 +211,8 @@ impl SendQueue { }; self.items[pos_prio].1.push_back(item); } + // used only in tests. They should probably be rewriten + #[allow(dead_code)] fn pop(&mut self) -> Option { match self.items.pop_front() { None => None, @@ -196,6 +228,54 @@ impl SendQueue { fn is_empty(&self) -> bool { self.items.iter().all(|(_k, v)| v.is_empty()) } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataReaderItem); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + for i in 0..self.queue.items.len() { + let (_prio, items_at_prio) = &mut self.queue.items[i]; + + for _ in 0..items_at_prio.len() { + let mut item = items_at_prio.pop_front().unwrap(); + + match Pin::new(&mut item.data).poll_next(ctx) { + Poll::Pending => items_at_prio.push_back(item), + Poll::Ready(Some(data)) => { + let id = item.id; + if data.may_have_more { + self.queue.push(item); + } else { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + } + return Poll::Ready((id, data)); + } + Poll::Ready(None) => { + if items_at_prio.is_empty() { + // this priority level is empty, remove it + self.queue.items.remove(i); + } + return Poll::Ready((item.id, DataReaderItem::empty_last())); + } + } + } + } + // TODO what do we do if self.queue is empty? We won't get scheduled again. + Poll::Pending + } } /// The SendLoop trait, which is implemented both by the client and the server @@ -219,77 +299,42 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } + let recv_fut = msg_recv.recv(); + futures::pin_mut!(recv_fut); + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + use futures::future::Either; + match futures::future::select(recv_fut, send_fut).await { + Either::Left((sth, _send_fut)) => { + if let Some((id, prio, data)) = sth { + sending.push(SendQueueItem { + id, + prio, + data: data.into(), + }); + } else { + should_exit = true; + }; } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else if let Some(mut item) = sending.pop() { - trace!("send_loop: sending bytes for {}", item.id,); + Either::Right(((id, data), _recv_fut)) => { + trace!("send_loop: sending bytes for {}", id); - let data = futures::select! { - data = item.data.next().fuse() => data, - default => { - // nothing to send yet; re-schedule and find something else to do - sending.push(item); - continue; + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; - // TODO if every SendQueueItem is waiting on data, use select_all to await - // something to do - } - }; + let body = &data.data[..data.len]; - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; + let size_header = if data.may_have_more { + ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION) + } else { + ChunkLength::to_be_bytes(data.len as u16) + }; - let data = match data.as_ref() { - Some((data, len)) => &data[..*len], - None => &[], - }; - - if data.len() == MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - - sending.push(item); - } else { - let size_header = ChunkLength::to_be_bytes(data.len() as u16); - write.write_all(&size_header[..]).await?; - - write.write_all(data).await?; - } - - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - match &data { - Data::Full(data) => { - trace!("send_loop: got {}, {} bytes", id, data.len()); - } - Data::Streaming(_) => { - trace!("send_loop: got {}, unknown size", id); - } - } - sending.push(SendQueueItem { - id, - prio, - data: data.into(), - }); - } else { - should_exit = true; + write.write_all(body).await?; + write.flush().await?; } } } diff --git a/src/util.rs b/src/util.rs index 02b4e7d..3ee0cb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -19,6 +19,14 @@ pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; /// A network key pub type NetworkKey = sodiumoxide::crypto::auth::Key; +/// A stream of associated data. +/// +/// The Stream can continue after receiving an error. +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// The error code have no predefined meaning, it's up to you application to define their +/// semantic. pub type AssociatedStream = Pin> + Send>>; /// Utility function: encodes any serializable value in MessagePack binary format