Add support for custom headers

Alex 2021-12-07 18:19:51 +01:00
parent 0682c74e9d
commit 489d364676
6 changed files with 78 additions and 39 deletions

.gitignore:

@@ -1 +1,2 @@
 /target
+run_local.sh

Certificate store (impl CertStore):

@@ -40,7 +40,7 @@ impl CertStore {
         for ent in proxy_config.entries.iter() {
             domains.insert(ent.host.clone());
         }
-        info!("Ensuring we have certs for domains: {:#?}", domains);
+        info!("Ensuring we have certs for domains: {:?}", domains);
         for dom in domains.iter() {
             if let Err(e) = self.get_cert(dom).await {
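
The only change here is switching from the pretty-printed `{:#?}` Debug format to the compact `{:?}` form, which keeps the whole set of domains on a single log line. A tiny standalone illustration of the difference (the domain names are invented):

    use std::collections::HashSet;

    fn main() {
        let mut domains = HashSet::new();
        domains.insert("example.com".to_string());
        domains.insert("www.example.com".to_string());

        // `{:?}` prints the whole collection on one line...
        println!("compact: {:?}", domains);
        // ...while `{:#?}` pretty-prints it across several lines.
        println!("pretty: {:#?}", domains);
    }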

HTTP server (serve_http):

@@ -1,4 +1,5 @@
 use std::sync::Arc;
+use std::net::SocketAddr;

 use anyhow::Result;
 use log::*;

@@ -11,6 +12,34 @@ use crate::consul::Consul;
 const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/";

+pub async fn serve_http(
+    bind_addr: SocketAddr,
+    consul: Consul,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+    let consul = Arc::new(consul);
+    // For every connection, we must make a `Service` to handle all
+    // incoming HTTP requests on said connection.
+    let make_svc = make_service_fn(|_conn| {
+        let consul = consul.clone();
+        // This is the `Service` that will handle the connection.
+        // `service_fn` is a helper to convert a function that
+        // returns a Response into a `Service`.
+        async move {
+            Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| {
+                let consul = consul.clone();
+                handle(req, consul)
+            }))
+        }
+    });
+
+    info!("Listening on http://{}", bind_addr);
+    let server = Server::bind(&bind_addr).serve(make_svc);
+    server.await?;
+    Ok(())
+}
+
 async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body>> {
     let path = req.uri().path();
     info!("HTTP request {}", path);

@@ -45,31 +74,3 @@ async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body>> {
             .body(Body::from(""))?)
     }
 }
-
-pub async fn serve_http(consul: Consul) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
-    let consul = Arc::new(consul);
-    // For every connection, we must make a `Service` to handle all
-    // incoming HTTP requests on said connection.
-    let make_svc = make_service_fn(|_conn| {
-        let consul = consul.clone();
-        // This is the `Service` that will handle the connection.
-        // `service_fn` is a helper to convert a function that
-        // returns a Response into a `Service`.
-        async move {
-            Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| {
-                let consul = consul.clone();
-                handle(req, consul)
-            }))
-        }
-    });
-
-    let addr = ([0, 0, 0, 0], 1080).into();
-    let server = Server::bind(&addr).serve(make_svc);
-    println!("Listening on http://{}", addr);
-    server.await?;
-    Ok(())
-}
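
The net effect is that `serve_http` moves to the top of the module and takes the bind address as a parameter instead of hard-coding port 1080. For readers unfamiliar with hyper's service plumbing used above, here is a minimal, self-contained sketch of the same `make_service_fn` / `service_fn` pattern with a configurable bind address, assuming hyper 0.14 and tokio; the handler and port are placeholders, not part of this commit:

    use std::convert::Infallible;
    use std::net::SocketAddr;

    use hyper::service::{make_service_fn, service_fn};
    use hyper::{Body, Request, Response, Server};

    async fn hello(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::from("hello\n")))
    }

    #[tokio::main]
    async fn main() {
        // The bind address is a parameter rather than a hard-coded port.
        let bind_addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();

        // One `Service` is built per connection; each service dispatches
        // every request on that connection to `hello`.
        let make_svc = make_service_fn(|_conn| async {
            Ok::<_, Infallible>(service_fn(hello))
        });

        println!("Listening on http://{}", bind_addr);
        if let Err(e) = Server::bind(&bind_addr).serve(make_svc).await {
            eprintln!("server error: {}", e);
        }
    }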

HTTPS server (serve_https):

@@ -8,6 +8,7 @@ use futures::FutureExt;
 use hyper::server::conn::Http;
 use hyper::service::service_fn;
 use hyper::{Body, Request, Response, StatusCode};
+use http::header::{HeaderName, HeaderValue};
 use tokio::net::TcpListener;
 use tokio::sync::watch;
 use tokio_rustls::TlsAcceptor;

@@ -17,11 +18,10 @@ use crate::proxy_config::ProxyConfig;
 use crate::reverse_proxy;

 pub async fn serve_https(
+    bind_addr: SocketAddr,
     cert_store: Arc<CertStore>,
     proxy_config: watch::Receiver<Arc<ProxyConfig>>,
 ) -> Result<()> {
-    let addr = format!("0.0.0.0:1443");
-
     let mut cfg = rustls::ServerConfig::builder()
         .with_safe_defaults()
         .with_no_client_auth()

@@ -31,9 +31,9 @@ pub async fn serve_https(
     let tls_cfg = Arc::new(cfg);
     let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg));

-    println!("Starting to serve on https://{}.", addr);
+    info!("Starting to serve on https://{}.", bind_addr);

-    let tcp = TcpListener::bind(&addr).await?;
+    let tcp = TcpListener::bind(bind_addr).await?;
     loop {
         let (socket, remote_addr) = tcp.accept().await?;

@@ -118,7 +118,13 @@ async fn handle(
         let to_addr = format!("http://{}", proxy_to.target_addr);
         info!("Proxying {} {} -> {}", host, path, to_addr);

-        reverse_proxy::call(remote_addr.ip(), &to_addr, req).await
+        let mut response = reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?;
+        for (header, value) in proxy_to.add_headers.iter() {
+            response.headers_mut().insert(
+                HeaderName::from_bytes(header.as_bytes())?,
+                HeaderValue::from_str(value)?,
+            );
+        }
+        Ok(response)
     } else {
         info!("Proxying {} {} -> NOT FOUND", host, path);

Main program (struct Opt / main):

@@ -1,6 +1,7 @@
 #[macro_use]
 extern crate anyhow;

+use std::net::SocketAddr;
 use structopt::StructOpt;

 mod cert;

@@ -27,6 +28,14 @@ struct Opt {
     /// Node name
     #[structopt(long = "node-name", env = "TRICOT_NODE_NAME", default_value = "<none>")]
     pub node_name: String,
+
+    /// Bind address for HTTP server
+    #[structopt(long = "http-bind-addr", env = "TRICOT_HTTP_BIND_ADDR", default_value = "0.0.0.0:80")]
+    pub http_bind_addr: SocketAddr,
+
+    /// Bind address for HTTPS server
+    #[structopt(long = "https-bind-addr", env = "TRICOT_HTTPS_BIND_ADDR", default_value = "0.0.0.0:443")]
+    pub https_bind_addr: SocketAddr,
 }

 #[tokio::main(flavor = "multi_thread", worker_threads = 10)]

@@ -50,13 +59,17 @@ async fn main() {
             .watch_proxy_config(rx_proxy_config.clone()),
     );

-    tokio::spawn(http::serve_http(consul.clone()));
+    tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()));
     tokio::spawn(https::serve_https(
+        opt.https_bind_addr,
         cert_store.clone(),
         rx_proxy_config.clone(),
     ));

     while rx_proxy_config.changed().await.is_ok() {
-        info!("Proxy config: {:#?}", *rx_proxy_config.borrow());
+        info!("Proxy config:");
+        for ent in rx_proxy_config.borrow().entries.iter() {
+            info!("  {:?}", ent);
+        }
     }
 }
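
The two new options rely on structopt's default `FromStr` parsing, so any string that parses as a `SocketAddr` is accepted, whether it comes from the command line, the `TRICOT_HTTP_BIND_ADDR` / `TRICOT_HTTPS_BIND_ADDR` environment variables, or the defaults. A stripped-down sketch of the same pattern (the binary name and the omission of the env attributes are for brevity only):

    use std::net::SocketAddr;

    use structopt::StructOpt;

    /// Minimal option struct mirroring the two new bind-address flags.
    #[derive(StructOpt, Debug)]
    struct Opt {
        /// Bind address for HTTP server
        #[structopt(long = "http-bind-addr", default_value = "0.0.0.0:80")]
        http_bind_addr: SocketAddr,

        /// Bind address for HTTPS server
        #[structopt(long = "https-bind-addr", default_value = "0.0.0.0:443")]
        https_bind_addr: SocketAddr,
    }

    fn main() {
        // e.g. `./proxy --http-bind-addr 127.0.0.1:8080 --https-bind-addr [::]:8443`
        let opt = Opt::from_args();
        println!("HTTP on {}, HTTPS on {}", opt.http_bind_addr, opt.https_bind_addr);
    }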

Proxy configuration (Consul tag parsing):

@@ -22,6 +22,7 @@ pub struct ProxyEntry {
     pub host: String,
     pub path_prefix: Option<String>,
     pub priority: u32,
+    pub add_headers: Vec<(String, String)>,

     // Counts the number of times this proxy server has been called to
     // This implements a round-robin load balancer if there are multiple

@@ -44,7 +45,7 @@ fn retry_to_time(retries: u32, max_time: Duration) -> Duration {
     ));
 }

-fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> {
+fn parse_tricot_tag(tag: &str, target_addr: SocketAddr, add_headers: &[(String, String)]) -> Option<ProxyEntry> {
     let splits = tag.split(' ').collect::<Vec<_>>();
     if (splits.len() != 2 && splits.len() != 3) || splits[0] != "tricot" {
         return None;

@@ -65,10 +66,20 @@ fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> {
         host: host.to_string(),
         path_prefix,
         priority,
+        add_headers: add_headers.to_vec(),
         calls: atomic::AtomicU64::from(0),
     })
 }

+fn parse_tricot_add_header_tag(tag: &str) -> Option<(String, String)> {
+    let splits = tag.split(' ').collect::<Vec<_>>();
+    if splits.len() == 3 && splits[0] == "tricot-add-header" {
+        Some((splits[1].to_string(), splits[2].to_string()))
+    } else {
+        None
+    }
+}
+
 fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> {
     let mut entries = vec![];
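
The new tag format is therefore `tricot-add-header <header-name> <value>`, carried as an ordinary Consul service tag next to the existing `tricot` routing tags; note that the three-way split means the value itself cannot contain spaces. A small self-contained check of the parser's behaviour (the example tags are invented):

    // Copy of the parser added in this commit, for a standalone demonstration.
    fn parse_tricot_add_header_tag(tag: &str) -> Option<(String, String)> {
        let splits = tag.split(' ').collect::<Vec<_>>();
        if splits.len() == 3 && splits[0] == "tricot-add-header" {
            Some((splits[1].to_string(), splits[2].to_string()))
        } else {
            None
        }
    }

    fn main() {
        // A matching tag yields one (header name, header value) pair...
        assert_eq!(
            parse_tricot_add_header_tag("tricot-add-header X-Frame-Options DENY"),
            Some(("X-Frame-Options".to_string(), "DENY".to_string()))
        );
        // ...while routing tags and malformed tags are ignored.
        assert_eq!(parse_tricot_add_header_tag("tricot example.com"), None);
        assert_eq!(
            parse_tricot_add_header_tag("tricot-add-header X-Too Many Parts Here"),
            None
        );
        println!("tag parsing ok");
    }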
@@ -78,8 +89,16 @@ fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> {
             _ => continue,
         };

         let addr = SocketAddr::new(ip_addr, svc.port);
+
+        let mut add_headers = vec![];
+        for tag in svc.tags.iter() {
+            if let Some(pair) = parse_tricot_add_header_tag(tag) {
+                add_headers.push(pair);
+            }
+        }
+
         for tag in svc.tags.iter() {
-            if let Some(ent) = parse_tricot_tag(addr, tag) {
+            if let Some(ent) = parse_tricot_tag(tag, addr, &add_headers[..]) {
                 entries.push(ent);
             }
         }
@@ -181,7 +200,6 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver<Arc<ProxyConfig>> {
             }
         }

         let config = ProxyConfig { entries };
-        debug!("Extracted configuration: {:#?}", config);
         tx.send(Arc::new(config)).expect("Internal error");
     }