From f270df21c02d4c33078b08ce99b47326c96a67c7 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 7 Jan 2022 17:03:25 +0100 Subject: [PATCH] Implement CORS in web server --- src/api/lib.rs | 2 +- src/api/s3_cors.rs | 64 ++++++++++++++- src/web/web_server.rs | 178 +++++++++++++++++++++++++++++------------- 3 files changed, 186 insertions(+), 58 deletions(-) diff --git a/src/api/lib.rs b/src/api/lib.rs index 9e19e731..bb5a8265 100644 --- a/src/api/lib.rs +++ b/src/api/lib.rs @@ -15,7 +15,7 @@ mod signature; pub mod helpers; mod s3_bucket; mod s3_copy; -mod s3_cors; +pub mod s3_cors; mod s3_delete; pub mod s3_get; mod s3_list; diff --git a/src/api/s3_cors.rs b/src/api/s3_cors.rs index 4a539ab4..def7fbc3 100644 --- a/src/api/s3_cors.rs +++ b/src/api/s3_cors.rs @@ -1,7 +1,12 @@ use quick_xml::de::from_reader; use std::sync::Arc; +use http::header::{ + ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_EXPOSE_HEADERS, +}; use hyper::{header::HeaderName, Body, Method, Request, Response, StatusCode}; + use serde::{Deserialize, Serialize}; use crate::error::*; @@ -188,7 +193,7 @@ impl CorsRule { GarageCorsRule { id: self.id.as_ref().map(|x| x.0.to_owned()), max_age_seconds: self.max_age_seconds.as_ref().map(|x| x.0 as u64), - allowed_origins: convert_vec(&self.allowed_methods), + allowed_origins: convert_vec(&self.allowed_origins), allowed_methods: convert_vec(&self.allowed_methods), allowed_headers: convert_vec(&self.allowed_headers), expose_headers: convert_vec(&self.expose_headers), @@ -196,10 +201,65 @@ impl CorsRule { } pub fn from_garage_cors_rule(rule: &GarageCorsRule) -> Self { - unimplemented!() + let convert_vec = |vval: &[String]| { + vval.iter() + .map(|x| Value(x.clone())) + .collect::>() + }; + Self { + id: rule.id.as_ref().map(|x| Value(x.clone())), + max_age_seconds: rule.max_age_seconds.map(|x| IntValue(x as i64)), + allowed_origins: convert_vec(&rule.allowed_origins), + allowed_methods: convert_vec(&rule.allowed_methods), + allowed_headers: convert_vec(&rule.allowed_headers), + expose_headers: convert_vec(&rule.expose_headers), + } } } +pub fn cors_rule_matches<'a, HI, S>( + rule: &GarageCorsRule, + origin: &'a str, + method: &'a str, + mut request_headers: HI, +) -> bool +where + HI: Iterator, + S: AsRef, +{ + rule.allowed_origins.iter().any(|x| x == "*" || x == origin) + && rule.allowed_methods.iter().any(|x| x == "*" || x == method) + && request_headers.all(|h| { + rule.allowed_headers + .iter() + .any(|x| x == "*" || x == h.as_ref()) + }) +} + +pub fn add_cors_headers( + resp: &mut Response, + rule: &GarageCorsRule, +) -> Result<(), http::header::InvalidHeaderValue> { + let h = resp.headers_mut(); + h.insert( + ACCESS_CONTROL_ALLOW_ORIGIN, + rule.allowed_origins.join(", ").parse()?, + ); + h.insert( + ACCESS_CONTROL_ALLOW_METHODS, + rule.allowed_methods.join(", ").parse()?, + ); + h.insert( + ACCESS_CONTROL_ALLOW_HEADERS, + rule.allowed_headers.join(", ").parse()?, + ); + h.insert( + ACCESS_CONTROL_EXPOSE_HEADERS, + rule.expose_headers.join(", ").parse()?, + ); + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 834f31e8..aefee7d1 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -2,19 +2,22 @@ use std::{borrow::Cow, convert::Infallible, net::SocketAddr, sync::Arc}; use futures::future::Future; +use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD}; use hyper::{ header::{HeaderValue, HOST}, server::conn::AddrStream, service::{make_service_fn, service_fn}, - Body, Method, Request, Response, Server, + Body, Method, Request, Response, Server, StatusCode, }; use crate::error::*; -use garage_api::error::{Error as ApiError, OkOrBadRequest}; +use garage_api::error::{Error as ApiError, OkOrBadRequest, OkOrInternalError}; use garage_api::helpers::{authority_to_host, host_to_bucket}; +use garage_api::s3_cors::{add_cors_headers, cors_rule_matches}; use garage_api::s3_get::{handle_get, handle_head}; +use garage_model::bucket_table::Bucket; use garage_model::garage::Garage; use garage_table::*; @@ -132,72 +135,137 @@ async fn serve_file(garage: Arc, req: &Request) -> Result handle_head(garage.clone(), req, bucket_id, &key).await, + Method::OPTIONS => return handle_options(&bucket, req), + Method::HEAD => { + return handle_head(garage.clone(), req, bucket_id, &key) + .await + .map_err(Error::from) + } Method::GET => handle_get(garage.clone(), req, bucket_id, &key).await, _ => Err(ApiError::BadRequest("HTTP method not supported".into())), } .map_err(Error::from); - if let Err(error) = ret_doc { - if *req.method() == Method::HEAD || !error.http_status_code().is_client_error() { - // Do not return the error document in the following cases: - // - the error is not a 4xx error code - // - the request is a HEAD method - // In this case we just return the error code and the error message in the body, - // by relying on err_to_res that is called above when we return an Err. - return Err(error); + match ret_doc { + Err(error) => { + // For a HEAD or OPTIONS method, we don't return the error document + // as content, we return above and just return the error message + // by relying on err_to_res that is called when we return an Err. + assert!(*req.method() != Method::HEAD && *req.method() != Method::OPTIONS); + + if !error.http_status_code().is_client_error() { + // Do not return the error document if it is not a 4xx error code. + return Err(error); + } + + // If no error document is set: just return the error directly + let error_document = match &website_config.error_document { + Some(ed) => ed.trim_start_matches('/').to_owned(), + None => return Err(error), + }; + + // We want to return the error document + // Create a fake HTTP request with path = the error document + let req2 = Request::builder() + .uri(format!("http://{}/{}", host, &error_document)) + .body(Body::empty()) + .unwrap(); + + match handle_get(garage, &req2, bucket_id, &error_document).await { + Ok(mut error_doc) => { + // The error won't be logged back in handle_request, + // so log it here + info!( + "{} {} {} {}", + req.method(), + req.uri(), + error.http_status_code(), + error + ); + + *error_doc.status_mut() = error.http_status_code(); + error.add_headers(error_doc.headers_mut()); + + // Preserve error message in a special header + for error_line in error.to_string().split('\n') { + if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) { + error_doc.headers_mut().append("X-Garage-Error", v); + } + } + + Ok(error_doc) + } + Err(error_doc_error) => { + warn!( + "Couldn't get error document {} for bucket {:?}: {}", + error_document, bucket_id, error_doc_error + ); + Err(error) + } + } } - - // Same if no error document is set: just return the error directly - let error_document = match &website_config.error_document { - Some(ed) => ed.trim_start_matches('/').to_owned(), - None => return Err(error), - }; - - // We want to return the error document - // Create a fake HTTP request with path = the error document - let req2 = Request::builder() - .uri(format!("http://{}/{}", host, &error_document)) - .body(Body::empty()) - .unwrap(); - - match handle_get(garage, &req2, bucket_id, &error_document).await { - Ok(mut error_doc) => { - // The error won't be logged back in handle_request, - // so log it here - info!( - "{} {} {} {}", - req.method(), - req.uri(), - error.http_status_code(), - error - ); - - *error_doc.status_mut() = error.http_status_code(); - error.add_headers(error_doc.headers_mut()); - - // Preserve error message in a special header - for error_line in error.to_string().split('\n') { - if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) { - error_doc.headers_mut().append("X-Garage-Error", v); + Ok(mut resp) => { + // Maybe add CORS headers + if let Some(cors_config) = bucket.params().unwrap().cors_config.get() { + if let Some(origin) = req.headers().get("Origin") { + let origin = origin.to_str()?; + let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) { + Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::>(), + None => vec![], + }; + let matching_rule = cors_config.iter().find(|rule| { + cors_rule_matches( + rule, + &origin, + &req.method().to_string(), + request_headers.iter(), + ) + }); + if let Some(rule) = matching_rule { + add_cors_headers(&mut resp, &rule) + .ok_or_internal_error("Invalid CORS configuration")?; } } - - Ok(error_doc) - } - Err(error_doc_error) => { - warn!( - "Couldn't get error document {} for bucket {:?}: {}", - error_document, bucket_id, error_doc_error - ); - Err(error) } + Ok(resp) } - } else { - ret_doc } } +fn handle_options(bucket: &Bucket, req: &Request) -> Result, Error> { + let origin = req + .headers() + .get("Origin") + .ok_or_bad_request("Missing Origin header")? + .to_str()?; + let request_method = req + .headers() + .get(ACCESS_CONTROL_REQUEST_METHOD) + .ok_or_bad_request("Missing Access-Control-Request-Method header")? + .to_str()?; + let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) { + Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::>(), + None => vec![], + }; + + if let Some(cors_config) = bucket.params().unwrap().cors_config.get() { + let matching_rule = cors_config + .iter() + .find(|rule| cors_rule_matches(rule, &origin, &request_method, request_headers.iter())); + if let Some(rule) = matching_rule { + let mut resp = Response::builder() + .status(StatusCode::OK) + .body(Body::empty()) + .map_err(ApiError::from)?; + add_cors_headers(&mut resp, &rule) + .ok_or_internal_error("Invalid CORS configuration")?; + return Ok(resp); + } + } + + Err(ApiError::Forbidden("No matching CORS rule".into()).into()) +} + /// Path to key /// /// Convert the provided path to the internal key