use std::convert::Infallible; use std::net::SocketAddr; use std::sync::{atomic::Ordering, Arc}; use std::time::Duration; use anyhow::Result; use log::*; use accept_encoding_fork::Encoding; use async_compression::tokio::bufread::*; use futures::stream::FuturesUnordered; use futures::{Future, StreamExt, TryStreamExt}; use http::header::{HeaderName, HeaderValue}; use http::method::Method; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; use tokio::select; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; use tokio_util::io::{ReaderStream, StreamReader}; use crate::cert_store::{CertStore, StoreResolver}; use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; const PROXY_TIMEOUT: Duration = Duration::from_secs(60); const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 3600); pub struct HttpsConfig { pub bind_addr: SocketAddr, pub enable_compression: bool, pub compress_mime_types: Vec, } pub async fn serve_https( config: HttpsConfig, cert_store: Arc, rx_proxy_config: watch::Receiver>, mut must_exit: watch::Receiver, ) -> Result<()> { let config = Arc::new(config); let mut tls_cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_cert_resolver(Arc::new(StoreResolver(cert_store))); tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg))); info!("Starting to serve on https://{}.", config.bind_addr); let tcp = TcpListener::bind(config.bind_addr).await?; let mut connections = FuturesUnordered::new(); while !*must_exit.borrow() { let (socket, remote_addr) = select! { a = tcp.accept() => a?, _ = connections.next() => continue, _ = must_exit.changed() => continue, }; let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); let config = config.clone(); let mut must_exit_2 = must_exit.clone(); let conn = tokio::spawn(async move { match tls_acceptor.accept(socket).await { Ok(stream) => { debug!("TLS handshake was successfull"); let http_conn = Http::new().serve_connection( stream, service_fn(move |req: Request| { let https_config = config.clone(); let proxy_config: Arc = rx_proxy_config.borrow().clone(); handle_outer(remote_addr, req, https_config, proxy_config) }), ); let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME); tokio::pin!(http_conn, timeout); let http_result = loop { select! ( r = &mut http_conn => break r.map_err(Into::into), _ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")), _ = must_exit_2.changed() => { if *must_exit_2.borrow() { http_conn.as_mut().graceful_shutdown(); } } ) }; if let Err(http_err) = http_result { warn!("HTTP error: {}", http_err); } } Err(e) => warn!("Error in TLS connection: {}", e), } }); connections.push(conn); } drop(tcp); info!("HTTPS server shutting down, draining remaining connections..."); while !connections.is_empty() { let _ = connections.next().await; } Ok(()) } async fn handle_outer( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, ) -> Result, Infallible> { match handle(remote_addr, req, https_config, proxy_config).await { Err(e) => { warn!("Handler error: {}", e); Ok(Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(format!("{}", e))) .unwrap()) } Ok(r) => Ok(r), } } // Custom echo service, handling two different routes and a // catch-all 404 responder. async fn handle( remote_addr: SocketAddr, req: Request, https_config: Arc, proxy_config: Arc, ) -> Result, anyhow::Error> { let method = req.method().clone(); let uri = req.uri().to_string(); let host = if let Some(auth) = req.uri().authority() { auth.as_str() } else { req.headers() .get("host") .ok_or_else(|| anyhow!("Missing host header"))? .to_str()? }; let path = req.uri().path(); let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]); let best_match = proxy_config .entries .iter() .filter(|ent| { ent.host.matches(host) && ent .path_prefix .as_ref() .map(|prefix| path.starts_with(prefix)) .unwrap_or(true) }) .max_by_key(|ent| { ( ent.priority, ent.path_prefix .as_ref() .map(|x| x.len() as i32) .unwrap_or(0), ent.same_node, ent.same_site, -ent.calls.load(Ordering::SeqCst), ) }); if let Some(proxy_to) = best_match { proxy_to.calls.fetch_add(1, Ordering::SeqCst); debug!("{}{} -> {}", host, path, proxy_to); trace!("Request: {:?}", req); let mut response = if proxy_to.https_target { let to_addr = format!("https://{}", proxy_to.target_addr); handle_timeout_and_error(reverse_proxy::call_https(remote_addr.ip(), &to_addr, req)) .await } else { let to_addr = format!("http://{}", proxy_to.target_addr); handle_timeout_and_error(reverse_proxy::call(remote_addr.ip(), &to_addr, req)).await }; // Do further processing (compression, additionnal headers) only for 2xx responses if !response.status().is_success() { return Ok(response); } for (header, value) in proxy_to.add_headers.iter() { response.headers_mut().insert( HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?, ); } trace!("Response: {:?}", response); info!("{} {} {}", method, response.status().as_u16(), uri); if https_config.enable_compression { try_compress(response, method, accept_encoding, &https_config).await } else { Ok(response) } } else { debug!("{}{} -> NOT FOUND", host, path); info!("{} 404 {}", method, uri); Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(Body::from("No matching proxy entry"))?) } } async fn handle_timeout_and_error( fut: impl Future>>, ) -> Response { select!( resp = fut => { match resp { Ok(resp) => resp, Err(e) => Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from(format!("Proxy error: {}", e))) .unwrap(), } } _ = tokio::time::sleep(PROXY_TIMEOUT) => { Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from("Proxy timeout")) .unwrap() } ) } async fn try_compress( response: Response, method: Method, accept_encoding: Vec<(Option, f32)>, https_config: &HttpsConfig, ) -> Result> { // Don't bother compressing successfull responses for HEAD and PUT (they should have an empty body) // Don't compress partial content, that would be wierd // If already compressed, return as is if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT)) || response.status() == StatusCode::PARTIAL_CONTENT || response.headers().get(header::CONTENT_ENCODING).is_some() { return Ok(response); } // Select preferred encoding among those proposed in accept_encoding let max_q: f32 = accept_encoding .iter() .max_by_key(|(_, q)| (q * 10000f32) as i64) .unwrap_or(&(None, 1.)) .1; let preference = [ Encoding::Zstd, //Encoding::Brotli, Encoding::Deflate, Encoding::Gzip, ]; #[allow(clippy::float_cmp)] let encoding_opt = accept_encoding .iter() .filter(|(_, q)| *q == max_q) .filter_map(|(enc, _)| *enc) .filter(|enc| preference.contains(enc)) .min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap()); // If preferred encoding is none, return as is let encoding = match encoding_opt { None | Some(Encoding::Identity) => return Ok(response), Some(enc) => enc, }; // If content type not in mime types for which to compress, return as is match response.headers().get(header::CONTENT_TYPE) { Some(ct) => { let ct_str = ct.to_str()?; let mime_type = match ct_str.split_once(';') { Some((mime_type, _params)) => mime_type, None => ct_str, }; if !https_config .compress_mime_types .iter() .any(|x| x == mime_type) { return Ok(response); } } None => return Ok(response), // don't compress if unknown mime type }; let (mut head, mut body) = response.into_parts(); // ---- If body is smaller than 1400 bytes, don't compress ---- let mut chunks = vec![]; let mut sum_lengths = 0; while sum_lengths < 1400 { match body.next().await { Some(chunk) => { let chunk = chunk?; sum_lengths += chunk.len(); chunks.push(chunk); } None => { return Ok(Response::from_parts(head, Body::from(chunks.concat()))); } } } // put beginning chunks back into body let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body); // make an async reader from that for compressor let body_rd = StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); trace!( "Compressing response body as {:?} (at least {} bytes)", encoding, sum_lengths ); // we don't know the compressed content-length so remove that header head.headers.remove(header::CONTENT_LENGTH); let (encoding, compressed_body) = match encoding { Encoding::Gzip => ( "gzip", Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))), ), // Encoding::Brotli => { // head.headers.insert(header::CONTENT_ENCODING, "br".parse()?); // Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd))) // } Encoding::Deflate => ( "deflate", Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))), ), Encoding::Zstd => ( "zstd", Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))), ), _ => unreachable!(), }; head.headers .insert(header::CONTENT_ENCODING, encoding.parse()?); Ok(Response::from_parts(head, compressed_body)) }