Improved management of ACME orders and certificate pre-expiration period

This commit is contained in:
Alex 2021-12-09 12:18:23 +01:00
parent 8153bdca46
commit fdb83162ce
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
7 changed files with 116 additions and 79 deletions

View file

@ -6,7 +6,7 @@ use chrono::{Date, NaiveDate, Utc};
use rustls::sign::CertifiedKey; use rustls::sign::CertifiedKey;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CertSer { pub struct CertSer {
pub hostname: String, pub hostname: String,
pub date: NaiveDate, pub date: NaiveDate,

View file

@ -1,11 +1,13 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::{Duration, Instant};
use anyhow::Result; use anyhow::Result;
use chrono::Utc; use chrono::Utc;
use futures::TryFutureExt;
use log::*; use log::*;
use tokio::sync::watch; use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place; use tokio::task::block_in_place;
use acme_micro::create_p384_key; use acme_micro::create_p384_key;
@ -14,6 +16,7 @@ use rustls::sign::CertifiedKey;
use crate::cert::{Cert, CertSer}; use crate::cert::{Cert, CertSer};
use crate::consul::*; use crate::consul::*;
use crate::exit_on_err;
use crate::proxy_config::*; use crate::proxy_config::*;
pub struct CertStore { pub struct CertStore {
@ -22,6 +25,7 @@ pub struct CertStore {
certs: RwLock<HashMap<String, Arc<Cert>>>, certs: RwLock<HashMap<String, Arc<Cert>>>,
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>, self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
tx_need_cert: mpsc::UnboundedSender<String>,
} }
impl CertStore { impl CertStore {
@ -30,21 +34,39 @@ impl CertStore {
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
letsencrypt_email: String, letsencrypt_email: String,
) -> Arc<Self> { ) -> Arc<Self> {
Arc::new(Self { let (tx, rx) = mpsc::unbounded_channel();
let cert_store = Arc::new(Self {
consul, consul,
certs: RwLock::new(HashMap::new()), certs: RwLock::new(HashMap::new()),
self_signed_certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()),
rx_proxy_config, rx_proxy_config,
letsencrypt_email, letsencrypt_email,
}) tx_need_cert: tx,
});
tokio::spawn(cert_store.clone().certificate_loop(rx).map_err(exit_on_err));
cert_store
} }
pub async fn watch_proxy_config(self: Arc<Self>) -> Result<()> { async fn certificate_loop(
self: Arc<Self>,
mut rx_need_cert: mpsc::UnboundedReceiver<String>,
) -> Result<()> {
let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut rx_proxy_config = self.rx_proxy_config.clone();
while rx_proxy_config.changed().await.is_ok() { let mut t_last_check: HashMap<String, Instant> = HashMap::new();
loop {
let mut domains: HashSet<String> = HashSet::new(); let mut domains: HashSet<String> = HashSet::new();
select! {
res = rx_proxy_config.changed() => {
if res.is_err() {
bail!("rx_proxy_config closed");
}
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone(); let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
for ent in proxy_config.entries.iter() { for ent in proxy_config.entries.iter() {
if let HostDescription::Hostname(domain) = &ent.host { if let HostDescription::Hostname(domain) = &ent.host {
@ -55,19 +77,35 @@ impl CertStore {
} }
} }
} }
}
need_cert = rx_need_cert.recv() => {
match need_cert {
Some(dom) => {
domains.insert(dom);
while let Ok(dom2) = rx_need_cert.try_recv() {
domains.insert(dom2);
}
}
None => bail!("rx_need_cert closed"),
};
}
}
debug!("Ensuring we have certs for domains: {:#?}", domains);
for dom in domains.iter() { for dom in domains.iter() {
if let Err(e) = self.get_cert(dom).await { match t_last_check.get(dom) {
warn!("Error get_cert {}: {}", dom, e); Some(t) if Instant::now() - *t < Duration::from_secs(3600) => continue,
_ => t_last_check.insert(dom.to_string(), Instant::now()),
};
debug!("Checking cert for domain: {}", dom);
if let Err(e) = self.check_cert(dom).await {
warn!("({}) Could not get certificate: {}", dom, e);
}
} }
} }
} }
bail!("rx_proxy_config closed"); fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
}
pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized // Check if domain is authorized
if !self if !self
.rx_proxy_config .rx_proxy_config
@ -81,35 +119,30 @@ impl CertStore {
// Check in local memory if it exists // Check in local memory if it exists
if let Some(cert) = self.certs.read().unwrap().get(domain) { if let Some(cert) = self.certs.read().unwrap().get(domain) {
if !cert.is_old() { if cert.is_old() {
return Ok(cert.clone()); self.tx_need_cert.send(domain.to_string())?;
} }
return Ok(cert.clone());
} }
// Not found in local memory, try to get it in background // Not found in local memory, try to get it in background
tokio::spawn(self.clone().get_cert_task(domain.to_string())); self.tx_need_cert.send(domain.to_string())?;
// In the meantime, use a self-signed certificate // In the meantime, use a self-signed certificate
if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) { if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
if !cert.is_old() {
return Ok(cert.clone()); return Ok(cert.clone());
} }
}
self.gen_self_signed_certificate(domain) self.gen_self_signed_certificate(domain)
} }
pub async fn get_cert_task(self: Arc<Self>, domain: String) -> Result<Arc<Cert>> { pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
self.get_cert(domain.as_str()).await
}
pub async fn get_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// First, try locally. // First, try locally.
{ {
let certs = self.certs.read().unwrap(); let certs = self.certs.read().unwrap();
if let Some(cert) = certs.get(domain) { if let Some(cert) = certs.get(domain) {
if !cert.is_old() { if !cert.is_old() {
return Ok(cert.clone()); return Ok(());
} }
} }
} }
@ -122,12 +155,12 @@ impl CertStore {
{ {
if let Ok(cert) = Cert::new(consul_cert) { if let Ok(cert) = Cert::new(consul_cert) {
let cert = Arc::new(cert); let cert = Arc::new(cert);
if !cert.is_old() {
self.certs self.certs
.write() .write()
.unwrap() .unwrap()
.insert(domain.to_string(), cert.clone()); .insert(domain.to_string(), cert.clone());
return Ok(cert); if !cert.is_old() {
return Ok(());
} }
} }
} }
@ -136,8 +169,14 @@ impl CertStore {
self.renew_cert(domain).await self.renew_cert(domain).await
} }
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
info!("Renewing certificate for {}", domain); info!("({}) Renewing certificate", domain);
// Basic sanity check (we could add more kinds of checks here)
// This is just to help avoid getting rate-limited against ACME server
if !domain.contains('.') || domain.ends_with(".local") {
bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
}
// ---- Acquire lock ---- // ---- Acquire lock ----
// the lock is acquired for fifteen minutes, // the lock is acquired for fifteen minutes,
@ -171,11 +210,13 @@ impl CertStore {
let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?; let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
let contact = vec![format!("mailto:{}", self.letsencrypt_email)]; let contact = vec![format!("mailto:{}", self.letsencrypt_email)];
let acc = // Use existing Let's encrypt account or register new one if necessary
if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? { let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
Some(acc_privkey) => {
info!("Using existing Let's encrypt account"); info!("Using existing Let's encrypt account");
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)? dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
} else { }
None => {
info!("Creating new Let's encrypt account"); info!("Creating new Let's encrypt account");
let acc = block_in_place(|| dir.register_account(contact.clone()))?; let acc = block_in_place(|| dir.register_account(contact.clone()))?;
self.consul self.consul
@ -185,8 +226,10 @@ impl CertStore {
) )
.await?; .await?;
acc acc
}
}; };
// Order certificate and perform validation
let mut ord_new = acc.new_order(domain, &[])?; let mut ord_new = acc.new_order(domain, &[])?;
let ord_csr = loop { let ord_csr = loop {
if let Some(ord_csr) = ord_new.confirm_validations() { if let Some(ord_csr) = ord_new.confirm_validations() {
@ -195,28 +238,29 @@ impl CertStore {
let auths = ord_new.authorizations()?; let auths = ord_new.authorizations()?;
info!("Creating challenge and storing in Consul"); info!("({}) Creating challenge and storing in Consul", domain);
let chall = auths[0].http_challenge().unwrap(); let chall = auths[0].http_challenge().unwrap();
let chall_key = format!("challenge/{}", chall.http_token()); let chall_key = format!("challenge/{}", chall.http_token());
self.consul self.consul
.acquire(&chall_key, chall.http_proof()?.into(), &session) .acquire(&chall_key, chall.http_proof()?.into(), &session)
.await?; .await?;
info!("Validating challenge"); info!("({}) Validating challenge", domain);
block_in_place(|| chall.validate(Duration::from_millis(5000)))?; block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
info!("Deleting challenge"); info!("({}) Deleting challenge", domain);
self.consul.kv_delete(&chall_key).await?; self.consul.kv_delete(&chall_key).await?;
block_in_place(|| ord_new.refresh())?; block_in_place(|| ord_new.refresh())?;
}; };
// Generate key and finalize certificate
let pkey_pri = create_p384_key()?; let pkey_pri = create_p384_key()?;
let ord_cert = let ord_cert =
block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?; block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
let cert = block_in_place(|| ord_cert.download_cert())?; let cert = block_in_place(|| ord_cert.download_cert())?;
info!("Keys and certificate obtained"); info!("({}) Keys and certificate obtained", domain);
let key_pem = cert.private_key().to_string(); let key_pem = cert.private_key().to_string();
let cert_pem = cert.certificate().to_string(); let cert_pem = cert.certificate().to_string();
@ -227,21 +271,20 @@ impl CertStore {
key_pem, key_pem,
cert_pem, cert_pem,
}; };
let cert = Arc::new(Cert::new(certser.clone())?);
// Store certificate in Consul and local store
self.certs.write().unwrap().insert(domain.to_string(), cert);
self.consul self.consul
.kv_put_json(&format!("certs/{}", domain), &certser) .kv_put_json(&format!("certs/{}", domain), &certser)
.await?; .await?;
// Release locks
self.consul.release(&lock_path, "".into(), &session).await?; self.consul.release(&lock_path, "".into(), &session).await?;
self.consul.kv_delete(&lock_path).await?; self.consul.kv_delete(&lock_path).await?;
let cert = Arc::new(Cert::new(certser)?); info!("({}) Cert successfully renewed and stored", domain);
self.certs Ok(())
.write()
.unwrap()
.insert(domain.to_string(), cert.clone());
info!("Cert successfully renewed: {}", domain);
Ok(cert)
} }
fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> { fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {

View file

@ -114,7 +114,6 @@ async fn handle(
) )
}); });
if let Some(proxy_to) = best_match { if let Some(proxy_to) = best_match {
proxy_to.calls.fetch_add(1, Ordering::SeqCst); proxy_to.calls.fetch_add(1, Ordering::SeqCst);

View file

@ -5,7 +5,6 @@ use futures::TryFutureExt;
use std::net::SocketAddr; use std::net::SocketAddr;
use structopt::StructOpt; use structopt::StructOpt;
mod tls_util;
mod cert; mod cert;
mod cert_store; mod cert_store;
mod consul; mod consul;
@ -13,6 +12,7 @@ mod http;
mod https; mod https;
mod proxy_config; mod proxy_config;
mod reverse_proxy; mod reverse_proxy;
mod tls_util;
use log::*; use log::*;
@ -85,7 +85,6 @@ async fn main() {
rx_proxy_config.clone(), rx_proxy_config.clone(),
opt.letsencrypt_email.clone(), opt.letsencrypt_email.clone(),
); );
tokio::spawn(cert_store.clone().watch_proxy_config().map_err(exit_on_err));
tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err)); tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err));
tokio::spawn( tokio::spawn(

View file

@ -99,7 +99,8 @@ fn parse_tricot_tag(
) -> Option<ProxyEntry> { ) -> Option<ProxyEntry> {
let splits = tag.split(' ').collect::<Vec<_>>(); let splits = tag.split(' ').collect::<Vec<_>>();
if (splits.len() != 2 && splits.len() != 3) if (splits.len() != 2 && splits.len() != 3)
|| (splits[0] != "tricot" && splits[0] != "tricot-https") { || (splits[0] != "tricot" && splits[0] != "tricot-https")
{
return None; return None;
} }

View file

@ -1,11 +1,11 @@
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy //! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See there for original Copyright notice //! See there for original Copyright notice
use std::sync::Arc;
use std::convert::TryInto; use std::convert::TryInto;
use std::time::SystemTime;
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::Result; use anyhow::Result;
use log::*; use log::*;
@ -13,9 +13,9 @@ use log::*;
use http::header::HeaderName; use http::header::HeaderName;
use hyper::header::{HeaderMap, HeaderValue}; use hyper::header::{HeaderMap, HeaderValue};
use hyper::{Body, Client, Request, Response, Uri}; use hyper::{Body, Client, Request, Response, Uri};
use rustls::{Certificate, ServerName};
use rustls::client::{ServerCertVerifier, ServerCertVerified};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, ServerName};
use crate::tls_util::HttpsConnectorFixedDnsname; use crate::tls_util::HttpsConnectorFixedDnsname;
@ -181,10 +181,8 @@ impl ServerCertVerifier for DontVerifyServerCert {
_server_name: &ServerName, _server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>, _scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8], _ocsp_response: &[u8],
_now: SystemTime _now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> { ) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion()) Ok(ServerCertVerified::assertion())
} }
} }

