Implement CORS in web server
This commit is contained in:
parent
a8d077cdec
commit
f270df21c0
|
@ -15,7 +15,7 @@ mod signature;
|
|||
pub mod helpers;
|
||||
mod s3_bucket;
|
||||
mod s3_copy;
|
||||
mod s3_cors;
|
||||
pub mod s3_cors;
|
||||
mod s3_delete;
|
||||
pub mod s3_get;
|
||||
mod s3_list;
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
use quick_xml::de::from_reader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use http::header::{
|
||||
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
};
|
||||
use hyper::{header::HeaderName, Body, Method, Request, Response, StatusCode};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::*;
|
||||
|
@ -188,7 +193,7 @@ impl CorsRule {
|
|||
GarageCorsRule {
|
||||
id: self.id.as_ref().map(|x| x.0.to_owned()),
|
||||
max_age_seconds: self.max_age_seconds.as_ref().map(|x| x.0 as u64),
|
||||
allowed_origins: convert_vec(&self.allowed_methods),
|
||||
allowed_origins: convert_vec(&self.allowed_origins),
|
||||
allowed_methods: convert_vec(&self.allowed_methods),
|
||||
allowed_headers: convert_vec(&self.allowed_headers),
|
||||
expose_headers: convert_vec(&self.expose_headers),
|
||||
|
@ -196,10 +201,65 @@ impl CorsRule {
|
|||
}
|
||||
|
||||
pub fn from_garage_cors_rule(rule: &GarageCorsRule) -> Self {
|
||||
unimplemented!()
|
||||
let convert_vec = |vval: &[String]| {
|
||||
vval.iter()
|
||||
.map(|x| Value(x.clone()))
|
||||
.collect::<Vec<Value>>()
|
||||
};
|
||||
Self {
|
||||
id: rule.id.as_ref().map(|x| Value(x.clone())),
|
||||
max_age_seconds: rule.max_age_seconds.map(|x| IntValue(x as i64)),
|
||||
allowed_origins: convert_vec(&rule.allowed_origins),
|
||||
allowed_methods: convert_vec(&rule.allowed_methods),
|
||||
allowed_headers: convert_vec(&rule.allowed_headers),
|
||||
expose_headers: convert_vec(&rule.expose_headers),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cors_rule_matches<'a, HI, S>(
|
||||
rule: &GarageCorsRule,
|
||||
origin: &'a str,
|
||||
method: &'a str,
|
||||
mut request_headers: HI,
|
||||
) -> bool
|
||||
where
|
||||
HI: Iterator<Item = S>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
rule.allowed_origins.iter().any(|x| x == "*" || x == origin)
|
||||
&& rule.allowed_methods.iter().any(|x| x == "*" || x == method)
|
||||
&& request_headers.all(|h| {
|
||||
rule.allowed_headers
|
||||
.iter()
|
||||
.any(|x| x == "*" || x == h.as_ref())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_cors_headers(
|
||||
resp: &mut Response<Body>,
|
||||
rule: &GarageCorsRule,
|
||||
) -> Result<(), http::header::InvalidHeaderValue> {
|
||||
let h = resp.headers_mut();
|
||||
h.insert(
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
rule.allowed_origins.join(", ").parse()?,
|
||||
);
|
||||
h.insert(
|
||||
ACCESS_CONTROL_ALLOW_METHODS,
|
||||
rule.allowed_methods.join(", ").parse()?,
|
||||
);
|
||||
h.insert(
|
||||
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
rule.allowed_headers.join(", ").parse()?,
|
||||
);
|
||||
h.insert(
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
rule.expose_headers.join(", ").parse()?,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -2,19 +2,22 @@ use std::{borrow::Cow, convert::Infallible, net::SocketAddr, sync::Arc};
|
|||
|
||||
use futures::future::Future;
|
||||
|
||||
use http::header::{ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD};
|
||||
use hyper::{
|
||||
header::{HeaderValue, HOST},
|
||||
server::conn::AddrStream,
|
||||
service::{make_service_fn, service_fn},
|
||||
Body, Method, Request, Response, Server,
|
||||
Body, Method, Request, Response, Server, StatusCode,
|
||||
};
|
||||
|
||||
use crate::error::*;
|
||||
|
||||
use garage_api::error::{Error as ApiError, OkOrBadRequest};
|
||||
use garage_api::error::{Error as ApiError, OkOrBadRequest, OkOrInternalError};
|
||||
use garage_api::helpers::{authority_to_host, host_to_bucket};
|
||||
use garage_api::s3_cors::{add_cors_headers, cors_rule_matches};
|
||||
use garage_api::s3_get::{handle_get, handle_head};
|
||||
|
||||
use garage_model::bucket_table::Bucket;
|
||||
use garage_model::garage::Garage;
|
||||
|
||||
use garage_table::*;
|
||||
|
@ -132,72 +135,137 @@ async fn serve_file(garage: Arc<Garage>, req: &Request<Body>) -> Result<Response
|
|||
);
|
||||
|
||||
let ret_doc = match *req.method() {
|
||||
Method::HEAD => handle_head(garage.clone(), req, bucket_id, &key).await,
|
||||
Method::OPTIONS => return handle_options(&bucket, req),
|
||||
Method::HEAD => {
|
||||
return handle_head(garage.clone(), req, bucket_id, &key)
|
||||
.await
|
||||
.map_err(Error::from)
|
||||
}
|
||||
Method::GET => handle_get(garage.clone(), req, bucket_id, &key).await,
|
||||
_ => Err(ApiError::BadRequest("HTTP method not supported".into())),
|
||||
}
|
||||
.map_err(Error::from);
|
||||
|
||||
if let Err(error) = ret_doc {
|
||||
if *req.method() == Method::HEAD || !error.http_status_code().is_client_error() {
|
||||
// Do not return the error document in the following cases:
|
||||
// - the error is not a 4xx error code
|
||||
// - the request is a HEAD method
|
||||
// In this case we just return the error code and the error message in the body,
|
||||
// by relying on err_to_res that is called above when we return an Err.
|
||||
return Err(error);
|
||||
match ret_doc {
|
||||
Err(error) => {
|
||||
// For a HEAD or OPTIONS method, we don't return the error document
|
||||
// as content, we return above and just return the error message
|
||||
// by relying on err_to_res that is called when we return an Err.
|
||||
assert!(*req.method() != Method::HEAD && *req.method() != Method::OPTIONS);
|
||||
|
||||
if !error.http_status_code().is_client_error() {
|
||||
// Do not return the error document if it is not a 4xx error code.
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
// If no error document is set: just return the error directly
|
||||
let error_document = match &website_config.error_document {
|
||||
Some(ed) => ed.trim_start_matches('/').to_owned(),
|
||||
None => return Err(error),
|
||||
};
|
||||
|
||||
// We want to return the error document
|
||||
// Create a fake HTTP request with path = the error document
|
||||
let req2 = Request::builder()
|
||||
.uri(format!("http://{}/{}", host, &error_document))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
match handle_get(garage, &req2, bucket_id, &error_document).await {
|
||||
Ok(mut error_doc) => {
|
||||
// The error won't be logged back in handle_request,
|
||||
// so log it here
|
||||
info!(
|
||||
"{} {} {} {}",
|
||||
req.method(),
|
||||
req.uri(),
|
||||
error.http_status_code(),
|
||||
error
|
||||
);
|
||||
|
||||
*error_doc.status_mut() = error.http_status_code();
|
||||
error.add_headers(error_doc.headers_mut());
|
||||
|
||||
// Preserve error message in a special header
|
||||
for error_line in error.to_string().split('\n') {
|
||||
if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) {
|
||||
error_doc.headers_mut().append("X-Garage-Error", v);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(error_doc)
|
||||
}
|
||||
Err(error_doc_error) => {
|
||||
warn!(
|
||||
"Couldn't get error document {} for bucket {:?}: {}",
|
||||
error_document, bucket_id, error_doc_error
|
||||
);
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Same if no error document is set: just return the error directly
|
||||
let error_document = match &website_config.error_document {
|
||||
Some(ed) => ed.trim_start_matches('/').to_owned(),
|
||||
None => return Err(error),
|
||||
};
|
||||
|
||||
// We want to return the error document
|
||||
// Create a fake HTTP request with path = the error document
|
||||
let req2 = Request::builder()
|
||||
.uri(format!("http://{}/{}", host, &error_document))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
match handle_get(garage, &req2, bucket_id, &error_document).await {
|
||||
Ok(mut error_doc) => {
|
||||
// The error won't be logged back in handle_request,
|
||||
// so log it here
|
||||
info!(
|
||||
"{} {} {} {}",
|
||||
req.method(),
|
||||
req.uri(),
|
||||
error.http_status_code(),
|
||||
error
|
||||
);
|
||||
|
||||
*error_doc.status_mut() = error.http_status_code();
|
||||
error.add_headers(error_doc.headers_mut());
|
||||
|
||||
// Preserve error message in a special header
|
||||
for error_line in error.to_string().split('\n') {
|
||||
if let Ok(v) = HeaderValue::from_bytes(error_line.as_bytes()) {
|
||||
error_doc.headers_mut().append("X-Garage-Error", v);
|
||||
Ok(mut resp) => {
|
||||
// Maybe add CORS headers
|
||||
if let Some(cors_config) = bucket.params().unwrap().cors_config.get() {
|
||||
if let Some(origin) = req.headers().get("Origin") {
|
||||
let origin = origin.to_str()?;
|
||||
let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::<Vec<_>>(),
|
||||
None => vec![],
|
||||
};
|
||||
let matching_rule = cors_config.iter().find(|rule| {
|
||||
cors_rule_matches(
|
||||
rule,
|
||||
&origin,
|
||||
&req.method().to_string(),
|
||||
request_headers.iter(),
|
||||
)
|
||||
});
|
||||
if let Some(rule) = matching_rule {
|
||||
add_cors_headers(&mut resp, &rule)
|
||||
.ok_or_internal_error("Invalid CORS configuration")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(error_doc)
|
||||
}
|
||||
Err(error_doc_error) => {
|
||||
warn!(
|
||||
"Couldn't get error document {} for bucket {:?}: {}",
|
||||
error_document, bucket_id, error_doc_error
|
||||
);
|
||||
Err(error)
|
||||
}
|
||||
Ok(resp)
|
||||
}
|
||||
} else {
|
||||
ret_doc
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_options(bucket: &Bucket, req: &Request<Body>) -> Result<Response<Body>, Error> {
|
||||
let origin = req
|
||||
.headers()
|
||||
.get("Origin")
|
||||
.ok_or_bad_request("Missing Origin header")?
|
||||
.to_str()?;
|
||||
let request_method = req
|
||||
.headers()
|
||||
.get(ACCESS_CONTROL_REQUEST_METHOD)
|
||||
.ok_or_bad_request("Missing Access-Control-Request-Method header")?
|
||||
.to_str()?;
|
||||
let request_headers = match req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS) {
|
||||
Some(h) => h.to_str()?.split(",").map(|h| h.trim()).collect::<Vec<_>>(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
if let Some(cors_config) = bucket.params().unwrap().cors_config.get() {
|
||||
let matching_rule = cors_config
|
||||
.iter()
|
||||
.find(|rule| cors_rule_matches(rule, &origin, &request_method, request_headers.iter()));
|
||||
if let Some(rule) = matching_rule {
|
||||
let mut resp = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::empty())
|
||||
.map_err(ApiError::from)?;
|
||||
add_cors_headers(&mut resp, &rule)
|
||||
.ok_or_internal_error("Invalid CORS configuration")?;
|
||||
return Ok(resp);
|
||||
}
|
||||
}
|
||||
|
||||
Err(ApiError::Forbidden("No matching CORS rule".into()).into())
|
||||
}
|
||||
|
||||
/// Path to key
|
||||
///
|
||||
/// Convert the provided path to the internal key
|
||||
|
|
Loading…
Reference in a new issue