use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; use anyhow::Result; use chrono::Utc; use futures::{FutureExt, TryFutureExt}; use tokio::select; use tokio::sync::{mpsc, watch}; use tokio::task::block_in_place; use tracing::*; use acme_micro::create_p384_key; use acme_micro::{Directory, DirectoryUrl}; use rustls::sign::CertifiedKey; use crate::cert::{Cert, CertSer}; use crate::consul::{self, Consul}; use crate::proxy_config::*; pub struct CertStore { consul: Consul, node_name: String, letsencrypt_email: String, certs: RwLock>>, self_signed_certs: RwLock>>, rx_proxy_config: watch::Receiver>, tx_need_cert: mpsc::UnboundedSender, } struct ProcessedDomains { static_domains: HashSet, on_demand_domains: Vec<(glob::Pattern, Option)>, } impl CertStore { pub fn new( consul: Consul, node_name: String, rx_proxy_config: watch::Receiver>, letsencrypt_email: String, exit_on_err: impl Fn(anyhow::Error) + Send + 'static, ) -> Arc { let (tx, rx) = mpsc::unbounded_channel(); let cert_store = Arc::new(Self { consul, node_name, letsencrypt_email, certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()), rx_proxy_config, tx_need_cert: tx, }); tokio::spawn( cert_store .clone() .certificate_loop(rx) .map_err(exit_on_err) .then(|_| async { info!("Certificate renewal task exited") }), ); cert_store } async fn certificate_loop( self: Arc, mut rx_need_cert: mpsc::UnboundedReceiver, ) -> Result<()> { let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut t_last_check: HashMap = HashMap::new(); let mut proc_domains: Option = None; loop { let domains = select! { // Refresh some internal states, schedule static_domains for renew res = rx_proxy_config.changed() => { if res.is_err() { bail!("rx_proxy_config closed"); } let mut static_domains: HashSet = HashSet::new(); let mut on_demand_domains: Vec<(glob::Pattern, Option)> = vec![]; let proxy_config: Arc = rx_proxy_config.borrow().clone(); for ent in proxy_config.entries.iter() { // Eagerly generate certificates for domains that // are not patterns match &ent.url_prefix.host { HostDescription::Hostname(domain) => { if let Some((host, _port)) = domain.split_once(':') { static_domains.insert(host.to_string()); } else { static_domains.insert(domain.clone()); } }, HostDescription::Pattern(pattern) => { on_demand_domains.push((pattern.clone(), ent.on_demand_tls_ask.clone())); }, } } // only static_domains are refreshed proc_domains = Some(ProcessedDomains { static_domains: static_domains.clone(), on_demand_domains }); self.domain_validation(static_domains, proc_domains.as_ref()).await } // renew static and on-demand domains need_cert = rx_need_cert.recv() => { match need_cert { Some(dom) => { let mut candidates: HashSet = HashSet::new(); // collect certificates as much as possible candidates.insert(dom); while let Ok(dom2) = rx_need_cert.try_recv() { candidates.insert(dom2); } self.domain_validation(candidates, proc_domains.as_ref()).await } None => bail!("rx_need_cert closed"), } } }; // Now that we have our list of domains to check, // actually do something for dom in domains.iter() { // Exclude from the list domains that were checked less than 60 // seconds ago match t_last_check.get(dom) { Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue, _ => t_last_check.insert(dom.to_string(), Instant::now()), }; // Actual Let's Encrypt calls are done here (in sister function) debug!("Checking cert for domain: {}", dom); if let Err(e) = self.check_cert(dom).await { warn!("({}) Could not get certificate: {}", dom, e); } } } } async fn domain_validation( &self, candidates: HashSet, maybe_proc_domains: Option<&ProcessedDomains>, ) -> HashSet { let mut domains: HashSet = HashSet::new(); // Handle initialization let proc_domains = match maybe_proc_domains { None => { warn!("Proxy config is not yet loaded, refusing all certificate generation"); return domains; } Some(proc) => proc, }; // Filter certificates... 'outer: for candidate in candidates.into_iter() { // Disallow obvious wrong domains... if !candidate.contains('.') || candidate.ends_with(".local") { warn!("{} is probably not a publicly accessible domain, skipping (a self-signed certificate will be used)", candidate); continue; } // Try to register domain as a static domain if proc_domains.static_domains.contains(&candidate) { trace!("domain {} validated as static domain", candidate); domains.insert(candidate); continue; } // It's not a static domain, maybe an on-demand domain? for (pattern, maybe_check_url) in proc_domains.on_demand_domains.iter() { // check glob pattern if pattern.matches(&candidate) { // if no check url is set, accept domain as long as it matches the pattern let check_url = match maybe_check_url { None => { trace!( "domain {} validated on glob pattern {} only", candidate, pattern ); domains.insert(candidate); continue 'outer; } Some(url) => url, }; // if a check url is set, call it // -- avoid DDoSing a backend tokio::time::sleep(Duration::from_secs(2)).await; match self.on_demand_tls_ask(check_url, &candidate).await { Ok(()) => { trace!( "domain {} validated on glob pattern {} and on check url {}", candidate, pattern, check_url ); domains.insert(candidate); continue 'outer; } Err(e) => { warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e); } } } } } return domains; } /// This function is also in charge of the refresh of the domain names fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { // Check in local memory if it exists if let Some(cert) = self.certs.read().unwrap().get(domain) { if cert.is_old() { self.tx_need_cert.send(domain.to_string())?; } return Ok(cert.clone()); } // Not found in local memory, try to get it in background self.tx_need_cert.send(domain.to_string())?; // In the meantime, use a self-signed certificate if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) { return Ok(cert.clone()); } self.gen_self_signed_certificate(domain) } pub async fn warmup_memory_store(self: &Arc) -> Result<()> { let consul_certs = self .consul .kv_get_prefix("certs/", None) .await? .into_inner(); trace!( "Fetched {} certificate entries from Consul", consul_certs.len() ); let mut loaded_certs: usize = 0; for (key, cert) in consul_certs { let certser: CertSer = match serde_json::from_slice(&cert) { Ok(cs) => cs, Err(e) => { warn!("Could not deserialize CertSer for {key}: {e}"); continue; } }; let domain = certser.hostname.clone(); let cert = match Cert::new(certser) { Ok(c) => c, Err(e) => { warn!("Could not create Cert from CertSer for domain {domain}: {e}"); continue; } }; self.certs .write() .unwrap() .insert(domain.to_string(), Arc::new(cert)); debug!("({domain}) Certificate loaded from Consul to the Memory Store"); loaded_certs += 1; } info!("Memory store warmed up with {loaded_certs} certificates"); Ok(()) } /// Check certificate ensure that the certificate is in the memory store /// and that it does not need to be renewed. /// /// If it's not in the memory store, it tries to load it from Consul, /// if it's not in Consul, it calls Let's Encrypt. /// /// If the certificate is outdated in the memory store, it tries to load /// a more recent version in Consul, if the Consul version is also outdated, /// it tries to renew it pub async fn check_cert(self: &Arc, domain: &str) -> Result<()> { // First, try locally. { let certs = self.certs.read().unwrap(); if let Some(cert) = certs.get(domain) { if !cert.is_old() { return Ok(()); } } } // Second, try from Consul. if let Some(consul_cert) = self .consul .kv_get_json::(&format!("certs/{}", domain)) .await? { match Cert::new(consul_cert) { Ok(cert) => { let cert = Arc::new(cert); self.certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); debug!("({domain}) Certificate loaded from Consul to the Memory Store"); if !cert.is_old() { return Ok(()); } } Err(e) => { warn!("Could not create Cert from CertSer for domain {domain}: {e}"); } }; } // Third, ask from Let's Encrypt self.renew_cert(domain).await } /// This is the place where certificates are generated or renewed pub async fn renew_cert(self: &Arc, domain: &str) -> Result<()> { info!("({}) Renewing certificate", domain); // ---- Acquire lock ---- // the lock is acquired for half an hour, // so that in case of an error we won't retry before // that delay expires let lock_path = format!("renew_lock/{}", domain); let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name); let session = self .consul .create_session(&consul::locking::SessionRequest { name: lock_name.clone(), node: None, lock_delay: Some("30m".into()), ttl: Some("45m".into()), behavior: Some("delete".into()), }) .await?; debug!("Lock session: {}", session); if !self .consul .acquire(&lock_path, lock_name.clone().into(), &session) .await? { bail!("Lock is already taken, not renewing for now."); } // ---- Accessibility check ---- // We don't want to ask Let's encrypt for a domain that // is not configured to point here. This can happen with wildcards: someone can send // a fake SNI to a domain that is not ours. We have to detect it here. self.check_domain_accessibility(domain, &session).await?; // ---- Do let's encrypt stuff ---- let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?; let contact = vec![format!("mailto:{}", self.letsencrypt_email)]; // Use existing Let's encrypt account or register new one if necessary let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? { Some(acc_privkey) => { info!("Using existing Let's encrypt account"); dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)? } None => { info!("Creating new Let's encrypt account"); let acc = block_in_place(|| dir.register_account(contact.clone()))?; self.consul .kv_put( "letsencrypt_account_key.pem", acc.acme_private_key_pem()?.into_bytes().into(), ) .await?; acc } }; // Order certificate and perform validation let mut ord_new = acc.new_order(domain, &[])?; let ord_csr = loop { if let Some(ord_csr) = ord_new.confirm_validations() { break ord_csr; } let auths = ord_new.authorizations()?; info!("({}) Creating challenge and storing in Consul", domain); let chall = auths[0].http_challenge().unwrap(); let chall_key = format!("challenge/{}", chall.http_token()); self.consul .acquire(&chall_key, chall.http_proof()?.into(), &session) .await?; info!("({}) Validating challenge", domain); block_in_place(|| chall.validate(Duration::from_millis(5000)))?; info!("({}) Deleting challenge", domain); self.consul.kv_delete(&chall_key).await?; block_in_place(|| ord_new.refresh())?; }; // Generate key and finalize certificate let pkey_pri = create_p384_key()?; let ord_cert = block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?; let cert = block_in_place(|| ord_cert.download_cert())?; info!("({}) Keys and certificate obtained", domain); let key_pem = cert.private_key().to_string(); let cert_pem = cert.certificate().to_string(); let certser = CertSer { hostname: domain.to_string(), date: Utc::now().date_naive(), valid_days: cert.valid_days_left()?, key_pem, cert_pem, }; let cert = Arc::new(Cert::new(certser.clone())?); // Store certificate in Consul and local store self.certs.write().unwrap().insert(domain.to_string(), cert); self.consul .kv_put_json(&format!("certs/{}", domain), &certser) .await?; // Release locks self.consul.release(&lock_path, "".into(), &session).await?; self.consul.kv_delete(&lock_path).await?; info!("({}) Cert successfully renewed and stored", domain); Ok(()) } async fn on_demand_tls_ask(&self, check_url: &str, domain: &str) -> Result<()> { let httpcli = reqwest::Client::new(); let chall_url = format!("{}?domain={}", check_url, domain); info!("({}) On-demand TLS check", domain); let httpresp = httpcli.get(&chall_url).send().await?; if httpresp.status() != reqwest::StatusCode::OK { bail!("{} is not authorized for on-demand TLS", domain); } Ok(()) } async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> { // Returns Ok(()) only if domain is a correct domain name that // redirects to this server let self_challenge_id = uuid::Uuid::new_v4().to_string(); let self_challenge_key = format!("challenge/{}", self_challenge_id); let self_challenge_resp = uuid::Uuid::new_v4().to_string(); self.consul .acquire( &self_challenge_key, self_challenge_resp.as_bytes().to_vec().into(), session, ) .await?; let httpcli = reqwest::Client::new(); let chall_url = format!( "http://{}/.well-known/acme-challenge/{}", domain, self_challenge_id ); for i in 1..=4 { tokio::time::sleep(Duration::from_secs(2)).await; info!("({}) Accessibility check {}/4", domain, i); let httpresp = httpcli.get(&chall_url).send().await?; if httpresp.status() == reqwest::StatusCode::OK && httpresp.bytes().await? == self_challenge_resp.as_bytes() { // Challenge successfully validated info!("({}) Accessibility check successfull", domain); return Ok(()); } tokio::time::sleep(Duration::from_secs(2)).await; } bail!("Unable to validate self-challenge for domain accessibility check"); } fn gen_self_signed_certificate(&self, domain: &str) -> Result> { let subject_alt_names = vec![domain.to_string(), "localhost".to_string()]; let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; let certser = CertSer { hostname: domain.to_string(), date: Utc::now().date_naive(), valid_days: 1024, key_pem: cert.serialize_private_key_pem(), cert_pem: cert.serialize_pem()?, }; let cert = Arc::new(Cert::new(certser)?); self.self_signed_certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); info!("Added self-signed certificate for {}", domain); Ok(cert) } } pub struct StoreResolver(pub Arc); impl rustls::server::ResolvesServerCert for StoreResolver { fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { let domain = client_hello.server_name()?; match self.0.get_cert_for_https(domain) { Ok(cert) => Some(cert.certkey.clone()), Err(e) => { warn!("Could not get certificate for {}: {}", domain, e); None } } } }