From 522f420e2bf30d5ef6f50dccb88adf86882ac7c6 Mon Sep 17 00:00:00 2001
From: Alex Auvolat
Date: Thu, 1 Sep 2022 15:54:11 +0200
Subject: [PATCH] Implement request cancellation

---
 src/client.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++--
 src/recv.rs   | 18 +++++++++++++-
 src/send.rs   | 57 +++++++++++++++++++++++++++++++++---------
 src/server.rs | 37 +++++++++++++++++++++++-----
 4 files changed, 159 insertions(+), 21 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index 9726125..d82c91e 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,7 +1,9 @@
 use std::collections::HashMap;
 use std::net::SocketAddr;
+use std::pin::Pin;
 use std::sync::atomic::{self, AtomicU32};
 use std::sync::{Arc, Mutex};
+use std::task::Poll;
 
 use arc_swap::ArcSwapOption;
 use async_trait::async_trait;
@@ -9,6 +11,7 @@ use bytes::Bytes;
 use log::{debug, error, trace};
 
 use futures::io::AsyncReadExt;
+use futures::Stream;
 use kuska_handshake::async_std::{handshake_client, BoxStream};
 use tokio::net::TcpStream;
 use tokio::select;
@@ -35,7 +38,7 @@ pub(crate) struct ClientConn {
 	pub(crate) remote_addr: SocketAddr,
 	pub(crate) peer_id: NodeID,
 
-	query_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>,
+	query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
 
 	next_query_number: AtomicU32,
 	inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
@@ -193,7 +196,9 @@ impl ClientConn {
 		#[cfg(feature = "telemetry")]
 		span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
 
-		query_send.send((id, prio, req_order, req_stream))?;
+		query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
+
+		let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
 
 		cfg_if::cfg_if! {
 			if #[cfg(feature = "telemetry")] {
@@ -205,6 +210,8 @@ impl ClientConn {
 			}
 		}
 
+		let stream = Box::pin(canceller.for_stream(stream));
+
 		let resp_enc = RespEnc::decode(stream).await?;
 		debug!("client: got response to request {} (path {})", id, path);
 		Resp::from_enc(resp_enc)
@@ -223,6 +230,63 @@ impl RecvLoop for ClientConn {
 			if ch.send(stream).is_err() {
 				debug!("Could not send request response, probably because request was interrupted. Dropping response.");
 			}
+		} else {
+			debug!("Got unexpected response to request {}, dropping it", id);
 		}
 	}
 }
+
+// ----
+
+struct CancelOnDrop {
+	id: RequestID,
+	query_send: mpsc::UnboundedSender<SendItem>,
+}
+
+impl CancelOnDrop {
+	fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
+		Self { id, query_send }
+	}
+	fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
+		CancelOnDropStream {
+			cancel: Some(self),
+			stream: stream,
+		}
+	}
+}
+
+impl Drop for CancelOnDrop {
+	fn drop(&mut self) {
+		trace!("cancelling request {}", self.id);
+		let _ = self.query_send.send(SendItem::Cancel(self.id));
+	}
+}
+
+#[pin_project::pin_project]
+struct CancelOnDropStream {
+	cancel: Option<CancelOnDrop>,
+	#[pin]
+	stream: ByteStream,
+}
+
+impl Stream for CancelOnDropStream {
+	type Item = Packet;
+
+	fn poll_next(
+		self: Pin<&mut Self>,
+		cx: &mut std::task::Context<'_>,
+	) -> Poll<Option<Self::Item>> {
+		let this = self.project();
+		let res = this.stream.poll_next(cx);
+		if matches!(res, Poll::Ready(None)) {
+			if let Some(c) = this.cancel.take() {
+				std::mem::forget(c)
+			}
+		}
+		res
+	}
+
+	fn size_hint(&self) -> (usize, Option<usize>) {
+		self.stream.size_hint()
+	}
+}
diff --git a/src/recv.rs b/src/recv.rs
index ac93c4b..8909190 100644
--- a/src/recv.rs
+++ b/src/recv.rs
@@ -53,6 +53,7 @@ impl Drop for Sender {
 #[async_trait]
 pub(crate) trait RecvLoop: Sync + 'static {
 	fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
+	fn cancel_handler(self: &Arc<Self>, _id: RequestID) {}
 
 	async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
 	where
@@ -78,6 +79,18 @@ pub(crate) trait RecvLoop: Sync + 'static {
 			read.read_exact(&mut header_size[..]).await?;
 			let size = ChunkLength::from_be_bytes(header_size);
 
+			if size == CANCEL_REQUEST {
+				if let Some(mut stream) = streams.remove(&id) {
+					let _ = stream.send(Err(std::io::Error::new(
+						std::io::ErrorKind::Other,
+						"netapp: cancel requested",
+					)));
+					stream.end();
+				}
+				self.cancel_handler(id);
+				continue;
+			}
+
 			let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
 			let is_error = (size & ERROR_MARKER) != 0;
 			let size = (size & CHUNK_LENGTH_MASK) as usize;
@@ -88,7 +101,10 @@ pub(crate) trait RecvLoop: Sync + 'static {
 				let kind = u8_to_io_errorkind(next_slice[0]);
 				let msg =
 					std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
-				debug!("recv_loop({}): got id {}, error {:?}: {}", debug_name, id, kind, msg);
+				debug!(
+					"recv_loop({}): got id {}, error {:?}: {}",
+					debug_name, id, kind, msg
+				);
 				Some(Err(std::io::Error::new(kind, msg.to_string())))
 			} else {
 				trace!(
diff --git a/src/send.rs b/src/send.rs
index d927d98..780bbcf 100644
--- a/src/send.rs
+++ b/src/send.rs
@@ -22,6 +22,7 @@ use crate::stream::*;
 //   CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream
 //   ERROR_MARKER if this chunk denotes an error
 //   (these two flags are exclusive, an error denotes the end of the stream)
+//   **special value** 0xFFFF indicates a CANCEL message
 // - [u8; chunk_length], either
 //   - if not error: chunk data
 //   - if error:
@@ -35,8 +36,14 @@ pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
 pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
 pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
 pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
+pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF;
 
-pub(crate) type SendStream = (RequestID, RequestPriority, Option<OrderTag>, ByteStream);
+pub(crate) enum SendItem {
+	Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream),
+	Cancel(RequestID),
+}
+
+// ----
 
 struct SendQueue {
 	items: Vec<(u8, SendQueuePriority)>,
@@ -71,6 +78,11 @@ impl SendQueue {
 		};
 		self.items[pos_prio].1.push(item);
 	}
+	fn remove(&mut self, id: RequestID) {
+		for (_, prioq) in self.items.iter_mut() {
+			prioq.remove(id);
+		}
+	}
 	fn is_empty(&self) -> bool {
 		self.items.iter().all(|(_k, v)| v.is_empty())
 	}
@@ -96,6 +108,16 @@ impl SendQueuePriority {
 		}
 		self.items.push_back(item);
 	}
