From 667b487427c1c16df3dd608c0c91b3b3b517b483 Mon Sep 17 00:00:00 2001 From: KokaKiwi Date: Mon, 23 May 2022 03:20:06 +0200 Subject: [PATCH] feat: Add proper Response type --- examples/simple.rs | 14 +++--- src/errors.rs | 33 +++++++++++++ src/lib.rs | 1 + src/proto/body.rs | 32 +++++++++++++ src/proto/mod.rs | 1 + src/proto/res.rs | 105 ++++++++++++++++++++++++++++++++++++++++- src/server/conn.rs | 48 +++++++++++++------ src/server/mod.rs | 19 +++----- src/server/pipeline.rs | 34 +++++++------ 9 files changed, 236 insertions(+), 51 deletions(-) create mode 100644 src/errors.rs create mode 100644 src/proto/body.rs diff --git a/examples/simple.rs b/examples/simple.rs index 7e4d440..c23f2a2 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,16 +1,18 @@ use miette::{IntoDiagnostic, Result}; + use boitalettres::proto::{Request, Response}; use boitalettres::server::accept::addr::{AddrIncoming, AddrStream}; use boitalettres::server::Server; async fn handle_req(req: Request) -> Result { - use imap_codec::types::response::Status; + use imap_codec::types::response::{Capability, Data}; tracing::debug!("Got request: {:#?}", req); - Ok(Response::Status( - Status::ok(Some(req.tag), None, "Ok").map_err(|e| eyre::eyre!(e))?, - )) + let capabilities = vec![Capability::Imap4Rev1, Capability::Idle]; + let body = vec![Data::Capability(capabilities)]; + + Ok(Response::ok("Done")?.with_body(body)) } #[tokio::main] @@ -24,9 +26,7 @@ async fn main() -> Result<()> { let make_service = tower::service_fn(|addr: &AddrStream| { tracing::debug!(remote_addr = %addr.remote_addr, local_addr = %addr.local_addr, "accept"); - let service = tower::ServiceBuilder::new() - .buffer(16) - .service_fn(handle_req); + let service = tower::ServiceBuilder::new().service_fn(handle_req); futures::future::ok::<_, std::convert::Infallible>(service) }); diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..c264799 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,33 @@ +use miette::Diagnostic; + +type BoxedError = tower::BoxError; + +#[derive(Debug, thiserror::Error, Diagnostic)] +pub enum Error { + #[error("Error occured when accepting new connections")] + #[diagnostic(code(boitalettres::accept))] + Accept(#[source] BoxedError), + + #[error("Error occured on service creation")] + #[diagnostic(code(boitalettres::make_service))] + MakeService(#[source] BoxedError), + + #[error("{0}")] + Text(String), +} + +impl Error { + pub(crate) fn accept>(err: E) -> Error { + Error::Accept(err.into()) + } + + pub(crate) fn make_service>(err: E) -> Error { + Error::MakeService(err.into()) + } + + pub(crate) fn text>(err: E) -> Error { + Error::Text(err.into()) + } +} + +pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index c6eebbb..d8c4fb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,3 @@ +pub mod errors; pub mod proto; pub mod server; diff --git a/src/proto/body.rs b/src/proto/body.rs new file mode 100644 index 0000000..a4e9b58 --- /dev/null +++ b/src/proto/body.rs @@ -0,0 +1,32 @@ +use imap_codec::types::response::Data as ImapData; + +#[derive(Debug)] +pub enum Body { + Once(Vec), +} + +impl Body { + pub(crate) fn into_data(self) -> Vec { + match self { + Body::Once(data) => data, + } + } +} + +impl FromIterator for Body { + fn from_iter>(iter: T) -> Self { + Body::Once(Vec::from_iter(iter)) + } +} + +impl From> for Body { + fn from(data: Vec) -> Self { + Body::from_iter(data) + } +} + +impl From for Body { + fn from(data: ImapData) -> Self { + Body::from_iter([data]) + } +} diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 852952c..c59663b 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -1,5 +1,6 @@ pub use self::req::Request; pub use self::res::Response; +pub mod body; pub mod req; pub mod res; diff --git a/src/proto/res.rs b/src/proto/res.rs index 1519ca5..c704c48 100644 --- a/src/proto/res.rs +++ b/src/proto/res.rs @@ -1 +1,104 @@ -pub type Response = imap_codec::types::response::Response; +use imap_codec::types::{ + core::{Tag, Text}, + response::{Code as ImapCode, Status as ImapStatus}, +}; + +use super::body::Body; +use crate::errors::{Error, Result}; + +#[derive(Debug)] +pub struct Response { + pub(crate) status: Status, + pub(crate) body: Option, +} + +impl Response { + pub fn status(code: StatusCode, msg: &str) -> Result { + Ok(Response { + status: Status::new(code, msg)?, + body: None, + }) + } + + pub fn ok(msg: &str) -> Result { + Self::status(StatusCode::Ok, msg) + } + + pub fn no(msg: &str) -> Result { + Self::status(StatusCode::No, msg) + } + + pub fn bad(msg: &str) -> Result { + Self::status(StatusCode::Bad, msg) + } + + pub fn bye(msg: &str) -> Result { + Self::status(StatusCode::Bye, msg) + } +} + +impl Response { + pub fn with_extra_code(mut self, extra: ImapCode) -> Self { + self.status.extra = Some(extra); + self + } + + pub fn with_body(mut self, body: impl Into) -> Self { + self.body = Some(body.into()); + self + } +} + +#[derive(Debug, Clone)] +pub struct Status { + pub(crate) code: StatusCode, + pub(crate) extra: Option, + pub(crate) text: Text, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StatusCode { + Ok, + No, + Bad, + PreAuth, + Bye, +} + +impl Status { + fn new(code: StatusCode, msg: &str) -> Result { + Ok(Status { + code, + extra: None, + text: msg.try_into().map_err(Error::text)?, + }) + } + + pub(crate) fn into_imap(self, tag: Option) -> ImapStatus { + match self.code { + StatusCode::Ok => ImapStatus::Ok { + tag, + code: self.extra, + text: self.text, + }, + StatusCode::No => ImapStatus::No { + tag, + code: self.extra, + text: self.text, + }, + StatusCode::Bad => ImapStatus::Bad { + tag, + code: self.extra, + text: self.text, + }, + StatusCode::PreAuth => ImapStatus::PreAuth { + code: self.extra, + text: self.text, + }, + StatusCode::Bye => ImapStatus::Bye { + code: self.extra, + text: self.text, + }, + } + } +} diff --git a/src/server/conn.rs b/src/server/conn.rs index 669321d..e37f787 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use std::task::{self, Poll}; use futures::io::{AsyncRead, AsyncWrite}; +use imap_codec::types::core::Tag; use tower::Service; use super::pipeline::Connection; @@ -34,6 +35,7 @@ where F: Future>, ME: std::fmt::Display, S: Service, + S::Future: Send + 'static, S::Error: std::fmt::Display, { type Output = (); @@ -56,27 +58,17 @@ where // TODO: Properly handle server greeting { - use imap_codec::types::response::{Response, Status}; + use futures::SinkExt; - let status = match Status::ok(None, None, "Hello") { - Ok(status) => status, - Err(err) => { - tracing::error!("Connection error: {}", err); - return Poll::Ready(()); - } - }; - let res = Response::Status(status); - - if let Err(err) = this.conn.send(res) { - tracing::error!("Connection error: {}", err); - return Poll::Ready(()); - }; + let greeting = Response::ok("Hello").unwrap(); // "Hello" is a valid + // greeting + this.conn.start_send_unpin((None, greeting)).unwrap(); } ConnectingState::Ready { service } } ConnectingStateProj::Ready { service } => { - let server = PipelineServer::new(this.conn, service); + let server = PipelineServer::new(this.conn, PipelineService { inner: service }); futures::pin_mut!(server); return server.poll(cx).map(|res| { @@ -91,3 +83,29 @@ where } } } + +struct PipelineService { + inner: S, +} + +impl Service for PipelineService +where + S: Service, + S::Future: Send + 'static, +{ + type Response = (Option, S::Response); + type Error = S::Error; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + use futures::{FutureExt, TryFutureExt}; + + let tag = req.tag.clone(); + + self.inner.call(req).map_ok(|res| (Some(tag), res)).boxed() + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index f749b14..88afa4e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -5,6 +5,7 @@ use std::task::{self, Poll}; use futures::io::{AsyncRead, AsyncWrite}; use imap_codec::types::response::Capability; +use crate::errors::{Error, Result}; use crate::proto::{Request, Response}; use accept::Accept; use pipeline::Connection; @@ -15,14 +16,6 @@ mod conn; mod pipeline; mod service; -#[derive(Debug, thiserror::Error, miette::Diagnostic)] -pub enum Error { - #[error("Error occured when accepting new connections")] - Accept(#[source] A), - #[error("Error occured on service creation")] - MakeService(#[source] tower::BoxError), -} - #[derive(Debug, Default, Clone)] pub struct Imap { pub capabilities: Vec, @@ -50,25 +43,25 @@ impl Future for Server where I: Accept, I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, - I::Error: 'static, + I::Error: Send + Sync + 'static, S: MakeServiceRef, S::MakeError: Into + std::fmt::Display, S::Error: std::fmt::Display, S::Future: Send + 'static, S::Service: Send + 'static, + >::Future: Send + 'static, { - type Output = Result<(), Error>; + type Output = Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { loop { let this = self.as_mut().project(); if let Some(conn) = futures::ready!(this.incoming.poll_accept(cx)) { - let conn = conn.map_err(Error::Accept)?; + let conn = conn.map_err(Error::accept)?; futures::ready!(this.make_service.poll_ready_ref(cx)) - .map_err(Into::into) - .map_err(Error::MakeService)?; + .map_err(Error::make_service)?; let service_fut = this.make_service.make_service_ref(&conn); diff --git a/src/server/pipeline.rs b/src/server/pipeline.rs index 477487c..149b346 100644 --- a/src/server/pipeline.rs +++ b/src/server/pipeline.rs @@ -5,6 +5,7 @@ use bytes::BytesMut; use futures::io::{AsyncRead, AsyncWrite}; use futures::sink::Sink; use futures::stream::Stream; +use imap_codec::types::core::Tag; use crate::proto::{Request, Response}; @@ -89,21 +90,9 @@ where Poll::Ready(Ok(())) } - - pub(crate) fn send(&mut self, item: Response) -> Result<()> { - use bytes::BufMut; - use imap_codec::codec::Encode; - - let mut writer = BufMut::writer(&mut self.write_buf); - - tracing::debug!(item = ?item, "transport.send"); - item.encode(&mut writer)?; - - Ok(()) - } } -impl Sink for Connection +impl Sink<(Option, Response)> for Connection where C: AsyncWrite + Unpin, { @@ -117,9 +106,24 @@ where Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: Response) -> Result<(), Self::Error> { + fn start_send( + self: Pin<&mut Self>, + (tag, res): (Option, Response), + ) -> Result<(), Self::Error> { + use bytes::BufMut; + use imap_codec::codec::Encode; + debug_assert!(self.write_buf.is_empty()); - self.get_mut().send(item)?; + + let write_buf = &mut self.get_mut().write_buf; + let mut writer = write_buf.writer(); + + let body = res.body.into_iter().flat_map(|body| body.into_data()); + for data in body { + data.encode(&mut writer)?; + } + + res.status.into_imap(tag).encode(&mut writer)?; Ok(()) }