use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use std::time::Duration; use anyhow::Result; use log::*; use accept_encoding_fork::Encoding; use async_compression::tokio::bufread::*; use futures::stream::FuturesUnordered; use futures::{StreamExt, TryStreamExt}; use http::header::{HeaderName, HeaderValue}; use http::method::Method; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::select; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; use tokio_util::io::{ReaderStream, StreamReader}; use opentelemetry::{metrics, KeyValue}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 3600); pub struct HttpsConfig { pub bind_addr: SocketAddr, pub enable_compression: bool, pub compress_mime_types: Vec, } struct HttpsMetrics { requests_received: metrics::Counter, requests_served: metrics::Counter, } pub async fn serve_https( config: HttpsConfig, cert_store: Arc, rx_proxy_config: watch::Receiver>, mut must_exit: watch::Receiver, ) -> Result<()> { let config = Arc::new(config); let meter = opentelemetry::global::meter("tricot"); let metrics = Arc::new(HttpsMetrics { requests_received: meter .u64_counter("https_requests_received") .with_description("Total number of requests received over HTTPS") .init(), requests_served: meter .u64_counter("https_requests_served") .with_description("Total number of requests served over HTTPS") .init(), }); let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); info!("Starting to serve on https://{}.", config.bind_addr); let tcp = TcpListener::bind(config.bind_addr).await?; let mut connections = FuturesUnordered::new(); while !*must_exit.borrow() { let wait_conn_finished = async { if connections.is_empty() { futures::future::pending().await } else { connections.next().await } }; let (socket, remote_addr) = select! { a = tcp.accept() => a?, _ = wait_conn_finished => continue, _ = must_exit.changed() => continue, }; let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); let config = config.clone(); let metrics = metrics.clone(); let mut must_exit_2 = must_exit.clone(); let conn = tokio::spawn(async move { match tls_acceptor.accept(socket).await { Ok(stream) => { debug!("TLS handshake was successfull"); let http_conn = Http::new() .serve_connection( stream, service_fn(move |req: Request| { let https_config = config.clone(); let proxy_config: Arc = rx_proxy_config.borrow().clone(); let metrics = metrics.clone(); handle_outer(remote_addr, req, https_config, proxy_config, metrics) }), ) .with_upgrades(); let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME); tokio::pin!(http_conn, timeout); let http_result = loop { select! ( r = &mut http_conn => break r.map_err(Into::into), _ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")), _ = must_exit_2.changed() => { if *must_exit_2.borrow() { http_conn.as_mut().graceful_shutdown(); } } ) }; if let Err(http_err) = http_result { warn!("HTTP error: {}", http_err); } } Err(e) => warn!("Error in TLS connection: {}", e), } }); connections.push(conn); } drop(tcp); info!("HTTPS server shutting down, draining remaining connections..."); while connections.next().await.is_some() {} Ok(()) } async fn handle_outer( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, metrics: Arc, ) -> Result, Infallible> { let mut tags = vec![ KeyValue::new("method", req.method().to_string()), KeyValue::new( "host", req.uri() .authority() .map(|auth| auth.to_string()) .or_else(|| { req.headers() .get("host") .map(|host| host.to_str().unwrap_or_default().to_string()) }) .unwrap_or_default(), ), ]; metrics.requests_received.add(1, &tags); let resp = match handle(remote_addr, req, https_config, proxy_config, &mut tags).await { Err(e) => { warn!("Handler error: {}", e); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(format!("{}", e))) .unwrap() } Ok(r) => r, }; tags.push(KeyValue::new( "response_code", resp.status().as_u16().to_string(), )); metrics.requests_served.add(1, &tags); Ok(resp) } // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn handle( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, tags: &mut Vec, ) -> Result, anyhow::Error> { let method = req.method().clone(); let uri = req.uri().to_string(); let host = if let Some(auth) = req.uri().authority() { auth.as_str() } else { req.headers() .get("host") .ok_or_else(|| anyhow!("Missing host header"))? .to_str()? }; let path = req.uri().path(); let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]); let best_match = proxy_config .entries .iter() .filter(|ent| { ent.host.matches(host) && ent .path_prefix .as_ref() .map(|prefix| path.starts_with(prefix)) .unwrap_or(true) }) .max_by_key(|ent| { ( ent.priority, ent.path_prefix .as_ref() .map(|x| x.len() as i32) .unwrap_or(0), ent.same_node, ent.same_site, -ent.calls.load(Ordering::SeqCst), ) }); if let Some(proxy_to) = best_match { tags.push(KeyValue::new("service_name", proxy_to.service_name.clone())); tags.push(KeyValue::new( "target_addr", proxy_to.target_addr.to_string(), )); tags.push(KeyValue::new( "https_target", proxy_to.https_target.to_string(), )); tags.push(KeyValue::new("same_node", proxy_to.same_node.to_string())); tags.push(KeyValue::new("same_site", proxy_to.same_site.to_string())); proxy_to.calls.fetch_add(1, Ordering::SeqCst); debug!("{}{} -> {}", host, path, proxy_to); trace!("Request: {:?}", req); let mut response = if proxy_to.https_target { let to_addr = format!("https://{}", proxy_to.target_addr); handle_error(reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await) } else { let to_addr = format!("http://{}", proxy_to.target_addr); handle_error(reverse_proxy::call(remote_addr.ip(), &to_addr, req).await) }; if response.status().is_success() { // (TODO: maybe we want to add these headers even if it's not a success?) for (header, value) in proxy_to.add_headers.iter() { response.headers_mut().insert( HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?, ); } } if https_config.enable_compression { response = try_compress(response, method.clone(), accept_encoding, &https_config).await? }; trace!("Final response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); Ok(response) } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("No matching proxy entry"))?) } } fn handle_error(resp: Result>) -> Response { match resp { Ok(resp) => resp, Err(e) => Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from(format!("Proxy error: {}", e))) .unwrap(), } } async fn try_compress( response: Response, method: Method, accept_encoding: Vec<(Option, f32)>, https_config: &HttpsConfig, ) -> Result> { // Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body) // Don't compress partial content as it causes issues // Don't bother compressing non-2xx results // Don't compress Upgrade responses (e.g. websockets) // Don't compress responses that are already compressed if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT)) || response.status() == StatusCode::PARTIAL_CONTENT || !response.status().is_success() || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade")) || response.headers().get(header::CONTENT_ENCODING).is_some() { return Ok(response); } // Select preferred encoding among those proposed in accept_encoding let max_q: f32 = accept_encoding .iter() .max_by_key(|(_, q)| (q * 10000f32) as i64) .unwrap_or(&(None, 1.)) .1; let preference = [ Encoding::Zstd, //Encoding::Brotli, Encoding::Deflate, Encoding::Gzip, ]; #[allow(clippy::float_cmp)] let encoding_opt = accept_encoding .iter() .filter(|(_, q)| *q == max_q) .filter_map(|(enc, _)| *enc) .filter(|enc| preference.contains(enc)) .min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap()); // If preferred encoding is none, return as is let encoding = match encoding_opt { None | Some(Encoding::Identity) => return Ok(response), Some(enc) => enc, }; // If content type not in mime types for which to compress, return as is match response.headers().get(header::CONTENT_TYPE) { Some(ct) => { let ct_str = ct.to_str()?; let mime_type = match ct_str.split_once(';') { Some((mime_type, _params)) => mime_type, None => ct_str, }; if !https_config .compress_mime_types .iter() .any(|x| x == mime_type) { return Ok(response); } } None => return Ok(response), // don't compress if unknown mime type }; let (mut head, mut body) = response.into_parts(); // ---- If body is smaller than 1400 bytes, don't compress ---- let mut chunks = vec![]; let mut sum_lengths = 0; while sum_lengths < 1400 { match body.next().await { Some(chunk) => { let chunk = chunk?; sum_lengths += chunk.len(); chunks.push(chunk); } None => { return Ok(Response::from_parts(head, Body::from(chunks.concat()))); } } } // put beginning chunks back into body let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body); // make an async reader from that for compressor let body_rd = StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); trace!( "Compressing response body as {:?} (at least {} bytes)", encoding, sum_lengths ); // we don't know the compressed content-length so remove that header head.headers.remove(header::CONTENT_LENGTH); let (encoding, compressed_body) = match encoding { Encoding::Gzip => ( "gzip", Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))), ), // Encoding::Brotli => { // head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); // Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) // } Encoding::Deflate => ( "deflate", Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))), ), Encoding::Zstd => ( "zstd", Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))), ), _ => unreachable!(), }; head.headers .insert(header::CONTENT_ENCODING, encoding.parse()?); Ok(Response::from_parts(head, compressed_body)) }