diff --git a/src/api/admin/api_server.rs b/src/api/admin/api_server.rs index 50c79120d..43497badd 100644 --- a/src/api/admin/api_server.rs +++ b/src/api/admin/api_server.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -18,6 +17,7 @@ use prometheus::{Encoder, TextEncoder}; use garage_model::garage::Garage; use garage_rpc::system::ClusterHealthStatus; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use crate::generic_server::*; @@ -61,7 +61,7 @@ impl AdminApiServer { pub async fn run( self, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, shutdown_signal: impl Future, ) -> Result<(), GarageError> { let region = self.garage.config.s3_api.s3_region.clone(); diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index 757b85ec5..42c65ab7c 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -1,4 +1,5 @@ -use std::net::SocketAddr; +use std::fs::{self, Permissions}; +use std::os::unix::fs::PermissionsExt; use std::sync::Arc; use async_trait::async_trait; @@ -11,6 +12,10 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server}; use hyper::{HeaderMap, StatusCode}; +use hyperlocal::UnixServerExt; + +use tokio::net::UnixStream; + use opentelemetry::{ global, metrics::{Counter, ValueRecorder}, @@ -21,6 +26,7 @@ use opentelemetry::{ use garage_util::error::Error as GarageError; use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; +use garage_util::socket_address::UnixOrTCPSocketAddress; pub(crate) trait ApiEndpoint: Send + Sync + 'static { fn name(&self) -> &'static str; @@ -91,10 +97,10 @@ impl ApiServer { pub async fn run_server( self: Arc, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, shutdown_signal: impl Future, ) -> Result<(), GarageError> { - let service = make_service_fn(|conn: &AddrStream| { + let tcp_service = make_service_fn(|conn: &AddrStream| { let this = self.clone(); let client_addr = conn.remote_addr(); @@ -102,28 +108,60 @@ impl ApiServer { Ok::<_, GarageError>(service_fn(move |req: Request| { let this = this.clone(); - this.handler(req, client_addr) + this.handler(req, client_addr.to_string()) })) } }); - let server = Server::bind(&bind_addr).serve(service); + let unix_service = make_service_fn(|_: &UnixStream| { + let this = self.clone(); + + let path = bind_addr.to_string(); + async move { + Ok::<_, GarageError>(service_fn(move |req: Request| { + let this = this.clone(); + + this.handler(req, path.clone()) + })) + } + }); - let graceful = server.with_graceful_shutdown(shutdown_signal); info!( - "{} API server listening on http://{}", + "{} API server listening on {}", A::API_NAME_DISPLAY, bind_addr ); - graceful.await?; + match bind_addr { + UnixOrTCPSocketAddress::TCPSocket(addr) => { + Server::bind(&addr) + .serve(tcp_service) + .with_graceful_shutdown(shutdown_signal) + .await? + } + UnixOrTCPSocketAddress::UnixSocket(ref path) => { + if path.exists() { + fs::remove_file(path)? + } + + let bound = Server::bind_unix(path)?; + + fs::set_permissions(path, Permissions::from_mode(0o222))?; + + bound + .serve(unix_service) + .with_graceful_shutdown(shutdown_signal) + .await?; + } + }; + Ok(()) } async fn handler( self: Arc, req: Request, - addr: SocketAddr, + addr: String, ) -> Result, GarageError> { let uri = req.uri().clone(); diff --git a/src/api/k2v/api_server.rs b/src/api/k2v/api_server.rs index bb85b2e7d..413fd61c8 100644 --- a/src/api/k2v/api_server.rs +++ b/src/api/k2v/api_server.rs @@ -1,4 +1,3 @@ -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -9,6 +8,7 @@ use hyper::{Body, Method, Request, Response}; use opentelemetry::{trace::SpanRef, KeyValue}; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use garage_model::garage::Garage; @@ -37,7 +37,7 @@ pub(crate) struct K2VApiEndpoint { impl K2VApiServer { pub async fn run( garage: Arc, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, s3_region: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { diff --git a/src/api/s3/api_server.rs b/src/api/s3/api_server.rs index 278372975..cc2091b81 100644 --- a/src/api/s3/api_server.rs +++ b/src/api/s3/api_server.rs @@ -1,4 +1,3 @@ -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -10,6 +9,7 @@ use hyper::{Body, Request, Response}; use opentelemetry::{trace::SpanRef, KeyValue}; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use garage_model::garage::Garage; use garage_model::key_table::Key; @@ -44,7 +44,7 @@ pub(crate) struct S3ApiEndpoint { impl S3ApiServer { pub async fn run( garage: Arc, - addr: SocketAddr, + addr: UnixOrTCPSocketAddress, s3_region: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { diff --git a/src/garage/server.rs b/src/garage/server.rs index 472616c74..3ad10b720 100644 --- a/src/garage/server.rs +++ b/src/garage/server.rs @@ -79,7 +79,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "S3 API", tokio::spawn(S3ApiServer::run( garage.clone(), - *s3_bind_addr, + s3_bind_addr.clone(), config.s3_api.s3_region.clone(), wait_from(watch_cancel.clone()), )), @@ -94,7 +94,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "K2V API", tokio::spawn(K2VApiServer::run( garage.clone(), - config.k2v_api.as_ref().unwrap().api_bind_addr, + config.k2v_api.as_ref().unwrap().api_bind_addr.clone(), config.s3_api.s3_region.clone(), wait_from(watch_cancel.clone()), )), @@ -110,7 +110,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "Web", tokio::spawn(WebServer::run( garage.clone(), - web_config.bind_addr, + web_config.bind_addr.clone(), web_config.root_domain.clone(), wait_from(watch_cancel.clone()), )), @@ -121,7 +121,9 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er info!("Launching Admin API server..."); servers.push(( "Admin", - tokio::spawn(admin_server.run(*admin_bind_addr, wait_from(watch_cancel.clone()))), + tokio::spawn( + admin_server.run(admin_bind_addr.clone(), wait_from(watch_cancel.clone())), + ), )); } diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 287aef1aa..73780efb4 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -1,4 +1,6 @@ -use std::{convert::Infallible, net::SocketAddr, sync::Arc}; +use std::fs::{self, Permissions}; +use std::os::unix::prelude::PermissionsExt; +use std::{convert::Infallible, sync::Arc}; use futures::future::Future; @@ -9,6 +11,10 @@ use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; +use hyperlocal::UnixServerExt; + +use tokio::net::UnixStream; + use opentelemetry::{ global, metrics::{Counter, ValueRecorder}, @@ -32,6 +38,7 @@ use garage_util::data::Uuid; use garage_util::error::Error as GarageError; use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; +use garage_util::socket_address::UnixOrTCPSocketAddress; struct WebMetrics { request_counter: Counter, @@ -69,7 +76,7 @@ impl WebServer { /// Run a web server pub async fn run( garage: Arc, - addr: SocketAddr, + addr: UnixOrTCPSocketAddress, root_domain: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { @@ -80,7 +87,7 @@ impl WebServer { root_domain, }); - let service = make_service_fn(|conn: &AddrStream| { + let tcp_service = make_service_fn(|conn: &AddrStream| { let web_server = web_server.clone(); let client_addr = conn.remote_addr(); @@ -88,23 +95,56 @@ impl WebServer { Ok::<_, Error>(service_fn(move |req: Request| { let web_server = web_server.clone(); - web_server.handle_request(req, client_addr) + web_server.handle_request(req, client_addr.to_string()) })) } }); - let server = Server::bind(&addr).serve(service); - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("Web server listening on http://{}", addr); + let unix_service = make_service_fn(|_: &UnixStream| { + let web_server = web_server.clone(); + + let path = addr.to_string(); + async move { + Ok::<_, Error>(service_fn(move |req: Request| { + let web_server = web_server.clone(); + + web_server.handle_request(req, path.clone()) + })) + } + }); + + info!("Web server listening on {}", addr); + + match addr { + UnixOrTCPSocketAddress::TCPSocket(addr) => { + Server::bind(&addr) + .serve(tcp_service) + .with_graceful_shutdown(shutdown_signal) + .await? + } + UnixOrTCPSocketAddress::UnixSocket(ref path) => { + if path.exists() { + fs::remove_file(path)? + } + + let bound = Server::bind_unix(path)?; + + fs::set_permissions(path, Permissions::from_mode(0o222))?; + + bound + .serve(unix_service) + .with_graceful_shutdown(shutdown_signal) + .await?; + } + }; - graceful.await?; Ok(()) } async fn handle_request( self: Arc, req: Request, - addr: SocketAddr, + addr: String, ) -> Result, Infallible> { if let Ok(forwarded_for_ip_addr) = forwarded_headers::handle_forwarded_for_headers(req.headers())