diff --git a/src/cert.rs b/src/cert.rs index 0be43f3..12b9218 100644 --- a/src/cert.rs +++ b/src/cert.rs @@ -6,7 +6,7 @@ use chrono::{Date, NaiveDate, Utc}; use rustls::sign::CertifiedKey; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct CertSer { pub hostname: String, pub date: NaiveDate, diff --git a/src/cert_store.rs b/src/cert_store.rs index fe2f8b0..2095660 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -1,11 +1,13 @@ use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; -use std::time::Duration; +use std::time::{Duration, Instant}; use anyhow::Result; use chrono::Utc; +use futures::TryFutureExt; use log::*; -use tokio::sync::watch; +use tokio::select; +use tokio::sync::{mpsc, watch}; use tokio::task::block_in_place; use acme_micro::create_p384_key; @@ -14,6 +16,7 @@ use rustls::sign::CertifiedKey; use crate::cert::{Cert, CertSer}; use crate::consul::*; +use crate::exit_on_err; use crate::proxy_config::*; pub struct CertStore { @@ -22,6 +25,7 @@ pub struct CertStore { certs: RwLock>>, self_signed_certs: RwLock>>, rx_proxy_config: watch::Receiver>, + tx_need_cert: mpsc::UnboundedSender, } impl CertStore { @@ -30,44 +34,78 @@ impl CertStore { rx_proxy_config: watch::Receiver>, letsencrypt_email: String, ) -> Arc { - Arc::new(Self { + let (tx, rx) = mpsc::unbounded_channel(); + + let cert_store = Arc::new(Self { consul, certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()), rx_proxy_config, letsencrypt_email, - }) + tx_need_cert: tx, + }); + + tokio::spawn(cert_store.clone().certificate_loop(rx).map_err(exit_on_err)); + + cert_store } - pub async fn watch_proxy_config(self: Arc) -> Result<()> { + async fn certificate_loop( + self: Arc, + mut rx_need_cert: mpsc::UnboundedReceiver, + ) -> Result<()> { let mut rx_proxy_config = self.rx_proxy_config.clone(); - while rx_proxy_config.changed().await.is_ok() { + let mut t_last_check: HashMap = HashMap::new(); + + loop { let mut domains: HashSet = HashSet::new(); - let proxy_config: Arc = rx_proxy_config.borrow().clone(); - for ent in proxy_config.entries.iter() { - if let HostDescription::Hostname(domain) = &ent.host { - if let Some((host, _port)) = domain.split_once(':') { - domains.insert(host.to_string()); - } else { - domains.insert(domain.clone()); + select! { + res = rx_proxy_config.changed() => { + if res.is_err() { + bail!("rx_proxy_config closed"); } + + let proxy_config: Arc = rx_proxy_config.borrow().clone(); + for ent in proxy_config.entries.iter() { + if let HostDescription::Hostname(domain) = &ent.host { + if let Some((host, _port)) = domain.split_once(':') { + domains.insert(host.to_string()); + } else { + domains.insert(domain.clone()); + } + } + } + } + need_cert = rx_need_cert.recv() => { + match need_cert { + Some(dom) => { + domains.insert(dom); + while let Ok(dom2) = rx_need_cert.try_recv() { + domains.insert(dom2); + } + } + None => bail!("rx_need_cert closed"), + }; } } - debug!("Ensuring we have certs for domains: {:#?}", domains); for dom in domains.iter() { - if let Err(e) = self.get_cert(dom).await { - warn!("Error get_cert {}: {}", dom, e); + match t_last_check.get(dom) { + Some(t) if Instant::now() - *t < Duration::from_secs(3600) => continue, + _ => t_last_check.insert(dom.to_string(), Instant::now()), + }; + + debug!("Checking cert for domain: {}", dom); + if let Err(e) = self.check_cert(dom).await { + warn!("({}) Could not get certificate: {}", dom, e); } } } - - bail!("rx_proxy_config closed"); } - pub fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { + fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { // Check if domain is authorized if !self .rx_proxy_config @@ -81,35 +119,30 @@ impl CertStore { // Check in local memory if it exists if let Some(cert) = self.certs.read().unwrap().get(domain) { - if !cert.is_old() { - return Ok(cert.clone()); + if cert.is_old() { + self.tx_need_cert.send(domain.to_string())?; } + return Ok(cert.clone()); } // Not found in local memory, try to get it in background - tokio::spawn(self.clone().get_cert_task(domain.to_string())); + self.tx_need_cert.send(domain.to_string())?; // In the meantime, use a self-signed certificate if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) { - if !cert.is_old() { - return Ok(cert.clone()); - } + return Ok(cert.clone()); } self.gen_self_signed_certificate(domain) } - pub async fn get_cert_task(self: Arc, domain: String) -> Result> { - self.get_cert(domain.as_str()).await - } - - pub async fn get_cert(self: &Arc, domain: &str) -> Result> { + pub async fn check_cert(self: &Arc, domain: &str) -> Result<()> { // First, try locally. { let certs = self.certs.read().unwrap(); if let Some(cert) = certs.get(domain) { if !cert.is_old() { - return Ok(cert.clone()); + return Ok(()); } } } @@ -122,12 +155,12 @@ impl CertStore { { if let Ok(cert) = Cert::new(consul_cert) { let cert = Arc::new(cert); + self.certs + .write() + .unwrap() + .insert(domain.to_string(), cert.clone()); if !cert.is_old() { - self.certs - .write() - .unwrap() - .insert(domain.to_string(), cert.clone()); - return Ok(cert); + return Ok(()); } } } @@ -136,8 +169,14 @@ impl CertStore { self.renew_cert(domain).await } - pub async fn renew_cert(self: &Arc, domain: &str) -> Result> { - info!("Renewing certificate for {}", domain); + pub async fn renew_cert(self: &Arc, domain: &str) -> Result<()> { + info!("({}) Renewing certificate", domain); + + // Basic sanity check (we could add more kinds of checks here) + // This is just to help avoid getting rate-limited against ACME server + if !domain.contains('.') || domain.ends_with(".local") { + bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)"); + } // ---- Acquire lock ---- // the lock is acquired for fifteen minutes, @@ -171,11 +210,13 @@ impl CertStore { let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?; let contact = vec![format!("mailto:{}", self.letsencrypt_email)]; - let acc = - if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? { + // Use existing Let's encrypt account or register new one if necessary + let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? { + Some(acc_privkey) => { info!("Using existing Let's encrypt account"); dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)? - } else { + } + None => { info!("Creating new Let's encrypt account"); let acc = block_in_place(|| dir.register_account(contact.clone()))?; self.consul @@ -185,8 +226,10 @@ impl CertStore { ) .await?; acc - }; + } + }; + // Order certificate and perform validation let mut ord_new = acc.new_order(domain, &[])?; let ord_csr = loop { if let Some(ord_csr) = ord_new.confirm_validations() { @@ -195,28 +238,29 @@ impl CertStore { let auths = ord_new.authorizations()?; - info!("Creating challenge and storing in Consul"); + info!("({}) Creating challenge and storing in Consul", domain); let chall = auths[0].http_challenge().unwrap(); let chall_key = format!("challenge/{}", chall.http_token()); self.consul .acquire(&chall_key, chall.http_proof()?.into(), &session) .await?; - info!("Validating challenge"); + info!("({}) Validating challenge", domain); block_in_place(|| chall.validate(Duration::from_millis(5000)))?; - info!("Deleting challenge"); + info!("({}) Deleting challenge", domain); self.consul.kv_delete(&chall_key).await?; block_in_place(|| ord_new.refresh())?; }; + // Generate key and finalize certificate let pkey_pri = create_p384_key()?; let ord_cert = block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?; let cert = block_in_place(|| ord_cert.download_cert())?; - info!("Keys and certificate obtained"); + info!("({}) Keys and certificate obtained", domain); let key_pem = cert.private_key().to_string(); let cert_pem = cert.certificate().to_string(); @@ -227,21 +271,20 @@ impl CertStore { key_pem, cert_pem, }; + let cert = Arc::new(Cert::new(certser.clone())?); + // Store certificate in Consul and local store + self.certs.write().unwrap().insert(domain.to_string(), cert); self.consul .kv_put_json(&format!("certs/{}", domain), &certser) .await?; + + // Release locks self.consul.release(&lock_path, "".into(), &session).await?; self.consul.kv_delete(&lock_path).await?; - let cert = Arc::new(Cert::new(certser)?); - self.certs - .write() - .unwrap() - .insert(domain.to_string(), cert.clone()); - - info!("Cert successfully renewed: {}", domain); - Ok(cert) + info!("({}) Cert successfully renewed and stored", domain); + Ok(()) } fn gen_self_signed_certificate(&self, domain: &str) -> Result> { diff --git a/src/https.rs b/src/https.rs index b0d452b..a389e72 100644 --- a/src/https.rs +++ b/src/https.rs @@ -114,7 +114,6 @@ async fn handle( ) }); - if let Some(proxy_to) = best_match { proxy_to.calls.fetch_add(1, Ordering::SeqCst); diff --git a/src/main.rs b/src/main.rs index 1fffcbc..faffac6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,6 @@ use futures::TryFutureExt; use std::net::SocketAddr; use structopt::StructOpt; -mod tls_util; mod cert; mod cert_store; mod consul; @@ -13,6 +12,7 @@ mod http; mod https; mod proxy_config; mod reverse_proxy; +mod tls_util; use log::*; @@ -85,7 +85,6 @@ async fn main() { rx_proxy_config.clone(), opt.letsencrypt_email.clone(), ); - tokio::spawn(cert_store.clone().watch_proxy_config().map_err(exit_on_err)); tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err)); tokio::spawn( diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 2c55eb5..820d40a 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -99,7 +99,8 @@ fn parse_tricot_tag( ) -> Option { let splits = tag.split(' ').collect::>(); if (splits.len() != 2 && splits.len() != 3) - || (splits[0] != "tricot" && splits[0] != "tricot-https") { + || (splits[0] != "tricot" && splits[0] != "tricot-https") + { return None; } diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs index 10f463c..7b0f261 100644 --- a/src/reverse_proxy.rs +++ b/src/reverse_proxy.rs @@ -1,11 +1,11 @@ //! Copied from https://github.com/felipenoris/hyper-reverse-proxy //! See there for original Copyright notice -use std::sync::Arc; use std::convert::TryInto; -use std::time::SystemTime; use std::net::IpAddr; use std::str::FromStr; +use std::sync::Arc; +use std::time::SystemTime; use anyhow::Result; use log::*; @@ -13,9 +13,9 @@ use log::*; use http::header::HeaderName; use hyper::header::{HeaderMap, HeaderValue}; use hyper::{Body, Client, Request, Response, Uri}; -use rustls::{Certificate, ServerName}; -use rustls::client::{ServerCertVerifier, ServerCertVerified}; use lazy_static::lazy_static; +use rustls::client::{ServerCertVerified, ServerCertVerifier}; +use rustls::{Certificate, ServerName}; use crate::tls_util::HttpsConnectorFixedDnsname; @@ -175,16 +175,14 @@ struct DontVerifyServerCert; impl ServerCertVerifier for DontVerifyServerCert { fn verify_server_cert( - &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: SystemTime - ) -> Result { + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { Ok(ServerCertVerified::assertion()) } } - - diff --git a/src/tls_util.rs b/src/tls_util.rs index 054c35a..91ad31c 100644 --- a/src/tls_util.rs +++ b/src/tls_util.rs @@ -1,21 +1,20 @@ use core::future::Future; use core::task::{Context, Poll}; use std::convert::TryFrom; +use std::io; use std::pin::Pin; use std::sync::Arc; -use std::io; use futures_util::future::*; -use rustls::ServerName; use hyper::client::connect::Connection; use hyper::client::HttpConnector; use hyper::service::Service; use hyper::Uri; use hyper_rustls::MaybeHttpsStream; +use rustls::ServerName; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsConnector; - #[derive(Clone)] pub struct HttpsConnectorFixedDnsname { http: T, @@ -62,8 +61,7 @@ where let cfg = self.tls_config.clone(); let connecting_future = self.http.call(dst); - let dnsname = - ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname"); + let dnsname = ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname"); let f = async move { let tcp = connecting_future.await.map_err(Into::into)?; let connector = TlsConnector::from(cfg); @@ -76,4 +74,3 @@ where f.boxed() } } -