[dep-upgrade-202402] refactor http listener code

This commit is contained in:
Alex 2024-02-07 14:33:07 +01:00
parent 22332e6c35
commit fe48d60d2b
Signed by untrusted user: lx
GPG key ID: 0E496D15096376BE
3 changed files with 132 additions and 103 deletions

View file

@ -5,6 +5,7 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use futures::future::Future; use futures::future::Future;
use futures::stream::{futures_unordered::FuturesUnordered, StreamExt};
use http_body_util::BodyExt; use http_body_util::BodyExt;
use hyper::header::HeaderValue; use hyper::header::HeaderValue;
@ -15,7 +16,7 @@ use hyper::{HeaderMap, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UnixListener}; use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use opentelemetry::{ use opentelemetry::{
global, global,
@ -110,20 +111,12 @@ impl<A: ApiHandler> ApiServer<A> {
bind_addr bind_addr
); );
tokio::pin!(shutdown_signal);
match bind_addr { match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => { UnixOrTCPSocketAddress::TCPSocket(addr) => {
let listener = TcpListener::bind(addr).await?; let listener = TcpListener::bind(addr).await?;
loop { let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
let (stream, client_addr) = tokio::select! { server_loop(listener, handler, shutdown_signal).await
acc = listener.accept() => acc?,
_ = &mut shutdown_signal => break,
};
self.launch_handler(stream, client_addr.to_string());
}
} }
UnixOrTCPSocketAddress::UnixSocket(ref path) => { UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() { if path.exists() {
@ -131,52 +124,24 @@ impl<A: ApiHandler> ApiServer<A> {
} }
let listener = UnixListener::bind(path)?; let listener = UnixListener::bind(path)?;
let listener = UnixListenerOn(listener, path.display().to_string());
fs::set_permissions( fs::set_permissions(
path, path,
Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)), Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)),
)?; )?;
loop { let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
let (stream, _) = tokio::select! { server_loop(listener, handler, shutdown_signal).await
acc = listener.accept() => acc?,
_ = &mut shutdown_signal => break,
};
self.launch_handler(stream, path.display().to_string());
}
} }
}; }
Ok(())
}
fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String)
where
S: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let this = self.clone();
let io = TokioIo::new(stream);
let serve =
move |req: Request<IncomingBody>| this.clone().handler(req, client_addr.to_string());
tokio::task::spawn(async move {
let io = Box::pin(io);
if let Err(e) = http1::Builder::new()
.serve_connection(io, service_fn(serve))
.await
{
debug!("Error handling HTTP connection: {}", e);
}
});
} }
async fn handler( async fn handler(
self: Arc<Self>, self: Arc<Self>,
req: Request<IncomingBody>, req: Request<IncomingBody>,
addr: String, addr: String,
) -> Result<Response<BoxBody<A::Error>>, GarageError> { ) -> Result<Response<BoxBody<A::Error>>, http::Error> {
let uri = req.uri().clone(); let uri = req.uri().clone();
if let Ok(forwarded_for_ip_addr) = if let Ok(forwarded_for_ip_addr) =
@ -278,3 +243,105 @@ impl<A: ApiHandler> ApiServer<A> {
res res
} }
} }
// ==== helper functions ====
#[async_trait]
pub trait Accept: Send + Sync + 'static {
type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static;
async fn accept(&self) -> std::io::Result<(Self::Stream, String)>;
}
#[async_trait]
impl Accept for TcpListener {
type Stream = TcpStream;
async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
self.accept()
.await
.map(|(stream, addr)| (stream, addr.to_string()))
}
}
pub struct UnixListenerOn(pub UnixListener, pub String);
#[async_trait]
impl Accept for UnixListenerOn {
type Stream = UnixStream;
async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
self.0
.accept()
.await
.map(|(stream, _addr)| (stream, self.1.clone()))
}
}
pub async fn server_loop<A, H, F, E>(
listener: A,
handler: H,
shutdown_signal: impl Future<Output = ()>,
) -> Result<(), GarageError>
where
A: Accept,
H: Fn(Request<IncomingBody>, String) -> F + Send + Sync + Clone + 'static,
F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static,
E: Send + Sync + std::error::Error + 'static,
{
tokio::pin!(shutdown_signal);
let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel();
let connection_collector = tokio::spawn(async move {
let mut collection = FuturesUnordered::new();
loop {
let collect_next = async {
if collection.is_empty() {
futures::future::pending().await
} else {
collection.next().await
}
};
tokio::select! {
result = collect_next => {
trace!("HTTP connection finished: {:?}", result);
}
new_fut = conn_out.recv() => {
match new_fut {
Some(f) => collection.push(f),
None => break,
}
}
}
}
debug!("Collecting last open HTTP connections.");
while let Some(conn_res) = collection.next().await {
trace!("HTTP connection finished: {:?}", conn_res);
}
debug!("No more HTTP connections to collect");
});
loop {
let (stream, client_addr) = tokio::select! {
acc = listener.accept() => acc?,
_ = &mut shutdown_signal => break,
};
let io = TokioIo::new(stream);
let handler = handler.clone();
let serve = move |req: Request<IncomingBody>| handler(req, client_addr.clone());
let fut = tokio::task::spawn(async move {
let io = Box::pin(io);
if let Err(e) = http1::Builder::new()
.serve_connection(io, service_fn(serve))
.await
{
debug!("Error handling HTTP connection: {}", e);
}
});
conn_in.send(fut)?;
}
connection_collector.await?;
Ok(())
}