+	fn remove(&mut self, id: RequestID) {
+		if let Some(i) = self.items.iter().position(|x| x.id == id) {
+			let item = self.items.remove(i).unwrap();
+			if let Some(OrderTag(stream, order)) = item.order_tag {
+				let order_vec = self.order.get_mut(&stream).unwrap();
+				let j = order_vec.iter().position(|x| *x == order).unwrap();
+				order_vec.remove(j).unwrap();
+			}
+		}
+	}
 	fn is_empty(&self) -> bool {
 		self.items.is_empty()
 	}
@@ -229,7 +251,7 @@ impl DataFrame {
 pub(crate) trait SendLoop: Sync {
 	async fn send_loop<W>(
 		self: Arc<Self>,
-		msg_recv: mpsc::UnboundedReceiver<SendStream>,
+		msg_recv: mpsc::UnboundedReceiver<SendItem>,
 		mut write: BoxStreamWrite<W>,
 		debug_name: String,
 	) -> Result<(), Error>
 	where
@@ -264,16 +286,27 @@ pub(crate) trait SendLoop: Sync {
 			tokio::select! {
 				biased;	// always read incomming channel first if it has data
 				sth = recv_fut => {
-					if let Some((id, prio, order_tag, data)) = sth {
-						trace!("send_loop({}): add stream {} to send", debug_name, id);
-						sending.push(SendQueueItem {
-							id,
-							prio,
-							order_tag,
-							data: ByteStreamReader::new(data),
-						});
-					} else {
-						msg_recv = None;
+					match sth {
+						Some(SendItem::Stream(id, prio, order_tag, data)) => {
+							trace!("send_loop({}): add stream {} to send", debug_name, id);
+							sending.push(SendQueueItem {
+								id,
+								prio,
+								order_tag,
+								data: ByteStreamReader::new(data),
+							})
+						}
+						Some(SendItem::Cancel(id)) => {
+							trace!("send_loop({}): cancelling {}", debug_name, id);
+							sending.remove(id);
+							let header_id = RequestID::to_be_bytes(id);
+							write.write_all(&header_id[..]).await?;
+							write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?;
+							write.flush().await?;
+						}
+						None => {
+							msg_recv = None;
+						}
 					};
 				}
 				(id, data) = send_fut => {
diff --git a/src/server.rs b/src/server.rs
index 2c12d9d..f9eb121 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,5 +1,6 @@
+use std::collections::HashMap;
 use std::net::SocketAddr;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use arc_swap::ArcSwapOption;
 use async_trait::async_trait;
@@ -53,7 +54,8 @@ pub(crate) struct ServerConn {
 
 	netapp: Arc<NetApp>,
 
-	resp_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>,
+	resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
+	running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
 }
 
 impl ServerConn {
@@ -99,6 +101,7 @@ impl ServerConn {
 			remote_addr,
 			peer_id,
 			resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
+			running_handlers: Mutex::new(HashMap::new()),
 		});
 
 		netapp.connected_as_server(peer_id, conn.clone());
@@ -174,10 +177,15 @@ impl SendLoop for ServerConn {}
 #[async_trait]
 impl RecvLoop for ServerConn {
 	fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
-		let resp_send = self.resp_send.load_full().unwrap();
+		let resp_send = match self.resp_send.load_full() {
+			Some(c) => c,
+			None => return,
+		};
+
+		let mut rh = self.running_handlers.lock().unwrap();
 
 		let self2 = self.clone();
-		tokio::spawn(async move {
+		let jh = tokio::spawn(async move {
 			debug!("server: recv_handler got {}", id);
 
 			let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
@@ -189,9 +197,26 @@ impl RecvLoop for ServerConn {
 			let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
 
 			resp_send
-				.send((id, prio, resp_order, resp_stream))
+				.send(SendItem::Stream(id, prio, resp_order, resp_stream))
 				.log_err("ServerConn recv_handler send resp bytes");
-			Ok::<_, Error>(())
+
+			self2.running_handlers.lock().unwrap().remove(&id);
 		});
+
+		rh.insert(id, jh);
+	}
+
+	fn cancel_handler(self: &Arc<Self>, id: RequestID) {
+		trace!("received cancel for request {}", id);
+
+		// If the handler is still running, abort it now
+		if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
+			jh.abort();
+		}
+
+		// Inform the response sender that we don't need to send the response
+		if let Some(resp_send) = self.resp_send.load_full() {
+			let _ = resp_send.send(SendItem::Cancel(id));
+		}
 	}
 }