// Forked from Deuxfleurs/tricot
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{atomic::Ordering, Arc};
use std::time::Duration;

use anyhow::{anyhow, Result};
use log::*;

use accept_encoding_fork::Encoding;
use async_compression::tokio::bufread::*;
use futures::stream::FuturesUnordered;
use futures::{StreamExt, TryStreamExt};
use http::header::{HeaderName, HeaderValue};
use http::method::Method;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{header, Body, Request, Response, StatusCode};
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::io::{ReaderStream, StreamReader};

use crate::cert_store::{CertStore, StoreResolver};
use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;
/// Maximum lifetime of a single client connection; connections older than
/// this are forcibly closed by `serve_https`.
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(24 * 3600);

/// Configuration for the HTTPS server entrypoint (`serve_https`).
pub struct HttpsConfig {
    /// Socket address (IP + port) on which to accept TLS connections.
    pub bind_addr: SocketAddr,
    /// Whether to attempt compressing response bodies (see `try_compress`).
    pub enable_compression: bool,
    /// MIME types (without parameters) for which compression is attempted.
    pub compress_mime_types: Vec<String>,
}
pub async fn serve_https(
|
|
config: HttpsConfig,
|
|
cert_store: Arc<CertStore>,
|
|
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
|
|
mut must_exit: watch::Receiver<bool>,
|
|
) -> Result<()> {
|
|
let config = Arc::new(config);
|
|
|
|
let mut tls_cfg = rustls::ServerConfig::builder()
|
|
.with_safe_defaults()
|
|
.with_no_client_auth()
|
|
.with_cert_resolver(Arc::new(StoreResolver(cert_store)));
|
|
|
|
tls_cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
|
let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(tls_cfg)));
|
|
|
|
info!("Starting to serve on https://{}.", config.bind_addr);
|
|
|
|
let tcp = TcpListener::bind(config.bind_addr).await?;
|
|
let mut connections = FuturesUnordered::new();
|
|
|
|
while !*must_exit.borrow() {
|
|
let wait_conn_finished = async {
|
|
if connections.is_empty() {
|
|
futures::future::pending().await
|
|
} else {
|
|
connections.next().await
|
|
}
|
|
};
|
|
let (socket, remote_addr) = select! {
|
|
a = tcp.accept() => a?,
|
|
_ = wait_conn_finished => continue,
|
|
_ = must_exit.changed() => continue,
|
|
};
|
|
|
|
let rx_proxy_config = rx_proxy_config.clone();
|
|
let tls_acceptor = tls_acceptor.clone();
|
|
let config = config.clone();
|
|
|
|
let mut must_exit_2 = must_exit.clone();
|
|
let conn = tokio::spawn(async move {
|
|
match tls_acceptor.accept(socket).await {
|
|
Ok(stream) => {
|
|
debug!("TLS handshake was successfull");
|
|
let http_conn = Http::new().serve_connection(
|
|
stream,
|
|
service_fn(move |req: Request<Body>| {
|
|
let https_config = config.clone();
|
|
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
|
|
handle_outer(remote_addr, req, https_config, proxy_config)
|
|
}),
|
|
);
|
|
let timeout = tokio::time::sleep(MAX_CONNECTION_LIFETIME);
|
|
tokio::pin!(http_conn, timeout);
|
|
let http_result = loop {
|
|
select! (
|
|
r = &mut http_conn => break r.map_err(Into::into),
|
|
_ = &mut timeout => break Err(anyhow!("Connection lived more than 24h, killing it.")),
|
|
_ = must_exit_2.changed() => {
|
|
if *must_exit_2.borrow() {
|
|
http_conn.as_mut().graceful_shutdown();
|
|
}
|
|
}
|
|
)
|
|
};
|
|
if let Err(http_err) = http_result {
|
|
warn!("HTTP error: {}", http_err);
|
|
}
|
|
}
|
|
Err(e) => warn!("Error in TLS connection: {}", e),
|
|
}
|
|
});
|
|
connections.push(conn);
|
|
}
|
|
|
|
drop(tcp);
|
|
|
|
info!("HTTPS server shutting down, draining remaining connections...");
|
|
while connections.next().await.is_some() {}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_outer(
|
|
remote_addr: SocketAddr,
|
|
req: Request<Body>,
|
|
https_config: Arc<HttpsConfig>,
|
|
proxy_config: Arc<ProxyConfig>,
|
|
) -> Result<Response<Body>, Infallible> {
|
|
match handle(remote_addr, req, https_config, proxy_config).await {
|
|
Err(e) => {
|
|
warn!("Handler error: {}", e);
|
|
Ok(Response::builder()
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.body(Body::from(format!("{}", e)))
|
|
.unwrap())
|
|
}
|
|
Ok(r) => Ok(r),
|
|
}
|
|
}
|
|
|
|
// Core request handler: pick the best-matching proxy entry for the request's
// host and path, forward the request to it, post-process the response
// (extra headers, optional compression), or return 404 when nothing matches.
/// Route `req` to the best-matching proxy entry and forward it.
///
/// Matching: an entry matches when its host pattern matches the request's
/// host (URI authority if present, else the `Host` header) and its optional
/// `path_prefix` is a prefix of the request path. Among matches, the winner
/// maximizes `(priority, prefix length, same_node, same_site, -calls)`, in
/// that lexicographic order.
///
/// On a match: the request is proxied over HTTP or HTTPS depending on the
/// entry, the entry's configured extra headers are added to 2xx responses,
/// and the response is compressed when enabled. Errors bubble up to
/// `handle_outer`, which turns them into 500 responses.
async fn handle(
    remote_addr: SocketAddr,
    req: Request<Body>,
    https_config: Arc<HttpsConfig>,
    proxy_config: Arc<ProxyConfig>,
) -> Result<Response<Body>, anyhow::Error> {
    // Copies kept for logging: `req` is consumed by the proxy call below.
    let method = req.method().clone();
    let uri = req.uri().to_string();

    // Host used for route matching; a missing Host header is an error.
    let host = if let Some(auth) = req.uri().authority() {
        auth.as_str()
    } else {
        req.headers()
            .get("host")
            .ok_or_else(|| anyhow!("Missing host header"))?
            .to_str()?
    };
    let path = req.uri().path();
    // Parsed Accept-Encoding entries; treated as empty on parse failure.
    let accept_encoding = accept_encoding_fork::encodings(req.headers()).unwrap_or_else(|_| vec![]);

    let best_match = proxy_config
        .entries
        .iter()
        .filter(|ent| {
            ent.host.matches(host)
                && ent
                    .path_prefix
                    .as_ref()
                    .map(|prefix| path.starts_with(prefix))
                    .unwrap_or(true)
        })
        .max_by_key(|ent| {
            (
                ent.priority,
                // Longer (more specific) path prefixes win.
                ent.path_prefix
                    .as_ref()
                    .map(|x| x.len() as i32)
                    .unwrap_or(0),
                ent.same_node,
                ent.same_site,
                // Negated counter: among otherwise-equal entries, prefer the
                // least-called one (spreads load across equivalent targets).
                -ent.calls.load(Ordering::SeqCst),
            )
        });

    if let Some(proxy_to) = best_match {
        proxy_to.calls.fetch_add(1, Ordering::SeqCst);

        debug!("{}{} -> {}", host, path, proxy_to);
        trace!("Request: {:?}", req);

        // Forward to the target; proxy failures become 502 via handle_error.
        let mut response = if proxy_to.https_target {
            let to_addr = format!("https://{}", proxy_to.target_addr);
            handle_error(reverse_proxy::call_https(remote_addr.ip(), &to_addr, req).await)
        } else {
            let to_addr = format!("http://{}", proxy_to.target_addr);
            handle_error(reverse_proxy::call(remote_addr.ip(), &to_addr, req).await)
        };

        if response.status().is_success() {
            // (TODO: maybe we want to add these headers even if it's not a success?)
            for (header, value) in proxy_to.add_headers.iter() {
                response.headers_mut().insert(
                    HeaderName::from_bytes(header.as_bytes())?,
                    HeaderValue::from_str(value)?,
                );
            }
        }

        if https_config.enable_compression {
            response =
                try_compress(response, method.clone(), accept_encoding, &https_config).await?
        };

        trace!("Final response: {:?}", response);
        info!("{} {} {}", method, response.status().as_u16(), uri);
        Ok(response)
    } else {
        debug!("{}{} -> NOT FOUND", host, path);
        info!("{} 404 {}", method, uri);

        Ok(Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(Body::from("No matching proxy entry"))?)
    }
}
fn handle_error(resp: Result<Response<Body>>) -> Response<Body> {
|
|
match resp {
|
|
Ok(resp) => resp,
|
|
Err(e) => Response::builder()
|
|
.status(StatusCode::BAD_GATEWAY)
|
|
.body(Body::from(format!("Proxy error: {}", e)))
|
|
.unwrap(),
|
|
}
|
|
}
|
|
|
|
/// Compress the response body with the client's preferred encoding, if the
/// response qualifies; otherwise return it unchanged.
///
/// Among the highest-q Accept-Encoding entries, the server-side preference
/// order is zstd > deflate > gzip (brotli is disabled). A response is
/// returned as-is when:
/// - it is a successful response to HEAD or PUT (body expected empty),
/// - it is 206 Partial Content or any non-2xx status,
/// - it is an Upgrade response or already has a Content-Encoding,
/// - the preferred acceptable encoding is none/identity,
/// - its Content-Type is missing or not in `compress_mime_types`,
/// - its body ends before 1400 bytes have been buffered (too small to be
///   worth compressing).
async fn try_compress(
    response: Response<Body>,
    method: Method,
    accept_encoding: Vec<(Option<Encoding>, f32)>,
    https_config: &HttpsConfig,
) -> Result<Response<Body>> {
    // Don't bother compressing successful responses for HEAD and PUT (they should have an empty body)
    // Don't compress partial content as it causes issues
    // Don't bother compressing non-2xx results
    // Don't compress Upgrade responses (e.g. websockets)
    // Don't compress responses that are already compressed
    // NOTE(review): the Upgrade test is a case-sensitive exact match; it
    // misses e.g. "upgrade" or "keep-alive, Upgrade" — confirm backends
    // always emit exactly "Upgrade".
    if (response.status().is_success() && (method == Method::HEAD || method == Method::PUT))
        || response.status() == StatusCode::PARTIAL_CONTENT
        || !response.status().is_success()
        || response.headers().get(header::CONNECTION) == Some(&HeaderValue::from_static("Upgrade"))
        || response.headers().get(header::CONTENT_ENCODING).is_some()
    {
        return Ok(response);
    }

    // Select preferred encoding among those proposed in accept_encoding
    // (highest q-value; defaults to 1.0 when the list is empty).
    let max_q: f32 = accept_encoding
        .iter()
        .max_by_key(|(_, q)| (q * 10000f32) as i64)
        .unwrap_or(&(None, 1.))
        .1;
    // Server-side preference order among supported encodings.
    let preference = [
        Encoding::Zstd,
        //Encoding::Brotli,
        Encoding::Deflate,
        Encoding::Gzip,
    ];
    // Among the entries tied at max_q, keep only supported encodings and
    // pick the one ranked earliest in `preference`.
    #[allow(clippy::float_cmp)]
    let encoding_opt = accept_encoding
        .iter()
        .filter(|(_, q)| *q == max_q)
        .filter_map(|(enc, _)| *enc)
        .filter(|enc| preference.contains(enc))
        .min_by_key(|enc| preference.iter().position(|x| x == enc).unwrap());

    // If preferred encoding is none, return as is
    let encoding = match encoding_opt {
        None | Some(Encoding::Identity) => return Ok(response),
        Some(enc) => enc,
    };

    // If content type not in mime types for which to compress, return as is
    match response.headers().get(header::CONTENT_TYPE) {
        Some(ct) => {
            let ct_str = ct.to_str()?;
            // Strip parameters: "text/html; charset=utf-8" -> "text/html".
            let mime_type = match ct_str.split_once(';') {
                Some((mime_type, _params)) => mime_type,
                None => ct_str,
            };
            if !https_config
                .compress_mime_types
                .iter()
                .any(|x| x == mime_type)
            {
                return Ok(response);
            }
        }
        None => return Ok(response), // don't compress if unknown mime type
    };

    let (mut head, mut body) = response.into_parts();

    // ---- If body is smaller than 1400 bytes, don't compress ----
    let mut chunks = vec![];
    let mut sum_lengths = 0;
    while sum_lengths < 1400 {
        match body.next().await {
            Some(chunk) => {
                let chunk = chunk?;
                sum_lengths += chunk.len();
                chunks.push(chunk);
            }
            None => {
                // Body ended before the threshold: send it uncompressed.
                return Ok(Response::from_parts(head, Body::from(chunks.concat())));
            }
        }
    }

    // put beginning chunks back into body
    let body = futures::stream::iter(chunks.into_iter().map(Ok)).chain(body);

    // make an async reader from that for compressor
    let body_rd =
        StreamReader::new(body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));

    trace!(
        "Compressing response body as {:?} (at least {} bytes)",
        encoding,
        sum_lengths
    );

    // we don't know the compressed content-length so remove that header
    head.headers.remove(header::CONTENT_LENGTH);

    // Wrap the body reader in the chosen streaming encoder and record the
    // matching Content-Encoding token.
    let (encoding, compressed_body) = match encoding {
        Encoding::Gzip => (
            "gzip",
            Body::wrap_stream(ReaderStream::new(GzipEncoder::new(body_rd))),
        ),
        // Encoding::Brotli => {
        //     head.headers.insert(header::CONTENT_ENCODING, "br".parse()?);
        //     Body::wrap_stream(ReaderStream::new(BrotliEncoder::new(body_rd)))
        // }
        Encoding::Deflate => (
            "deflate",
            Body::wrap_stream(ReaderStream::new(DeflateEncoder::new(body_rd))),
        ),
        Encoding::Zstd => (
            "zstd",
            Body::wrap_stream(ReaderStream::new(ZstdEncoder::new(body_rd))),
        ),
        // Identity/None returned early above; other variants are filtered
        // out by `preference`, so this branch cannot be reached.
        _ => unreachable!(),
    };
    head.headers
        .insert(header::CONTENT_ENCODING, encoding.parse()?);

    Ok(Response::from_parts(head, compressed_body))
}