Exit more agressively on certain errors

This commit is contained in:
Alex 2021-12-08 17:50:40 +01:00
parent 0e6e60d35a
commit 3bdb417bfb
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
5 changed files with 31 additions and 15 deletions

View file

@ -39,7 +39,7 @@ impl CertStore {
}) })
} }
pub async fn watch_proxy_config(self: Arc<Self>) { pub async fn watch_proxy_config(self: Arc<Self>) -> Result<()> {
let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut rx_proxy_config = self.rx_proxy_config.clone();
while rx_proxy_config.changed().await.is_ok() { while rx_proxy_config.changed().await.is_ok() {
@ -59,6 +59,8 @@ impl CertStore {
} }
} }
} }
bail!("rx_proxy_config closed");
} }
pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {

View file

@ -12,10 +12,7 @@ use crate::consul::Consul;
const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/"; const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/";
pub async fn serve_http( pub async fn serve_http(bind_addr: SocketAddr, consul: Consul) -> Result<()> {
bind_addr: SocketAddr,
consul: Consul,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let consul = Arc::new(consul); let consul = Arc::new(consul);
// For every connection, we must make a `Service` to handle all // For every connection, we must make a `Service` to handle all
// incoming HTTP requests on said connection. // incoming HTTP requests on said connection.

View file

@ -1,6 +1,7 @@
#[macro_use] #[macro_use]
extern crate anyhow; extern crate anyhow;
use futures::TryFutureExt;
use std::net::SocketAddr; use std::net::SocketAddr;
use structopt::StructOpt; use structopt::StructOpt;
@ -65,6 +66,12 @@ async fn main() {
} }
pretty_env_logger::init(); pretty_env_logger::init();
// Abort on panic (same behavior as in Go)
std::panic::set_hook(Box::new(|panic_info| {
error!("{}", panic_info.to_string());
std::process::abort();
}));
let opt = Opt::from_args(); let opt = Opt::from_args();
info!("Starting Tricot"); info!("Starting Tricot");
@ -77,14 +84,17 @@ async fn main() {
rx_proxy_config.clone(), rx_proxy_config.clone(),
opt.letsencrypt_email.clone(), opt.letsencrypt_email.clone(),
); );
tokio::spawn(cert_store.clone().watch_proxy_config()); tokio::spawn(cert_store.clone().watch_proxy_config().map_err(exit_on_err));
tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone())); tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err));
tokio::spawn(https::serve_https( tokio::spawn(
https::serve_https(
opt.https_bind_addr, opt.https_bind_addr,
cert_store.clone(), cert_store.clone(),
rx_proxy_config.clone(), rx_proxy_config.clone(),
)); )
.map_err(exit_on_err),
);
while rx_proxy_config.changed().await.is_ok() { while rx_proxy_config.changed().await.is_ok() {
info!("Proxy config:"); info!("Proxy config:");
@ -93,3 +103,8 @@ async fn main() {
} }
} }
} }
fn exit_on_err(e: anyhow::Error) -> () {
error!("{}", e);
std::process::exit(1);
}

View file

@ -102,7 +102,7 @@ fn parse_tricot_tag(
Some(i) => { Some(i) => {
let (host, pp) = splits[1].split_at(i); let (host, pp) = splits[1].split_at(i);
(host, Some(pp.to_string())) (host, Some(pp.to_string()))
}, }
None => (splits[1], None), None => (splits[1], None),
}; };

View file

@ -72,7 +72,6 @@ fn create_proxied_request<B>(
*builder.headers_mut().unwrap() = remove_hop_headers(request.headers()); *builder.headers_mut().unwrap() = remove_hop_headers(request.headers());
// If request does not have host header, add it from original URI authority // If request does not have host header, add it from original URI authority
let host_header_name = "host"; let host_header_name = "host";
if let Some(authority) = request.uri().authority() { if let Some(authority) = request.uri().authority() {
@ -100,7 +99,10 @@ fn create_proxied_request<B>(
} }
} }
builder.headers_mut().unwrap().insert(HeaderName::from_bytes(b"x-forwarded-proto")?, "https".try_into()?); builder.headers_mut().unwrap().insert(
HeaderName::from_bytes(b"x-forwarded-proto")?,
"https".try_into()?,
);
if let Some(conn) = request.headers().get("connection") { if let Some(conn) = request.headers().get("connection") {
if conn.to_str()?.to_lowercase() == "upgrade" { if conn.to_str()?.to_lowercase() == "upgrade" {