diff --git a/src/api/s3_post_object.rs b/src/api/s3_post_object.rs index 52229303..20b2e13c 100644 --- a/src/api/s3_post_object.rs +++ b/src/api/s3_post_object.rs @@ -2,9 +2,11 @@ use std::collections::HashMap; use std::convert::TryInto; use std::ops::RangeInclusive; use std::sync::Arc; +use std::task::{Context, Poll}; +use bytes::Bytes; use chrono::{DateTime, Duration, Utc}; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use hyper::header::{self, HeaderMap, HeaderName, HeaderValue}; use hyper::{Body, Request, Response, StatusCode}; use multer::{Constraints, Multipart, SizeLimit}; @@ -29,7 +31,7 @@ pub async fn handle_post_object( .and_then(|ct| multer::parse_boundary(ct).ok()) .ok_or_bad_request("Counld not get multipart boundary")?; - // 16k seems plenty for a header. 5G is the max size of a single part, so it seemrs reasonable + // 16k seems plenty for a header. 5G is the max size of a single part, so it seems reasonable // for a PostObject let constraints = Constraints::new().size_limit( SizeLimit::new() @@ -119,9 +121,9 @@ pub async fn handle_post_object( let conditions = decoded_policy.into_conditions()?; - for (key, value) in params.iter() { - let key = key.as_str(); - if key.eq_ignore_ascii_case("content-type") { + for (param_key, value) in params.iter() { + let param_key = param_key.as_str(); + if param_key.eq_ignore_ascii_case("content-type") { for cond in &conditions.content_type { let ok = match cond { Operation::Equal(s) => value == s, @@ -132,13 +134,29 @@ pub async fn handle_post_object( if !ok { return Err(Error::BadRequest(format!( "Key '{}' has value not allowed in policy", - key + param_key + ))); + } + } + } else if param_key == "key" { + let conds = conditions.params.get("key").ok_or_else(|| { + Error::BadRequest(format!("Key '{}' is not allowed in policy", param_key)) + })?; + for cond in conds { + let ok = match cond { + Operation::Equal(s) => s == &key, + Operation::StartsWith(s) => key.starts_with(s), + }; + if !ok { + return Err(Error::BadRequest(format!( + "Key '{}' has value not allowed in policy", + param_key ))); } } } else { - let conds = conditions.params.get(key).ok_or_else(|| { - Error::BadRequest(format!("Key '{}' is not allowed in policy", key)) + let conds = conditions.params.get(param_key).ok_or_else(|| { + Error::BadRequest(format!("Key '{}' is not allowed in policy", param_key)) })?; for cond in conds { let ok = match cond { @@ -148,16 +166,13 @@ pub async fn handle_post_object( if !ok { return Err(Error::BadRequest(format!( "Key '{}' has value not allowed in policy", - key + param_key ))); } } } } - // TODO validate policy against request - // unsafe to merge until implemented - let content_type = field .content_type() .map(AsRef::as_ref) @@ -169,10 +184,11 @@ pub async fn handle_post_object( params.append(header::CONTENT_TYPE, content_type); let headers = get_headers(¶ms)?; + let stream = field.map(|r| r.map_err(Into::into)); let res = save_stream( garage, headers, - field.map(|r| r.map_err(Into::into)), + StreamLimiter::new(stream, conditions.content_length), bucket_id, &key, None, @@ -183,9 +199,10 @@ pub async fn handle_post_object( let resp = if let Some(target) = params .get("success_action_redirect") .and_then(|h| h.to_str().ok()) + .and_then(|u| url::Url::parse(u).ok()) + .filter(|u| u.scheme() == "https" || u.scheme() == "http") { - // TODO should validate it's a valid url - let target = target.to_owned(); + let target = target.to_string(); Response::builder() .status(StatusCode::SEE_OTHER) .header(header::LOCATION, target.clone()) @@ -309,3 +326,52 @@ enum Operation { Equal(String), StartsWith(String), } + +struct StreamLimiter { + inner: T, + length: RangeInclusive, + read: u64, +} + +impl StreamLimiter { + fn new(stream: T, length: RangeInclusive) -> Self { + StreamLimiter { + inner: stream, + length, + read: 0, + } + } +} + +impl Stream for StreamLimiter +where + T: Stream> + Unpin, +{ + type Item = Result; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + ctx: &mut Context<'_>, + ) -> Poll> { + let res = std::pin::Pin::new(&mut self.inner).poll_next(ctx); + match &res { + Poll::Ready(Some(Ok(bytes))) => { + self.read += bytes.len() as u64; + // optimization to fail early when we know before the end it's too long + if self.length.end() < &self.read { + return Poll::Ready(Some(Err(Error::BadRequest( + "File size does not match policy".to_owned(), + )))); + } + } + Poll::Ready(None) => { + if !self.length.contains(&self.read) { + return Poll::Ready(Some(Err(Error::BadRequest( + "File size does not match policy".to_owned(), + )))); + } + } + _ => {} + } + res + } +}