diff --git a/src/proto.rs b/src/proto.rs
index 417b508..e3f9be8 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -78,7 +78,7 @@ enum DataReader {
 	Streaming {
 		#[pin]
 		reader: AssociatedStream,
-		packet: Vec<u8>,
+		packet: Result<Vec<u8>, u8>,
 		pos: usize,
 		buf: Vec<u8>,
 		eos: bool,
@@ -91,7 +91,7 @@ impl From<Data> for DataReader {
 			Data::Full(data) => DataReader::Full { data, pos: 0 },
 			Data::Streaming(reader) => DataReader::Streaming {
 				reader,
-				packet: Vec::new(),
+				packet: Ok(Vec::new()),
 				pos: 0,
 				buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
 				eos: false,
@@ -100,11 +100,18 @@ impl From<Data> for DataReader {
 	}
 }
 
+enum DataFrame {
+	Data {
+		/// a fixed size buffer containing some data, possibly padded with 0s
+		data: [u8; MAX_CHUNK_LENGTH as usize],
+		/// actual length of data
+		len: usize,
+	},
+	Error(u8),
+}
+
 struct DataReaderItem {
-	/// a fixed size buffer containing some data, possibly padded with 0s
-	data: [u8; MAX_CHUNK_LENGTH as usize],
-	/// actuall lenght of data
-	len: usize,
+	data: DataFrame,
 	/// whethere there may be more data comming from this stream. Can be used for some
 	/// optimization. It's an error to set it to false if there is more data, but it is correct
 	/// (albeit sub-optimal) to set it to true if there is nothing coming after
@@ -114,11 +121,34 @@ struct DataReaderItem {
 impl DataReaderItem {
 	fn empty_last() -> Self {
 		DataReaderItem {
-			data: [0; MAX_CHUNK_LENGTH as usize],
-			len: 0,
+			data: DataFrame::Data {
+				data: [0; MAX_CHUNK_LENGTH as usize],
+				len: 0,
+			},
 			may_have_more: false,
 		}
 	}
+
+	fn header(&self) -> [u8; 2] {
+		let continuation = if self.may_have_more {
+			CHUNK_HAS_CONTINUATION
+		} else {
+			0
+		};
+		let len = match self.data {
+			DataFrame::Data { len, .. } => len as u16,
+			DataFrame::Error(e) => e as u16 | ERROR_MARKER,
+		};
+
+		ChunkLength::to_be_bytes(len | continuation)
+	}
+
+	fn data(&self) -> &[u8] {
+		match self.data {
+			DataFrame::Data { ref data, len } => &data[..len],
+			DataFrame::Error(_) => &[],
+		}
+	}
 }
 
 impl Stream for DataReader {
@@ -137,15 +167,14 @@ impl Stream for DataReader {
 					body[..len].copy_from_slice(&data[*pos..end]);
 					*pos = end;
 					Poll::Ready(Some(DataReaderItem {
-						data: body,
-						len,
+						data: DataFrame::Data { data: body, len },
 						may_have_more: end < data.len(),
 					}))
 				}
 			}
 			DataReaderProj::Streaming {
 				mut reader,
-				packet,
+				packet: res_packet,
 				pos,
 				buf,
 				eos,
@@ -156,6 +185,17 @@ impl Stream for DataReader {
 					return Poll::Ready(None);
 				}
 				loop {
+					let packet = match res_packet {
+						Ok(v) => v,
+						Err(e) => {
+							let e = *e;
+							*res_packet = Ok(Vec::new());
+							return Poll::Ready(Some(DataReaderItem {
+								data: DataFrame::Error(e),
+								may_have_more: true,
+							}));
+						}
+					};
 					let packet_left = packet.len() - *pos;
 					let buf_left = MAX_CHUNK_LENGTH as usize - buf.len();
 					let to_read = std::cmp::min(buf_left, packet_left);
@@ -168,8 +208,13 @@ impl Stream for DataReader {
 
 					// we don't have a full buf, packet is empty; try receive more
 					if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) {
-						*packet = p;
+						*res_packet = p;
 						*pos = 0;
+						// if buf is empty, we will loop and return the error directly. If buf
+						// isn't empty, send it before by breaking.
+						if res_packet.is_err() && !buf.is_empty() {
+							break;
+						}
 					} else {
 						*eos = true;
 						break;
@@ -181,8 +226,7 @@ impl Stream for DataReader {
 				body[..len].copy_from_slice(buf);
 				buf.clear();
 				Poll::Ready(Some(DataReaderItem {
-					data: body,
-					len,
+					data: DataFrame::Data { data: body, len },
 					may_have_more: !*eos,
 				}))
 			}
@@ -211,8 +255,8 @@ impl SendQueue {
 		};
 		self.items[pos_prio].1.push_back(item);
 	}
-	// used only in tests. They should probably be rewriten
-	#[allow(dead_code)]
+	// used only in tests. They should probably be rewritten
+	#[allow(dead_code)]
 	fn pop(&mut self) -> Option<SendQueueItem> {
 		match self.items.pop_front() {
 			None => None,
@@ -324,16 +368,8 @@ pub(crate) trait SendLoop: Sync {
 
 				let header_id = RequestID::to_be_bytes(id);
 				write.write_all(&header_id[..]).await?;
-				let body = &data.data[..data.len];
-
-				let size_header = if data.may_have_more {
-					ChunkLength::to_be_bytes(data.len as u16 | CHUNK_HAS_CONTINUATION)
-				} else {
-					ChunkLength::to_be_bytes(data.len as u16)
-				};
-
-				write.write_all(&size_header[..]).await?;
-				write.write_all(body).await?;
+				write.write_all(&data.header()).await?;
+				write.write_all(data.data()).await?;
 				write.flush().await?;
 			}
 		}
@@ -413,7 +449,13 @@ pub(crate) trait RecvLoop: Sync + 'static {
 			trace!("recv_loop: got header size: {:04x}", size);
 
 			let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
-			let size = size & !CHUNK_HAS_CONTINUATION;
+			let is_error = (size & ERROR_MARKER) != 0;
+			let size = if !is_error {
+				size & !CHUNK_HAS_CONTINUATION
+			} else {
+				0
+			};
+			// TODO propagate errors
 
 			let mut next_slice = vec![0; size as usize];
 			read.read_exact(&mut next_slice[..]).await?;
@@ -430,7 +472,8 @@ pub(crate) trait RecvLoop: Sync + 'static {
 				let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default();
 
 				if let Some(receiver) = channel_pair.take_receiver() {
-					self.recv_handler(id, msg_bytes, Box::pin(receiver));
+					use futures::StreamExt;
+					self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v))));
 				} else {
 					warn!("Couldn't take receiver part of stream")
 				}
diff --git a/src/util.rs b/src/util.rs
index 3ee0cb9..76d7ecf 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -27,7 +27,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key;
 ///
 /// The error code have no predefined meaning, it's up to you application to define their
 /// semantic.
-pub type AssociatedStream = Pin<Box<dyn Stream<Item = Vec<u8>> + Send>>;
+pub type AssociatedStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>, u8>> + Send>>;
 
 /// Utility function: encodes any serializable value in MessagePack binary format
 /// using the RMP library.