implement feature

This commit is contained in:
Quentin 2023-11-30 16:53:04 +01:00
parent ca449ebff4
commit 753903ee02
Signed by: quentin
GPG key ID: E9602264D639FF68
2 changed files with 143 additions and 49 deletions

View file

@ -67,58 +67,73 @@ impl CertStore {
let mut t_last_check: HashMap<String, Instant> = HashMap::new(); let mut t_last_check: HashMap<String, Instant> = HashMap::new();
loop { // Collect data from proxy config
let mut domains: HashSet<String> = HashSet::new(); let mut static_domains: HashSet<String> = HashSet::new();
let mut on_demand_checks: Vec<(glob::Pattern, Option<String>)> = vec![];
// Collect domains that need a TLS certificate loop {
// either from the proxy configuration (eagerly) // Collect domains that need a TLS certificate
// or on reaction to a user request (lazily) // either from the proxy configuration (eagerly)
select! { // or on reaction to a user request (lazily)
let domains = select! {
res = rx_proxy_config.changed() => { res = rx_proxy_config.changed() => {
if res.is_err() { if res.is_err() {
bail!("rx_proxy_config closed"); bail!("rx_proxy_config closed");
} }
on_demand_checks.clear();
let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone(); let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
for ent in proxy_config.entries.iter() { for ent in proxy_config.entries.iter() {
// Eagerly generate certificates for domains that // Eagerly generate certificates for domains that
// are not patterns // are not patterns
if let HostDescription::Hostname(domain) = &ent.url_prefix.host { match &ent.url_prefix.host {
HostDescription::Hostname(domain) => {
if let Some((host, _port)) = domain.split_once(':') { if let Some((host, _port)) = domain.split_once(':') {
domains.insert(host.to_string()); static_domains.insert(host.to_string());
//domains.insert(host.to_string());
} else { } else {
domains.insert(domain.clone()); static_domains.insert(domain.clone());
//domains.insert(domain.clone());
} }
} },
HostDescription::Pattern(pattern) => {
// @TODO Register a map of on_demand_checks.push((pattern.clone(), ent.on_demand_tls_ask.clone()));
// UrlPrefix -> OnDemandTlsAskCheckUrl }
}
} }
// only static_domains are refreshed
static_domains.clone()
} }
need_cert = rx_need_cert.recv() => { need_cert = rx_need_cert.recv() => {
match need_cert { match need_cert {
Some(dom) => { Some(dom) => {
domains.insert(dom); let mut candidates: HashSet<String> = HashSet::new();
// collect certificates as much as possible
candidates.insert(dom);
while let Ok(dom2) = rx_need_cert.try_recv() { while let Ok(dom2) = rx_need_cert.try_recv() {
domains.insert(dom2); candidates.insert(dom2);
} }
self.domain_validation(candidates, &static_domains, on_demand_checks.as_slice()).await
} }
None => bail!("rx_need_cert closed"), None => bail!("rx_need_cert closed"),
}; }
} }
} };
// Now that we have our list of domains to check, // Now that we have our list of domains to check,
// actually do something // actually do something
for dom in domains.iter() { for dom in domains.iter() {
// Exclude from the list domains that were checked less than 60 // Exclude from the list domains that were checked less than 60
// seconds ago // seconds ago
match t_last_check.get(dom) { match t_last_check.get(dom) {
Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue, Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue,
_ => t_last_check.insert(dom.to_string(), Instant::now()), _ => t_last_check.insert(dom.to_string(), Instant::now()),
}; };
// Actual Let's Encrypt calls are done here (in sister function) // Actual Let's Encrypt calls are done here (in sister function)
debug!("Checking cert for domain: {}", dom); debug!("Checking cert for domain: {}", dom);
if let Err(e) = self.check_cert(dom).await { if let Err(e) = self.check_cert(dom).await {
warn!("({}) Could not get certificate: {}", dom, e); warn!("({}) Could not get certificate: {}", dom, e);
@ -127,6 +142,73 @@ impl CertStore {
} }
} }
async fn domain_validation(
&self,
candidates: HashSet<String>,
static_domains: &HashSet<String>,
checks: &[(glob::Pattern, Option<String>)],
) -> HashSet<String> {
let mut domains: HashSet<String> = HashSet::new();
// Filter certificates...
for candidate in candidates.into_iter() {
// Disallow obvious wrong domains...
if !candidate.contains('.') || candidate.ends_with(".local") {
warn!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
continue;
}
// Try to register domain as a static domain
if static_domains.contains(&candidate) {
trace!("domain {} validated as static domain", candidate);
domains.insert(candidate);
continue;
}
// It's not a static domain, maybe an on-demand domain?
for (pattern, maybe_check_url) in checks.iter() {
// check glob pattern
if pattern.matches(&candidate) {
// if no check url is set, accept domain as long as it matches the pattern
let check_url = match maybe_check_url {
None => {
trace!(
"domain {} validated on glob pattern {} only",
candidate,
pattern
);
domains.insert(candidate);
break;
}
Some(url) => url,
};
// if a check url is set, call it
match self.on_demand_tls_ask(check_url, &candidate).await {
Ok(()) => {
trace!(
"domain {} validated on glob pattern {} and on check url {}",
candidate,
pattern,
check_url
);
domains.insert(candidate);
break;
}
Err(e) => {
warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e);
}
}
}
}
// Avoid DDoSing a backend
tokio::time::sleep(Duration::from_secs(2)).await;
}
return domains;
}
/// This function is also in charge of the refresh of the domain names
fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized // Check if domain is authorized
if !self if !self
@ -199,15 +281,15 @@ impl CertStore {
Ok(()) Ok(())
} }
/// Check certificate ensure that the certificate is in the memory store /// Check certificate ensure that the certificate is in the memory store
/// and that it does not need to be renewed. /// and that it does not need to be renewed.
/// ///
/// If it's not in the memory store, it tries to load it from Consul, /// If it's not in the memory store, it tries to load it from Consul,
/// if it's not in Consul, it calls Let's Encrypt. /// if it's not in Consul, it calls Let's Encrypt.
/// ///
/// If the certificate is outdated in the memory store, it tries to load /// If the certificate is outdated in the memory store, it tries to load
/// a more recent version in Consul, if the Consul version is also outdated, /// a more recent version in Consul, if the Consul version is also outdated,
/// it tries to renew it /// it tries to renew it
pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> { pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
// First, try locally. // First, try locally.
{ {
@ -248,16 +330,10 @@ impl CertStore {
self.renew_cert(domain).await self.renew_cert(domain).await
} }
/// This is the place where certificates are generated or renewed /// This is the place where certificates are generated or renewed
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> { pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
info!("({}) Renewing certificate", domain); info!("({}) Renewing certificate", domain);
// Basic sanity check (we could add more kinds of checks here)
// This is just to help avoid getting rate-limited against ACME server
if !domain.contains('.') || domain.ends_with(".local") {
bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
}
// ---- Acquire lock ---- // ---- Acquire lock ----
// the lock is acquired for half an hour, // the lock is acquired for half an hour,
// so that in case of an error we won't retry before // so that in case of an error we won't retry before
@ -373,6 +449,19 @@ impl CertStore {
Ok(()) Ok(())
} }
async fn on_demand_tls_ask(&self, check_url: &str, domain: &str) -> Result<()> {
let httpcli = reqwest::Client::new();
let chall_url = format!("{}?domain={}", check_url, domain);
info!("({}) On-demand TLS check", domain);
let httpresp = httpcli.get(&chall_url).send().await?;
if httpresp.status() != reqwest::StatusCode::OK {
bail!("{} is not authorized for on-demand TLS", domain);
}
Ok(())
}
async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> { async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> {
// Returns Ok(()) only if domain is a correct domain name that // Returns Ok(()) only if domain is a correct domain name that
// redirects to this server // redirects to this server

View file

@ -108,6 +108,10 @@ pub struct ProxyEntry {
/// when matching this rule /// when matching this rule
pub redirects: Vec<(UrlPrefix, UrlPrefix, u16)>, pub redirects: Vec<(UrlPrefix, UrlPrefix, u16)>,
/// Wether or not the domain must be validated before asking a certificate
/// to let's encrypt (only for Glob patterns)
pub on_demand_tls_ask: Option<String>,
/// Number of calls in progress, used to deprioritize slow back-ends /// Number of calls in progress, used to deprioritize slow back-ends
pub calls_in_progress: atomic::AtomicI64, pub calls_in_progress: atomic::AtomicI64,
/// Time of last call, used for round-robin selection /// Time of last call, used for round-robin selection
@ -142,14 +146,14 @@ impl ProxyEntry {
let mut add_headers = vec![]; let mut add_headers = vec![];
let mut redirects = vec![]; let mut redirects = vec![];
let mut on_demand_tls_ask: Option<String> = None;
for mid in middleware.into_iter() { for mid in middleware.into_iter() {
// LocalLb and GlobalLb are handled in the parent function
match mid { match mid {
ConfigTag::AddHeader(k, v) => add_headers.push((k.to_string(), v.clone())), ConfigTag::AddHeader(k, v) => add_headers.push((k.to_string(), v.clone())),
ConfigTag::AddRedirect(m, r, c) => redirects.push(((*m).clone(), (*r).clone(), *c)), ConfigTag::AddRedirect(m, r, c) => redirects.push(((*m).clone(), (*r).clone(), *c)),
ConfigTag::LocalLb | ConfigTag::GlobalLb => { ConfigTag::OnDemandTlsAsk(url) => on_demand_tls_ask = Some(url.to_string()),
/* handled in parent fx */ ConfigTag::LocalLb | ConfigTag::GlobalLb => (),
()
}
}; };
} }
@ -166,6 +170,7 @@ impl ProxyEntry {
flags, flags,
add_headers, add_headers,
redirects, redirects,
on_demand_tls_ask,
// internal // internal
last_call: atomic::AtomicI64::from(0), last_call: atomic::AtomicI64::from(0),
calls_in_progress: atomic::AtomicI64::from(0), calls_in_progress: atomic::AtomicI64::from(0),
@ -247,6 +252,7 @@ enum MatchTag {
enum ConfigTag<'a> { enum ConfigTag<'a> {
AddHeader(&'a str, String), AddHeader(&'a str, String),
AddRedirect(UrlPrefix, UrlPrefix, u16), AddRedirect(UrlPrefix, UrlPrefix, u16),
OnDemandTlsAsk(&'a str),
GlobalLb, GlobalLb,
LocalLb, LocalLb,
} }
@ -321,6 +327,9 @@ fn parse_tricot_tags(tag: &str) -> Option<ParsedTag> {
p_match, p_replace, http_code, p_match, p_replace, http_code,
))) )))
} }
["tricot-on-demand-tls-ask", url, ..] => {
Some(ParsedTag::Middleware(ConfigTag::OnDemandTlsAsk(url)))
}
["tricot-global-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::GlobalLb)), ["tricot-global-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::GlobalLb)),
["tricot-local-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::LocalLb)), ["tricot-local-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::LocalLb)),
_ => None, _ => None,
@ -369,13 +378,9 @@ fn parse_consul_service(
// some legacy processing that would need a refactor later // some legacy processing that would need a refactor later
for mid in collected_middleware.iter() { for mid in collected_middleware.iter() {
match mid { match mid {
ConfigTag::AddHeader(_, _) | ConfigTag::AddRedirect(_, _, _) =>
/* not handled here */
{
()
}
ConfigTag::GlobalLb => flags.global_lb = true, ConfigTag::GlobalLb => flags.global_lb = true,
ConfigTag::LocalLb => flags.site_lb = true, ConfigTag::LocalLb => flags.site_lb = true,
_ => (),
}; };
} }