From 8c6114c3d306acebca908f37861e2c03d8562032 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 6 May 2022 12:21:15 +0200 Subject: [PATCH] Try to clean up code and to fix WebSocket problems --- src/https.rs | 39 +++++++++++++++++--------------- src/proxy_config.rs | 10 ++++---- src/reverse_proxy.rs | 54 ++++++++++++++++++++++++++++---------------- 3 files changed, 60 insertions(+), 43 deletions(-) diff --git a/src/https.rs b/src/https.rs index 83eca7c..7dcf051 100644 --- a/src/https.rs +++ b/src/https.rs @@ -194,25 +194,24 @@ async fn handle( handle_error(reverse_proxy::call(remote_addr.ip(), &to_addr, req).await) }; - // Do further processing (compression, additionnal headers) only for 2xx responses - if !response.status().is_success() { - return Ok(response); + if response.status().is_success() { + // (TODO: maybe we want to add these headers even if it's not a success?) + for (header, value) in proxy_to.add_headers.iter() { + response.headers_mut().insert( + HeaderName::from_bytes(header.as_bytes())?, + HeaderValue::from_str(value)?, + ); + } } - for (header, value) in proxy_to.add_headers.iter() { - response.headers_mut().insert( - HeaderName::from_bytes(header.as_bytes())?, - HeaderValue::from_str(value)?, - ); - } - trace!("Response: {:?}", response); - info!("{} {} {}", method, response.status().as_u16(), uri); - if https_config.enable_compression { - try_compress(response, method, accept_encoding, &https_config).await - } else { - Ok(response) - } + response = + try_compress(response, method.clone(), accept_encoding, &https_config).await? + }; + + trace!("Final response: {:?}", response); + info!("{} {} {}", method, response.status().as_u16(), uri); + Ok(response) } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); @@ -240,10 +239,14 @@ async fn try_compress( https_config: &HttpsConfig, ) -> Result> { // Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body) - // Don't compress partial content, that would be wierd - // If already compressed, return as is + // Don't compress partial content as it causes issues + // Don't bother compressing non-2xx results + // Don't compress Upgrade responses (e.g. websockets) + // Don't compress responses that are already compressed if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT)) || response.status() == StatusCode::PARTIAL_CONTENT + || !response.status().is_success() + || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade")) || response.headers().get(header::CONTENT_ENCODING).is_some() { return Ok(response); diff --git a/src/proxy_config.rs b/src/proxy_config.rs index bf7ae32..e45cc7b 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -355,16 +355,16 @@ pub fn spawn_proxy_config_task( #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_parse_tricot_add_header_tag() { - match parse_tricot_add_header_tag("tricot-add-header Content-Security-Policy default-src 'none'; img-src 'self'; script-src 'self'; style-src 'self'") { + #[test] + fn test_parse_tricot_add_header_tag() { + match parse_tricot_add_header_tag("tricot-add-header Content-Security-Policy default-src 'none'; img-src 'self'; script-src 'self'; style-src 'self'") { Some((name, value)) => { assert_eq!(name, "Content-Security-Policy"); assert_eq!(value, "default-src 'none'; img-src 'self'; script-src 'self'; style-src 'self'"); } _ => panic!("Passed a valid tag but the function says it is not valid") } - } + } } diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs index b5916be..99a7c98 100644 --- a/src/reverse_proxy.rs +++ b/src/reverse_proxy.rs @@ -48,9 +48,19 @@ fn remove_hop_headers(headers: &HeaderMap) -> HeaderMap(mut response: Response) -> Response { - *response.headers_mut() = remove_hop_headers(response.headers()); - response +fn copy_upgrade_headers( + old_headers: &HeaderMap, + new_headers: &mut HeaderMap, +) -> Result<()> { + if let Some(conn) = old_headers.get(header::CONNECTION) { + if conn.to_str()?.to_lowercase() == "upgrade" { + if let Some(upgrade) = old_headers.get(header::UPGRADE) { + new_headers.insert(header::CONNECTION, "Upgrade".try_into()?); + new_headers.insert(header::UPGRADE, upgrade.clone()); + } + } + } + Ok(()) } fn forward_uri(forward_url: &str, req: &Request) -> Result { @@ -72,12 +82,13 @@ fn create_proxied_request( .uri(forward_uri(forward_url, &request)?) .version(hyper::Version::HTTP_11); - let headers = builder.headers_mut().unwrap(); + let old_headers = request.headers(); + let new_headers = builder.headers_mut().unwrap(); - *headers = remove_hop_headers(request.headers()); + *new_headers = remove_hop_headers(old_headers); // If request does not have host header, add it from original URI authority - if let header::Entry::Vacant(entry) = headers.entry(header::HOST) { + if let header::Entry::Vacant(entry) = new_headers.entry(header::HOST) { if let Some(authority) = request.uri().authority() { entry.insert(authority.as_str().parse()?); } @@ -86,7 +97,7 @@ fn create_proxied_request( // Concatenate cookie headers into single header // (HTTP/2 allows several cookie headers, but we are proxying to HTTP/1.1 that does not) let mut cookie_concat = vec![]; - for cookie in headers.get_all(header::COOKIE) { + for cookie in new_headers.get_all(header::COOKIE) { if !cookie_concat.is_empty() { cookie_concat.extend(b"; "); } @@ -94,12 +105,12 @@ fn create_proxied_request( } if !cookie_concat.is_empty() { // insert clears the old value of COOKIE and inserts the concatenated version instead - headers.insert(header::COOKIE, cookie_concat.try_into()?); + new_headers.insert(header::COOKIE, cookie_concat.try_into()?); } // Add forwarding information in the headers let x_forwarded_for_header_name = "x-forwarded-for"; - match headers.entry(x_forwarded_for_header_name) { + match new_headers.entry(x_forwarded_for_header_name) { header::Entry::Vacant(entry) => { entry.insert(client_ip.to_string().parse()?); } @@ -110,24 +121,27 @@ fn create_proxied_request( } } - headers.insert( + new_headers.insert( HeaderName::from_bytes(b"x-forwarded-proto")?, "https".try_into()?, ); // Proxy upgrade requests properly - if let Some(conn) = request.headers().get(header::CONNECTION) { - if conn.to_str()?.to_lowercase() == "upgrade" { - if let Some(upgrade) = request.headers().get(header::UPGRADE) { - headers.insert(header::CONNECTION, "Upgrade".try_into()?); - headers.insert(header::UPGRADE, upgrade.clone()); - } - } - } + copy_upgrade_headers(old_headers, new_headers)?; Ok(builder.body(request.into_body())?) } +fn create_proxied_response(mut response: Response) -> Result> { + let old_headers = response.headers(); + let mut new_headers = remove_hop_headers(old_headers); + + copy_upgrade_headers(old_headers, &mut new_headers)?; + + *response.headers_mut() = new_headers; + Ok(response) +} + pub async fn call( client_ip: IpAddr, forward_uri: &str, @@ -146,7 +160,7 @@ pub async fn call( trace!("Inner response: {:?}", response); - let proxied_response = create_proxied_response(response); + let proxied_response = create_proxied_response(response)?; Ok(proxied_response) } @@ -173,7 +187,7 @@ pub async fn call_https( trace!("Inner response (HTTPS): {:?}", response); - let proxied_response = create_proxied_response(response); + let proxied_response = create_proxied_response(response)?; Ok(proxied_response) }