use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::Duration; use anyhow::Result; use chrono::Utc; use log::*; use tokio::sync::watch; use tokio::task::block_in_place; use acme_micro::create_p384_key; use acme_micro::{Directory, DirectoryUrl}; use rustls::sign::CertifiedKey; use crate::cert::{Cert, CertSer}; use crate::consul::*; use crate::proxy_config::*; pub struct CertStore { consul: Consul, certs: RwLock>>, self_signed_certs: RwLock>>, rx_proxy_config: watch::Receiver>, } impl CertStore { pub fn new(consul: Consul, rx_proxy_config: watch::Receiver>) -> Arc { Arc::new(Self { consul, certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()), rx_proxy_config, }) } pub async fn watch_proxy_config(self: Arc) { let mut rx_proxy_config = self.rx_proxy_config.clone(); while rx_proxy_config.changed().await.is_ok() { let mut domains: HashSet = HashSet::new(); let proxy_config: Arc = rx_proxy_config.borrow().clone(); for ent in proxy_config.entries.iter() { if let HostDescription::Hostname(domain) = &ent.host { domains.insert(domain.clone()); } } for dom in domains.iter() { info!("Ensuring we have certs for domains: {:?}", domains); if let Err(e) = self.get_cert(dom).await { warn!("Error get_cert {}: {}", dom, e); } } } } pub fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { // Check if domain is authorized if !self .rx_proxy_config .borrow() .entries .iter() .any(|ent| ent.host.matches(domain)) { bail!("Domain {} should not have a TLS certificate.", domain); } // Check in local memory if it exists if let Some(cert) = self.certs.read().unwrap().get(domain) { if !cert.is_old() { return Ok(cert.clone()); } } // Not found in local memory, try to get it in background tokio::spawn(self.clone().get_cert_task(domain.to_string())); // In the meantime, use a self-signed certificate if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) { if !cert.is_old() { return Ok(cert.clone()); } } self.gen_self_signed_certificate(domain) } pub async fn get_cert_task(self: Arc, domain: String) -> Result> { self.get_cert(domain.as_str()).await } pub async fn get_cert(self: &Arc, domain: &str) -> Result> { // First, try locally. { let certs = self.certs.read().unwrap(); if let Some(cert) = certs.get(domain) { if !cert.is_old() { return Ok(cert.clone()); } } } // Second, try from Consul. if let Some(consul_cert) = self .consul .kv_get_json::(&format!("certs/{}", domain)) .await? { if let Ok(cert) = Cert::new(consul_cert) { let cert = Arc::new(cert); if !cert.is_old() { self.certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); return Ok(cert); } } } // Third, ask from Let's Encrypt self.renew_cert(domain).await } pub async fn renew_cert(self: &Arc, domain: &str) -> Result> { info!("Renewing certificate for {}", domain); // ---- Acquire lock ---- // the lock is acquired for fifteen minutes, // so that in case of an error we won't retry before // that delay expires let lock_path = format!("renew_lock/{}", domain); let lock_name = format!("tricot/renew:{}@{}", domain, self.consul.local_node.clone()); let session = self .consul .create_session(&ConsulSessionRequest { name: lock_name.clone(), node: None, lock_delay: Some("15m".into()), ttl: Some("30m".into()), behavior: Some("delete".into()), }) .await?; debug!("Lock session: {}", session); if !self .consul .acquire(&lock_path, lock_name.clone().into(), &session) .await? { bail!("Lock is already taken, not renewing for now."); } // ---- Do let's encrypt stuff ---- let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?; let contact = vec!["mailto:alex@adnab.me".to_string()]; let acc = if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? { info!("Using existing Let's encrypt account"); dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)? } else { info!("Creating new Let's encrypt account"); let acc = block_in_place(|| dir.register_account(contact.clone()))?; self.consul .kv_put( "letsencrypt_account_key.pem", acc.acme_private_key_pem()?.into_bytes().into(), ) .await?; acc }; let mut ord_new = acc.new_order(domain, &[])?; let ord_csr = loop { if let Some(ord_csr) = ord_new.confirm_validations() { break ord_csr; } let auths = ord_new.authorizations()?; info!("Creating challenge and storing in Consul"); let chall = auths[0].http_challenge().unwrap(); let chall_key = format!("challenge/{}", chall.http_token()); self.consul .acquire(&chall_key, chall.http_proof()?.into(), &session) .await?; info!("Validating challenge"); block_in_place(|| chall.validate(Duration::from_millis(5000)))?; info!("Deleting challenge"); self.consul.kv_delete(&chall_key).await?; block_in_place(|| ord_new.refresh())?; }; let pkey_pri = create_p384_key()?; let ord_cert = block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?; let cert = block_in_place(|| ord_cert.download_cert())?; info!("Keys and certificate obtained"); let key_pem = cert.private_key().to_string(); let cert_pem = cert.certificate().to_string(); let certser = CertSer { hostname: domain.to_string(), date: Utc::today().naive_utc(), valid_days: cert.valid_days_left()?, key_pem, cert_pem, }; self.consul .kv_put_json(&format!("certs/{}", domain), &certser) .await?; self.consul.release(&lock_path, "".into(), &session).await?; let cert = Arc::new(Cert::new(certser)?); self.certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); info!("Cert successfully renewed: {}", domain); Ok(cert) } fn gen_self_signed_certificate(&self, domain: &str) -> Result> { let subject_alt_names = vec![domain.to_string(), "localhost".to_string()]; let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; let certser = CertSer { hostname: domain.to_string(), date: Utc::today().naive_utc(), valid_days: 1024, key_pem: cert.serialize_private_key_pem(), cert_pem: cert.serialize_pem()?, }; let cert = Arc::new(Cert::new(certser)?); self.self_signed_certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); info!("Added self-signed certificate for {}", domain); Ok(cert) } } pub struct StoreResolver(pub Arc); impl rustls::server::ResolvesServerCert for StoreResolver { fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { let domain = client_hello.server_name()?; match self.0.get_cert_for_https(domain) { Ok(cert) => Some(cert.certkey.clone()), Err(e) => { warn!("Could not get certificate for {}: {}", domain, e); None } } } }