Improved management of ACME orders and certificate pre-expiration period

This commit is contained in:
Alex 2021-12-09 12:18:23 +01:00
parent 8153bdca46
commit fdb83162ce
No known key found for this signature in database
GPG Key ID: EDABF9711E244EB1
7 changed files with 116 additions and 79 deletions

View File

@ -6,7 +6,7 @@ use chrono::{Date, NaiveDate, Utc};
use rustls::sign::CertifiedKey;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CertSer {
pub hostname: String,
pub date: NaiveDate,

View File

@ -1,11 +1,13 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use std::time::{Duration, Instant};
use anyhow::Result;
use chrono::Utc;
use futures::TryFutureExt;
use log::*;
use tokio::sync::watch;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place;
use acme_micro::create_p384_key;
@ -14,6 +16,7 @@ use rustls::sign::CertifiedKey;
use crate::cert::{Cert, CertSer};
use crate::consul::*;
use crate::exit_on_err;
use crate::proxy_config::*;
pub struct CertStore {
@ -22,6 +25,7 @@ pub struct CertStore {
certs: RwLock<HashMap<String, Arc<Cert>>>,
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
tx_need_cert: mpsc::UnboundedSender<String>,
}
impl CertStore {
@ -30,21 +34,39 @@ impl CertStore {
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
letsencrypt_email: String,
) -> Arc<Self> {
Arc::new(Self {
let (tx, rx) = mpsc::unbounded_channel();
let cert_store = Arc::new(Self {
consul,
certs: RwLock::new(HashMap::new()),
self_signed_certs: RwLock::new(HashMap::new()),
rx_proxy_config,
letsencrypt_email,
})
tx_need_cert: tx,
});
tokio::spawn(cert_store.clone().certificate_loop(rx).map_err(exit_on_err));
cert_store
}
pub async fn watch_proxy_config(self: Arc<Self>) -> Result<()> {
async fn certificate_loop(
self: Arc<Self>,
mut rx_need_cert: mpsc::UnboundedReceiver<String>,
) -> Result<()> {
let mut rx_proxy_config = self.rx_proxy_config.clone();
while rx_proxy_config.changed().await.is_ok() {
let mut t_last_check: HashMap<String, Instant> = HashMap::new();
loop {
let mut domains: HashSet<String> = HashSet::new();
select! {
res = rx_proxy_config.changed() => {
if res.is_err() {
bail!("rx_proxy_config closed");
}
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
for ent in proxy_config.entries.iter() {
if let HostDescription::Hostname(domain) = &ent.host {
@ -55,19 +77,35 @@ impl CertStore {
}
}
}
}
need_cert = rx_need_cert.recv() => {
match need_cert {
Some(dom) => {
domains.insert(dom);
while let Ok(dom2) = rx_need_cert.try_recv() {
domains.insert(dom2);
}
}
None => bail!("rx_need_cert closed"),
};
}
}
debug!("Ensuring we have certs for domains: {:#?}", domains);
for dom in domains.iter() {
if let Err(e) = self.get_cert(dom).await {
warn!("Error get_cert {}: {}", dom, e);
match t_last_check.get(dom) {
Some(t) if Instant::now() - *t < Duration::from_secs(3600) => continue,
_ => t_last_check.insert(dom.to_string(), Instant::now()),
};
debug!("Checking cert for domain: {}", dom);
if let Err(e) = self.check_cert(dom).await {
warn!("({}) Could not get certificate: {}", dom, e);
}
}
}
}
bail!("rx_proxy_config closed");
}
pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized
if !self
.rx_proxy_config
@ -81,35 +119,30 @@ impl CertStore {
// Check in local memory if it exists
if let Some(cert) = self.certs.read().unwrap().get(domain) {
if !cert.is_old() {
return Ok(cert.clone());
if cert.is_old() {
self.tx_need_cert.send(domain.to_string())?;
}
return Ok(cert.clone());
}
// Not found in local memory, try to get it in background
tokio::spawn(self.clone().get_cert_task(domain.to_string()));
self.tx_need_cert.send(domain.to_string())?;
// In the meantime, use a self-signed certificate
if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
if !cert.is_old() {
return Ok(cert.clone());
}
}
self.gen_self_signed_certificate(domain)
}
pub async fn get_cert_task(self: Arc<Self>, domain: String) -> Result<Arc<Cert>> {
self.get_cert(domain.as_str()).await
}
pub async fn get_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
// First, try locally.
{
let certs = self.certs.read().unwrap();
if let Some(cert) = certs.get(domain) {
if !cert.is_old() {
return Ok(cert.clone());
return Ok(());
}
}
}
@ -122,12 +155,12 @@ impl CertStore {
{
if let Ok(cert) = Cert::new(consul_cert) {
let cert = Arc::new(cert);
if !cert.is_old() {
self.certs
.write()
.unwrap()
.insert(domain.to_string(), cert.clone());
return Ok(cert);
if !cert.is_old() {
return Ok(());
}
}
}
@ -136,8 +169,14 @@ impl CertStore {
self.renew_cert(domain).await
}
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
info!("Renewing certificate for {}", domain);
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
info!("({}) Renewing certificate", domain);
// Basic sanity check (we could add more kinds of checks here)
// This is just to help avoid getting rate-limited against ACME server
if !domain.contains('.') || domain.ends_with(".local") {
bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
}
// ---- Acquire lock ----
// the lock is acquired for fifteen minutes,
@ -171,11 +210,13 @@ impl CertStore {
let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
let contact = vec![format!("mailto:{}", self.letsencrypt_email)];
let acc =
if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? {
// Use existing Let's encrypt account or register new one if necessary
let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
Some(acc_privkey) => {
info!("Using existing Let's encrypt account");
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
} else {
}
None => {
info!("Creating new Let's encrypt account");
let acc = block_in_place(|| dir.register_account(contact.clone()))?;
self.consul
@ -185,8 +226,10 @@ impl CertStore {
)
.await?;
acc
}
};
// Order certificate and perform validation
let mut ord_new = acc.new_order(domain, &[])?;
let ord_csr = loop {
if let Some(ord_csr) = ord_new.confirm_validations() {
@ -195,28 +238,29 @@ impl CertStore {
let auths = ord_new.authorizations()?;
info!("Creating challenge and storing in Consul");
info!("({}) Creating challenge and storing in Consul", domain);
let chall = auths[0].http_challenge().unwrap();
let chall_key = format!("challenge/{}", chall.http_token());
self.consul
.acquire(&chall_key, chall.http_proof()?.into(), &session)
.await?;
info!("Validating challenge");
info!("({}) Validating challenge", domain);
block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
info!("Deleting challenge");
info!("({}) Deleting challenge", domain);
self.consul.kv_delete(&chall_key).await?;
block_in_place(|| ord_new.refresh())?;
};
// Generate key and finalize certificate
let pkey_pri = create_p384_key()?;
let ord_cert =
block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
let cert = block_in_place(|| ord_cert.download_cert())?;
info!("Keys and certificate obtained");
info!("({}) Keys and certificate obtained", domain);
let key_pem = cert.private_key().to_string();
let cert_pem = cert.certificate().to_string();
@ -227,21 +271,20 @@ impl CertStore {
key_pem,
cert_pem,
};
let cert = Arc::new(Cert::new(certser.clone())?);
// Store certificate in Consul and local store
self.certs.write().unwrap().insert(domain.to_string(), cert);
self.consul
.kv_put_json(&format!("certs/{}", domain), &certser)
.await?;
// Release locks
self.consul.release(&lock_path, "".into(), &session).await?;
self.consul.kv_delete(&lock_path).await?;
let cert = Arc::new(Cert::new(certser)?);
self.certs
.write()
.unwrap()
.insert(domain.to_string(), cert.clone());
info!("Cert successfully renewed: {}", domain);
Ok(cert)
info!("({}) Cert successfully renewed and stored", domain);
Ok(())
}
fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {

View File

@ -114,7 +114,6 @@ async fn handle(
)
});
if let Some(proxy_to) = best_match {
proxy_to.calls.fetch_add(1, Ordering::SeqCst);

View File

@ -5,7 +5,6 @@ use futures::TryFutureExt;
use std::net::SocketAddr;
use structopt::StructOpt;
mod tls_util;
mod cert;
mod cert_store;
mod consul;
@ -13,6 +12,7 @@ mod http;
mod https;
mod proxy_config;
mod reverse_proxy;
mod tls_util;
use log::*;
@ -85,7 +85,6 @@ async fn main() {
rx_proxy_config.clone(),
opt.letsencrypt_email.clone(),
);
tokio::spawn(cert_store.clone().watch_proxy_config().map_err(exit_on_err));
tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err));
tokio::spawn(

View File

@ -99,7 +99,8 @@ fn parse_tricot_tag(
) -> Option<ProxyEntry> {
let splits = tag.split(' ').collect::<Vec<_>>();
if (splits.len() != 2 && splits.len() != 3)
|| (splits[0] != "tricot" && splits[0] != "tricot-https") {
|| (splits[0] != "tricot" && splits[0] != "tricot-https")
{
return None;
}

View File

@ -1,11 +1,11 @@
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See there for original Copyright notice
use std::sync::Arc;
use std::convert::TryInto;
use std::time::SystemTime;
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::Result;
use log::*;
@ -13,9 +13,9 @@ use log::*;
use http::header::HeaderName;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{Body, Client, Request, Response, Uri};
use rustls::{Certificate, ServerName};
use rustls::client::{ServerCertVerifier, ServerCertVerified};
use lazy_static::lazy_static;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, ServerName};
use crate::tls_util::HttpsConnectorFixedDnsname;
@ -181,10 +181,8 @@ impl ServerCertVerifier for DontVerifyServerCert {
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@ -1,21 +1,20 @@
use core::future::Future;
use core::task::{Context, Poll};
use std::convert::TryFrom;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::io;
use futures_util::future::*;
use rustls::ServerName;
use hyper::client::connect::Connection;
use hyper::client::HttpConnector;
use hyper::service::Service;
use hyper::Uri;
use hyper_rustls::MaybeHttpsStream;
use rustls::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
#[derive(Clone)]
pub struct HttpsConnectorFixedDnsname<T> {
http: T,
@ -62,8 +61,7 @@ where
let cfg = self.tls_config.clone();
let connecting_future = self.http.call(dst);
let dnsname =
ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
let dnsname = ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
let f = async move {
let tcp = connecting_future.await.map_err(Into::into)?;
let connector = TlsConnector::from(cfg);
@ -76,4 +74,3 @@ where
f.boxed()
}
}