Quentin Dufour
d31212e56b
All checks were successful
continuous-integration/drone/push Build is passing
543 lines
16 KiB
Rust
543 lines
16 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
use std::sync::{Arc, RwLock};
|
|
use std::time::{Duration, Instant};
|
|
|
|
use anyhow::Result;
|
|
use chrono::Utc;
|
|
use futures::{FutureExt, TryFutureExt};
|
|
use tokio::select;
|
|
use tokio::sync::{mpsc, watch};
|
|
use tokio::task::block_in_place;
|
|
use tracing::*;
|
|
|
|
use acme_micro::create_p384_key;
|
|
use acme_micro::{Directory, DirectoryUrl};
|
|
use rustls::sign::CertifiedKey;
|
|
|
|
use crate::cert::{Cert, CertSer};
|
|
use crate::consul::{self, Consul};
|
|
use crate::proxy_config::*;
|
|
|
|
pub struct CertStore {
|
|
consul: Consul,
|
|
node_name: String,
|
|
letsencrypt_email: String,
|
|
|
|
certs: RwLock<HashMap<String, Arc<Cert>>>,
|
|
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
|
|
|
|
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
|
|
tx_need_cert: mpsc::UnboundedSender<String>,
|
|
}
|
|
|
|
struct ProcessedDomains {
|
|
static_domains: HashSet<String>,
|
|
on_demand_domains: Vec<(glob::Pattern, Option<String>)>,
|
|
}
|
|
|
|
impl CertStore {
|
|
pub fn new(
|
|
consul: Consul,
|
|
node_name: String,
|
|
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
|
|
letsencrypt_email: String,
|
|
exit_on_err: impl Fn(anyhow::Error) + Send + 'static,
|
|
) -> Arc<Self> {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
|
|
let cert_store = Arc::new(Self {
|
|
consul,
|
|
node_name,
|
|
letsencrypt_email,
|
|
certs: RwLock::new(HashMap::new()),
|
|
self_signed_certs: RwLock::new(HashMap::new()),
|
|
rx_proxy_config,
|
|
tx_need_cert: tx,
|
|
});
|
|
|
|
tokio::spawn(
|
|
cert_store
|
|
.clone()
|
|
.certificate_loop(rx)
|
|
.map_err(exit_on_err)
|
|
.then(|_| async { info!("Certificate renewal task exited") }),
|
|
);
|
|
|
|
cert_store
|
|
}
|
|
|
|
async fn certificate_loop(
|
|
self: Arc<Self>,
|
|
mut rx_need_cert: mpsc::UnboundedReceiver<String>,
|
|
) -> Result<()> {
|
|
let mut rx_proxy_config = self.rx_proxy_config.clone();
|
|
|
|
let mut t_last_check: HashMap<String, Instant> = HashMap::new();
|
|
let mut proc_domains: Option<ProcessedDomains> = None;
|
|
|
|
loop {
|
|
let domains = select! {
|
|
// Refresh some internal states, schedule static_domains for renew
|
|
res = rx_proxy_config.changed() => {
|
|
if res.is_err() {
|
|
bail!("rx_proxy_config closed");
|
|
}
|
|
|
|
let mut static_domains: HashSet<String> = HashSet::new();
|
|
let mut on_demand_domains: Vec<(glob::Pattern, Option<String>)> = vec![];
|
|
|
|
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
|
|
|
|
for ent in proxy_config.entries.iter() {
|
|
// Eagerly generate certificates for domains that
|
|
// are not patterns
|
|
match &ent.url_prefix.host {
|
|
HostDescription::Hostname(domain) => {
|
|
if let Some((host, _port)) = domain.split_once(':') {
|
|
static_domains.insert(host.to_string());
|
|
} else {
|
|
static_domains.insert(domain.clone());
|
|
}
|
|
},
|
|
HostDescription::Pattern(pattern) => {
|
|
on_demand_domains.push((pattern.clone(), ent.on_demand_tls_ask.clone()));
|
|
},
|
|
}
|
|
}
|
|
|
|
// only static_domains are refreshed
|
|
proc_domains = Some(ProcessedDomains { static_domains: static_domains.clone(), on_demand_domains });
|
|
self.domain_validation(static_domains, proc_domains.as_ref()).await
|
|
}
|
|
// renew static and on-demand domains
|
|
need_cert = rx_need_cert.recv() => {
|
|
match need_cert {
|
|
Some(dom) => {
|
|
let mut candidates: HashSet<String> = HashSet::new();
|
|
|
|
// collect certificates as much as possible
|
|
candidates.insert(dom);
|
|
while let Ok(dom2) = rx_need_cert.try_recv() {
|
|
candidates.insert(dom2);
|
|
}
|
|
|
|
self.domain_validation(candidates, proc_domains.as_ref()).await
|
|
}
|
|
None => bail!("rx_need_cert closed"),
|
|
}
|
|
}
|
|
};
|
|
|
|
// Now that we have our list of domains to check,
|
|
// actually do something
|
|
for dom in domains.iter() {
|
|
// Exclude from the list domains that were checked less than 60
|
|
// seconds ago
|
|
match t_last_check.get(dom) {
|
|
Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue,
|
|
_ => t_last_check.insert(dom.to_string(), Instant::now()),
|
|
};
|
|
|
|
// Actual Let's Encrypt calls are done here (in sister function)
|
|
debug!("Checking cert for domain: {}", dom);
|
|
if let Err(e) = self.check_cert(dom).await {
|
|
warn!("({}) Could not get certificate: {}", dom, e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn domain_validation(
|
|
&self,
|
|
candidates: HashSet<String>,
|
|
maybe_proc_domains: Option<&ProcessedDomains>,
|
|
) -> HashSet<String> {
|
|
let mut domains: HashSet<String> = HashSet::new();
|
|
|
|
// Handle initialization
|
|
let proc_domains = match maybe_proc_domains {
|
|
None => {
|
|
warn!("Proxy config is not yet loaded, refusing all certificate generation");
|
|
return domains;
|
|
}
|
|
Some(proc) => proc,
|
|
};
|
|
|
|
// Filter certificates...
|
|
'outer: for candidate in candidates.into_iter() {
|
|
// Disallow obvious wrong domains...
|
|
if !candidate.contains('.') || candidate.ends_with(".local") {
|
|
warn!("{} is probably not a publicly accessible domain, skipping (a self-signed certificate will be used)", candidate);
|
|
continue;
|
|
}
|
|
|
|
// Try to register domain as a static domain
|
|
if proc_domains.static_domains.contains(&candidate) {
|
|
trace!("domain {} validated as static domain", candidate);
|
|
domains.insert(candidate);
|
|
continue;
|
|
}
|
|
|
|
// It's not a static domain, maybe an on-demand domain?
|
|
for (pattern, maybe_check_url) in proc_domains.on_demand_domains.iter() {
|
|
// check glob pattern
|
|
if pattern.matches(&candidate) {
|
|
// if no check url is set, accept domain as long as it matches the pattern
|
|
let check_url = match maybe_check_url {
|
|
None => {
|
|
trace!(
|
|
"domain {} validated on glob pattern {} only",
|
|
candidate,
|
|
pattern
|
|
);
|
|
domains.insert(candidate);
|
|
continue 'outer;
|
|
}
|
|
Some(url) => url,
|
|
};
|
|
|
|
// if a check url is set, call it
|
|
// -- avoid DDoSing a backend
|
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
|
match self.on_demand_tls_ask(check_url, &candidate).await {
|
|
Ok(()) => {
|
|
trace!(
|
|
"domain {} validated on glob pattern {} and on check url {}",
|
|
candidate,
|
|
pattern,
|
|
check_url
|
|
);
|
|
domains.insert(candidate);
|
|
continue 'outer;
|
|
}
|
|
Err(e) => {
|
|
warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return domains;
|
|
}
|
|
|
|
/// This function is also in charge of the refresh of the domain names
|
|
fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
|
|
// Check in local memory if it exists
|
|
if let Some(cert) = self.certs.read().unwrap().get(domain) {
|
|
if cert.is_old() {
|
|
self.tx_need_cert.send(domain.to_string())?;
|
|
}
|
|
return Ok(cert.clone());
|
|
}
|
|
|
|
// Not found in local memory, try to get it in background
|
|
self.tx_need_cert.send(domain.to_string())?;
|
|
|
|
// In the meantime, use a self-signed certificate
|
|
if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
|
|
return Ok(cert.clone());
|
|
}
|
|
|
|
self.gen_self_signed_certificate(domain)
|
|
}
|
|
|
|
pub async fn warmup_memory_store(self: &Arc<Self>) -> Result<()> {
|
|
let consul_certs = self
|
|
.consul
|
|
.kv_get_prefix("certs/", None)
|
|
.await?
|
|
.into_inner();
|
|
|
|
trace!(
|
|
"Fetched {} certificate entries from Consul",
|
|
consul_certs.len()
|
|
);
|
|
let mut loaded_certs: usize = 0;
|
|
for (key, cert) in consul_certs {
|
|
let certser: CertSer = match serde_json::from_slice(&cert) {
|
|
Ok(cs) => cs,
|
|
Err(e) => {
|
|
warn!("Could not deserialize CertSer for {key}: {e}");
|
|
continue;
|
|
}
|
|
};
|
|
let domain = certser.hostname.clone();
|
|
|
|
let cert = match Cert::new(certser) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
warn!("Could not create Cert from CertSer for domain {domain}: {e}");
|
|
continue;
|
|
}
|
|
};
|
|
|
|
self.certs
|
|
.write()
|
|
.unwrap()
|
|
.insert(domain.to_string(), Arc::new(cert));
|
|
|
|
debug!("({domain}) Certificate loaded from Consul to the Memory Store");
|
|
loaded_certs += 1;
|
|
}
|
|
info!("Memory store warmed up with {loaded_certs} certificates");
|
|
Ok(())
|
|
}
|
|
|
|
/// Check certificate ensure that the certificate is in the memory store
|
|
/// and that it does not need to be renewed.
|
|
///
|
|
/// If it's not in the memory store, it tries to load it from Consul,
|
|
/// if it's not in Consul, it calls Let's Encrypt.
|
|
///
|
|
/// If the certificate is outdated in the memory store, it tries to load
|
|
/// a more recent version in Consul, if the Consul version is also outdated,
|
|
/// it tries to renew it
|
|
pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
|
|
// First, try locally.
|
|
{
|
|
let certs = self.certs.read().unwrap();
|
|
if let Some(cert) = certs.get(domain) {
|
|
if !cert.is_old() {
|
|
return Ok(());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Second, try from Consul.
|
|
if let Some(consul_cert) = self
|
|
.consul
|
|
.kv_get_json::<CertSer>(&format!("certs/{}", domain))
|
|
.await?
|
|
{
|
|
match Cert::new(consul_cert) {
|
|
Ok(cert) => {
|
|
let cert = Arc::new(cert);
|
|
self.certs
|
|
.write()
|
|
.unwrap()
|
|
.insert(domain.to_string(), cert.clone());
|
|
debug!("({domain}) Certificate loaded from Consul to the Memory Store");
|
|
|
|
if !cert.is_old() {
|
|
return Ok(());
|
|
}
|
|
}
|
|
Err(e) => {
|
|
warn!("Could not create Cert from CertSer for domain {domain}: {e}");
|
|
}
|
|
};
|
|
}
|
|
|
|
// Third, ask from Let's Encrypt
|
|
self.renew_cert(domain).await
|
|
}
|
|
|
|
/// This is the place where certificates are generated or renewed
|
|
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
|
|
info!("({}) Renewing certificate", domain);
|
|
|
|
// ---- Acquire lock ----
|
|
// the lock is acquired for half an hour,
|
|
// so that in case of an error we won't retry before
|
|
// that delay expires
|
|
|
|
let lock_path = format!("renew_lock/{}", domain);
|
|
let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name);
|
|
let session = self
|
|
.consul
|
|
.create_session(&consul::locking::SessionRequest {
|
|
name: lock_name.clone(),
|
|
node: None,
|
|
lock_delay: Some("30m".into()),
|
|
ttl: Some("45m".into()),
|
|
behavior: Some("delete".into()),
|
|
})
|
|
.await?;
|
|
debug!("Lock session: {}", session);
|
|
|
|
if !self
|
|
.consul
|
|
.acquire(&lock_path, lock_name.clone().into(), &session)
|
|
.await?
|
|
{
|
|
bail!("Lock is already taken, not renewing for now.");
|
|
}
|
|
|
|
// ---- Accessibility check ----
|
|
// We don't want to ask Let's encrypt for a domain that
|
|
// is not configured to point here. This can happen with wildcards: someone can send
|
|
// a fake SNI to a domain that is not ours. We have to detect it here.
|
|
self.check_domain_accessibility(domain, &session).await?;
|
|
|
|
// ---- Do let's encrypt stuff ----
|
|
|
|
let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
|
|
let contact = vec![format!("mailto:{}", self.letsencrypt_email)];
|
|
|
|
// Use existing Let's encrypt account or register new one if necessary
|
|
let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
|
|
Some(acc_privkey) => {
|
|
info!("Using existing Let's encrypt account");
|
|
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
|
|
}
|
|
None => {
|
|
info!("Creating new Let's encrypt account");
|
|
let acc = block_in_place(|| dir.register_account(contact.clone()))?;
|
|
self.consul
|
|
.kv_put(
|
|
"letsencrypt_account_key.pem",
|
|
acc.acme_private_key_pem()?.into_bytes().into(),
|
|
)
|
|
.await?;
|
|
acc
|
|
}
|
|
};
|
|
|
|
// Order certificate and perform validation
|
|
let mut ord_new = acc.new_order(domain, &[])?;
|
|
let ord_csr = loop {
|
|
if let Some(ord_csr) = ord_new.confirm_validations() {
|
|
break ord_csr;
|
|
}
|
|
|
|
let auths = ord_new.authorizations()?;
|
|
|
|
info!("({}) Creating challenge and storing in Consul", domain);
|
|
let chall = auths[0].http_challenge().unwrap();
|
|
let chall_key = format!("challenge/{}", chall.http_token());
|
|
self.consul
|
|
.acquire(&chall_key, chall.http_proof()?.into(), &session)
|
|
.await?;
|
|
|
|
info!("({}) Validating challenge", domain);
|
|
block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
|
|
|
|
info!("({}) Deleting challenge", domain);
|
|
self.consul.kv_delete(&chall_key).await?;
|
|
|
|
block_in_place(|| ord_new.refresh())?;
|
|
};
|
|
|
|
// Generate key and finalize certificate
|
|
let pkey_pri = create_p384_key()?;
|
|
let ord_cert =
|
|
block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
|
|
let cert = block_in_place(|| ord_cert.download_cert())?;
|
|
|
|
info!("({}) Keys and certificate obtained", domain);
|
|
let key_pem = cert.private_key().to_string();
|
|
let cert_pem = cert.certificate().to_string();
|
|
|
|
let certser = CertSer {
|
|
hostname: domain.to_string(),
|
|
date: Utc::now().date_naive(),
|
|
valid_days: cert.valid_days_left()?,
|
|
key_pem,
|
|
cert_pem,
|
|
};
|
|
let cert = Arc::new(Cert::new(certser.clone())?);
|
|
|
|
// Store certificate in Consul and local store
|
|
self.certs.write().unwrap().insert(domain.to_string(), cert);
|
|
self.consul
|
|
.kv_put_json(&format!("certs/{}", domain), &certser)
|
|
.await?;
|
|
|
|
// Release locks
|
|
self.consul.release(&lock_path, "".into(), &session).await?;
|
|
self.consul.kv_delete(&lock_path).await?;
|
|
|
|
info!("({}) Cert successfully renewed and stored", domain);
|
|
Ok(())
|
|
}
|
|
|
|
async fn on_demand_tls_ask(&self, check_url: &str, domain: &str) -> Result<()> {
|
|
let httpcli = reqwest::Client::new();
|
|
let chall_url = format!("{}?domain={}", check_url, domain);
|
|
info!("({}) On-demand TLS check", domain);
|
|
|
|
let httpresp = httpcli.get(&chall_url).send().await?;
|
|
if httpresp.status() != reqwest::StatusCode::OK {
|
|
bail!("{} is not authorized for on-demand TLS", domain);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> {
|
|
// Returns Ok(()) only if domain is a correct domain name that
|
|
// redirects to this server
|
|
let self_challenge_id = uuid::Uuid::new_v4().to_string();
|
|
let self_challenge_key = format!("challenge/{}", self_challenge_id);
|
|
let self_challenge_resp = uuid::Uuid::new_v4().to_string();
|
|
|
|
self.consul
|
|
.acquire(
|
|
&self_challenge_key,
|
|
self_challenge_resp.as_bytes().to_vec().into(),
|
|
session,
|
|
)
|
|
.await?;
|
|
|
|
let httpcli = reqwest::Client::new();
|
|
let chall_url = format!(
|
|
"http://{}/.well-known/acme-challenge/{}",
|
|
domain, self_challenge_id
|
|
);
|
|
|
|
for i in 1..=4 {
|
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
|
info!("({}) Accessibility check {}/4", domain, i);
|
|
|
|
let httpresp = httpcli.get(&chall_url).send().await?;
|
|
if httpresp.status() == reqwest::StatusCode::OK
|
|
&& httpresp.bytes().await? == self_challenge_resp.as_bytes()
|
|
{
|
|
// Challenge successfully validated
|
|
info!("({}) Accessibility check successfull", domain);
|
|
return Ok(());
|
|
}
|
|
|
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
|
}
|
|
|
|
bail!("Unable to validate self-challenge for domain accessibility check");
|
|
}
|
|
|
|
fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
|
|
let subject_alt_names = vec![domain.to_string(), "localhost".to_string()];
|
|
let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;
|
|
|
|
let certser = CertSer {
|
|
hostname: domain.to_string(),
|
|
date: Utc::now().date_naive(),
|
|
valid_days: 1024,
|
|
key_pem: cert.serialize_private_key_pem(),
|
|
cert_pem: cert.serialize_pem()?,
|
|
};
|
|
let cert = Arc::new(Cert::new(certser)?);
|
|
self.self_signed_certs
|
|
.write()
|
|
.unwrap()
|
|
.insert(domain.to_string(), cert.clone());
|
|
info!("Added self-signed certificate for {}", domain);
|
|
|
|
Ok(cert)
|
|
}
|
|
}
|
|
|
|
pub struct StoreResolver(pub Arc<CertStore>);
|
|
|
|
impl rustls::server::ResolvesServerCert for StoreResolver {
|
|
fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
|
let domain = client_hello.server_name()?;
|
|
match self.0.get_cert_for_https(domain) {
|
|
Ok(cert) => Some(cert.certkey.clone()),
|
|
Err(e) => {
|
|
warn!("Could not get certificate for {}: {}", domain, e);
|
|
None
|
|
}
|
|
}
|
|
}
|
|
}
|