diff --git a/Cargo.lock b/Cargo.lock index 9054c450..52e3fa0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1271,6 +1271,7 @@ dependencies = [ "http-range", "httpdate", "hyper", + "hyperlocal", "idna", "md-5", "multer", @@ -1464,8 +1465,10 @@ dependencies = [ "garage_util", "http", "hyper", + "hyperlocal", "opentelemetry", "percent-encoding", + "tokio", "tracing", ] @@ -1775,6 +1778,19 @@ dependencies = [ "tokio-io-timeout", ] +[[package]] +name = "hyperlocal" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fafdf7b2b2de7c9784f76e02c0935e65a8117ec3b768644379983ab333ac98c" +dependencies = [ + "futures-util", + "hex", + "hyper", + "pin-project", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.57" diff --git a/Cargo.nix b/Cargo.nix index 763f6268..5cd347fb 100644 --- a/Cargo.nix +++ b/Cargo.nix @@ -33,7 +33,7 @@ args@{ ignoreLockHash, }: let - nixifiedLockHash = "8ff415a3cc93dd7330ffcc18ee0b3a76c2863e1108be1c88d8e37f29182651f2"; + nixifiedLockHash = "3f325a8a549c43a788ff702e65f6de2d42ad19a46067248e29108e90212ca2f5"; workspaceSrc = if args.workspaceSrc == null then ./. else args.workspaceSrc; currentLockHash = builtins.hashFile "sha256" (workspaceSrc + /Cargo.lock); lockHashIgnored = if ignoreLockHash @@ -1809,6 +1809,7 @@ in http_range = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".http-range."0.1.5" { inherit profileName; }).out; httpdate = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".httpdate."1.0.3" { inherit profileName; }).out; hyper = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hyper."0.14.27" { inherit profileName; }).out; + hyperlocal = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hyperlocal."0.8.0" { inherit profileName; }).out; idna = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".idna."0.4.0" { inherit profileName; }).out; md5 = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".md-5."0.10.5" { inherit profileName; }).out; multer = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".multer."2.1.0" { inherit profileName; }).out; @@ -2064,8 +2065,10 @@ in garage_util = (rustPackages."unknown".garage_util."0.8.4" { inherit profileName; }).out; http = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".http."0.2.9" { inherit profileName; }).out; hyper = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hyper."0.14.27" { inherit profileName; }).out; + hyperlocal = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hyperlocal."0.8.0" { inherit profileName; }).out; opentelemetry = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".opentelemetry."0.17.0" { inherit profileName; }).out; percent_encoding = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".percent-encoding."2.3.0" { inherit profileName; }).out; + tokio = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".tokio."1.32.0" { inherit profileName; }).out; tracing = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".tracing."0.1.37" { inherit profileName; }).out; }; }); @@ -2491,6 +2494,23 @@ in }; }); + "registry+https://github.com/rust-lang/crates.io-index".hyperlocal."0.8.0" = overridableMkRustCrate (profileName: rec { + name = "hyperlocal"; + version = "0.8.0"; + registry = "registry+https://github.com/rust-lang/crates.io-index"; + src = fetchCratesIo { inherit name version; sha256 = "0fafdf7b2b2de7c9784f76e02c0935e65a8117ec3b768644379983ab333ac98c"; }; + features = builtins.concatLists [ + [ "server" ] + ]; + dependencies = { + futures_util = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".futures-util."0.3.28" { inherit profileName; }).out; + hex = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hex."0.4.3" { inherit profileName; }).out; + hyper = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".hyper."0.14.27" { inherit profileName; }).out; + pin_project = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".pin-project."1.1.3" { inherit profileName; }).out; + tokio = (rustPackages."registry+https://github.com/rust-lang/crates.io-index".tokio."1.32.0" { inherit profileName; }).out; + }; + }); + "registry+https://github.com/rust-lang/crates.io-index".iana-time-zone."0.1.57" = overridableMkRustCrate (profileName: rec { name = "iana-time-zone"; version = "0.1.57"; diff --git a/doc/book/connect/apps/index.md b/doc/book/connect/apps/index.md index 7bad9d09..f67a29c9 100644 --- a/doc/book/connect/apps/index.md +++ b/doc/book/connect/apps/index.md @@ -421,7 +421,7 @@ Now we can write a simple script (eg `~/.local/bin/matrix-cache-gc`): ## CONFIGURATION ## AWS_ACCESS_KEY_ID=GKxxx AWS_SECRET_ACCESS_KEY=xxxx -S3_ENDPOINT=http://localhost:3900 +AWS_ENDPOINT_URL=http://localhost:3900 S3_BUCKET=matrix MEDIA_STORE=/var/lib/matrix-synapse/media PG_USER=matrix @@ -442,7 +442,7 @@ EOF s3_media_upload update-db 1d s3_media_upload --no-progress check-deleted $MEDIA_STORE -s3_media_upload --no-progress upload $MEDIA_STORE $S3_BUCKET --delete --endpoint-url $S3_ENDPOINT +s3_media_upload --no-progress upload $MEDIA_STORE $S3_BUCKET --delete --endpoint-url $AWS_ENDPOINT_URL ``` This script will list all the medias that were not accessed in the 24 hours according to your database. diff --git a/doc/book/connect/cli.md b/doc/book/connect/cli.md index 591ac151..c9ffd4f4 100644 --- a/doc/book/connect/cli.md +++ b/doc/book/connect/cli.md @@ -70,16 +70,17 @@ Then a file named `~/.aws/config` and put: ```toml [default] region=garage +endpoint_url=http://127.0.0.1:3900 ``` Now, supposing Garage is listening on `http://127.0.0.1:3900`, you can list your buckets with: ```bash -aws --endpoint-url http://127.0.0.1:3900 s3 ls +aws s3 ls ``` -Passing the `--endpoint-url` parameter to each command is annoying but AWS developers do not provide a corresponding configuration entry. -As a workaround, you can redefine the aws command by editing the file `~/.bashrc`: +If you're using awscli `<1.29.0` or `<2.13.0`, you need to pass `--endpoint-url` to each CLI invocation explicitly. +As a workaround, you can redefine the aws command by editing the file `~/.bashrc` in this case: ``` function aws { command aws --endpoint-url http://127.0.0.1:3900 $@ ; } diff --git a/doc/book/quick-start/_index.md b/doc/book/quick-start/_index.md index bd64e3eb..8ed36b7d 100644 --- a/doc/book/quick-start/_index.md +++ b/doc/book/quick-start/_index.md @@ -269,12 +269,14 @@ named `~/.awsrc` with this content: export AWS_ACCESS_KEY_ID=xxxx # put your Key ID here export AWS_SECRET_ACCESS_KEY=xxxx # put your Secret key here export AWS_DEFAULT_REGION='garage' -export AWS_ENDPOINT='http://localhost:3900' +export AWS_ENDPOINT_URL='http://localhost:3900' -function aws { command aws --endpoint-url $AWS_ENDPOINT $@ ; } aws --version ``` +Note you need to have at least `awscli` `>=1.29.0` or `>=2.13.0`, otherwise you +need to specify `--endpoint-url` explicitly on each `awscli` invocation. + Now, each time you want to use `awscli` on this target, run: ```bash diff --git a/doc/book/reference-manual/configuration.md b/doc/book/reference-manual/configuration.md index f07fb1e0..1ac681cf 100644 --- a/doc/book/reference-manual/configuration.md +++ b/doc/book/reference-manual/configuration.md @@ -468,6 +468,8 @@ manually. The IP and port on which to bind for accepting S3 API calls. This endpoint does not suport TLS: a reverse proxy should be used to provide it. +Alternatively, since `v0.8.5`, a path can be used to create a unix socket with 0222 mode. + ### `s3_region` Garage will accept S3 API calls that are targetted to the S3 region defined here. @@ -497,6 +499,8 @@ The IP and port on which to bind for accepting HTTP requests to buckets configur for website access. This endpoint does not suport TLS: a reverse proxy should be used to provide it. +Alternatively, since `v0.8.5`, a path can be used to create a unix socket with 0222 mode. + ### `root_domain` The optional suffix appended to bucket names for the corresponding HTTP Host. @@ -516,6 +520,9 @@ If specified, Garage will bind an HTTP server to this port and address, on which it will listen to requests for administration features. See [administration API reference](@/documentation/reference-manual/admin-api.md) to learn more about these features. +Alternatively, since `v0.8.5`, a path can be used to create a unix socket. Note that for security reasons, +the socket will have 0220 mode. Make sure to set user and group permissions accordingly. + ### `metrics_token`, `metrics_token_file` or `GARAGE_METRICS_TOKEN` (env) The token for accessing the Metrics endpoint. If this token is not set, the diff --git a/script/dev-env-aws.sh b/script/dev-env-aws.sh index 9436c2c7..808f9cf1 100644 --- a/script/dev-env-aws.sh +++ b/script/dev-env-aws.sh @@ -1,7 +1,7 @@ export AWS_ACCESS_KEY_ID=`cat /tmp/garage.s3 |cut -d' ' -f1` export AWS_SECRET_ACCESS_KEY=`cat /tmp/garage.s3 |cut -d' ' -f2` export AWS_DEFAULT_REGION='garage' - +# FUTUREWORK: set AWS_ENDPOINT_URL instead, once nixpkgs bumps awscli to >=2.13.0. function aws { command aws --endpoint-url http://127.0.0.1:3911 $@ ; } aws --version diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 6425591f..cb9e2e55 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -45,6 +45,7 @@ http = "0.2" httpdate = "1.0" http-range = "0.1" hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] } +hyperlocal = { version = "0.8.0", default-features = false, features = ["server"] } multer = "2.0" percent-encoding = "2.1.0" roxmltree = "0.18" diff --git a/src/api/admin/api_server.rs b/src/api/admin/api_server.rs index cc04d81f..53503220 100644 --- a/src/api/admin/api_server.rs +++ b/src/api/admin/api_server.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -18,6 +17,7 @@ use prometheus::{Encoder, TextEncoder}; use garage_model::garage::Garage; use garage_rpc::system::ClusterHealthStatus; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use crate::generic_server::*; @@ -61,12 +61,12 @@ impl AdminApiServer { pub async fn run( self, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, shutdown_signal: impl Future, ) -> Result<(), GarageError> { let region = self.garage.config.s3_api.s3_region.clone(); ApiServer::new(region, self) - .run_server(bind_addr, shutdown_signal) + .run_server(bind_addr, Some(0o220), shutdown_signal) .await } diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index 757b85ec..fa346f48 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -1,4 +1,5 @@ -use std::net::SocketAddr; +use std::fs::{self, Permissions}; +use std::os::unix::fs::PermissionsExt; use std::sync::Arc; use async_trait::async_trait; @@ -11,6 +12,10 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server}; use hyper::{HeaderMap, StatusCode}; +use hyperlocal::UnixServerExt; + +use tokio::net::UnixStream; + use opentelemetry::{ global, metrics::{Counter, ValueRecorder}, @@ -21,6 +26,7 @@ use opentelemetry::{ use garage_util::error::Error as GarageError; use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; +use garage_util::socket_address::UnixOrTCPSocketAddress; pub(crate) trait ApiEndpoint: Send + Sync + 'static { fn name(&self) -> &'static str; @@ -91,10 +97,11 @@ impl ApiServer { pub async fn run_server( self: Arc, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, + unix_bind_addr_mode: Option, shutdown_signal: impl Future, ) -> Result<(), GarageError> { - let service = make_service_fn(|conn: &AddrStream| { + let tcp_service = make_service_fn(|conn: &AddrStream| { let this = self.clone(); let client_addr = conn.remote_addr(); @@ -102,28 +109,63 @@ impl ApiServer { Ok::<_, GarageError>(service_fn(move |req: Request| { let this = this.clone(); - this.handler(req, client_addr) + this.handler(req, client_addr.to_string()) })) } }); - let server = Server::bind(&bind_addr).serve(service); + let unix_service = make_service_fn(|_: &UnixStream| { + let this = self.clone(); + + let path = bind_addr.to_string(); + async move { + Ok::<_, GarageError>(service_fn(move |req: Request| { + let this = this.clone(); + + this.handler(req, path.clone()) + })) + } + }); - let graceful = server.with_graceful_shutdown(shutdown_signal); info!( - "{} API server listening on http://{}", + "{} API server listening on {}", A::API_NAME_DISPLAY, bind_addr ); - graceful.await?; + match bind_addr { + UnixOrTCPSocketAddress::TCPSocket(addr) => { + Server::bind(&addr) + .serve(tcp_service) + .with_graceful_shutdown(shutdown_signal) + .await? + } + UnixOrTCPSocketAddress::UnixSocket(ref path) => { + if path.exists() { + fs::remove_file(path)? + } + + let bound = Server::bind_unix(path)?; + + fs::set_permissions( + path, + Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)), + )?; + + bound + .serve(unix_service) + .with_graceful_shutdown(shutdown_signal) + .await?; + } + }; + Ok(()) } async fn handler( self: Arc, req: Request, - addr: SocketAddr, + addr: String, ) -> Result, GarageError> { let uri = req.uri().clone(); diff --git a/src/api/k2v/api_server.rs b/src/api/k2v/api_server.rs index bb85b2e7..3a032aba 100644 --- a/src/api/k2v/api_server.rs +++ b/src/api/k2v/api_server.rs @@ -1,4 +1,3 @@ -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -9,6 +8,7 @@ use hyper::{Body, Method, Request, Response}; use opentelemetry::{trace::SpanRef, KeyValue}; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use garage_model::garage::Garage; @@ -37,12 +37,12 @@ pub(crate) struct K2VApiEndpoint { impl K2VApiServer { pub async fn run( garage: Arc, - bind_addr: SocketAddr, + bind_addr: UnixOrTCPSocketAddress, s3_region: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { ApiServer::new(s3_region, K2VApiServer { garage }) - .run_server(bind_addr, shutdown_signal) + .run_server(bind_addr, None, shutdown_signal) .await } } diff --git a/src/api/s3/api_server.rs b/src/api/s3/api_server.rs index 3f995d34..d675ab61 100644 --- a/src/api/s3/api_server.rs +++ b/src/api/s3/api_server.rs @@ -1,4 +1,3 @@ -use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; @@ -10,6 +9,7 @@ use hyper::{Body, Request, Response}; use opentelemetry::{trace::SpanRef, KeyValue}; use garage_util::error::Error as GarageError; +use garage_util::socket_address::UnixOrTCPSocketAddress; use garage_model::garage::Garage; use garage_model::key_table::Key; @@ -46,12 +46,12 @@ pub(crate) struct S3ApiEndpoint { impl S3ApiServer { pub async fn run( garage: Arc, - addr: SocketAddr, + addr: UnixOrTCPSocketAddress, s3_region: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { ApiServer::new(s3_region, S3ApiServer { garage }) - .run_server(addr, shutdown_signal) + .run_server(addr, None, shutdown_signal) .await } diff --git a/src/garage/server.rs b/src/garage/server.rs index 472616c7..3ad10b72 100644 --- a/src/garage/server.rs +++ b/src/garage/server.rs @@ -79,7 +79,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "S3 API", tokio::spawn(S3ApiServer::run( garage.clone(), - *s3_bind_addr, + s3_bind_addr.clone(), config.s3_api.s3_region.clone(), wait_from(watch_cancel.clone()), )), @@ -94,7 +94,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "K2V API", tokio::spawn(K2VApiServer::run( garage.clone(), - config.k2v_api.as_ref().unwrap().api_bind_addr, + config.k2v_api.as_ref().unwrap().api_bind_addr.clone(), config.s3_api.s3_region.clone(), wait_from(watch_cancel.clone()), )), @@ -110,7 +110,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er "Web", tokio::spawn(WebServer::run( garage.clone(), - web_config.bind_addr, + web_config.bind_addr.clone(), web_config.root_domain.clone(), wait_from(watch_cancel.clone()), )), @@ -121,7 +121,9 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er info!("Launching Admin API server..."); servers.push(( "Admin", - tokio::spawn(admin_server.run(*admin_bind_addr, wait_from(watch_cancel.clone()))), + tokio::spawn( + admin_server.run(admin_bind_addr.clone(), wait_from(watch_cancel.clone())), + ), )); } diff --git a/src/util/config.rs b/src/util/config.rs index cf31c87c..ad5c8e1f 100644 --- a/src/util/config.rs +++ b/src/util/config.rs @@ -7,6 +7,7 @@ use std::path::PathBuf; use serde::{de, Deserialize}; use crate::error::Error; +use crate::socket_address::UnixOrTCPSocketAddress; /// Represent the whole configuration #[derive(Deserialize, Debug, Clone)] @@ -129,7 +130,7 @@ pub struct DataDir { #[derive(Deserialize, Debug, Clone)] pub struct S3ApiConfig { /// Address and port to bind for api serving - pub api_bind_addr: Option, + pub api_bind_addr: Option, /// S3 region to use pub s3_region: String, /// Suffix to remove from domain name to find bucket. If None, @@ -141,14 +142,14 @@ pub struct S3ApiConfig { #[derive(Deserialize, Debug, Clone)] pub struct K2VApiConfig { /// Address and port to bind for api serving - pub api_bind_addr: SocketAddr, + pub api_bind_addr: UnixOrTCPSocketAddress, } /// Configuration for serving files as normal web server #[derive(Deserialize, Debug, Clone)] pub struct WebConfig { /// Address and port to bind for web serving - pub bind_addr: SocketAddr, + pub bind_addr: UnixOrTCPSocketAddress, /// Suffix to remove from domain name to find bucket pub root_domain: String, } @@ -157,7 +158,7 @@ pub struct WebConfig { #[derive(Deserialize, Debug, Clone, Default)] pub struct AdminConfig { /// Address and port to bind for admin API serving - pub api_bind_addr: Option, + pub api_bind_addr: Option, /// Bearer token to use to scrape metrics pub metrics_token: Option, diff --git a/src/util/lib.rs b/src/util/lib.rs index 15f0f829..7df77959 100644 --- a/src/util/lib.rs +++ b/src/util/lib.rs @@ -14,6 +14,7 @@ pub mod forwarded_headers; pub mod metrics; pub mod migrate; pub mod persister; +pub mod socket_address; pub mod time; pub mod tranquilizer; pub mod version; diff --git a/src/util/socket_address.rs b/src/util/socket_address.rs new file mode 100644 index 00000000..f01225f6 --- /dev/null +++ b/src/util/socket_address.rs @@ -0,0 +1,44 @@ +use std::fmt::{Debug, Display, Formatter}; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::str::FromStr; + +use serde::de::Error; +use serde::{Deserialize, Deserializer}; + +#[derive(Debug, Clone)] +pub enum UnixOrTCPSocketAddress { + TCPSocket(SocketAddr), + UnixSocket(PathBuf), +} + +impl Display for UnixOrTCPSocketAddress { + fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result { + match self { + UnixOrTCPSocketAddress::TCPSocket(address) => write!(formatter, "http://{}", address), + UnixOrTCPSocketAddress::UnixSocket(path) => { + write!(formatter, "http+unix://{}", path.to_string_lossy()) + } + } + } +} + +impl<'de> Deserialize<'de> for UnixOrTCPSocketAddress { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let string = String::deserialize(deserializer)?; + let string = string.as_str(); + + if string.starts_with("/") { + Ok(UnixOrTCPSocketAddress::UnixSocket( + PathBuf::from_str(string).map_err(Error::custom)?, + )) + } else { + Ok(UnixOrTCPSocketAddress::TCPSocket( + SocketAddr::from_str(string).map_err(Error::custom)?, + )) + } + } +} diff --git a/src/web/Cargo.toml b/src/web/Cargo.toml index 6d0eba3a..eec47bcd 100644 --- a/src/web/Cargo.toml +++ b/src/web/Cargo.toml @@ -27,5 +27,8 @@ futures = "0.3" http = "0.2" hyper = { version = "0.14", features = ["server", "http1", "runtime", "tcp", "stream"] } +hyperlocal = { version = "0.8.0", default-features = false, features = ["server"] } + +tokio = { version = "1.0", default-features = false, features = ["net"] } opentelemetry = "0.17" diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 287aef1a..73780efb 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -1,4 +1,6 @@ -use std::{convert::Infallible, net::SocketAddr, sync::Arc}; +use std::fs::{self, Permissions}; +use std::os::unix::prelude::PermissionsExt; +use std::{convert::Infallible, sync::Arc}; use futures::future::Future; @@ -9,6 +11,10 @@ use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; +use hyperlocal::UnixServerExt; + +use tokio::net::UnixStream; + use opentelemetry::{ global, metrics::{Counter, ValueRecorder}, @@ -32,6 +38,7 @@ use garage_util::data::Uuid; use garage_util::error::Error as GarageError; use garage_util::forwarded_headers; use garage_util::metrics::{gen_trace_id, RecordDuration}; +use garage_util::socket_address::UnixOrTCPSocketAddress; struct WebMetrics { request_counter: Counter, @@ -69,7 +76,7 @@ impl WebServer { /// Run a web server pub async fn run( garage: Arc, - addr: SocketAddr, + addr: UnixOrTCPSocketAddress, root_domain: String, shutdown_signal: impl Future, ) -> Result<(), GarageError> { @@ -80,7 +87,7 @@ impl WebServer { root_domain, }); - let service = make_service_fn(|conn: &AddrStream| { + let tcp_service = make_service_fn(|conn: &AddrStream| { let web_server = web_server.clone(); let client_addr = conn.remote_addr(); @@ -88,23 +95,56 @@ impl WebServer { Ok::<_, Error>(service_fn(move |req: Request| { let web_server = web_server.clone(); - web_server.handle_request(req, client_addr) + web_server.handle_request(req, client_addr.to_string()) })) } }); - let server = Server::bind(&addr).serve(service); - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("Web server listening on http://{}", addr); + let unix_service = make_service_fn(|_: &UnixStream| { + let web_server = web_server.clone(); + + let path = addr.to_string(); + async move { + Ok::<_, Error>(service_fn(move |req: Request| { + let web_server = web_server.clone(); + + web_server.handle_request(req, path.clone()) + })) + } + }); + + info!("Web server listening on {}", addr); + + match addr { + UnixOrTCPSocketAddress::TCPSocket(addr) => { + Server::bind(&addr) + .serve(tcp_service) + .with_graceful_shutdown(shutdown_signal) + .await? + } + UnixOrTCPSocketAddress::UnixSocket(ref path) => { + if path.exists() { + fs::remove_file(path)? + } + + let bound = Server::bind_unix(path)?; + + fs::set_permissions(path, Permissions::from_mode(0o222))?; + + bound + .serve(unix_service) + .with_graceful_shutdown(shutdown_signal) + .await?; + } + }; - graceful.await?; Ok(()) } async fn handle_request( self: Arc, req: Request, - addr: SocketAddr, + addr: String, ) -> Result, Infallible> { if let Ok(forwarded_for_ip_addr) = forwarded_headers::handle_forwarded_for_headers(req.headers())