add streaming body to requests and responses #3

Merged
lx merged 64 commits from stream-body into main 2022-09-13 10:56:54 +00:00
4 changed files with 64 additions and 54 deletions
Showing only changes of commit 7909a95d3c - Show all commits

View file

@ -175,7 +175,7 @@ impl ClientConn {
"Too many inflight requests! RequestID collision. Interrupting previous request." "Too many inflight requests! RequestID collision. Interrupting previous request."
); );
let _ = old_ch.send(Box::pin(futures::stream::once(async move { let _ = old_ch.send(Box::pin(futures::stream::once(async move {
Err(Error::IdCollision.code()) Err(std::io::Error::new(std::io::ErrorKind::Other, "RequestID collision, too many inflight requests"))
}))); })));
} }

View file

@ -35,7 +35,7 @@ impl Sender {
impl Drop for Sender { impl Drop for Sender {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(inner) = self.inner.take() { if let Some(inner) = self.inner.take() {
let _ = inner.send(Err(255)); let _ = inner.send(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Netapp connection dropped before end of stream")));
} }
} }
} }
@ -76,25 +76,26 @@ pub(crate) trait RecvLoop: Sync + 'static {
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
let is_error = (size & ERROR_MARKER) != 0; let is_error = (size & ERROR_MARKER) != 0;
let size = (size & CHUNK_LENGTH_MASK) as usize;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;
let packet = if is_error { let packet = if is_error {
trace!( let msg = String::from_utf8(next_slice).unwrap_or("<invalid utf8 error message>".into());
"recv_loop: got id {}, header_size {:04x}, error {}", debug!("recv_loop: got id {}, error: {}", id, msg);
id, Some(Err(std::io::Error::new(std::io::ErrorKind::Other, msg)))
size,
size & !ERROR_MARKER
);
Err((size & !ERROR_MARKER) as u8)
} else { } else {
let size = size & !CHUNK_HAS_CONTINUATION;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;
trace!( trace!(
"recv_loop: got id {}, header_size {:04x}, {} bytes", "recv_loop: got id {}, size {}, has_cont {}",
id, id,
size, size,
next_slice.len() has_cont
); );
Ok(Bytes::from(next_slice)) if !next_slice.is_empty() {
Some(Ok(Bytes::from(next_slice)))
} else {
None
}
}; };
let mut sender = if let Some(send) = streams.remove(&(id)) { let mut sender = if let Some(send) = streams.remove(&(id)) {
@ -109,9 +110,12 @@ pub(crate) trait RecvLoop: Sync + 'static {
Sender::new(send) Sender::new(send)
}; };
// If we get an error, the receiving end is disconnected. if let Some(packet) = packet {
// We still need to reach eos before dropping this sender // If we cannot put packet in channel, it means that the
let _ = sender.send(packet); // receiving end of the channel is disconnected.
// We still need to reach eos before dropping this sender
let _ = sender.send(packet);
}
if has_cont { if has_cont {
assert!(!is_error); assert!(!is_error);

View file

@ -18,9 +18,11 @@ use crate::stream::*;
// Messages are sent by chunks // Messages are sent by chunks
// Chunk format: // Chunk format:
// - u32 BE: request id (same for request and response) // - u32 BE: request id (same for request and response)
// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag // - u16 BE: chunk length + flags:
// when this is not the last chunk of the message // CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream
// - [u8; chunk_length] chunk data // ERROR_MARKER if this chunk denotes an error
// (these two flags are exclusive, an error denotes the end of the stream)
// - [u8; chunk_length] chunk data / error message
pub(crate) type RequestID = u32; pub(crate) type RequestID = u32;
pub(crate) type ChunkLength = u16; pub(crate) type ChunkLength = u16;
@ -28,6 +30,7 @@ pub(crate) type ChunkLength = u16;
pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
struct SendQueue { struct SendQueue {
items: Vec<(u8, VecDeque<SendQueueItem>)>, items: Vec<(u8, VecDeque<SendQueueItem>)>,
@ -92,29 +95,12 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> {
let id = item.id; let id = item.id;
let eos = item.data.eos(); let eos = item.data.eos();
let data_frame = match bytes_or_err { let packet = bytes_or_err.map_err(|e| match e {
Ok(bytes) => { ReadExactError::Stream(err) => err,
trace!( _ => unreachable!(),
"send queue poll next ready: id {} eos {:?} bytes {}", });
id,
eos, let data_frame = DataFrame::from_packet(packet, !eos);
bytes.len()
);
DataFrame::Data(bytes, !eos)
}
Err(e) => DataFrame::Error(match e {
ReadExactError::Stream(code) => {
trace!(
"send queue poll next ready: id {} eos {:?} ERROR {}",
id,
eos,
code
);
code
}
_ => unreachable!(),
}),
};
if !eos && !matches!(data_frame, DataFrame::Error(_)) { if !eos && !matches!(data_frame, DataFrame::Error(_)) {
items_at_prio.push_back(item); items_at_prio.push_back(item);
@ -139,15 +125,32 @@ enum DataFrame {
/// (albeit sub-optimal) to set it to true if there is nothing coming after /// (albeit sub-optimal) to set it to true if there is nothing coming after
Data(Bytes, bool), Data(Bytes, bool),
/// An error code automatically signals the end of the stream /// An error code automatically signals the end of the stream
Error(u8), Error(Bytes),
} }
impl DataFrame { impl DataFrame {
fn from_packet(p: Packet, has_cont: bool) -> Self {
match p {
Ok(bytes) => {
assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize);
Self::Data(bytes, has_cont)
}
Err(e) => {
let msg = format!("{}", e);
let mut msg = Bytes::from(msg.into_bytes());
if msg.len() > MAX_CHUNK_LENGTH as usize {
msg = msg.slice(..MAX_CHUNK_LENGTH as usize);
}
Self::Error(msg)
}
}
}
fn header(&self) -> [u8; 2] { fn header(&self) -> [u8; 2] {
let header_u16 = match self { let header_u16 = match self {
DataFrame::Data(data, false) => data.len() as u16, DataFrame::Data(data, false) => data.len() as u16,
DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION,
DataFrame::Error(e) => *e as u16 | ERROR_MARKER, DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER,
}; };
ChunkLength::to_be_bytes(header_u16) ChunkLength::to_be_bytes(header_u16)
} }
@ -155,7 +158,7 @@ impl DataFrame {
fn data(&self) -> &[u8] { fn data(&self) -> &[u8] {
match self { match self {
DataFrame::Data(ref data, _) => &data[..], DataFrame::Data(ref data, _) => &data[..],
DataFrame::Error(_) => &[], DataFrame::Error(ref msg) => &msg[..],
} }
} }
} }

View file

@ -4,7 +4,7 @@ use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
use futures::Future; use futures::Future;
use futures::{Stream, StreamExt, TryStreamExt}; use futures::{Stream, StreamExt};
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
use crate::bytes_buf::BytesBuf; use crate::bytes_buf::BytesBuf;
@ -18,7 +18,7 @@ use crate::bytes_buf::BytesBuf;
/// meaning, it's up to your application to define their semantic. /// meaning, it's up to your application to define their semantic.
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
pub type Packet = Result<Bytes, u8>; pub type Packet = Result<Bytes, std::io::Error>;
// ---- // ----
@ -26,7 +26,7 @@ pub struct ByteStreamReader {
stream: ByteStream, stream: ByteStream,
buf: BytesBuf, buf: BytesBuf,
eos: bool, eos: bool,
err: Option<u8>, err: Option<std::io::Error>,
} }
impl ByteStreamReader { impl ByteStreamReader {
@ -99,7 +99,7 @@ impl ByteStreamReader {
pub enum ReadExactError { pub enum ReadExactError {
UnexpectedEos, UnexpectedEos,
Stream(u8), Stream(std::io::Error),
} }
#[pin_project::pin_project] #[pin_project::pin_project]
@ -120,7 +120,8 @@ impl<'a> Future for ByteStreamReadExact<'a> {
if let Some(bytes) = this.reader.try_get(*this.read_len) { if let Some(bytes) = this.reader.try_get(*this.read_len) {
return Poll::Ready(Ok(bytes)); return Poll::Ready(Ok(bytes));
} }
if let Some(err) = this.reader.err { if let Some(err) = &this.reader.err {
let err = std::io::Error::new(err.kind(), format!("{}", err));
return Poll::Ready(Err(ReadExactError::Stream(err))); return Poll::Ready(Err(ReadExactError::Stream(err)));
} }
if this.reader.eos { if this.reader.eos {
@ -149,6 +150,7 @@ impl<'a> Future for ByteStreamReadExact<'a> {
// ---- // ----
/*
fn u8_to_io_error(v: u8) -> std::io::Error { fn u8_to_io_error(v: u8) -> std::io::Error {
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
let kind = match v { let kind = match v {
@ -183,11 +185,12 @@ fn io_error_to_u8(e: std::io::Error) -> u8 {
_ => 100, _ => 100,
} }
} }
*/
pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream { pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream {
Box::pin(tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8)) Box::pin(tokio_util::io::ReaderStream::new(reader))
} }
pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static {
tokio_util::io::StreamReader::new(stream.map_err(u8_to_io_error)) tokio_util::io::StreamReader::new(stream)
} }