diff --git a/src/cert_store.rs b/src/cert_store.rs index a2f67ec..f1b7d2b 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -19,20 +19,21 @@ use crate::proxy_config::ProxyConfig; pub struct CertStore { consul: Consul, certs: RwLock>>, + rx_proxy_config: watch::Receiver>, } impl CertStore { - pub fn new(consul: Consul) -> Arc { + pub fn new(consul: Consul, rx_proxy_config: watch::Receiver>) -> Arc { Arc::new(Self { consul, certs: RwLock::new(HashMap::new()), + rx_proxy_config, }) } - pub async fn watch_proxy_config( - self: Arc, - mut rx_proxy_config: watch::Receiver>, - ) { + pub async fn watch_proxy_config(self: Arc) { + let mut rx_proxy_config = self.rx_proxy_config.clone(); + while rx_proxy_config.changed().await.is_ok() { let mut domains: HashSet = HashSet::new(); @@ -50,6 +51,35 @@ impl CertStore { } } + pub fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { + // Check if domain is authorized + if !self + .rx_proxy_config + .borrow() + .entries + .iter() + .any(|ent| ent.host == domain) + { + bail!("Domain {} should not have a TLS certificate.", domain); + } + + // Check in local memory if it exists + let certs = self.certs.read().unwrap(); + if let Some(cert) = certs.get(domain) { + if !cert.is_old() { + return Ok(cert.clone()); + } + } + + // Not found in local memory + tokio::spawn(self.clone().get_cert_task(domain.to_string())); + bail!("Certificate not found (will try to get it in background)"); + } + + pub async fn get_cert_task(self: Arc, domain: String) -> Result> { + self.get_cert(domain.as_str()).await + } + pub async fn get_cert(self: &Arc, domain: &str) -> Result> { // First, try locally. { @@ -196,7 +226,7 @@ pub struct StoreResolver(pub Arc); impl rustls::server::ResolvesServerCert for StoreResolver { fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { let domain = client_hello.server_name()?; - let cert = futures::executor::block_on(self.0.get_cert(domain)).ok()?; + let cert = self.0.get_cert_for_https(domain).ok()?; Some(cert.certkey.clone()) } } diff --git a/src/consul.rs b/src/consul.rs index eb7aafd..1b94dd0 100644 --- a/src/consul.rs +++ b/src/consul.rs @@ -95,7 +95,11 @@ impl Consul { Ok(resp.into_iter().map(|n| n.node).collect::>()) } - pub async fn watch_node(&self, host: &str, idx: Option) -> Result<(ConsulNodeCatalog, usize)> { + pub async fn watch_node( + &self, + host: &str, + idx: Option, + ) -> Result<(ConsulNodeCatalog, usize)> { debug!("watch_node {} {:?}", host, idx); let url = match idx { diff --git a/src/http.rs b/src/http.rs index 4731645..2b26e6d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,5 @@ -use std::sync::Arc; use std::net::SocketAddr; +use std::sync::Arc; use anyhow::Result; use log::*; diff --git a/src/https.rs b/src/https.rs index 43a93e2..3621e4f 100644 --- a/src/https.rs +++ b/src/https.rs @@ -5,10 +5,10 @@ use anyhow::Result; use log::*; use futures::FutureExt; +use http::header::{HeaderName, HeaderValue}; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{Body, Request, Response, StatusCode}; -use http::header::{HeaderName, HeaderValue}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; @@ -121,7 +121,10 @@ async fn handle( let mut response = reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?; for (header, value) in proxy_to.add_headers.iter() { - response.headers_mut().insert(HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?); + response.headers_mut().insert( + HeaderName::from_bytes(header.as_bytes())?, + HeaderValue::from_str(value)?, + ); } Ok(response) diff --git a/src/main.rs b/src/main.rs index f38767e..c6fd1d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,11 +18,19 @@ use log::*; #[structopt(name = "tricot")] struct Opt { /// Address of consul server - #[structopt(long = "consul-addr", env = "TRICOT_CONSUL_HOST", default_value = "http://127.0.0.1:8500/")] + #[structopt( + long = "consul-addr", + env = "TRICOT_CONSUL_HOST", + default_value = "http://127.0.0.1:8500/" + )] pub consul_addr: String, /// Prefix of Tricot's entries in Consul KV space - #[structopt(long = "consul-kv-prefix", env = "TRICOT_CONSUL_KV_PREFIX", default_value = "tricot/")] + #[structopt( + long = "consul-kv-prefix", + env = "TRICOT_CONSUL_KV_PREFIX", + default_value = "tricot/" + )] pub consul_kv_prefix: String, /// Node name @@ -30,15 +38,22 @@ struct Opt { pub node_name: String, /// Bind address for HTTP server - #[structopt(long = "http-bind-addr", env = "TRICOT_HTTP_BIND_ADDR", default_value = "0.0.0.0:80")] + #[structopt( + long = "http-bind-addr", + env = "TRICOT_HTTP_BIND_ADDR", + default_value = "0.0.0.0:80" + )] pub http_bind_addr: SocketAddr, /// Bind address for HTTPS server - #[structopt(long = "https-bind-addr", env = "TRICOT_HTTPS_BIND_ADDR", default_value = "0.0.0.0:443")] + #[structopt( + long = "https-bind-addr", + env = "TRICOT_HTTPS_BIND_ADDR", + default_value = "0.0.0.0:443" + )] pub https_bind_addr: SocketAddr, } - #[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() { if std::env::var("RUST_LOG").is_err() { @@ -53,16 +68,12 @@ async fn main() { let consul = consul::Consul::new(&opt.consul_addr, &opt.consul_kv_prefix, &opt.node_name); let mut rx_proxy_config = proxy_config::spawn_proxy_config_task(consul.clone()); - let cert_store = cert_store::CertStore::new(consul.clone()); - tokio::spawn( - cert_store - .clone() - .watch_proxy_config(rx_proxy_config.clone()), - ); + let cert_store = cert_store::CertStore::new(consul.clone(), rx_proxy_config.clone()); + tokio::spawn(cert_store.clone().watch_proxy_config()); tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone())); tokio::spawn(https::serve_https( - opt.https_bind_addr, + opt.https_bind_addr, cert_store.clone(), rx_proxy_config.clone(), )); diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 3e3e62f..d4fe039 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -1,12 +1,12 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{atomic, Arc}; -use std::collections::HashMap; use std::{cmp, time::Duration}; use anyhow::Result; +use futures::future::BoxFuture; use futures::stream::{FuturesUnordered, StreamExt}; -use futures::future::{BoxFuture}; use log::*; use tokio::{sync::watch, time::sleep}; @@ -45,7 +45,11 @@ fn retry_to_time(retries: u32, max_time: Duration) -> Duration { )); } -fn parse_tricot_tag(tag: &str, target_addr: SocketAddr, add_headers: &[(String, String)]) -> Option { +fn parse_tricot_tag( + tag: &str, + target_addr: SocketAddr, + add_headers: &[(String, String)], +) -> Option { let splits = tag.split(' ').collect::>(); if (splits.len() != 2 && splits.len() != 3) || splits[0] != "tricot" { return None; @@ -89,10 +93,13 @@ fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec { _ => match catalog.node.address.parse() { Ok(ip) => ip, _ => { - warn!("Could not get address for service {} at node {}", svc.service, catalog.node.node); + warn!( + "Could not get address for service {} at node {}", + svc.service, catalog.node.node + ); continue; } - } + }, }; let addr = SocketAddr::new(ip_addr, svc.port); @@ -124,7 +131,7 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver watch::Receiver