View file

@ -113,12 +113,11 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
if let Some(web_config) = &config.s3_web { if let Some(web_config) = &config.s3_web {
info!("Initializing web server..."); info!("Initializing web server...");
let web_server = WebServer::new(garage.clone(), web_config.root_domain.clone());
servers.push(( servers.push((
"Web", "Web",
tokio::spawn(WebServer::run( tokio::spawn(web_server.run(
garage.clone(),
web_config.bind_addr.clone(), web_config.bind_addr.clone(),
web_config.root_domain.clone(),
wait_from(watch_cancel.clone()), wait_from(watch_cancel.clone()),
)), )),
)); ));

View file

@ -4,16 +4,12 @@ use std::{convert::Infallible, sync::Arc};
use futures::future::Future; use futures::future::Future;
use hyper::server::conn::http1;
use hyper::{ use hyper::{
body::Incoming as IncomingBody, body::Incoming as IncomingBody,
header::{HeaderValue, HOST}, header::{HeaderValue, HOST},
service::service_fn,
Method, Request, Response, StatusCode, Method, Request, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UnixListener}; use tokio::net::{TcpListener, UnixListener};
use opentelemetry::{ use opentelemetry::{
@ -25,6 +21,7 @@ use opentelemetry::{
use crate::error::*; use crate::error::*;
use garage_api::generic_server::{server_loop, UnixListenerOn};
use garage_api::helpers::*; use garage_api::helpers::*;
use garage_api::s3::cors::{add_cors_headers, find_matching_cors_rule, handle_options_for_bucket}; use garage_api::s3::cors::{add_cors_headers, find_matching_cors_rule, handle_options_for_bucket};
use garage_api::s3::error::{ use garage_api::s3::error::{
@ -75,35 +72,29 @@ pub struct WebServer {
impl WebServer { impl WebServer {
/// Run a web server /// Run a web server
pub async fn run( pub fn new(garage: Arc<Garage>, root_domain: String) -> Arc<Self> {
garage: Arc<Garage>,
bind_addr: UnixOrTCPSocketAddress,
root_domain: String,
shutdown_signal: impl Future<Output = ()>,
) -> Result<(), GarageError> {
let metrics = Arc::new(WebMetrics::new()); let metrics = Arc::new(WebMetrics::new());
let web_server = Arc::new(WebServer { Arc::new(WebServer {
garage, garage,
metrics, metrics,
root_domain, root_domain,
}); })
}
pub async fn run(
self: Arc<Self>,
bind_addr: UnixOrTCPSocketAddress,
shutdown_signal: impl Future<Output = ()>,
) -> Result<(), GarageError> {
info!("Web server listening on {}", bind_addr); info!("Web server listening on {}", bind_addr);
tokio::pin!(shutdown_signal);
match bind_addr { match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => { UnixOrTCPSocketAddress::TCPSocket(addr) => {
let listener = TcpListener::bind(addr).await?; let listener = TcpListener::bind(addr).await?;
loop { let handler =
let (stream, client_addr) = tokio::select! { move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
acc = listener.accept() => acc?, server_loop(listener, handler, shutdown_signal).await
_ = &mut shutdown_signal => break,
};
web_server.launch_handler(stream, client_addr.to_string());
}
} }
UnixOrTCPSocketAddress::UnixSocket(ref path) => { UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() { if path.exists() {
@ -111,50 +102,22 @@ impl WebServer {
} }
let listener = UnixListener::bind(path)?; let listener = UnixListener::bind(path)?;
let listener = UnixListenerOn(listener, path.display().to_string());
fs::set_permissions(path, Permissions::from_mode(0o222))?; fs::set_permissions(path, Permissions::from_mode(0o222))?;
loop { let handler =
let (stream, _) = tokio::select! { move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
acc = listener.accept() => acc?, server_loop(listener, handler, shutdown_signal).await
_ = &mut shutdown_signal => break,
};
web_server.launch_handler(stream, path.display().to_string());
}
} }
}; }
Ok(())
}
fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String)
where
S: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let this = self.clone();
let io = TokioIo::new(stream);
let serve = move |req: Request<IncomingBody>| {
this.clone().handle_request(req, client_addr.to_string())
};
tokio::task::spawn(async move {
let io = Box::pin(io);
if let Err(e) = http1::Builder::new()
.serve_connection(io, service_fn(serve))
.await
{
debug!("Error handling HTTP connection: {}", e);
}
});
} }
async fn handle_request( async fn handle_request(
self: Arc<Self>, self: Arc<Self>,
req: Request<IncomingBody>, req: Request<IncomingBody>,
addr: String, addr: String,
) -> Result<Response<BoxBody<Error>>, Infallible> { ) -> Result<Response<BoxBody<Error>>, http::Error> {
if let Ok(forwarded_for_ip_addr) = if let Ok(forwarded_for_ip_addr) =
forwarded_headers::handle_forwarded_for_headers(req.headers()) forwarded_headers::handle_forwarded_for_headers(req.headers())
{ {