forked from Deuxfleurs/tricot
The host tag is not currently included, but could be added later with some refactorings if we need it.
539 lines
15 KiB
Rust
539 lines
15 KiB
Rust
use std::convert::Infallible;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{atomic::Ordering, Arc};
|
|
use std::time::{Duration, Instant};
|
|
|
|
use anyhow::Result;
|
|
use tracing::*;
|
|
|
|
use accept_encoding_fork::Encoding;
|
|
use async_compression::tokio::bufread::*;
|
|
use futures::stream::FuturesUnordered;
|
|
use futures::{StreamExt, TryStreamExt};
|
|
use http::header::{HeaderName, HeaderValue};
|
|
use http::method::Method;
|
|
use hyper::server::conn::Http;
|
|
use hyper::service::service_fn;
|
|
use hyper::{header, Body, Request, Response, StatusCode};
|
|
use tokio::net::TcpListener;
|
|
use tokio::select;
|
|
use tokio::sync::watch;
|
|
use tokio_rustls::TlsAcceptor;
|
|
use tokio_util::io::{ReaderStream, StreamReader};
|
|
|
|
use opentelemetry::{metrics, KeyValue};
|
|
|
|
use crate::cert_store::{CertStore, StoreResolver};
|
|
use crate::proxy_config::{HostDescription, ProxyConfig, ProxyEntry};
|
|
use crate::reverse_proxy;
|
|
|
|
/// Maximum time a served HTTPS connection may stay open (24 hours).
///
/// Once it elapses, `serve_https` drops the connection and logs an error.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60 * 60 * 24);
|
|
|
|
/// Settings for the HTTPS front-end started by `serve_https`.
pub struct HttpsConfig {
    /// Address the TLS listener binds to.
    pub bind_addr: SocketAddr,
    /// Whether proxied responses may be compressed according to the client's `Accept-Encoding`.
    pub enable_compression: bool,
    /// MIME types (without `;`-parameters) whose responses are eligible for compression.
    pub compress_mime_types: Vec<String>,
    /// Reference instant used internally to convert `Instant`s into plain integers
    /// (milliseconds elapsed since this origin).
    pub time_origin: Instant,
}
|
|
|
|
struct HttpsMetrics {
|
|
requests_received: metrics::Counter<u64>,
|
|
requests_served: metrics::Counter<u64>,
|
|
requests_in_flight: metrics::UpDownCounter<i64>,
|
|
request_proxy_duration: metrics::Histogram<f64>,
|
|
}
|
|
|
|
struct InFlightGuard<'a, 'b> {
|
|
metrics: &'a HttpsMetrics,
|
|
tags: &'b [KeyValue],
|
|
}
|
|
|
|
impl<'a, 'b> Drop for InFlightGuard<'a, 'b> {
|
|
fn drop(&mut self) {
|
|
self.metrics.requests_in_flight.add(-1, self.tags)
|
|
}
|
|
}
|
|
|
|
pub async fn serve_https(
|
|
config: HttpsConfig,
|
|
cert_store: Arc<CertStore>,
|
|
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
|
|
mut must_exit: watch::Receiver<bool>,
|
|
) -> Result<()> {
|
|
let config = Arc::new(config);
|
|
|
|
let meter = opentelemetry::global::meter("tricot");
|
|
let metrics = Arc::new(HttpsMetrics {
|
|
requests_received: meter
|
|
.u64_counter("https_requests_received")
|
|
.with_description("Total number of requests received over HTTPS")
|
|
.init(),
|
|
requests_served: meter
|
|
.u64_counter("https_requests_served")
|
|
.with_description("Total number of requests served over HTTPS")
|
|
.init(),
|
|
requests_in_flight: meter
|
|
.i64_up_down_counter("https_requests_in_flight")
|
|
.with_description("Current number of requests handled over HTTPS")
|
|
.init(),
|
|
request_proxy_duration: meter
|
|
.f64_histogram("https_request_proxy_duration")
|
|
.with_description("Duration between time when request was received, and time when backend returned status code and headers")
|
|
.init(),
|
|
});
|
|
|
|
let mut tls_cfg = rustls::ServerConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_no_client_auth()
|
|
.with_cert_resolver(Arc::new(StoreResolver(cert_store)));
|
|
|
|
tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
|
let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));
|
|
|
|
info!("Starting to serve on https://{}.", config.bind_addr);
|
|
|
|
let tcp = TcpListener::bind(config.bind_addr).await?;
|
|
let mut connections = FuturesUnordered::new();
|
|
|
|
while !*must_exit.borrow() {
|
|
let wait_conn_finished = async {
|
|
if connections.is_empty() {
|
|
futures::future::pending().await
|
|
} else {
|
|
connections.next().await
|
|
}
|
|
};
|
|
let (socket, remote_addr) = select! {
|
|
a = tcp.accept() => a?,
|
|
_ = wait_conn_finished => continue,
|
|
_ = must_exit.changed() => continue,
|
|
};
|
|
|
|
let rx_proxy_config = rx_proxy_config.clone();
|
|
let tls_acceptor = tls_acceptor.clone();
|
|
let config = config.clone();
|
|
let metrics = metrics.clone();
|
|
|
|
let mut must_exit_2 = must_exit.clone();
|
|
let conn = tokio::spawn(async move {
|
|
match tls_acceptor.accept(socket).await {
|
|
Ok(stream) => {
|
|
debug!("TLS handshake was successfull");
|
|
let http_conn = Http::new()
|
|
.serve_connection(
|
|
stream,
|
|
service_fn(move |req: Request<Body>| {
|
|
let https_config = config.clone();
|
|
let proxy_config: Arc<ProxyConfig> =
|
|
rx_proxy_config.borrow().clone();
|
|
let metrics = metrics.clone();
|
|
handle_request(
|
|
remote_addr,
|
|
req,
|
|
https_config,
|
|
proxy_config,
|
|
metrics,
|
|
)
|
|
}),
|
|
)
|
|
.with_upgrades();
|
|
let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
|
|
tokio::pin!(http_conn, timeout);
|
|
let http_result = loop {
|
|
select! (
|
|
r = &mut http_conn => break r.map_err(Into::into),
|
|
_ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
|
|
_ = must_exit_2.changed() => {
|
|
if *must_exit_2.borrow() {
|
|
http_conn.as_mut().graceful_shutdown();
|
|
}
|
|
}
|
|
)
|
|
};
|
|
if let Err(http_err) = http_result {
|
|
warn!("HTTP error: {}", http_err);
|
|
}
|
|
}
|
|
Err(e) => warn!("Error in TLS connection: {}", e),
|
|
}
|
|
});
|
|
connections.push(conn);
|
|
}
|
|
|
|
drop(tcp);
|
|
|
|
info!("HTTPS server shutting down, draining remaining connections...");
|
|
while connections.next().await.is_some() {}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_request(
|
|
remote_addr: SocketAddr,
|
|
req: Request<Body>,
|
|
https_config: Arc<HttpsConfig>,
|
|
proxy_config: Arc<ProxyConfig>,
|
|
metrics: Arc<HttpsMetrics>,
|
|
) -> Result<Response<Body>, Infallible> {
|
|
let method_tag = KeyValue::new("method", req.method().to_string());
|
|
|
|
// The host tag is only included in the requests_received metric,
|
|
// as for other metrics it can easily lead to cardinality explosions.
|
|
let host_tag = KeyValue::new(
|
|
"host",
|
|
req.uri()
|
|
.authority()
|
|
.map(|auth| auth.to_string())
|
|
.or_else(|| {
|
|
req.headers()
|
|
.get("host")
|
|
.map(|host| host.to_str().unwrap_or_default().to_string())
|
|
})
|
|
.unwrap_or_default(),
|
|
);
|
|
|
|
metrics.requests_received.add(1, &[host_tag, method_tag.clone()]);
|
|
|
|
let mut tags = vec![method_tag];
|
|
let resp = select_target_and_proxy(
|
|
&https_config,
|
|
&proxy_config,
|
|
&metrics,
|
|
remote_addr,
|
|
req,
|
|
&mut tags,
|
|
)
|
|
.await;
|
|
|
|
tags.push(KeyValue::new("status_code", resp.status().as_u16() as i64));
|
|
metrics.requests_served.add(1, &tags);
|
|
|
|
Ok(resp)
|
|
}
|
|
|
|
// Custom echo service, handling two different routes and a
|
|
// catch-all 404 responder.
|
|
async fn select_target_and_proxy(
|
|
https_config: &HttpsConfig,
|
|
proxy_config: &ProxyConfig,
|
|
metrics: &HttpsMetrics,
|
|
remote_addr: SocketAddr,
|
|
req: Request<Body>,
|
|
tags: &mut Vec<KeyValue>,
|
|
) -> Response<Body> {
|
|
let received_time = Instant::now();
|
|
|
|
let method = req.method().clone();
|
|
let uri = req.uri().to_string();
|
|
|
|
let host = if let Some(auth) = req.uri().authority() {
|
|
auth.as_str()
|
|
} else {
|
|
match req.headers().get("host").and_then(|x| x.to_str().ok()) {
|
|
Some(host) => host,
|
|
None => {
|
|
return Response::builder()
|
|
.status(StatusCode::BAD_REQUEST)
|
|
.body(Body::from("Missing Host header"))
|
|
.unwrap();
|
|
}
|
|
}
|
|
};
|
|
let path = req.uri().path();
|
|
|
|
let best_match = proxy_config
|
|
.entries
|
|
.iter()
|
|
.filter(|ent| {
|
|
ent.flags.healthy
|
|
&& ent.url_prefix.host.matches(host)
|
|
&& ent
|
|
.url_prefix
|
|
.path_prefix
|
|
.as_ref()
|
|
.map(|prefix| path.starts_with(prefix))
|
|
.unwrap_or(true)
|
|
})
|
|
.max_by_key(|ent| {
|
|
(
|
|
ent.priority,
|
|
ent.url_prefix
|
|
.path_prefix
|
|
.as_ref()
|
|
.map(|x| x.len() as i32)
|
|
.unwrap_or(0),
|
|
(ent.flags.same_node || ent.flags.site_lb || ent.flags.global_lb),
|
|
(ent.flags.same_site || ent.flags.global_lb),
|
|
-ent.calls_in_progress.load(Ordering::SeqCst),
|
|
-ent.last_call.load(Ordering::SeqCst),
|
|
)
|
|
});
|
|
|
|
if let Some(proxy_to) = best_match {
|
|
tags.push(KeyValue::new("service", proxy_to.service_name.clone()));
|
|
tags.push(KeyValue::new(
|
|
"target_addr",
|
|
proxy_to.target_addr.to_string(),
|
|
));
|
|
tags.push(KeyValue::new("same_node", proxy_to.flags.same_node));
|
|
tags.push(KeyValue::new("same_site", proxy_to.flags.same_site));
|
|
|
|
proxy_to.last_call.fetch_max(
|
|
(received_time - https_config.time_origin).as_millis() as i64,
|
|
Ordering::Relaxed,
|
|
);
|
|
proxy_to.calls_in_progress.fetch_add(1, Ordering::SeqCst);
|
|
|
|
let tags_in_flight = &tags.clone();
|
|
metrics.requests_in_flight.add(1, &tags_in_flight);
|
|
// The guard ensures that we decrement requests_in_flight in all cases where
|
|
// the current tasks ends, including the case where it gets canceled and
|
|
// doesn't run to completion (which may happen e.g. if it timeouts).
|
|
// (Crucially we create the guard before the first .await in this function.)
|
|
let _guard = InFlightGuard {
|
|
metrics: &metrics,
|
|
tags: &tags_in_flight,
|
|
};
|
|
|
|
// Forward to backend
|
|
debug!("{}{} -> {}", host, path, proxy_to);
|
|
trace!("Request: {:?}", req);
|
|
|
|
let response = if let Some(http_res) = try_redirect(host, path, proxy_to) {
|
|
// redirection middleware
|
|
http_res
|
|
} else {
|
|
// proxying to backend
|
|
match do_proxy(https_config, remote_addr, req, proxy_to).await {
|
|
Ok(resp) => resp,
|
|
Err(e) => Response::builder()
|
|
.status(StatusCode::BAD_GATEWAY)
|
|
.body(Body::from(format!("Proxy error: {}", e)))
|
|
.unwrap(),
|
|
}
|
|
};
|
|
|
|
proxy_to.calls_in_progress.fetch_sub(1, Ordering::SeqCst);
|
|
metrics
|
|
.request_proxy_duration
|
|
.record(received_time.elapsed().as_secs_f64(), tags);
|
|
|
|
trace!("Final response: {:?}", response);
|
|
info!("{} {} {}", method, response.status().as_u16(), uri);
|
|
response
|
|
} else {
|
|
debug!("{}{} -> NOT FOUND", host, path);
|
|
info!("{} 404 {}", method, uri);
|
|
|
|
Response::builder()
|
|
.status(StatusCode::NOT_FOUND)
|
|
.body(Body::from("No matching proxy entry"))
|
|
.unwrap()
|
|
}
|
|
}
|
|
|
|
fn try_redirect(req_host: &str, req_path: &str, proxy_to: &ProxyEntry) -> Option<Response<Body>> {
|
|
let maybe_redirect = proxy_to.redirects.iter().find(|(src, _, _)| {
|
|
let mut matched: bool = src.host.matches(req_host);
|
|
|
|
if let Some(path) = &src.path_prefix {
|
|
matched &= req_path.starts_with(path);
|
|
}
|
|
|
|
matched
|
|
});
|
|
|
|
let (src_prefix, dst_prefix, code) = match maybe_redirect {
|
|
None => return None,
|
|
Some(redirect) => redirect,
|
|
};
|
|
|
|
let new_host = match &dst_prefix.host {
|
|
HostDescription::Hostname(h) => h,
|
|
_ => unreachable!(), // checked when ProxyEntry is created
|
|
};
|
|
|
|
let new_prefix = dst_prefix.path_prefix.as_deref().unwrap_or("");
|
|
let original_prefix = src_prefix.path_prefix.as_deref().unwrap_or("");
|
|
let suffix = &req_path[original_prefix.len()..];
|
|
|
|
let uri = format!("https://{}{}{}", new_host, new_prefix, suffix);
|
|
|
|
let status = match StatusCode::from_u16(*code) {
|
|
Err(e) => {
|
|
warn!(
|
|
"Couldn't redirect {}{} to {} as code {} in invalid: {}",
|
|
req_host, req_path, uri, code, e
|
|
);
|
|
return None;
|
|
}
|
|
Ok(sc) => sc,
|
|
};
|
|
|
|
Response::builder()
|
|
.header("Location", uri.clone())
|
|
.status(status)
|
|
.body(Body::from(uri))
|
|
.ok()
|
|
}
|
|
|
|
async fn do_proxy(
|
|
https_config: &HttpsConfig,
|
|
remote_addr: SocketAddr,
|
|
req: Request<Body>,
|
|
proxy_to: &ProxyEntry,
|
|
) -> Result<Response<Body>> {
|
|
let method = req.method().clone();
|
|
let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]);
|
|
|
|
let mut response = if proxy_to.https_target {
|
|
let to_addr = format!("https://{}", proxy_to.target_addr);
|
|
reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await?
|
|
} else {
|
|
let to_addr = format!("http://{}", proxy_to.target_addr);
|
|
reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?
|
|
};
|
|
|
|
if response.status().is_success() {
|
|
// (TODO: maybe we want to add these headers even if it's not a success?)
|
|
for (header, value) in proxy_to.add_headers.iter() {
|
|
response.headers_mut().insert(
|
|
HeaderName::from_bytes(header.as_bytes())?,
|
|
HeaderValue::from_str(value)?,
|
|
);
|
|
}
|
|
}
|
|
|
|
if https_config.enable_compression {
|
|
response = try_compress(response, method, accept_encoding, https_config).await?
|
|
};
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
async fn try_compress(
|
|
response: Response<Body>,
|
|
method: Method,
|
|
accept_encoding: Vec<(Option<Encoding>, f32)>,
|
|
https_config: &HttpsConfig,
|
|
) -> Result<Response<Body>> {
|
|
// Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body)
|
|
// Don't compress partial content as it causes issues
|
|
// Don't bother compressing non-2xx results
|
|
// Don't compress Upgrade responses (e.g. websockets)
|
|
// Don't compress responses that are already compressed
|
|
if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT))
|
|
|| response.status() == StatusCode::PARTIAL_CONTENT
|
|
|| !response.status().is_success()
|
|
|| response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
|
|
|| response.headers().get(header::CONTENT_ENCODING).is_some()
|
|
{
|
|
return Ok(response);
|
|
}
|
|
|
|
// Select preferred encoding among those proposed in accept_encoding
|
|
let max_q: f32 = accept_encoding
|
|
.iter()
|
|
.max_by_key(|(_, q)| (q * 10000f32) as i64)
|
|
.unwrap_or(&(None, 1.))
|
|
.1;
|
|
let preference = [
|
|
Encoding::Zstd,
|
|
//Encoding::Brotli,
|
|
Encoding::Deflate,
|
|
Encoding::Gzip,
|
|
];
|
|
#[allow(clippy::float_cmp)]
|
|
let encoding_opt = accept_encoding
|
|
.iter()
|
|
.filter(|(_, q)| *q == max_q)
|
|
.filter_map(|(enc, _)| *enc)
|
|
.filter(|enc| preference.contains(enc))
|
|
.min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());
|
|
|
|
// If preferred encoding is none, return as is
|
|
let encoding = match encoding_opt {
|
|
None | Some(Encoding::Identity) => return Ok(response),
|
|
Some(enc) => enc,
|
|
};
|
|
|
|
// If content type not in mime types for which to compress, return as is
|
|
match response.headers().get(header::CONTENT_TYPE) {
|
|
Some(ct) => {
|
|
let ct_str = ct.to_str()?;
|
|
let mime_type = match ct_str.split_once(';') {
|
|
Some((mime_type, _params)) => mime_type,
|
|
None => ct_str,
|
|
};
|
|
if !https_config
|
|
.compress_mime_types
|
|
.iter()
|
|
.any(|x| x == mime_type)
|
|
{
|
|
return Ok(response);
|
|
}
|
|
}
|
|
None => return Ok(response), // don't compress if unknown mime type
|
|
};
|
|
|
|
let (mut head, mut body) = response.into_parts();
|
|
|
|
// ---- If body is smaller than 1400 bytes, don't compress ----
|
|
let mut chunks = vec![];
|
|
let mut sum_lengths = 0;
|
|
while sum_lengths < 1400 {
|
|
match body.next().await {
|
|
Some(chunk) => {
|
|
let chunk = chunk?;
|
|
sum_lengths += chunk.len();
|
|
chunks.push(chunk);
|
|
}
|
|
None => {
|
|
return Ok(Response::from_parts(head, Body::from(chunks.concat())));
|
|
}
|
|
}
|
|
}
|
|
|
|
// put beginning chunks back into body
|
|
let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);
|
|
|
|
// make an async reader from that for compressor
|
|
let body_rd =
|
|
StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
|
|
|
|
trace!(
|
|
"Compressing response body as {:?} (at least {} bytes)",
|
|
encoding,
|
|
sum_lengths
|
|
);
|
|
|
|
// we don't know the compressed content-length so remove that header
|
|
head.headers.remove(header::CONTENT_LENGTH);
|
|
|
|
let (encoding, compressed_body) = match encoding {
|
|
Encoding::Gzip => (
|
|
"gzip",
|
|
Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
|
|
),
|
|
// Encoding::Brotli => {
|
|
// head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
|
|
// Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
|
|
// }
|
|
Encoding::Deflate => (
|
|
"deflate",
|
|
Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
|
|
),
|
|
Encoding::Zstd => (
|
|
"zstd",
|
|
Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
|
|
),
|
|
_ => unreachable!(),
|
|
};
|
|
head.headers
|
|
.insert(header::CONTENT_ENCODING, encoding.parse()?);
|
|
|
|
Ok(Response::from_parts(head, compressed_body))
|
|
}
|