bind streaming codec to hyper 1.x

This commit is contained in:
Quentin 2024-03-16 16:48:46 +01:00
parent 3abdafb0db
commit 902d33c434
Signed by: quentin
GPG key ID: E9602264D639FF68
8 changed files with 186 additions and 39 deletions

12
Cargo.lock generated
View file

@ -91,6 +91,7 @@ dependencies = [
"hyper-util", "hyper-util",
"imap-codec", "imap-codec",
"imap-flow", "imap-flow",
"quick-xml",
"rustls 0.22.2", "rustls 0.22.2",
"rustls-pemfile 2.1.1", "rustls-pemfile 2.1.1",
"smtp-message", "smtp-message",
@ -98,6 +99,7 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-rustls 0.25.0", "tokio-rustls 0.25.0",
"tokio-stream",
"tokio-util", "tokio-util",
"tracing", "tracing",
] ]
@ -1824,12 +1826,12 @@ dependencies = [
[[package]] [[package]]
name = "http-body-util" name = "http-body-util"
version = "0.1.0" version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-util", "futures-core",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.0", "http-body 1.0.0",
"pin-project-lite 0.2.13", "pin-project-lite 0.2.13",
@ -3381,9 +3383,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.35.1" version = "1.36.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",

View file

@ -24,8 +24,9 @@ aero-proto = { version = "0.3.0", path = "aero-proto" }
aerogramme = { version = "0.3.0", path = "aerogramme" } aerogramme = { version = "0.3.0", path = "aerogramme" }
# async runtime # async runtime
tokio = { version = "1.18", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "time", "macros", "sync", "signal", "fs"] } tokio = { version = "1.36", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "time", "macros", "sync", "signal", "fs"] }
tokio-util = { version = "0.7", features = [ "compat" ] } tokio-util = { version = "0.7", features = [ "compat" ] }
tokio-stream = { version = "0.1" }
futures = "0.3" futures = "0.3"
# debug # debug
@ -57,8 +58,8 @@ imap-codec = { version = "2.0.0", features = ["bounded-static", "ext_condstore_q
imap-flow = { git = "https://github.com/duesee/imap-flow.git", branch = "main" } imap-flow = { git = "https://github.com/duesee/imap-flow.git", branch = "main" }
# http & web # http & web
http = "1.0" http = "1.1"
http-body-util = "0.1" http-body-util = "0.1.1"
hyper = "1.2" hyper = "1.2"
hyper-rustls = { version = "0.26", features = ["http2"] } hyper-rustls = { version = "0.26", features = ["http2"] }
hyper-util = { version = "0.1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] }

View file

@ -393,7 +393,14 @@ impl QWrite for CompKind {
for comp in many_comp.iter() { for comp in many_comp.iter() {
// Required: recursion in an async fn requires boxing // Required: recursion in an async fn requires boxing
// rustc --explain E0733 // rustc --explain E0733
Box::pin(comp.qwrite(xml)).await?; // Cycle detected when computing type of ...
// For more information about this error, try `rustc --explain E0391`.
// https://github.com/rust-lang/rust/issues/78649
#[inline(always)]
fn recurse<'a>(comp: &'a Comp, xml: &'a mut Writer<impl IWrite>) -> futures::future::BoxFuture<'a, Result<(), QError>> {
Box::pin(comp.qwrite(xml))
}
recurse(comp, xml).await?;
} }
Ok(()) Ok(())
} }
@ -525,7 +532,14 @@ impl QWrite for CompFilterMatch {
for comp_item in self.comp_filter.iter() { for comp_item in self.comp_filter.iter() {
// Required: recursion in an async fn requires boxing // Required: recursion in an async fn requires boxing
// rustc --explain E0733 // rustc --explain E0733
Box::pin(comp_item.qwrite(xml)).await?; // Cycle detected when computing type of ...
// For more information about this error, try `rustc --explain E0391`.
// https://github.com/rust-lang/rust/issues/78649
#[inline(always)]
fn recurse<'a>(comp: &'a CompFilter, xml: &'a mut Writer<impl IWrite>) -> futures::future::BoxFuture<'a, Result<(), QError>> {
Box::pin(comp.qwrite(xml))
}
recurse(comp_item, xml).await?;
} }
Ok(()) Ok(())
} }

View file

@ -15,6 +15,25 @@ pub enum ParsingError {
Int(std::num::ParseIntError), Int(std::num::ParseIntError),
Eof Eof
} }
impl std::fmt::Display for ParsingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Recoverable => write!(f, "Recoverable"),
Self::MissingChild => write!(f, "Missing child"),
Self::MissingAttribute => write!(f, "Missing attribute"),
Self::NamespacePrefixAlreadyUsed => write!(f, "Namespace prefix already used"),
Self::WrongToken => write!(f, "Wrong token"),
Self::TagNotFound => write!(f, "Tag not found"),
Self::InvalidValue => write!(f, "Invalid value"),
Self::Utf8Error(_) => write!(f, "Utf8 Error"),
Self::QuickXml(_) => write!(f, "Quick XML error"),
Self::Chrono(_) => write!(f, "Chrono error"),
Self::Int(_) => write!(f, "Number parsing error"),
Self::Eof => write!(f, "Found EOF while expecting data"),
}
}
}
impl std::error::Error for ParsingError {}
impl From<AttrError> for ParsingError { impl From<AttrError> for ParsingError {
fn from(value: AttrError) -> Self { fn from(value: AttrError) -> Self {
Self::QuickXml(value.into()) Self::QuickXml(value.into())

View file

@ -11,8 +11,8 @@ impl xml::QRead<Disabled> for Disabled {
} }
} }
impl xml::QWrite for Disabled { impl xml::QWrite for Disabled {
async fn qwrite(&self, _xml: &mut xml::Writer<impl xml::IWrite>) -> Result<(), quick_xml::Error> { fn qwrite(&self, _xml: &mut xml::Writer<impl xml::IWrite>) -> impl futures::Future<Output = Result<(), quick_xml::Error>> + Send {
unreachable!(); async { unreachable!(); }
} }
} }

View file

@ -12,19 +12,19 @@ pub const CAL_URN: &[u8] = b"urn:ietf:params:xml:ns:caldav";
pub const CARD_URN: &[u8] = b"urn:ietf:params:xml:ns:carddav"; pub const CARD_URN: &[u8] = b"urn:ietf:params:xml:ns:carddav";
// Async traits // Async traits
pub trait IWrite = AsyncWrite + Unpin; pub trait IWrite = AsyncWrite + Unpin + Send;
pub trait IRead = AsyncBufRead + Unpin; pub trait IRead = AsyncBufRead + Unpin;
// Serialization/Deserialization traits // Serialization/Deserialization traits
pub trait QWrite { pub trait QWrite {
fn qwrite(&self, xml: &mut Writer<impl IWrite>) -> impl Future<Output = Result<(), quick_xml::Error>>; fn qwrite(&self, xml: &mut Writer<impl IWrite>) -> impl Future<Output = Result<(), quick_xml::Error>> + Send;
} }
pub trait QRead<T> { pub trait QRead<T> {
fn qread(xml: &mut Reader<impl IRead>) -> impl Future<Output = Result<T, ParsingError>>; fn qread(xml: &mut Reader<impl IRead>) -> impl Future<Output = Result<T, ParsingError>>;
} }
// The representation of an XML node in Rust // The representation of an XML node in Rust
pub trait Node<T> = QRead<T> + QWrite + std::fmt::Debug + PartialEq; pub trait Node<T> = QRead<T> + QWrite + std::fmt::Debug + PartialEq + Sync;
// --------------- // ---------------

View file

@ -22,6 +22,7 @@ futures.workspace = true
tokio.workspace = true tokio.workspace = true
tokio-util.workspace = true tokio-util.workspace = true
tokio-rustls.workspace = true tokio-rustls.workspace = true
tokio-stream.workspace = true
rustls.workspace = true rustls.workspace = true
rustls-pemfile.workspace = true rustls-pemfile.workspace = true
imap-codec.workspace = true imap-codec.workspace = true
@ -33,3 +34,4 @@ duplexify.workspace = true
smtp-message.workspace = true smtp-message.workspace = true
smtp-server.workspace = true smtp-server.workspace = true
tracing.workspace = true tracing.workspace = true
quick-xml.workspace = true

View file

@ -6,6 +6,8 @@ use base64::Engine;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::{Request, Response, body::Bytes}; use hyper::{Request, Response, body::Bytes};
use hyper::server::conn::http1 as http; use hyper::server::conn::http1 as http;
use hyper::rt::{Read, Write};
use hyper::body::Incoming;
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use http_body_util::Full; use http_body_util::Full;
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
@ -13,13 +15,16 @@ use tokio::net::TcpListener;
use tokio::sync::watch; use tokio::sync::watch;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use hyper::rt::{Read, Write};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::AsyncWriteExt;
use rustls_pemfile::{certs, private_key}; use rustls_pemfile::{certs, private_key};
use aero_user::config::{DavConfig, DavUnsecureConfig}; use aero_user::config::{DavConfig, DavUnsecureConfig};
use aero_user::login::ArcLoginProvider; use aero_user::login::ArcLoginProvider;
use aero_collections::user::User; use aero_collections::user::User;
use aero_dav::types::{PropFind, Multistatus, PropValue, ResponseDescription};
use aero_dav::realization::{Core, Calendar};
use aero_dav::xml as dav;
pub struct Server { pub struct Server {
bind_addr: SocketAddr, bind_addr: SocketAddr,
@ -94,13 +99,10 @@ impl Server {
//abitrarily bound //abitrarily bound
//@FIXME replace with a handler supporting http2 and TLS //@FIXME replace with a handler supporting http2 and TLS
match http::Builder::new().serve_connection(stream, service_fn(|req: Request<hyper::body::Incoming>| { match http::Builder::new().serve_connection(stream, service_fn(|req: Request<hyper::body::Incoming>| {
let login = login.clone(); let login = login.clone();
tracing::info!("{:?} {:?}", req.method(), req.uri()); tracing::info!("{:?} {:?}", req.method(), req.uri());
async move { auth(login, req)
auth(login, req).await
}
})).await { })).await {
Err(e) => tracing::warn!(err=?e, "connection failed"), Err(e) => tracing::warn!(err=?e, "connection failed"),
Ok(()) => tracing::trace!("connection terminated with success"), Ok(()) => tracing::trace!("connection terminated with success"),
@ -127,11 +129,13 @@ impl Server {
} }
} }
use http_body_util::BodyExt;
//@FIXME We should not support only BasicAuth //@FIXME We should not support only BasicAuth
async fn auth( async fn auth(
login: ArcLoginProvider, login: ArcLoginProvider,
req: Request<impl hyper::body::Body>, req: Request<Incoming>,
) -> Result<Response<Full<Bytes>>> { ) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!("headers: {:?}", req.headers()); tracing::info!("headers: {:?}", req.headers());
let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) { let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) {
@ -141,7 +145,7 @@ async fn auth(
return Ok(Response::builder() return Ok(Response::builder()
.status(401) .status(401)
.header("WWW-Authenticate", "Basic realm=\"Aerogramme\"") .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
.body(Full::new(Bytes::from("Missing Authorization field")))?) .body(text_body("Missing Authorization field"))?)
}, },
}; };
@ -151,7 +155,7 @@ async fn auth(
tracing::info!("Unsupported authorization field"); tracing::info!("Unsupported authorization field");
return Ok(Response::builder() return Ok(Response::builder()
.status(400) .status(400)
.body(Full::new(Bytes::from("Unsupported Authorization field")))?) .body(text_body("Unsupported Authorization field"))?)
}, },
}; };
@ -176,7 +180,7 @@ async fn auth(
return Ok(Response::builder() return Ok(Response::builder()
.status(401) .status(401)
.header("WWW-Authenticate", "Basic realm=\"Aerogramme\"") .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
.body(Full::new(Bytes::from("Wrong credentials")))?) .body(text_body("Wrong credentials"))?)
}, },
}; };
@ -187,26 +191,131 @@ async fn auth(
router(user, req).await router(user, req).await
} }
async fn router(user: std::sync::Arc<User>, req: Request<impl hyper::body::Body>) -> Result<Response<Full<Bytes>>> { async fn router(user: std::sync::Arc<User>, req: Request<Incoming>) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
let path_segments: Vec<_> = req.uri().path().split("/").filter(|s| *s != "").collect(); let path = req.uri().path().to_string();
tracing::info!("router"); let path_segments: Vec<_> = path.split("/").filter(|s| *s != "").collect();
match path_segments.as_slice() { let method = req.method().as_str().to_uppercase();
[] => tracing::info!("root"),
[ username, ..] if *username != user.username => return Ok(Response::builder() match (method.as_str(), path_segments.as_slice()) {
("PROPFIND", []) => propfind_root(user, req).await,
(_, [ username, ..]) if *username != user.username => return Ok(Response::builder()
.status(403) .status(403)
.body(Full::new(Bytes::from("Accessing other user ressources is not allowed")))?), .body(text_body("Accessing other user ressources is not allowed"))?),
[ _ ] => tracing::info!("user home"), ("PROPFIND", [ _ ]) => propfind_home(user, &req).await,
[ _, "calendar" ] => tracing::info!("user calendars"), ("PROPFIND", [ _, "calendar" ]) => propfind_all_calendars(user, &req).await,
[ _, "calendar", colname ] => tracing::info!(name=colname, "selected calendar"), ("PROPFIND", [ _, "calendar", colname ]) => propfind_this_calendar(user, &req, colname).await,
[ _, "calendar", colname, member ] => tracing::info!(name=colname, obj=member, "selected event"), ("PROPFIND", [ _, "calendar", colname, event ]) => propfind_event(user, req, colname, event).await,
_ => return Ok(Response::builder() _ => return Ok(Response::builder()
.status(404) .status(501)
.body(Full::new(Bytes::from("Resource not found")))?), .body(text_body("Not implemented"))?),
} }
Ok(Response::new(Full::new(Bytes::from("Hello World!"))))
} }
/// <D:propfind xmlns:D='DAV:' xmlns:A='http://apple.com/ns/ical/'>
/// <D:prop><D:getcontenttype/><D:resourcetype/><D:displayname/><A:calendar-color/>
/// </D:prop></D:propfind>
async fn propfind_root(user: std::sync::Arc<User>, req: Request<Incoming>) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!("root");
let r = deserialize::<PropFind<Core>>(req).await?;
println!("r: {:?}", r);
serialize(Multistatus::<Core, PropValue<Core>> {
responses: vec![],
responsedescription: Some(ResponseDescription("hello world".to_string())),
})
}
async fn propfind_home(user: std::sync::Arc<User>, req: &Request<impl hyper::body::Body>) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!("user home");
Ok(Response::new(text_body("Hello World!")))
}
async fn propfind_all_calendars(user: std::sync::Arc<User>, req: &Request<impl hyper::body::Body>) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!("calendar");
Ok(Response::new(text_body("Hello World!")))
}
async fn propfind_this_calendar(
user: std::sync::Arc<User>,
req: &Request<Incoming>,
colname: &str
) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!(name=colname, "selected calendar");
Ok(Response::new(text_body("Hello World!")))
}
async fn propfind_event(
user: std::sync::Arc<User>,
req: Request<Incoming>,
colname: &str,
event: &str,
) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
tracing::info!(name=colname, obj=event, "selected event");
Ok(Response::new(text_body("Hello World!")))
}
#[allow(dead_code)] #[allow(dead_code)]
async fn collections(_user: std::sync::Arc<User>, _req: Request<impl hyper::body::Body>) -> Result<Response<Full<Bytes>>> { async fn collections(_user: std::sync::Arc<User>, _req: Request<impl hyper::body::Body>) -> Result<Response<Full<Bytes>>> {
unimplemented!(); unimplemented!();
} }
use futures::stream::TryStreamExt;
use http_body_util::{BodyStream, Empty};
use http_body_util::StreamBody;
use http_body_util::combinators::BoxBody;
use hyper::body::Frame;
use tokio_util::sync::PollSender;
use std::io::{Error, ErrorKind};
use futures::sink::SinkExt;
use tokio_util::io::{SinkWriter, CopyToBytes};
fn text_body(txt: &'static str) -> BoxBody<Bytes, std::io::Error> {
BoxBody::new(Full::new(Bytes::from(txt)).map_err(|e| match e {}))
}
fn serialize<T: dav::QWrite + Send + 'static>(elem: T) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
let (tx, rx) = tokio::sync::mpsc::channel::<Bytes>(1);
// Build the writer
tokio::task::spawn(async move {
let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
let mut writer = SinkWriter::new(CopyToBytes::new(sink));
let q = quick_xml::writer::Writer::new_with_indent(&mut writer, b' ', 4);
let ns_to_apply = vec![ ("xmlns:D".into(), "DAV:".into()) ];
let mut qwriter = dav::Writer { q, ns_to_apply };
match elem.qwrite(&mut qwriter).await {
Ok(_) => tracing::debug!("fully serialized object"),
Err(e) => tracing::error!(err=?e, "failed to serialize object"),
}
});
// Build the reader
let recv = tokio_stream::wrappers::ReceiverStream::new(rx);
let stream = StreamBody::new(recv.map(|v| Ok(Frame::data(v))));
let boxed_body = BoxBody::new(stream);
let response = Response::builder()
.status(hyper::StatusCode::OK)
.body(boxed_body)?;
Ok(response)
}
/// Deserialize a request body to an XML request
async fn deserialize<T: dav::Node<T>>(req: Request<Incoming>) -> Result<T> {
let stream_of_frames = BodyStream::new(req.into_body());
let stream_of_bytes = stream_of_frames
.try_filter_map(|frame| async move { Ok(frame.into_data().ok()) })
.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err));
let async_read = tokio_util::io::StreamReader::new(stream_of_bytes);
let async_read = std::pin::pin!(async_read);
let mut rdr = dav::Reader::new(quick_xml::reader::NsReader::from_reader(async_read)).await?;
let parsed = rdr.find::<T>().await?;
Ok(parsed)
}