use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll}; use bytes::Bytes; use futures::Future; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::io::AsyncRead; /// A stream of associated data. /// /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// /// Error code 255 means the stream was cut before its end. Other codes have no predefined /// meaning, it's up to your application to define their semantic. pub type ByteStream = Pin + Send + Sync>>; pub type Packet = Result; // ---- pub struct ByteStreamReader { stream: ByteStream, buf: VecDeque, buf_len: usize, eos: bool, err: Option, } impl ByteStreamReader { pub fn new(stream: ByteStream) -> Self { ByteStreamReader { stream, buf: VecDeque::with_capacity(8), buf_len: 0, eos: false, err: None, } } pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { ByteStreamReadExact { reader: self, read_len, fail_on_eos: true, } } pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { ByteStreamReadExact { reader: self, read_len, fail_on_eos: false, } } pub async fn read_u8(&mut self) -> Result { Ok(self.read_exact(1).await?[0]) } pub async fn read_u16(&mut self) -> Result { let bytes = self.read_exact(2).await?; let mut b = [0u8; 2]; b.copy_from_slice(&bytes[..]); Ok(u16::from_be_bytes(b)) } pub async fn read_u32(&mut self) -> Result { let bytes = self.read_exact(4).await?; let mut b = [0u8; 4]; b.copy_from_slice(&bytes[..]); Ok(u32::from_be_bytes(b)) } pub fn into_stream(self) -> ByteStream { let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok)); if let Some(err) = self.err { Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) } else if self.eos { Box::pin(buf_stream) } else { Box::pin(buf_stream.chain(self.stream)) } } pub fn take_buffer(&mut self) -> Bytes { let bytes = Bytes::from(self.buf.iter().map(|x| &x[..]).collect::>().concat()); self.buf.clear(); self.buf_len = 0; bytes } pub fn eos(&self) -> bool { self.buf.is_empty() && self.eos } fn try_get(&mut self, read_len: usize) -> Option { if self.buf_len >= read_len { let mut slices = Vec::with_capacity(self.buf.len()); let mut taken = 0; while taken < read_len { let front = self.buf.pop_front().unwrap(); if taken + front.len() <= read_len { taken += front.len(); self.buf_len -= front.len(); slices.push(front); } else { let front_take = read_len - taken; slices.push(front.slice(..front_take)); self.buf.push_front(front.slice(front_take..)); self.buf_len -= front_take; break; } } Some( slices .iter() .map(|x| &x[..]) .collect::>() .concat() .into(), ) } else { None } } } pub enum ReadExactError { UnexpectedEos, Stream(u8), } #[pin_project::pin_project] pub struct ByteStreamReadExact<'a> { #[pin] reader: &'a mut ByteStreamReader, read_len: usize, fail_on_eos: bool, } impl<'a> Future for ByteStreamReadExact<'a> { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); loop { if let Some(bytes) = this.reader.try_get(*this.read_len) { return Poll::Ready(Ok(bytes)); } if let Some(err) = this.reader.err { return Poll::Ready(Err(ReadExactError::Stream(err))); } if this.reader.eos { if *this.fail_on_eos { return Poll::Ready(Err(ReadExactError::UnexpectedEos)); } else { return Poll::Ready(Ok(this.reader.take_buffer())); } } match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { Some(Ok(slice)) => { this.reader.buf_len += slice.len(); this.reader.buf.push_back(slice); } Some(Err(e)) => { this.reader.err = Some(e); this.reader.eos = true; } None => { this.reader.eos = true; } } } } } // ---- fn u8_to_io_error(v: u8) -> std::io::Error { use std::io::{Error, ErrorKind}; let kind = match v { 101 => ErrorKind::ConnectionAborted, 102 => ErrorKind::BrokenPipe, 103 => ErrorKind::WouldBlock, 104 => ErrorKind::InvalidInput, 105 => ErrorKind::InvalidData, 106 => ErrorKind::TimedOut, 107 => ErrorKind::Interrupted, 108 => ErrorKind::UnexpectedEof, 109 => ErrorKind::OutOfMemory, 110 => ErrorKind::ConnectionReset, _ => ErrorKind::Other, }; Error::new(kind, "(in netapp stream)") } fn io_error_to_u8(e: std::io::Error) -> u8 { use std::io::ErrorKind; match e.kind() { ErrorKind::ConnectionAborted => 101, ErrorKind::BrokenPipe => 102, ErrorKind::WouldBlock => 103, ErrorKind::InvalidInput => 104, ErrorKind::InvalidData => 105, ErrorKind::TimedOut => 106, ErrorKind::Interrupted => 107, ErrorKind::UnexpectedEof => 108, ErrorKind::OutOfMemory => 109, ErrorKind::ConnectionReset => 110, _ => 100, } } pub fn asyncread_stream(reader: R) -> ByteStream { Box::pin(tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8)) } pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { tokio_util::io::StreamReader::new(stream.map_err(u8_to_io_error)) }