diff --git a/src/server/accept/addr.rs b/src/server/accept/addr.rs index 4f77a94..378b4b3 100644 --- a/src/server/accept/addr.rs +++ b/src/server/accept/addr.rs @@ -62,6 +62,14 @@ impl AsyncRead for AddrStream { ) -> Poll> { self.project().stream.poll_read(cx, buf) } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &mut [std::io::IoSliceMut<'_>], + ) -> Poll> { + self.project().stream.poll_read_vectored(cx, bufs) + } } impl AsyncWrite for AddrStream { diff --git a/src/server/conn.rs b/src/server/conn.rs index e37f787..e3cf53e 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -4,35 +4,111 @@ use std::task::{self, Poll}; use futures::io::{AsyncRead, AsyncWrite}; use imap_codec::types::core::Tag; +use tokio_tower::pipeline::Server as PipelineServer; use tower::Service; use super::pipeline::Connection; use super::Imap; use crate::proto::{Request, Response}; -#[pin_project::pin_project] -pub struct Connecting { - pub conn: Connection, - #[pin] - pub state: ConnectingState, +pub struct Connecting +where + C: AsyncRead + AsyncWrite + Unpin, + S: Service, + S::Future: Send + 'static, +{ + pub state: Option>, pub protocol: Imap, } -#[pin_project::pin_project(project = ConnectingStateProj)] -pub enum ConnectingState { +pub enum ConnectingState +where + C: AsyncRead + AsyncWrite + Unpin, + S: Service, + S::Future: Send + 'static, +{ Waiting { - #[pin] + conn: Connection, service_fut: F, }, Ready { + conn: Connection, service: S, }, + Serving { + server: PipelineServer, PipelineService>, + }, + Finished, +} + +impl ConnectingState +where + C: AsyncRead + AsyncWrite + Unpin, + F: Future> + Unpin, + ME: std::fmt::Display, + S: Service, + S::Future: Send + 'static, + S::Error: std::fmt::Display, +{ + fn poll_new_state(self, cx: &mut task::Context) -> (Self, Option>) { + match self { + ConnectingState::Waiting { + conn, + mut service_fut, + } => { + let service = match Pin::new(&mut service_fut).poll(cx) { + Poll::Ready(Ok(service)) => service, + Poll::Ready(Err(err)) => { + tracing::error!("Connection error: {}", err); + return ( + ConnectingState::Waiting { conn, service_fut }, + Some(Poll::Ready(())), + ); + } + Poll::Pending => { + return ( + ConnectingState::Waiting { conn, service_fut }, + Some(Poll::Pending), + ) + } + }; + + let mut conn = conn; + + // TODO: Properly handle server greeting + { + use futures::SinkExt; + + let greeting = Response::ok("Hello").unwrap(); // "Hello" is a valid + // greeting + conn.start_send_unpin((None, greeting)).unwrap(); + } + + (ConnectingState::Ready { conn, service }, None) + } + ConnectingState::Ready { conn, service } => ( + ConnectingState::Serving { + server: PipelineServer::new(conn, PipelineService { inner: service }), + }, + None, + ), + ConnectingState::Serving { mut server } => match Pin::new(&mut server).poll(cx) { + Poll::Ready(Ok(_)) => (ConnectingState::Finished, Some(Poll::Ready(()))), + Poll::Ready(Err(err)) => { + tracing::debug!("Connecting error: {}", err); + (ConnectingState::Finished, Some(Poll::Ready(()))) + } + Poll::Pending => (ConnectingState::Serving { server }, Some(Poll::Pending)), + }, + ConnectingState::Finished => (self, Some(Poll::Ready(()))), + } + } } impl Future for Connecting where C: AsyncRead + AsyncWrite + Unpin, - F: Future>, + F: Future> + Unpin, ME: std::fmt::Display, S: Service, S::Future: Send + 'static, @@ -40,51 +116,28 @@ where { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - use tokio_tower::pipeline::Server as PipelineServer; - - let mut this = self.project(); - + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { loop { - let next = match this.state.as_mut().project() { - ConnectingStateProj::Waiting { service_fut } => { - let service = match futures::ready!(service_fut.poll(cx)) { - Ok(service) => service, - Err(err) => { - tracing::error!("Connection error: {}", err); - return Poll::Ready(()); - } - }; + let state = self.as_mut().state.take().unwrap(); + let (next, res) = state.poll_new_state(cx); - // TODO: Properly handle server greeting - { - use futures::SinkExt; - - let greeting = Response::ok("Hello").unwrap(); // "Hello" is a valid - // greeting - this.conn.start_send_unpin((None, greeting)).unwrap(); - } - - ConnectingState::Ready { service } - } - ConnectingStateProj::Ready { service } => { - let server = PipelineServer::new(this.conn, PipelineService { inner: service }); - futures::pin_mut!(server); - - return server.poll(cx).map(|res| { - if let Err(err) = res { - tracing::debug!("Connection error: {}", err); - } - }); - } - }; - - this.state.set(next); + self.state = Some(next); + if let Some(res) = res { + return res; + } } } } -struct PipelineService { +impl Unpin for Connecting +where + C: AsyncRead + AsyncWrite + Unpin, + S: Service, + S::Future: Send + 'static, +{ +} + +pub struct PipelineService { inner: S, } diff --git a/src/server/mod.rs b/src/server/mod.rs index 88afa4e..db02c95 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -46,8 +46,8 @@ where I::Error: Send + Sync + 'static, S: MakeServiceRef, S::MakeError: Into + std::fmt::Display, - S::Error: std::fmt::Display, - S::Future: Send + 'static, + S::Error: std::fmt::Display + Send + 'static, + S::Future: Unpin + Send + 'static, S::Service: Send + 'static, >::Future: Send + 'static, { @@ -66,8 +66,10 @@ where let service_fut = this.make_service.make_service_ref(&conn); tokio::task::spawn(conn::Connecting { - conn: Connection::new(conn), - state: conn::ConnectingState::Waiting { service_fut }, + state: Some(conn::ConnectingState::Waiting { + conn: Connection::new(conn), + service_fut, + }), protocol: this.protocol.clone(), }); } else {