add streaming body to requests and responses #3
4 changed files with 159 additions and 21 deletions
|
@ -1,7 +1,9 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::atomic::{self, AtomicU32};
|
use std::sync::atomic::{self, AtomicU32};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::task::Poll;
|
||||||
|
|
||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -9,6 +11,7 @@ use bytes::Bytes;
|
||||||
use log::{debug, error, trace};
|
use log::{debug, error, trace};
|
||||||
|
|
||||||
use futures::io::AsyncReadExt;
|
use futures::io::AsyncReadExt;
|
||||||
|
use futures::Stream;
|
||||||
use kuska_handshake::async_std::{handshake_client, BoxStream};
|
use kuska_handshake::async_std::{handshake_client, BoxStream};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
|
@ -35,7 +38,7 @@ pub(crate) struct ClientConn {
|
||||||
pub(crate) remote_addr: SocketAddr,
|
pub(crate) remote_addr: SocketAddr,
|
||||||
pub(crate) peer_id: NodeID,
|
pub(crate) peer_id: NodeID,
|
||||||
|
|
||||||
query_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>,
|
query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
|
||||||
|
|
||||||
next_query_number: AtomicU32,
|
next_query_number: AtomicU32,
|
||||||
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
|
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
|
||||||
|
@ -193,7 +196,9 @@ impl ClientConn {
|
||||||
#[cfg(feature = "telemetry")]
|
#[cfg(feature = "telemetry")]
|
||||||
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
|
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
|
||||||
|
|
||||||
query_send.send((id, prio, req_order, req_stream))?;
|
query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
|
||||||
|
|
||||||
|
let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
|
||||||
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "telemetry")] {
|
if #[cfg(feature = "telemetry")] {
|
||||||
|
@ -205,6 +210,8 @@ impl ClientConn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let stream = Box::pin(canceller.for_stream(stream));
|
||||||
|
|
||||||
let resp_enc = RespEnc::decode(stream).await?;
|
let resp_enc = RespEnc::decode(stream).await?;
|
||||||
debug!("client: got response to request {} (path {})", id, path);
|
debug!("client: got response to request {} (path {})", id, path);
|
||||||
Resp::from_enc(resp_enc)
|
Resp::from_enc(resp_enc)
|
||||||
|
@ -223,6 +230,63 @@ impl RecvLoop for ClientConn {
|
||||||
if ch.send(stream).is_err() {
|
if ch.send(stream).is_err() {
|
||||||
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
debug!("Got unexpected response to request {}, dropping it", id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
|
struct CancelOnDrop {
|
||||||
|
id: RequestID,
|
||||||
|
query_send: mpsc::UnboundedSender<SendItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CancelOnDrop {
|
||||||
|
fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
|
||||||
|
Self { id, query_send }
|
||||||
|
}
|
||||||
|
fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
|
||||||
|
CancelOnDropStream {
|
||||||
|
cancel: Some(self),
|
||||||
|
stream: stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for CancelOnDrop {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
trace!("cancelling request {}", self.id);
|
||||||
|
let _ = self.query_send.send(SendItem::Cancel(self.id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pin_project::pin_project]
|
||||||
|
struct CancelOnDropStream {
|
||||||
|
cancel: Option<CancelOnDrop>,
|
||||||
|
#[pin]
|
||||||
|
stream: ByteStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for CancelOnDropStream {
|
||||||
|
type Item = Packet;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Option<Self::Item>> {
|
||||||
|
let this = self.project();
|
||||||
|
let res = this.stream.poll_next(cx);
|
||||||
|
if matches!(res, Poll::Ready(None)) {
|
||||||
|
if let Some(c) = this.cancel.take() {
|
||||||
|
std::mem::forget(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
self.stream.size_hint()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
18
src/recv.rs
18
src/recv.rs
|
@ -53,6 +53,7 @@ impl Drop for Sender {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub(crate) trait RecvLoop: Sync + 'static {
|
pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
|
||||||
|
fn cancel_handler(self: &Arc<Self>, _id: RequestID) {}
|
||||||
|
|
||||||
async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
|
async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
|
@ -78,6 +79,18 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
read.read_exact(&mut header_size[..]).await?;
|
read.read_exact(&mut header_size[..]).await?;
|
||||||
let size = ChunkLength::from_be_bytes(header_size);
|
let size = ChunkLength::from_be_bytes(header_size);
|
||||||
|
|
||||||
|
if size == CANCEL_REQUEST {
|
||||||
|
if let Some(mut stream) = streams.remove(&id) {
|
||||||
|
let _ = stream.send(Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
"netapp: cancel requested",
|
||||||
|
)));
|
||||||
|
stream.end();
|
||||||
|
}
|
||||||
|
self.cancel_handler(id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
|
||||||
let is_error = (size & ERROR_MARKER) != 0;
|
let is_error = (size & ERROR_MARKER) != 0;
|
||||||
let size = (size & CHUNK_LENGTH_MASK) as usize;
|
let size = (size & CHUNK_LENGTH_MASK) as usize;
|
||||||
|
@ -88,7 +101,10 @@ pub(crate) trait RecvLoop: Sync + 'static {
|
||||||
let kind = u8_to_io_errorkind(next_slice[0]);
|
let kind = u8_to_io_errorkind(next_slice[0]);
|
||||||
let msg =
|
let msg =
|
||||||
std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
|
std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
|
||||||
debug!("recv_loop({}): got id {}, error {:?}: {}", debug_name, id, kind, msg);
|
debug!(
|
||||||
|
"recv_loop({}): got id {}, error {:?}: {}",
|
||||||
|
debug_name, id, kind, msg
|
||||||
|
);
|
||||||
Some(Err(std::io::Error::new(kind, msg.to_string())))
|
Some(Err(std::io::Error::new(kind, msg.to_string())))
|
||||||
} else {
|
} else {
|
||||||
trace!(
|
trace!(
|
||||||
|
|
57
src/send.rs
57
src/send.rs
|
@ -22,6 +22,7 @@ use crate::stream::*;
|
||||||
// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream
|
// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream
|
||||||
// ERROR_MARKER if this chunk denotes an error
|
// ERROR_MARKER if this chunk denotes an error
|
||||||
// (these two flags are exclusive, an error denotes the end of the stream)
|
// (these two flags are exclusive, an error denotes the end of the stream)
|
||||||
|
// **special value** 0xFFFF indicates a CANCEL message
|
||||||
// - [u8; chunk_length], either
|
// - [u8; chunk_length], either
|
||||||
// - if not error: chunk data
|
// - if not error: chunk data
|
||||||
// - if error:
|
// - if error:
|
||||||
|
@ -35,8 +36,14 @@ pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
|
||||||
pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
|
pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
|
||||||
pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
|
||||||
pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
|
pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
|
||||||
|
pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF;
|
||||||
|
|
||||||
pub(crate) type SendStream = (RequestID, RequestPriority, Option<OrderTag>, ByteStream);
|
pub(crate) enum SendItem {
|
||||||
|
Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream),
|
||||||
|
Cancel(RequestID),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
struct SendQueue {
|
struct SendQueue {
|
||||||
items: Vec<(u8, SendQueuePriority)>,
|
items: Vec<(u8, SendQueuePriority)>,
|
||||||
|
@ -71,6 +78,11 @@ impl SendQueue {
|
||||||
};
|
};
|
||||||
self.items[pos_prio].1.push(item);
|
self.items[pos_prio].1.push(item);
|
||||||
}
|
}
|
||||||
|
fn remove(&mut self, id: RequestID) {
|
||||||
|
for (_, prioq) in self.items.iter_mut() {
|
||||||
|
prioq.remove(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.items.iter().all(|(_k, v)| v.is_empty())
|
self.items.iter().all(|(_k, v)| v.is_empty())
|
||||||
}
|
}
|
||||||
|
@ -96,6 +108,16 @@ impl SendQueuePriority {
|
||||||
}
|
}
|
||||||
self.items.push_back(item);
|
self.items.push_back(item);
|
||||||
}
|
}
|
||||||
|
fn remove(&mut self, id: RequestID) {
|
||||||
|
if let Some(i) = self.items.iter().position(|x| x.id == id) {
|
||||||
|
let item = self.items.remove(i).unwrap();
|
||||||
|
if let Some(OrderTag(stream, order)) = item.order_tag {
|
||||||
|
let order_vec = self.order.get_mut(&stream).unwrap();
|
||||||
|
let j = order_vec.iter().position(|x| *x == order).unwrap();
|
||||||
|
order_vec.remove(j).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.items.is_empty()
|
self.items.is_empty()
|
||||||
}
|
}
|
||||||
|
@ -229,7 +251,7 @@ impl DataFrame {
|
||||||
pub(crate) trait SendLoop: Sync {
|
pub(crate) trait SendLoop: Sync {
|
||||||
async fn send_loop<W>(
|
async fn send_loop<W>(
|
||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
msg_recv: mpsc::UnboundedReceiver<SendStream>,
|
msg_recv: mpsc::UnboundedReceiver<SendItem>,
|
||||||
mut write: BoxStreamWrite<W>,
|
mut write: BoxStreamWrite<W>,
|
||||||
debug_name: String,
|
debug_name: String,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
|
@ -264,16 +286,27 @@ pub(crate) trait SendLoop: Sync {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
biased; // always read incomming channel first if it has data
|
biased; // always read incomming channel first if it has data
|
||||||
sth = recv_fut => {
|
sth = recv_fut => {
|
||||||
if let Some((id, prio, order_tag, data)) = sth {
|
match sth {
|
||||||
trace!("send_loop({}): add stream {} to send", debug_name, id);
|
Some(SendItem::Stream(id, prio, order_tag, data)) => {
|
||||||
sending.push(SendQueueItem {
|
trace!("send_loop({}): add stream {} to send", debug_name, id);
|
||||||
id,
|
sending.push(SendQueueItem {
|
||||||
prio,
|
id,
|
||||||
order_tag,
|
prio,
|
||||||
data: ByteStreamReader::new(data),
|
order_tag,
|
||||||
});
|
data: ByteStreamReader::new(data),
|
||||||
} else {
|
})
|
||||||
msg_recv = None;
|
}
|
||||||
|
Some(SendItem::Cancel(id)) => {
|
||||||
|
trace!("send_loop({}): cancelling {}", debug_name, id);
|
||||||
|
sending.remove(id);
|
||||||
|
let header_id = RequestID::to_be_bytes(id);
|
||||||
|
write.write_all(&header_id[..]).await?;
|
||||||
|
write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?;
|
||||||
|
write.flush().await?;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
msg_recv = None;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
(id, data) = send_fut => {
|
(id, data) = send_fut => {
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use arc_swap::ArcSwapOption;
|
use arc_swap::ArcSwapOption;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -53,7 +54,8 @@ pub(crate) struct ServerConn {
|
||||||
|
|
||||||
netapp: Arc<NetApp>,
|
netapp: Arc<NetApp>,
|
||||||
|
|
||||||
resp_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>,
|
resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
|
||||||
|
running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerConn {
|
impl ServerConn {
|
||||||
|
@ -99,6 +101,7 @@ impl ServerConn {
|
||||||
remote_addr,
|
remote_addr,
|
||||||
peer_id,
|
peer_id,
|
||||||
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
|
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
|
||||||
|
running_handlers: Mutex::new(HashMap::new()),
|
||||||
});
|
});
|
||||||
|
|
||||||
netapp.connected_as_server(peer_id, conn.clone());
|
netapp.connected_as_server(peer_id, conn.clone());
|
||||||
|
@ -174,10 +177,15 @@ impl SendLoop for ServerConn {}
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RecvLoop for ServerConn {
|
impl RecvLoop for ServerConn {
|
||||||
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
|
||||||
let resp_send = self.resp_send.load_full().unwrap();
|
let resp_send = match self.resp_send.load_full() {
|
||||||
|
Some(c) => c,
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut rh = self.running_handlers.lock().unwrap();
|
||||||
|
|
||||||
let self2 = self.clone();
|
let self2 = self.clone();
|
||||||
tokio::spawn(async move {
|
let jh = tokio::spawn(async move {
|
||||||
debug!("server: recv_handler got {}", id);
|
debug!("server: recv_handler got {}", id);
|
||||||
|
|
||||||
let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
|
let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
|
||||||
|
@ -189,9 +197,26 @@ impl RecvLoop for ServerConn {
|
||||||
|
|
||||||
let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
|
let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
|
||||||
resp_send
|
resp_send
|
||||||
.send((id, prio, resp_order, resp_stream))
|
.send(SendItem::Stream(id, prio, resp_order, resp_stream))
|
||||||
.log_err("ServerConn recv_handler send resp bytes");
|
.log_err("ServerConn recv_handler send resp bytes");
|
||||||
Ok::<_, Error>(())
|
|
||||||
|
self2.running_handlers.lock().unwrap().remove(&id);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
rh.insert(id, jh);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cancel_handler(self: &Arc<Self>, id: RequestID) {
|
||||||
|
trace!("received cancel for request {}", id);
|
||||||
|
|
||||||
|
// If the handler is still running, abort it now
|
||||||
|
if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
|
||||||
|
jh.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inform the response sender that we don't need to send the response
|
||||||
|
if let Some(resp_send) = self.resp_send.load_full() {
|
||||||
|
let _ = resp_send.send(SendItem::Cancel(id));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue