centralize all the checks in the same place

This commit is contained in:
Quentin 2023-11-30 17:34:07 +01:00
parent 753903ee02
commit b9b035034f
Signed by untrusted user: quentin
GPG key ID: E9602264D639FF68

View file

@ -22,12 +22,19 @@ pub struct CertStore {
consul: Consul, consul: Consul,
node_name: String, node_name: String,
letsencrypt_email: String, letsencrypt_email: String,
certs: RwLock<HashMap<String, Arc<Cert>>>, certs: RwLock<HashMap<String, Arc<Cert>>>,
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>, self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
tx_need_cert: mpsc::UnboundedSender<String>, tx_need_cert: mpsc::UnboundedSender<String>,
} }
struct ProcessedDomains {
static_domains: HashSet<String>,
on_demand_domains: Vec<(glob::Pattern, Option<String>)>,
}
impl CertStore { impl CertStore {
pub fn new( pub fn new(
consul: Consul, consul: Consul,
@ -41,10 +48,10 @@ impl CertStore {
let cert_store = Arc::new(Self { let cert_store = Arc::new(Self {
consul, consul,
node_name, node_name,
letsencrypt_email,
certs: RwLock::new(HashMap::new()), certs: RwLock::new(HashMap::new()),
self_signed_certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()),
rx_proxy_config, rx_proxy_config,
letsencrypt_email,
tx_need_cert: tx, tx_need_cert: tx,
}); });
@ -66,23 +73,21 @@ impl CertStore {
let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut rx_proxy_config = self.rx_proxy_config.clone();
let mut t_last_check: HashMap<String, Instant> = HashMap::new(); let mut t_last_check: HashMap<String, Instant> = HashMap::new();
let mut proc_domains: Option<ProcessedDomains> = None;
// Collect data from proxy config
let mut static_domains: HashSet<String> = HashSet::new();
let mut on_demand_checks: Vec<(glob::Pattern, Option<String>)> = vec![];
loop { loop {
// Collect domains that need a TLS certificate
// either from the proxy configuration (eagerly)
// or on reaction to a user request (lazily)
let domains = select! { let domains = select! {
// Refresh some internal states, schedule static_domains for renew
res = rx_proxy_config.changed() => { res = rx_proxy_config.changed() => {
if res.is_err() { if res.is_err() {
bail!("rx_proxy_config closed"); bail!("rx_proxy_config closed");
} }
on_demand_checks.clear(); let mut static_domains: HashSet<String> = HashSet::new();
let mut on_demand_domains: Vec<(glob::Pattern, Option<String>)> = vec![];
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone(); let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
for ent in proxy_config.entries.iter() { for ent in proxy_config.entries.iter() {
// Eagerly generate certificates for domains that // Eagerly generate certificates for domains that
// are not patterns // are not patterns
@ -90,21 +95,21 @@ impl CertStore {
HostDescription::Hostname(domain) => { HostDescription::Hostname(domain) => {
if let Some((host, _port)) = domain.split_once(':') { if let Some((host, _port)) = domain.split_once(':') {
static_domains.insert(host.to_string()); static_domains.insert(host.to_string());
//domains.insert(host.to_string());
} else { } else {
static_domains.insert(domain.clone()); static_domains.insert(domain.clone());
//domains.insert(domain.clone());
} }
}, },
HostDescription::Pattern(pattern) => { HostDescription::Pattern(pattern) => {
on_demand_checks.push((pattern.clone(), ent.on_demand_tls_ask.clone())); on_demand_domains.push((pattern.clone(), ent.on_demand_tls_ask.clone()));
} },
} }
} }
// only static_domains are refreshed // only static_domains are refreshed
static_domains.clone() proc_domains = Some(ProcessedDomains { static_domains: static_domains.clone(), on_demand_domains });
self.domain_validation(static_domains, proc_domains.as_ref()).await
} }
// renew static and on-demand domains
need_cert = rx_need_cert.recv() => { need_cert = rx_need_cert.recv() => {
match need_cert { match need_cert {
Some(dom) => { Some(dom) => {
@ -116,7 +121,7 @@ impl CertStore {
candidates.insert(dom2); candidates.insert(dom2);
} }
self.domain_validation(candidates, &static_domains, on_demand_checks.as_slice()).await self.domain_validation(candidates, proc_domains.as_ref()).await
} }
None => bail!("rx_need_cert closed"), None => bail!("rx_need_cert closed"),
} }
@ -145,28 +150,36 @@ impl CertStore {
async fn domain_validation( async fn domain_validation(
&self, &self,
candidates: HashSet<String>, candidates: HashSet<String>,
static_domains: &HashSet<String>, maybe_proc_domains: Option<&ProcessedDomains>,
checks: &[(glob::Pattern, Option<String>)],
) -> HashSet<String> { ) -> HashSet<String> {
let mut domains: HashSet<String> = HashSet::new(); let mut domains: HashSet<String> = HashSet::new();
// Handle initialization
let proc_domains = match maybe_proc_domains {
None => {
warn!("Proxy config is not yet loaded, refusing all certificate generation");
return domains;
}
Some(proc) => proc,
};
// Filter certificates... // Filter certificates...
for candidate in candidates.into_iter() { 'outer: for candidate in candidates.into_iter() {
// Disallow obvious wrong domains... // Disallow obvious wrong domains...
if !candidate.contains('.') || candidate.ends_with(".local") { if !candidate.contains('.') || candidate.ends_with(".local") {
warn!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)"); warn!("{} is probably not a publicly accessible domain, skipping (a self-signed certificate will be used)", candidate);
continue; continue;
} }
// Try to register domain as a static domain // Try to register domain as a static domain
if static_domains.contains(&candidate) { if proc_domains.static_domains.contains(&candidate) {
trace!("domain {} validated as static domain", candidate); trace!("domain {} validated as static domain", candidate);
domains.insert(candidate); domains.insert(candidate);
continue; continue;
} }
// It's not a static domain, maybe an on-demand domain? // It's not a static domain, maybe an on-demand domain?
for (pattern, maybe_check_url) in checks.iter() { for (pattern, maybe_check_url) in proc_domains.on_demand_domains.iter() {
// check glob pattern // check glob pattern
if pattern.matches(&candidate) { if pattern.matches(&candidate) {
// if no check url is set, accept domain as long as it matches the pattern // if no check url is set, accept domain as long as it matches the pattern
@ -178,12 +191,14 @@ impl CertStore {
pattern pattern
); );
domains.insert(candidate); domains.insert(candidate);
break; continue 'outer;
} }
Some(url) => url, Some(url) => url,
}; };
// if a check url is set, call it // if a check url is set, call it
// -- avoid DDoSing a backend
tokio::time::sleep(Duration::from_secs(2)).await;
match self.on_demand_tls_ask(check_url, &candidate).await { match self.on_demand_tls_ask(check_url, &candidate).await {
Ok(()) => { Ok(()) => {
trace!( trace!(
@ -193,7 +208,7 @@ impl CertStore {
check_url check_url
); );
domains.insert(candidate); domains.insert(candidate);
break; continue 'outer;
} }
Err(e) => { Err(e) => {
warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e); warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e);
@ -201,8 +216,6 @@ impl CertStore {
} }
} }
} }
// Avoid DDoSing a backend
tokio::time::sleep(Duration::from_secs(2)).await;
} }
return domains; return domains;
@ -210,17 +223,6 @@ impl CertStore {
/// This function is also in charge of the refresh of the domain names /// This function is also in charge of the refresh of the domain names
fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized
if !self
.rx_proxy_config
.borrow()
.entries
.iter()
.any(|ent| ent.url_prefix.host.matches(domain))
{
bail!("Domain {} should not have a TLS certificate.", domain);
}
// Check in local memory if it exists // Check in local memory if it exists
if let Some(cert) = self.certs.read().unwrap().get(domain) { if let Some(cert) = self.certs.read().unwrap().get(domain) {
if cert.is_old() { if cert.is_old() {