forked from Deuxfleurs/tricot
Handle HTTP/1.1 SWITCHING_PROTOCOL to handle Connection: Upgrade correctly
This commit is contained in:
parent
df4a36990c
commit
cbf7a03836
1 changed files with 57 additions and 15 deletions
|
@ -10,7 +10,7 @@ use std::time::{Duration, SystemTime};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::*;
|
use log::*;
|
||||||
|
|
||||||
use http::header::HeaderName;
|
use http::{header::HeaderName, StatusCode};
|
||||||
use hyper::header::{HeaderMap, HeaderValue};
|
use hyper::header::{HeaderMap, HeaderValue};
|
||||||
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
|
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
|
||||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||||
|
@ -51,20 +51,22 @@ fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue
|
||||||
fn copy_upgrade_headers(
|
fn copy_upgrade_headers(
|
||||||
old_headers: &HeaderMap<HeaderValue>,
|
old_headers: &HeaderMap<HeaderValue>,
|
||||||
new_headers: &mut HeaderMap<HeaderValue>,
|
new_headers: &mut HeaderMap<HeaderValue>,
|
||||||
) -> Result<()> {
|
) -> Result<bool> {
|
||||||
// The Connection header is stripped as it is a hop header that we are not supposed to proxy.
|
// The Connection header is stripped as it is a hop header that we are not supposed to proxy.
|
||||||
// However, it might also contain an Upgrade directive, e.g. for Websockets:
|
// However, it might also contain an Upgrade directive, e.g. for Websockets:
|
||||||
// when that happen, we do want to preserve that directive.
|
// when that happen, we do want to preserve that directive.
|
||||||
|
let mut is_upgrade = false;
|
||||||
if let Some(conn) = old_headers.get(header::CONNECTION) {
|
if let Some(conn) = old_headers.get(header::CONNECTION) {
|
||||||
let conn_str = conn.to_str()?.to_lowercase();
|
let conn_str = conn.to_str()?.to_lowercase();
|
||||||
if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") {
|
if conn_str.split(',').map(str::trim).any(|x| x == "upgrade") {
|
||||||
if let Some(upgrade) = old_headers.get(header::UPGRADE) {
|
if let Some(upgrade) = old_headers.get(header::UPGRADE) {
|
||||||
new_headers.insert(header::CONNECTION, "Upgrade".try_into()?);
|
new_headers.insert(header::CONNECTION, "Upgrade".try_into()?);
|
||||||
new_headers.insert(header::UPGRADE, upgrade.clone());
|
new_headers.insert(header::UPGRADE, upgrade.clone());
|
||||||
|
is_upgrade = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(is_upgrade)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
|
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
|
||||||
|
@ -76,11 +78,11 @@ fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
|
||||||
Ok(Uri::from_str(forward_uri.as_str())?)
|
Ok(Uri::from_str(forward_uri.as_str())?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_proxied_request<B>(
|
fn create_proxied_request<B: std::default::Default>(
|
||||||
client_ip: IpAddr,
|
client_ip: IpAddr,
|
||||||
forward_url: &str,
|
forward_url: &str,
|
||||||
request: Request<B>,
|
request: Request<B>,
|
||||||
) -> Result<Request<B>> {
|
) -> Result<(Request<B>, Option<Request<B>>)> {
|
||||||
let mut builder = Request::builder()
|
let mut builder = Request::builder()
|
||||||
.method(request.method())
|
.method(request.method())
|
||||||
.uri(forward_uri(forward_url, &request)?)
|
.uri(forward_uri(forward_url, &request)?)
|
||||||
|
@ -131,19 +133,57 @@ fn create_proxied_request<B>(
|
||||||
);
|
);
|
||||||
|
|
||||||
// Proxy upgrade requests properly
|
// Proxy upgrade requests properly
|
||||||
copy_upgrade_headers(old_headers, new_headers)?;
|
let is_upgrade = copy_upgrade_headers(old_headers, new_headers)?;
|
||||||
|
|
||||||
Ok(builder.body(request.into_body())?)
|
if is_upgrade {
|
||||||
|
Ok((builder.body(B::default())?, Some(request)))
|
||||||
|
} else {
|
||||||
|
Ok((builder.body(request.into_body())?, None))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_proxied_response<B>(mut response: Response<B>) -> Result<Response<B>> {
|
async fn create_proxied_response<B: std::default::Default + Send + Sync + 'static>(
|
||||||
|
mut response: Response<B>,
|
||||||
|
upgrade_request: Option<Request<B>>,
|
||||||
|
) -> Result<Response<B>> {
|
||||||
let old_headers = response.headers();
|
let old_headers = response.headers();
|
||||||
let mut new_headers = remove_hop_headers(old_headers);
|
|
||||||
|
|
||||||
|
let mut new_headers = remove_hop_headers(old_headers);
|
||||||
copy_upgrade_headers(old_headers, &mut new_headers)?;
|
copy_upgrade_headers(old_headers, &mut new_headers)?;
|
||||||
|
|
||||||
*response.headers_mut() = new_headers;
|
if response.status() == StatusCode::SWITCHING_PROTOCOLS {
|
||||||
Ok(response)
|
if let Some(mut req) = upgrade_request {
|
||||||
|
let mut res_upgraded = hyper::upgrade::on(response).await?;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match hyper::upgrade::on(&mut req).await {
|
||||||
|
Ok(mut req_upgraded) => {
|
||||||
|
if let Err(e) =
|
||||||
|
tokio::io::copy_bidirectional(&mut req_upgraded, &mut res_upgraded)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!("Error copying data in upgraded request: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Could not upgrade client request when switching protocols: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut new_res = Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
|
||||||
|
*new_res.headers_mut().unwrap() = new_headers;
|
||||||
|
Ok(new_res.body(B::default())?)
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!("Switching protocols but not an upgrade request"));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*response.headers_mut() = new_headers;
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn call(
|
pub async fn call(
|
||||||
|
@ -151,7 +191,8 @@ pub async fn call(
|
||||||
forward_uri: &str,
|
forward_uri: &str,
|
||||||
request: Request<Body>,
|
request: Request<Body>,
|
||||||
) -> Result<Response<Body>> {
|
) -> Result<Response<Body>> {
|
||||||
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
let (proxied_request, upgrade_request) =
|
||||||
|
create_proxied_request(client_ip, forward_uri, request)?;
|
||||||
|
|
||||||
trace!("Proxied request: {:?}", proxied_request);
|
trace!("Proxied request: {:?}", proxied_request);
|
||||||
|
|
||||||
|
@ -164,7 +205,7 @@ pub async fn call(
|
||||||
|
|
||||||
trace!("Inner response: {:?}", response);
|
trace!("Inner response: {:?}", response);
|
||||||
|
|
||||||
let proxied_response = create_proxied_response(response)?;
|
let proxied_response = create_proxied_response(response, upgrade_request).await?;
|
||||||
Ok(proxied_response)
|
Ok(proxied_response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,7 +214,8 @@ pub async fn call_https(
|
||||||
forward_uri: &str,
|
forward_uri: &str,
|
||||||
request: Request<Body>,
|
request: Request<Body>,
|
||||||
) -> Result<Response<Body>> {
|
) -> Result<Response<Body>> {
|
||||||
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
|
let (proxied_request, upgrade_request) =
|
||||||
|
create_proxied_request(client_ip, forward_uri, request)?;
|
||||||
|
|
||||||
trace!("Proxied request (HTTPS): {:?}", proxied_request);
|
trace!("Proxied request (HTTPS): {:?}", proxied_request);
|
||||||
|
|
||||||
|
@ -191,7 +233,7 @@ pub async fn call_https(
|
||||||
|
|
||||||
trace!("Inner response (HTTPS): {:?}", response);
|
trace!("Inner response (HTTPS): {:?}", response);
|
||||||
|
|
||||||
let proxied_response = create_proxied_response(response)?;
|
let proxied_response = create_proxied_response(response, upgrade_request).await?;
|
||||||
Ok(proxied_response)
|
Ok(proxied_response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue