use a framing protocol instead of even/odd channel

trinity-1686a 2022-06-20 23:40:31 +02:00
parent 0fec85b47a
commit d3d18b8e8b
5 changed files with 192 additions and 233 deletions
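Summary of the change: previously each request used two chunk IDs on the wire, the even ID carrying the serialized message and the odd ID (id + 1) carrying the optional associated stream. With this commit a single ID carries one framed stream per request: a 4-byte big-endian length, then the direct message bytes, then the bytes of the associated stream. A rough sketch of that layout (illustrative only; the actual encoding is the Framing type added to proto.rs below, which builds the same bytes lazily as a stream):

// Hypothetical helper, not part of this commit: materializes the frame layout eagerly.
fn frame(direct: &[u8], stream_tail: &[u8]) -> Vec<u8> {
    assert!(direct.len() <= u32::MAX as usize);
    let mut out = Vec::with_capacity(4 + direct.len() + stream_tail.len());
    out.extend_from_slice(&(direct.len() as u32).to_be_bytes()); // length prefix, big-endian
    out.extend_from_slice(direct);                               // "direct" message bytes
    out.extend_from_slice(stream_tail);                          // optional associated stream follows
    out
}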


@@ -37,10 +37,11 @@ pub(crate) struct ClientConn {
     pub(crate) remote_addr: SocketAddr,
     pub(crate) peer_id: NodeID,

-    query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>,
+    query_send:
+        ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,

     next_query_number: AtomicU32,
-    inflight: Mutex<HashMap<RequestID, oneshot::Sender<(Vec<u8>, AssociatedStream)>>>,
+    inflight: Mutex<HashMap<RequestID, oneshot::Sender<AssociatedStream>>>,
 }

 impl ClientConn {
@@ -148,11 +149,9 @@ impl ClientConn {
     {
         let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;

-        // increment by 2; even are direct data; odd are associated stream
         let id = self
             .next_query_number
-            .fetch_add(2, atomic::Ordering::Relaxed);
-        let stream_id = id + 1;
+            .fetch_add(1, atomic::Ordering::Relaxed);

         cfg_if::cfg_if! {
             if #[cfg(feature = "telemetry")] {
@@ -187,10 +186,7 @@ impl ClientConn {
             error!(
                 "Too many inflight requests! RequestID collision. Interrupting previous request."
             );
-            if old_ch
-                .send((vec![], Box::pin(futures::stream::empty())))
-                .is_err()
-            {
+            if old_ch.send(Box::pin(futures::stream::empty())).is_err() {
                 debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
             }
         }
@@ -200,22 +196,18 @@ impl ClientConn {
         #[cfg(feature = "telemetry")]
         span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));

-        query_send.send((id, prio, Data::Full(bytes)))?;
-        if let Some(stream) = stream {
-            query_send.send((stream_id, prio | PRIO_SECONDARY, Data::Streaming(stream)))?;
-        } else {
-            query_send.send((stream_id, prio, Data::Full(Vec::new())))?;
-        }
+        query_send.send((id, prio, Framing::new(bytes, stream).into_stream()))?;

         cfg_if::cfg_if! {
             if #[cfg(feature = "telemetry")] {
-                let (resp, stream) = resp_recv
+                let stream = resp_recv
                     .with_context(Context::current_with_span(span))
                     .await?;
             } else {
-                let (resp, stream) = resp_recv.await?;
+                let stream = resp_recv.await?;
             }
         }

+        let (resp, stream) = Framing::from_stream(stream).await?.into_parts();
+
         if resp.is_empty() {
             return Err(Error::Message(
@@ -240,12 +232,12 @@ impl SendLoop for ClientConn {}

 #[async_trait]
 impl RecvLoop for ClientConn {
-    fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream) {
-        trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
+    fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) {
+        trace!("ClientConn recv_handler {}", id);

         let mut inflight = self.inflight.lock().unwrap();
         if let Some(ch) = inflight.remove(&id) {
-            if ch.send((msg, stream)).is_err() {
+            if ch.send(stream).is_err() {
                 debug!("Could not send request response, probably because request was interrupted. Dropping response.");
             }
         }


@@ -23,7 +23,6 @@ pub trait Message: SerializeMessage + Send + Sync {
 pub trait SerializeMessage: Sized {
     type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;

-    // TODO should return Result
     fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);

     // TODO should return Result


@@ -25,6 +25,9 @@ pub enum Error {
     #[error(display = "UTF8 error: {}", _0)]
     UTF8(#[error(source)] std::string::FromUtf8Error),

+    #[error(display = "Framing protocol error")]
+    Framing,
+
     #[error(display = "{}", _0)]
     Message(String),
@@ -50,6 +53,7 @@ impl Error {
             Self::RMPEncode(_) => 10,
             Self::RMPDecode(_) => 11,
             Self::UTF8(_) => 12,
+            Self::Framing => 13,
             Self::NoHandler => 20,
             Self::ConnectionClosed => 21,
             Self::Handshake(_) => 30,


@@ -3,11 +3,11 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};

-use log::{trace, warn};
+use log::trace;

-use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
-use futures::Stream;
+use futures::channel::mpsc::{unbounded, UnboundedSender};
 use futures::{AsyncReadExt, AsyncWriteExt};
+use futures::{Stream, StreamExt};
 use kuska_handshake::async_std::BoxStreamWrite;

 use tokio::sync::mpsc;
@@ -63,39 +63,24 @@ struct SendQueueItem {
     data: DataReader,
 }

-pub(crate) enum Data {
-    Full(Vec<u8>),
-    Streaming(AssociatedStream),
-}
-
-#[pin_project::pin_project(project = DataReaderProj)]
-enum DataReader {
-    Full {
-        #[pin]
-        data: Vec<u8>,
-        pos: usize,
-    },
-    Streaming {
+#[pin_project::pin_project]
+struct DataReader {
     #[pin]
     reader: AssociatedStream,
     packet: Result<Vec<u8>, u8>,
     pos: usize,
     buf: Vec<u8>,
     eos: bool,
-    },
 }

-impl From<Data> for DataReader {
-    fn from(data: Data) -> DataReader {
-        match data {
-            Data::Full(data) => DataReader::Full { data, pos: 0 },
-            Data::Streaming(reader) => DataReader::Streaming {
-                reader,
+impl From<AssociatedStream> for DataReader {
+    fn from(data: AssociatedStream) -> DataReader {
+        DataReader {
+            reader: data,
             packet: Ok(Vec::new()),
             pos: 0,
             buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize),
             eos: false,
-            },
-        }
+        }
     }
 }
@@ -155,83 +140,61 @@ impl Stream for DataReader {
     type Item = DataReaderItem;

     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.project() {
-            DataReaderProj::Full { data, pos } => {
-                let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos);
-                let end = *pos + len;
-
-                if len == 0 {
-                    Poll::Ready(None)
-                } else {
-                    let mut body = [0; MAX_CHUNK_LENGTH as usize];
-                    body[..len].copy_from_slice(&data[*pos..end]);
-                    *pos = end;
-                    Poll::Ready(Some(DataReaderItem {
-                        data: DataFrame::Data { data: body, len },
-                        may_have_more: end < data.len(),
-                    }))
-                }
-            }
-            DataReaderProj::Streaming {
-                mut reader,
-                packet: res_packet,
-                pos,
-                buf,
-                eos,
-            } => {
-                if *eos {
-                    // eos was reached at previous call to poll_next, where a partial packet
-                    // was returned. Now return None
-                    return Poll::Ready(None);
-                }
-
-                loop {
-                    let packet = match res_packet {
-                        Ok(v) => v,
-                        Err(e) => {
-                            let e = *e;
-                            *res_packet = Ok(Vec::new());
-                            return Poll::Ready(Some(DataReaderItem {
-                                data: DataFrame::Error(e),
-                                may_have_more: true,
-                            }));
-                        }
-                    };
-                    let packet_left = packet.len() - *pos;
-                    let buf_left = MAX_CHUNK_LENGTH as usize - buf.len();
-                    let to_read = std::cmp::min(buf_left, packet_left);
-                    buf.extend_from_slice(&packet[*pos..*pos + to_read]);
-                    *pos += to_read;
-                    if buf.len() == MAX_CHUNK_LENGTH as usize {
-                        // we have a full buf, ready to send
-                        break;
-                    }
-
-                    // we don't have a full buf, packet is empty; try receive more
-                    if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) {
-                        *res_packet = p;
-                        *pos = 0;
-                        // if buf is empty, we will loop and return the error directly. If buf
-                        // isn't empty, send it before by breaking.
-                        if res_packet.is_err() && !buf.is_empty() {
-                            break;
-                        }
-                    } else {
-                        *eos = true;
-                        break;
-                    }
-                }
-
-                let mut body = [0; MAX_CHUNK_LENGTH as usize];
-                let len = buf.len();
-                body[..len].copy_from_slice(buf);
-                buf.clear();
-                Poll::Ready(Some(DataReaderItem {
-                    data: DataFrame::Data { data: body, len },
-                    may_have_more: !*eos,
-                }))
-            }
-        }
+        let mut this = self.project();
+
+        if *this.eos {
+            // eos was reached at previous call to poll_next, where a partial packet
+            // was returned. Now return None
+            return Poll::Ready(None);
+        }
+
+        loop {
+            let packet = match this.packet {
+                Ok(v) => v,
+                Err(e) => {
+                    let e = *e;
+                    *this.packet = Ok(Vec::new());
+                    return Poll::Ready(Some(DataReaderItem {
+                        data: DataFrame::Error(e),
+                        may_have_more: true,
+                    }));
+                }
+            };
+            let packet_left = packet.len() - *this.pos;
+            let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len();
+            let to_read = std::cmp::min(buf_left, packet_left);
+            this.buf
+                .extend_from_slice(&packet[*this.pos..*this.pos + to_read]);
+            *this.pos += to_read;
+            if this.buf.len() == MAX_CHUNK_LENGTH as usize {
+                // we have a full buf, ready to send
+                break;
+            }
+
+            // we don't have a full buf, packet is empty; try receive more
+            if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) {
+                *this.packet = p;
+                *this.pos = 0;
+                // if buf is empty, we will loop and return the error directly. If buf
+                // isn't empty, send it before by breaking.
+                if this.packet.is_err() && !this.buf.is_empty() {
+                    break;
+                }
+            } else {
+                *this.eos = true;
+                break;
+            }
+        }
+
+        let mut body = [0; MAX_CHUNK_LENGTH as usize];
+        let len = this.buf.len();
+        body[..len].copy_from_slice(this.buf);
+        this.buf.clear();
+        Poll::Ready(Some(DataReaderItem {
+            data: DataFrame::Data { data: body, len },
+            may_have_more: !*this.eos,
+        }))
     }
 }

 struct SendQueue {
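Note on the rewritten DataReader above: it re-chunks the framed stream into fixed-size frames for the send loop, flushing a partially filled buffer before emitting an error frame and ending with a final (possibly partial) chunk. A simplified, synchronous model of that buffering logic (illustrative only; chunk_packets and its max_chunk parameter are made up here, the crate does this lazily in poll_next with MAX_CHUNK_LENGTH):

// Hypothetical model of the buffering in DataReader::poll_next, not part of this commit.
fn chunk_packets(packets: Vec<Result<Vec<u8>, u8>>, max_chunk: usize) -> Vec<Result<Vec<u8>, u8>> {
    let mut out = Vec::new();
    let mut buf: Vec<u8> = Vec::with_capacity(max_chunk);
    for packet in packets {
        match packet {
            Err(e) => {
                if !buf.is_empty() {
                    out.push(Ok(std::mem::take(&mut buf))); // flush the partial chunk first
                }
                out.push(Err(e)); // then forward the error frame
            }
            Ok(data) => {
                let mut pos = 0;
                while pos < data.len() {
                    let to_read = std::cmp::min(max_chunk - buf.len(), data.len() - pos);
                    buf.extend_from_slice(&data[pos..pos + to_read]);
                    pos += to_read;
                    if buf.len() == max_chunk {
                        out.push(Ok(std::mem::take(&mut buf))); // full chunk, ready to send
                    }
                }
            }
        }
    }
    out.push(Ok(buf)); // final, possibly partial, chunk marks end of stream
    out
}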
@@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> {
 pub(crate) trait SendLoop: Sync {
     async fn send_loop<W>(
         self: Arc<Self>,
-        mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>,
+        mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
         mut write: BoxStreamWrite<W>,
     ) -> Result<(), Error>
     where
@@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync {
     }
 }

-struct ChannelPair {
-    receiver: Option<UnboundedReceiver<Vec<u8>>>,
-    sender: Option<UnboundedSender<Vec<u8>>>,
-}
-
-impl ChannelPair {
-    fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
-        self.receiver.take()
-    }
-
-    fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
-        self.sender.take()
-    }
-
-    fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
-        self.sender.as_ref().take()
-    }
-
-    fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
-        if self.receiver.is_some() || self.sender.is_some() {
-            map.insert(index, self);
-        }
-    }
-}
-
-impl Default for ChannelPair {
-    fn default() -> Self {
-        let (send, recv) = unbounded();
-        ChannelPair {
-            receiver: Some(recv),
-            sender: Some(send),
-        }
-    }
-}
+pub(crate) struct Framing {
+    direct: Vec<u8>,
+    stream: Option<AssociatedStream>,
+}
+
+impl Framing {
+    pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
+        assert!(direct.len() <= u32::MAX as usize);
+        Framing { direct, stream }
+    }
+
+    pub fn into_stream(self) -> AssociatedStream {
+        use futures::stream;
+        let len = self.direct.len() as u32;
+        // required because otherwise the borrow-checker complains
+        let Framing { direct, stream } = self;
+
+        let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
+            .chain(stream::once(async move { Ok(direct) }));
+
+        if let Some(stream) = stream {
+            Box::pin(res.chain(stream))
+        } else {
+            Box::pin(res)
+        }
+    }
+
+    pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>(
+        mut stream: S,
+    ) -> Result<Self, Error> {
+        let mut packet = stream
+            .next()
+            .await
+            .ok_or(Error::Framing)?
+            .map_err(|_| Error::Framing)?;
+        if packet.len() < 4 {
+            return Err(Error::Framing);
+        }
+
+        let mut len = [0; 4];
+        len.copy_from_slice(&packet[..4]);
+        let len = u32::from_be_bytes(len);
+        packet.drain(..4);
+
+        let mut buffer = Vec::new();
+        let len = len as usize;
+        loop {
+            let max_cp = std::cmp::min(len - buffer.len(), packet.len());
+
+            buffer.extend_from_slice(&packet[..max_cp]);
+            if buffer.len() == len {
+                packet.drain(..max_cp);
+                break;
+            }
+            packet = stream
+                .next()
+                .await
+                .ok_or(Error::Framing)?
+                .map_err(|_| Error::Framing)?;
+        }
+
+        let stream: AssociatedStream = if packet.is_empty() {
+            Box::pin(stream)
+        } else {
+            Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
+        };
+
+        Ok(Framing {
+            direct: buffer,
+            stream: Some(stream),
+        })
+    }
+
+    pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
+        let Framing { direct, stream } = self;
+        (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
+    }
+}
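For reference, a hypothetical round-trip through the new Framing type (only usable inside the crate, since Framing is pub(crate); the payloads and the function name are made up for illustration):

// Hypothetical usage of Framing, assuming the crate-internal types above (Framing, AssociatedStream, Error).
async fn framing_roundtrip() -> Result<(), Error> {
    // sender side: a direct message plus an optional associated stream
    let body: AssociatedStream =
        Box::pin(futures::stream::once(async { Ok::<_, u8>(b"stream data".to_vec()) }));
    let framed = Framing::new(b"hello".to_vec(), Some(body)).into_stream();

    // receiver side: split the framed stream back into direct bytes + remaining stream
    let (direct, _rest) = Framing::from_stream(framed).await?.into_parts();
    assert_eq!(direct, b"hello");
    Ok(())
}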
@@ -424,14 +431,13 @@ impl Default for ChannelPair {
 /// the full message is passed to the receive handler.
 #[async_trait]
 pub(crate) trait RecvLoop: Sync + 'static {
-    fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream);
+    fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream);

     async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
     where
         R: AsyncReadExt + Unpin + Send + Sync,
     {
-        let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new();
-        let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new();
+        let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new();
         loop {
             trace!("recv_loop: reading packet");
             let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static {
             let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
             let is_error = (size & ERROR_MARKER) != 0;

-            let size = if !is_error {
-                size & !CHUNK_HAS_CONTINUATION
+            let packet = if is_error {
+                Err(size as u8)
             } else {
-                0
-            };
-
-            // TODO propagate errors
-
-            let mut next_slice = vec![0; size as usize];
-            read.read_exact(&mut next_slice[..]).await?;
-            trace!("recv_loop: read {} bytes", next_slice.len());
-
-            if id & 1 == 0 {
-                // main stream
-                let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
-                msg_bytes.extend_from_slice(&next_slice[..]);
-
-                if has_cont {
-                    receiving.insert(id, msg_bytes);
-                } else {
-                    let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default();
-
-                    if let Some(receiver) = channel_pair.take_receiver() {
-                        use futures::StreamExt;
-                        self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v))));
-                    } else {
-                        warn!("Couldn't take receiver part of stream")
-                    }
-
-                    channel_pair.insert_into(&mut streams, id | 1);
-                }
-            } else {
-                // associated stream
-                let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
-
-                // if we get an error, the receiving end is disconnected. We still need to
-                // reach eos before dropping this sender
-                if !next_slice.is_empty() {
-                    if let Some(sender) = channel_pair.ref_sender() {
-                        let _ = sender.unbounded_send(next_slice);
-                    } else {
-                        warn!("Couldn't take sending part of stream")
-                    }
-                }
-
-                if !has_cont {
-                    channel_pair.take_sender();
-                }
-
-                channel_pair.insert_into(&mut streams, id);
+                let size = size & !CHUNK_HAS_CONTINUATION;
+
+                let mut next_slice = vec![0; size as usize];
+                read.read_exact(&mut next_slice[..]).await?;
+                trace!("recv_loop: read {} bytes", next_slice.len());
+                Ok(next_slice)
+            };
+
+            let sender = if let Some(send) = streams.remove(&(id)) {
+                send
+            } else {
+                let (send, recv) = unbounded();
+                self.recv_handler(id, Box::pin(recv));
+                send
+            };
+
+            // if we get an error, the receiving end is disconnected. We still need to
+            // reach eos before dropping this sender
+            let _ = sender.unbounded_send(packet);
+
+            if has_cont {
+                streams.insert(id, sender);
             }
         }
         Ok(())
@@ -509,55 +490,44 @@
 mod test {
     use super::*;

+    fn empty_data() -> DataReader {
+        type Item = Result<Vec<u8>, u8>;
+        let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
+            Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>());
+        stream.into()
+    }
+
     #[test]
     fn test_priority_queue() {
         let i1 = SendQueueItem {
             id: 1,
             prio: PRIO_NORMAL,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };
         let i2 = SendQueueItem {
             id: 2,
             prio: PRIO_HIGH,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };
         let i2bis = SendQueueItem {
             id: 20,
             prio: PRIO_HIGH,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };
         let i3 = SendQueueItem {
             id: 3,
             prio: PRIO_HIGH | PRIO_SECONDARY,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };
         let i4 = SendQueueItem {
             id: 4,
             prio: PRIO_BACKGROUND | PRIO_SECONDARY,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };
         let i5 = SendQueueItem {
             id: 5,
             prio: PRIO_BACKGROUND | PRIO_PRIMARY,
-            data: DataReader::Full {
-                data: vec![],
-                pos: 0,
-            },
+            data: empty_data(),
         };

         let mut q = SendQueue::new();


@@ -2,7 +2,6 @@ use std::net::SocketAddr;
 use std::sync::Arc;

 use arc_swap::ArcSwapOption;
-use bytes::Bytes;
 use log::{debug, trace};

 #[cfg(feature = "telemetry")]
@@ -55,7 +54,7 @@ pub(crate) struct ServerConn {
     netapp: Arc<NetApp>,

-    resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Data)>>,
+    resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
 }

 impl ServerConn {
@@ -177,13 +176,13 @@ impl SendLoop for ServerConn {}

 #[async_trait]
 impl RecvLoop for ServerConn {
-    fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>, stream: AssociatedStream) {
+    fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) {
         let resp_send = self.resp_send.load_full().unwrap();

         let self2 = self.clone();
         tokio::spawn(async move {
-            trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
-            let bytes: Bytes = bytes.into();
+            trace!("ServerConn recv_handler {}", id);
+            let (bytes, stream) = Framing::from_stream(stream).await?.into_parts();

             let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
             let resp = self2.recv_handler_aux(&bytes[..], stream).await;
@@ -204,18 +203,13 @@ impl RecvLoop for ServerConn {
             trace!("ServerConn sending response to {}: ", id);

             resp_send
-                .send((id, prio, Data::Full(resp_bytes)))
+                .send((
+                    id,
+                    prio,
+                    Framing::new(resp_bytes, resp_stream).into_stream(),
+                ))
                 .log_err("ServerConn recv_handler send resp bytes");
-
-            if let Some(resp_stream) = resp_stream {
-                resp_send
-                    .send((id + 1, prio, Data::Streaming(resp_stream)))
-                    .log_err("ServerConn recv_handler send resp stream");
-            } else {
-                resp_send
-                    .send((id + 1, prio, Data::Full(Vec::new())))
-                    .log_err("ServerConn recv_handler send resp stream");
-            }
+            Ok::<_, Error>(())
         });
     }
 }