diff --git a/src/api/s3_post_object.rs b/src/api/s3_post_object.rs index 6d602b5d..52229303 100644 --- a/src/api/s3_post_object.rs +++ b/src/api/s3_post_object.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use std::convert::TryInto; +use std::ops::RangeInclusive; use std::sync::Arc; +use chrono::{DateTime, Duration, Utc}; use futures::StreamExt; use hyper::header::{self, HeaderMap, HeaderName, HeaderValue}; use hyper::{Body, Request, Response, StatusCode}; @@ -103,7 +105,55 @@ pub async fn handle_post_object( } let decoded_policy = base64::decode(&policy)?; - let _decoded_policy: Policy = serde_json::from_slice(&decoded_policy).unwrap(); + let decoded_policy: Policy = + serde_json::from_slice(&decoded_policy).ok_or_bad_request("Invalid policy")?; + + let expiration: DateTime = DateTime::parse_from_rfc3339(&decoded_policy.expiration) + .ok_or_bad_request("Invalid expiration date")? + .into(); + if Utc::now() - expiration > Duration::zero() { + return Err(Error::BadRequest( + "Expiration date is in the paste".to_string(), + )); + } + + let conditions = decoded_policy.into_conditions()?; + + for (key, value) in params.iter() { + let key = key.as_str(); + if key.eq_ignore_ascii_case("content-type") { + for cond in &conditions.content_type { + let ok = match cond { + Operation::Equal(s) => value == s, + Operation::StartsWith(s) => { + value.to_str()?.split(',').all(|v| v.starts_with(s)) + } + }; + if !ok { + return Err(Error::BadRequest(format!( + "Key '{}' has value not allowed in policy", + key + ))); + } + } + } else { + let conds = conditions.params.get(key).ok_or_else(|| { + Error::BadRequest(format!("Key '{}' is not allowed in policy", key)) + })?; + for cond in conds { + let ok = match cond { + Operation::Equal(s) => s == value, + Operation::StartsWith(s) => value.to_str()?.starts_with(s), + }; + if !ok { + return Err(Error::BadRequest(format!( + "Key '{}' has value not allowed in policy", + key + ))); + } + } + } + } // TODO validate policy against request // unsafe to merge until implemented @@ -169,15 +219,74 @@ pub async fn handle_post_object( )) } -// TODO remove allow(dead_code) when policy is verified - -#[allow(dead_code)] #[derive(Deserialize)] struct Policy { expiration: String, conditions: Vec, } +impl Policy { + fn into_conditions(self) -> Result { + let mut params = HashMap::<_, Vec<_>>::new(); + let mut content_type = Vec::new(); + + let mut length = (0, u64::MAX); + for condition in self.conditions { + match condition { + PolicyCondition::Equal(map) => { + if map.len() != 1 { + return Err(Error::BadRequest("Invalid policy item".to_owned())); + } + let (k, v) = map.into_iter().next().expect("size was verified"); + if k.eq_ignore_ascii_case("content-type") { + content_type.push(Operation::Equal(v)); + } else { + params.entry(k).or_default().push(Operation::Equal(v)); + } + } + PolicyCondition::OtherOp([cond, mut key, value]) => { + if key.remove(0) != '$' { + return Err(Error::BadRequest("Invalid policy item".to_owned())); + } + match cond.as_str() { + "eq" => { + if key.eq_ignore_ascii_case("content-type") { + content_type.push(Operation::Equal(value)); + } else { + params.entry(key).or_default().push(Operation::Equal(value)); + } + } + "starts-with" => { + if key.eq_ignore_ascii_case("content-type") { + content_type.push(Operation::StartsWith(value)); + } else { + params + .entry(key) + .or_default() + .push(Operation::StartsWith(value)); + } + } + _ => return Err(Error::BadRequest("Invalid policy item".to_owned())), + } + } + PolicyCondition::SizeRange(key, min, max) => { + if key == "content-length-range" { + length.0 = length.0.max(min); + length.1 = length.1.min(max); + } else { + return Err(Error::BadRequest("Invalid policy item".to_owned())); + } + } + } + } + Ok(Conditions { + params, + content_type, + content_length: RangeInclusive::new(length.0, length.1), + }) + } +} + /// A single condition from a policy #[derive(Deserialize)] #[serde(untagged)] @@ -188,11 +297,15 @@ enum PolicyCondition { SizeRange(String, u64, u64), } -#[allow(dead_code)] +struct Conditions { + params: HashMap>, + content_type: Vec, + #[allow(dead_code)] + content_length: RangeInclusive, +} + #[derive(PartialEq, Eq)] enum Operation { - Equal, - StartsWith, - StartsWithCT, - SizeRange, + Equal(String), + StartsWith(String), }