tricot/src/reverse_proxy.rs

189 lines
5.3 KiB
Rust
Raw Normal View History

2021-12-07 14:20:45 +00:00
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See that repository for the original copyright notice.
use std::convert::TryInto;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
2021-12-07 14:20:45 +00:00
use anyhow::Result;
2021-12-08 12:28:07 +00:00
use log::*;
2021-12-07 14:20:45 +00:00
2021-12-08 12:28:07 +00:00
use http::header::HeaderName;
2021-12-07 14:20:45 +00:00
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{client::HttpConnector, header, Body, Client, Request, Response, Uri};
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, ServerName};
use crate::tls_util::HttpsConnectorFixedDnsname;
2021-12-07 14:20:45 +00:00
/// Timeout for reaching a backend: sixty seconds.
///
/// In this module it is applied as the TCP connect timeout of the
/// `HttpConnector` used by both `call` and `call_https`; it does not bound
/// the total duration of a proxied request.
pub const PROXY_TIMEOUT: Duration = Duration::from_millis(60_000);
const HOP_HEADERS: &[HeaderName] = &[
header::CONNECTION,
// header::KEEP_ALIVE, // not found in http::header
2021-12-09 14:43:19 +00:00
header::PROXY_AUTHENTICATE,
header::PROXY_AUTHORIZATION,
header::TE,
header::TRAILER,
header::TRANSFER_ENCODING,
header::UPGRADE,
];
fn is_hop_header(name: &HeaderName) -> bool {
HOP_HEADERS.iter().any(|h| h == name)
2021-12-07 14:20:45 +00:00
}
/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
let mut result = HeaderMap::new();
for (k, v) in headers.iter() {
2021-12-09 22:38:56 +00:00
if !is_hop_header(k) {
2021-12-09 10:20:04 +00:00
result.append(k.clone(), v.clone());
2021-12-07 14:20:45 +00:00
}
}
result
}
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
*response.headers_mut() = remove_hop_headers(response.headers());
response
}
/// Builds the backend URI by appending the path of `req`'s URI, and its
/// query string if present (preceded by `?`), to `forward_url`.
///
/// `forward_url` is used verbatim as a prefix. NOTE(review): callers
/// presumably pass a base such as `http://host:port` without a trailing
/// slash — confirm.
///
/// # Errors
/// Returns an error if the concatenated string is not a valid `Uri`.
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
    let original = req.uri();
    let mut target = String::from(forward_url);
    target.push_str(original.path());
    if let Some(query) = original.query() {
        target.push('?');
        target.push_str(query);
    }
    Ok(target.parse::<Uri>()?)
}
fn create_proxied_request<B>(
client_ip: IpAddr,
forward_url: &str,
request: Request<B>,
) -> Result<Request<B>> {
2021-12-08 12:28:07 +00:00
let mut builder = Request::builder()
.method(request.method())
.uri(forward_uri(forward_url, &request)?)
.version(hyper::Version::HTTP_11);
2021-12-07 14:20:45 +00:00
*builder.headers_mut().unwrap() = remove_hop_headers(request.headers());
// If request does not have host header, add it from original URI authority
2021-12-09 14:43:19 +00:00
if let header::Entry::Vacant(entry) = builder.headers_mut().unwrap().entry(header::HOST) {
2021-12-09 10:20:04 +00:00
if let Some(authority) = request.uri().authority() {
2021-12-07 14:20:45 +00:00
entry.insert(authority.as_str().parse()?);
}
}
// Add forwarding information in the headers
2021-12-08 16:27:27 +00:00
let x_forwarded_for_header_name = "x-forwarded-for";
2021-12-07 14:20:45 +00:00
match builder
.headers_mut()
.unwrap()
.entry(x_forwarded_for_header_name)
{
2021-12-09 14:43:19 +00:00
header::Entry::Vacant(entry) => {
2021-12-07 14:20:45 +00:00
entry.insert(client_ip.to_string().parse()?);
}
2021-12-09 14:43:19 +00:00
header::Entry::Occupied(mut entry) => {
2021-12-07 14:20:45 +00:00
let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
entry.insert(addr.parse()?);
}
}
builder.headers_mut().unwrap().insert(
HeaderName::from_bytes(b"x-forwarded-proto")?,
"https".try_into()?,
);
2021-12-08 16:27:27 +00:00
2021-12-09 10:20:04 +00:00
// Proxy upgrade requests properly
2021-12-09 14:43:19 +00:00
if let Some(conn) = request.headers().get(header::CONNECTION) {
2021-12-08 12:28:07 +00:00
if conn.to_str()?.to_lowercase() == "upgrade" {
2021-12-09 14:43:19 +00:00
if let Some(upgrade) = request.headers().get(header::UPGRADE) {
builder
.headers_mut()
.unwrap()
.insert(header::CONNECTION, "Upgrade".try_into()?);
2021-12-08 12:28:07 +00:00
builder
.headers_mut()
.unwrap()
2021-12-09 14:43:19 +00:00
.insert(header::UPGRADE, upgrade.clone());
2021-12-08 12:28:07 +00:00
}
}
}
2021-12-07 14:20:45 +00:00
Ok(builder.body(request.into_body())?)
}
pub async fn call(
client_ip: IpAddr,
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>> {
2021-12-09 11:20:37 +00:00
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
2021-12-07 14:20:45 +00:00
2021-12-08 12:28:07 +00:00
trace!("Proxied request: {:?}", proxied_request);
let mut connector = HttpConnector::new();
connector.set_connect_timeout(Some(PROXY_TIMEOUT));
let client: Client<_, hyper::Body> = Client::builder().set_host(false).build(connector);
2021-12-07 14:20:45 +00:00
let response = client.request(proxied_request).await?;
2021-12-08 12:28:07 +00:00
trace!("Inner response: {:?}", response);
2021-12-07 14:20:45 +00:00
let proxied_response = create_proxied_response(response);
Ok(proxied_response)
}
2021-12-08 21:58:19 +00:00
pub async fn call_https(
client_ip: IpAddr,
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>> {
2021-12-09 11:20:37 +00:00
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
2021-12-08 21:58:19 +00:00
trace!("Proxied request (HTTPS): {:?}", proxied_request);
let tls_config = rustls::client::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(DontVerifyServerCert))
.with_no_client_auth();
let mut http_connector = HttpConnector::new();
http_connector.set_connect_timeout(Some(PROXY_TIMEOUT));
let connector = HttpsConnectorFixedDnsname::new(tls_config, "dummy", http_connector);
let client: Client<_, hyper::Body> = Client::builder().set_host(false).build(connector);
2021-12-08 21:58:19 +00:00
let response = client.request(proxied_request).await?;
trace!("Inner response (HTTPS): {:?}", response);
let proxied_response = create_proxied_response(response);
Ok(proxied_response)
}
/// A rustls server-certificate verifier that accepts every certificate.
///
/// The end-entity certificate, intermediates, server name, SCTs, OCSP
/// response and current time are all ignored. TLS connections that use it
/// are encrypted but NOT authenticated. NOTE(review): presumably intended
/// for backends with self-signed certificates on a trusted network — confirm
/// before reusing this type anywhere else.
struct DontVerifyServerCert;

impl ServerCertVerifier for DontVerifyServerCert {
    fn verify_server_cert(
        &self,
        _leaf: &Certificate,
        _chain: &[Certificate],
        _name: &ServerName,
        _sct_list: &mut dyn Iterator<Item = &[u8]>,
        _ocsp: &[u8],
        _at: SystemTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        // Unconditional acceptance: no check of any kind is performed.
        let verdict = ServerCertVerified::assertion();
        Ok(verdict)
    }
}