View file

@ -1,21 +1,20 @@
use core::future::Future; use core::future::Future;
use core::task::{Context, Poll}; use core::task::{Context, Poll};
use std::convert::TryFrom; use std::convert::TryFrom;
use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::io;
use futures_util::future::*; use futures_util::future::*;
use rustls::ServerName;
use hyper::client::connect::Connection; use hyper::client::connect::Connection;
use hyper::client::HttpConnector; use hyper::client::HttpConnector;
use hyper::service::Service; use hyper::service::Service;
use hyper::Uri; use hyper::Uri;
use hyper_rustls::MaybeHttpsStream; use hyper_rustls::MaybeHttpsStream;
use rustls::ServerName;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
#[derive(Clone)] #[derive(Clone)]
pub struct HttpsConnectorFixedDnsname<T> { pub struct HttpsConnectorFixedDnsname<T> {
http: T, http: T,
@ -62,8 +61,7 @@ where
let cfg = self.tls_config.clone(); let cfg = self.tls_config.clone();
let connecting_future = self.http.call(dst); let connecting_future = self.http.call(dst);
let dnsname = let dnsname = ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
let f = async move { let f = async move {
let tcp = connecting_future.await.map_err(Into::into)?; let tcp = connecting_future.await.map_err(Into::into)?;
let connector = TlsConnector::from(cfg); let connector = TlsConnector::from(cfg);
@ -76,4 +74,3 @@ where
f.boxed() f.boxed()
} }
} }