From 02e98e2d100a6af96369a72bc6979580424fe7df Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 9 Feb 2024 15:34:42 +0100 Subject: [PATCH] [header-override-650] implement header overriding in GetObject (fix #650) --- src/api/s3/api_server.rs | 22 +++++++++++++-- src/api/s3/get.rs | 50 ++++++++++++++++++++++++++++++++-- src/api/s3/router.rs | 21 +++++++++++++- src/garage/tests/s3/objects.rs | 24 ++++++++++++++++ src/web/web_server.rs | 22 +++++++++++++-- 5 files changed, 131 insertions(+), 8 deletions(-) diff --git a/src/api/s3/api_server.rs b/src/api/s3/api_server.rs index 0065ca59..7fac6261 100644 --- a/src/api/s3/api_server.rs +++ b/src/api/s3/api_server.rs @@ -178,8 +178,26 @@ impl ApiHandler for S3ApiServer { key, part_number, .. } => handle_head(garage, &req, bucket_id, &key, part_number).await, Endpoint::GetObject { - key, part_number, .. - } => handle_get(garage, &req, bucket_id, &key, part_number).await, + key, + part_number, + response_cache_control, + response_content_disposition, + response_content_encoding, + response_content_language, + response_content_type, + response_expires, + .. + } => { + let overrides = GetObjectOverrides { + response_cache_control, + response_content_disposition, + response_content_encoding, + response_content_language, + response_content_type, + response_expires, + }; + handle_get(garage, &req, bucket_id, &key, part_number, overrides).await + } Endpoint::UploadPart { key, part_number, diff --git a/src/api/s3/get.rs b/src/api/s3/get.rs index f70dad7d..53f0a345 100644 --- a/src/api/s3/get.rs +++ b/src/api/s3/get.rs @@ -1,12 +1,14 @@ //! Function related to GET and HEAD requests +use std::convert::TryInto; use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; use futures::future; use futures::stream::{self, StreamExt}; use http::header::{ - ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, IF_MODIFIED_SINCE, - IF_NONE_MATCH, LAST_MODIFIED, RANGE, + ACCEPT_RANGES, CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE, + CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, EXPIRES, IF_MODIFIED_SINCE, IF_NONE_MATCH, + LAST_MODIFIED, RANGE, }; use hyper::{body::Body, Request, Response, StatusCode}; use tokio::sync::mpsc; @@ -27,6 +29,16 @@ use crate::s3::error::*; const X_AMZ_MP_PARTS_COUNT: &str = "x-amz-mp-parts-count"; +#[derive(Default)] +pub struct GetObjectOverrides { + pub(crate) response_cache_control: Option, + pub(crate) response_content_disposition: Option, + pub(crate) response_content_encoding: Option, + pub(crate) response_content_language: Option, + pub(crate) response_content_type: Option, + pub(crate) response_expires: Option, +} + fn object_headers( version: &ObjectVersion, version_meta: &ObjectVersionMeta, @@ -52,6 +64,32 @@ fn object_headers( resp } +/// Override headers according to specific query parameters, see +/// section "Overriding response header values through the request" in +/// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html +fn getobject_override_headers( + overrides: GetObjectOverrides, + resp: &mut http::response::Builder, +) -> Result<(), Error> { + // TODO: this only applies for signed requests, so when we support + // anonymous access in the future we will have to do a permission check here + let overrides = [ + (CACHE_CONTROL, overrides.response_cache_control), + (CONTENT_DISPOSITION, overrides.response_content_disposition), + (CONTENT_ENCODING, overrides.response_content_encoding), + (CONTENT_LANGUAGE, overrides.response_content_language), + (CONTENT_TYPE, overrides.response_content_type), + (EXPIRES, overrides.response_expires), + ]; + for (hdr, val_opt) in overrides { + if let Some(val) = val_opt { + let val = val.try_into().ok_or_bad_request("invalid header value")?; + resp.headers_mut().unwrap().insert(hdr, val); + } + } + Ok(()) +} + fn try_answer_cached( version: &ObjectVersion, version_meta: &ObjectVersionMeta, @@ -185,6 +223,7 @@ pub async fn handle_get( bucket_id: Uuid, key: &str, part_number: Option, + overrides: GetObjectOverrides, ) -> Result, Error> { let object = garage .object_table @@ -236,9 +275,10 @@ pub async fn handle_get( (None, None) => (), } - let resp_builder = object_headers(last_v, last_v_meta) + let mut resp_builder = object_headers(last_v, last_v_meta) .header(CONTENT_LENGTH, format!("{}", last_v_meta.size)) .status(StatusCode::OK); + getobject_override_headers(overrides, &mut resp_builder)?; match &last_v_data { ObjectVersionData::DeleteMarker => unreachable!(), @@ -303,6 +343,9 @@ async fn handle_get_range( begin: u64, end: u64, ) -> Result, Error> { + // Here we do not use getobject_override_headers because we don't + // want to add any overridden headers (those should not be added + // when returning PARTIAL_CONTENT) let resp_builder = object_headers(version, version_meta) .header(CONTENT_LENGTH, format!("{}", end - begin)) .header( @@ -343,6 +386,7 @@ async fn handle_get_part( version_meta: &ObjectVersionMeta, part_number: u64, ) -> Result, Error> { + // Same as for get_range, no getobject_override_headers let resp_builder = object_headers(object_version, version_meta).status(StatusCode::PARTIAL_CONTENT); diff --git a/src/api/s3/router.rs b/src/api/s3/router.rs index 05882885..e7ac1d77 100644 --- a/src/api/s3/router.rs +++ b/src/api/s3/router.rs @@ -125,6 +125,12 @@ pub enum Endpoint { key: String, part_number: Option, version_id: Option, + response_cache_control: Option, + response_content_disposition: Option, + response_content_encoding: Option, + response_content_language: Option, + response_content_type: Option, + response_expires: Option, }, GetObjectAcl { key: String, @@ -358,7 +364,14 @@ impl Endpoint { (query.keyword.take().unwrap_or_default(), key, query, None), key: [ EMPTY if upload_id => ListParts (query::upload_id, opt_parse::max_parts, opt_parse::part_number_marker), - EMPTY => GetObject (query_opt::version_id, opt_parse::part_number), + EMPTY => GetObject (query_opt::version_id, + opt_parse::part_number, + query_opt::response_cache_control, + query_opt::response_content_disposition, + query_opt::response_content_encoding, + query_opt::response_content_language, + query_opt::response_content_type, + query_opt::response_expires), ACL => GetObjectAcl (query_opt::version_id), LEGAL_HOLD => GetObjectLegalHold (query_opt::version_id), RETENTION => GetObjectRetention (query_opt::version_id), @@ -671,6 +684,12 @@ generateQueryParameters! { "partNumber" => part_number, "part-number-marker" => part_number_marker, "prefix" => prefix, + "response-cache-control" => response_cache_control, + "response-content-disposition" => response_content_disposition, + "response-content-encoding" => response_content_encoding, + "response-content-language" => response_content_language, + "response-content-type" => response_content_type, + "response-expires" => response_expires, "select-type" => select_type, "start-after" => start_after, "uploadId" => upload_id, diff --git a/src/garage/tests/s3/objects.rs b/src/garage/tests/s3/objects.rs index ca35b435..ad5f63f1 100644 --- a/src/garage/tests/s3/objects.rs +++ b/src/garage/tests/s3/objects.rs @@ -185,6 +185,30 @@ async fn test_getobject() { assert_eq!(o.content_range.unwrap().as_str(), "bytes 57-61/62"); assert_bytes_eq!(o.body, &BODY[57..]); } + { + let exp = aws_sdk_s3::primitives::DateTime::from_secs(10000000000); + let o = ctx + .client + .get_object() + .bucket(&bucket) + .key(STD_KEY) + .response_content_type("application/x-dummy-test") + .response_cache_control("ccdummy") + .response_content_disposition("cddummy") + .response_content_encoding("cedummy") + .response_content_language("cldummy") + .response_expires(exp) + .send() + .await + .unwrap(); + assert_eq!(o.content_type.unwrap().as_str(), "application/x-dummy-test"); + assert_eq!(o.cache_control.unwrap().as_str(), "ccdummy"); + assert_eq!(o.content_disposition.unwrap().as_str(), "cddummy"); + assert_eq!(o.content_encoding.unwrap().as_str(), "cedummy"); + assert_eq!(o.content_language.unwrap().as_str(), "cldummy"); + assert_eq!(o.expires.unwrap(), exp); + assert_bytes_eq!(o.body, &BODY[..]); + } } #[tokio::test] diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 269f37f2..0f9b5dc8 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -247,7 +247,17 @@ impl WebServer { .map_err(ApiError::from) .map(|res| res.map(|_empty_body: EmptyBody| empty_body())), Method::HEAD => handle_head(self.garage.clone(), &req, bucket_id, &key, None).await, - Method::GET => handle_get(self.garage.clone(), &req, bucket_id, &key, None).await, + Method::GET => { + handle_get( + self.garage.clone(), + &req, + bucket_id, + &key, + None, + Default::default(), + ) + .await + } _ => Err(ApiError::bad_request("HTTP method not supported")), }; @@ -291,7 +301,15 @@ impl WebServer { .body(empty_body::()) .unwrap(); - match handle_get(self.garage.clone(), &req2, bucket_id, &error_document, None).await + match handle_get( + self.garage.clone(), + &req2, + bucket_id, + &error_document, + None, + Default::default(), + ) + .await { Ok(mut error_doc) => { // The error won't be logged back in handle_request,