Implement CORS in web server

This commit is contained in:
Alex 2022-01-07 17:03:25 +01:00
parent a8d077cdec
commit f270df21c0
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
3 changed files with 186 additions and 58 deletions

View file

@ -15,7 +15,7 @@ mod signature;
pub mod helpers;
mod s3_bucket;
mod s3_copy;
mod s3_cors;
pub mod s3_cors;
mod s3_delete;
pub mod s3_get;
mod s3_list;

View file

@ -1,7 +1,12 @@
use quick_xml::de::from_reader;
use std::sync::Arc;
use http::header::{
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_EXPOSE_HEADERS,
};
use hyper::{header::HeaderName, Body, Method, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use crate::error::*;
@ -188,7 +193,7 @@ impl CorsRule {
GarageCorsRule {
id: self.id.as_ref().map(|x| x.0.to_owned()),
max_age_seconds: self.max_age_seconds.as_ref().map(|x| x.0 as u64),
allowed_origins: convert_vec(&self.allowed_methods),
allowed_origins: convert_vec(&self.allowed_origins),
allowed_methods: convert_vec(&self.allowed_methods),
allowed_headers: convert_vec(&self.allowed_headers),
expose_headers: convert_vec(&self.expose_headers),
@ -196,10 +201,65 @@ impl CorsRule {
}
pub fn from_garage_cors_rule(rule: &GarageCorsRule) -> Self {
unimplemented!()
let convert_vec = |vval: &[String]| {
vval.iter()
.map(|x| Value(x.clone()))
.collect::<Vec<Value>>()
};
Self {
id: rule.id.as_ref().map(|x| Value(x.clone())),
max_age_seconds: rule.max_age_seconds.map(|x| IntValue(x as i64)),
allowed_origins: convert_vec(&rule.allowed_origins),
allowed_methods: convert_vec(&rule.allowed_methods),
allowed_headers: convert_vec(&rule.allowed_headers),
expose_headers: convert_vec(&rule.expose_headers),
}
}
}
pub fn cors_rule_matches<'a, HI, S>(
rule: &GarageCorsRule,
origin: &'a str,
method: &'a str,
mut request_headers: HI,
) -> bool
where
HI: Iterator<Item = S>,
S: AsRef<str>,
{
rule.allowed_origins.iter().any(|x| x == "*" || x == origin)
&& rule.allowed_methods.iter().any(|x| x == "*" || x == method)
&& request_headers.all(|h| {
rule.allowed_headers
.iter()
.any(|x| x == "*" || x == h.as_ref())
})
}
pub fn add_cors_headers(
resp: &mut Response<Body>,
rule: &GarageCorsRule,
) -> Result<(), http::header::InvalidHeaderValue> {
let h = resp.headers_mut();
h.insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
rule.allowed_origins.join(", ").parse()?,
);
h.insert(
ACCESS_CONTROL_ALLOW_METHODS,
rule.allowed_methods.join(", ").parse()?,
);
h.insert(
ACCESS_CONTROL_ALLOW_HEADERS,
rule.allowed_headers.join(", ").parse()?,
);
h.insert(
ACCESS_CONTROL_EXPOSE_HEADERS,
rule.expose_headers.join(", ").parse()?,
);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -2,19 +2,22 @@ use std::{borrow::Cow, convert::Infallible, net::SocketAddr, sync::Arc};
use futures::future::Future;
use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD};
use hyper::{
header::{HeaderValue, HOST},
server::conn::AddrStream,
service::{make_service_fn, service_fn},
Body, Method, Request, Response, Server,
Body, Method, Request, Response, Server, StatusCode,
};
use crate::error::*;
use garage_api::error::{Error as ApiError, OkOrBadRequest};
use garage_api::error::{Error as ApiError, OkOrBadRequest, OkOrInternalError};
use garage_api::helpers::{authority_to_host, host_to_bucket};
use garage_api::s3_cors::{add_cors_headers, cors_rule_matches};
use garage_api::s3_get::{handle_get, handle_head};
use garage_model::bucket_table::Bucket;
use garage_model::garage::Garage;
use garage_table::*;
@ -132,72 +135,137 @@ async fn serve_file(garage: Arc<Garage>, req: &Request<Body>) -> Result<Response
);
let ret_doc = match *req.method() {
Method::HEAD => handle_head(garage.clone(), req, bucket_id, &key).await,
Method::OPTIONS => return handle_options(&bucket, req),
Method::HEAD => {
return handle_head(garage.clone(), req, bucket_id, &key)
.await
.map_err(Error::from)
}
Method::GET => handle_get(garage.clone(), req, bucket_id, &key).await,
_ => Err(ApiError::BadRequest("HTTP method not supported".into())),
}
.map_err(Error::from);
if let Err(error) = ret_doc {
if *req.method() == Method::HEAD || !error.http_status_code().is_client_error() {
// Do not return the error document in the following cases:
// - the error is not a 4xx error code
// - the request is a HEAD method
// In this case we just return the error code and the error message in the body,
// by relying on err_to_res that is called above when we return an Err.
return Err(error);
match ret_doc {
Err(error) => {
// For a HEAD or OPTIONS method, we don't return the error document
// as content, we return above and just return the error message
// by relying on err_to_res that is called when we return an Err.
assert!(*req.method() != Method::HEAD && *req.method() != Method::OPTIONS);
if !error.http_status_code().is_client_error() {
// Do not return the error document if it is not a 4xx error code.
return Err(error);
}
// If no error document is set: just return the error directly
let error_document = match &website_config.error_document {
Some(ed) => ed.trim_start_matches('/').to_owned(),
None => return Err(error),
};
// We want to return the error document
// Create a fake HTTP request with path = the error document
let req2 = Request::builder()
.uri(format!("http://{}/{}", host, &error_document))
.body(Body::empty())
.unwrap();
match handle_get(garage, &req2, bucket_id, &error_document).await {
Ok(mut error_doc) => {
// The error won't be logged back in handle_request,
// so log it here
info!(
"{} {} {} {}",
req.method(),
req.uri(),
error.http_status_code(),
error
);
*error_doc.status_mut() = error.http_status_code();
error.add_headers(error_doc.headers_mut());
// Preserve error message in a special header
for error_line in error.to_string().split('\n') {
if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) {
error_doc.headers_mut().append("X-Garage-Error", v);
}
}
Ok(error_doc)
}
Err(error_doc_error) => {
warn!(
"Couldn't get error document {} for bucket {:?}: {}",
error_document, bucket_id, error_doc_error
);
Err(error)
}
}
}
// Same if no error document is set: just return the error directly
let error_document = match &website_config.error_document {
Some(ed) => ed.trim_start_matches('/').to_owned(),
None => return Err(error),
};
// We want to return the error document
// Create a fake HTTP request with path = the error document
let req2 = Request::builder()
.uri(format!("http://{}/{}", host, &error_document))
.body(Body::empty())
.unwrap();
match handle_get(garage, &req2, bucket_id, &error_document).await {
Ok(mut error_doc) => {
// The error won't be logged back in handle_request,
// so log it here
info!(
"{} {} {} {}",
req.method(),
req.uri(),
error.http_status_code(),
error
);
*error_doc.status_mut() = error.http_status_code();
error.add_headers(error_doc.headers_mut());
// Preserve error message in a special header
for error_line in error.to_string().split('\n') {
if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) {
error_doc.headers_mut().append("X-Garage-Error", v);
Ok(mut resp) => {
// Maybe add CORS headers
if let Some(cors_config) = bucket.params().unwrap().cors_config.get() {
if let Some(origin) = req.headers().get("Origin") {
let origin = origin.to_str()?;
let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::<Vec<_>>(),
None => vec![],
};
let matching_rule = cors_config.iter().find(|rule| {
cors_rule_matches(
rule,
&origin,
&req.method().to_string(),
request_headers.iter(),
)
});
if let Some(rule) = matching_rule {
add_cors_headers(&mut resp, &rule)
.ok_or_internal_error("Invalid CORS configuration")?;
}
}
Ok(error_doc)
}
Err(error_doc_error) => {
warn!(
"Couldn't get error document {} for bucket {:?}: {}",
error_document, bucket_id, error_doc_error
);
Err(error)
}
Ok(resp)
}
} else {
ret_doc
}
}
fn handle_options(bucket: &Bucket, req: &Request<Body>) -> Result<Response<Body>, Error> {
let origin = req
.headers()
.get("Origin")
.ok_or_bad_request("Missing Origin header")?
.to_str()?;
let request_method = req
.headers()
.get(ACCESS_CONTROL_REQUEST_METHOD)
.ok_or_bad_request("Missing Access-Control-Request-Method header")?
.to_str()?;
let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::<Vec<_>>(),
None => vec![],
};
if let Some(cors_config) = bucket.params().unwrap().cors_config.get() {
let matching_rule = cors_config
.iter()
.find(|rule| cors_rule_matches(rule, &origin, &request_method, request_headers.iter()));
if let Some(rule) = matching_rule {
let mut resp = Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.map_err(ApiError::from)?;
add_cors_headers(&mut resp, &rule)
.ok_or_internal_error("Invalid CORS configuration")?;
return Ok(resp);
}
}
Err(ApiError::Forbidden("No matching CORS rule".into()).into())
}
/// Path to key
///
/// Convert the provided path to the internal key