From 1f6e64d34e44b8b7bc7247af38bccf3ade86cf0b Mon Sep 17 00:00:00 2001 From: Quentin Dufour Date: Thu, 14 Dec 2023 13:03:04 +0100 Subject: [PATCH] add support for hot reloading --- src/login/static_provider.rs | 75 +++++++++++++++++++++--------------- src/main.rs | 2 +- src/server.rs | 8 ++-- 3 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/login/static_provider.rs b/src/login/static_provider.rs index 85d55ef..4a8d484 100644 --- a/src/login/static_provider.rs +++ b/src/login/static_provider.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use std::sync::Arc; use std::path::PathBuf; +use tokio::sync::watch; +use tokio::signal::unix::{signal, SignalKind}; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; @@ -9,48 +11,59 @@ use crate::config::*; use crate::login::*; use crate::storage; -pub struct StaticLoginProvider { - user_list: PathBuf, +#[derive(Default)] +pub struct UserDatabase { users: HashMap>, users_by_email: HashMap>, } -impl StaticLoginProvider { - pub fn new(config: LoginStaticConfig) -> Result { - let mut lp = Self { - user_list: config.user_list.clone(), - users: HashMap::new(), - users_by_email: HashMap::new(), +pub struct StaticLoginProvider { + user_db: watch::Receiver, +} + +pub async fn update_user_list(config: PathBuf, up: watch::Sender) -> Result<()> { + let mut stream = signal(SignalKind::user_defined1()).expect("failed to install SIGUSR1 signal hander for reload"); + + loop { + let ulist: UserList = match read_config(config.clone()) { + Ok(x) => x, + Err(e) => { + tracing::warn!(path=%config.as_path().to_string_lossy(), error=%e, "Unable to load config"); + continue; + } }; - lp - .update_user_list() - .context( - format!( - "failed to read {:?}, make sure it exists and it's correctly formatted", - config.user_list))?; - - Ok(lp) - } - - pub fn update_user_list(&mut self) -> Result<()> { - let ulist: UserList = read_config(self.user_list.clone())?; - - self.users = ulist + let users = ulist .into_iter() .map(|(k, v)| (k, Arc::new(v))) .collect::>(); - self.users_by_email.clear(); - for (_, u) in self.users.iter() { + let mut users_by_email = HashMap::new(); + for (_, u) in users.iter() { for m in u.email_addresses.iter() { - if self.users_by_email.contains_key(m) { - bail!("Several users have same email address: {}", m); + if users_by_email.contains_key(m) { + tracing::warn!("Several users have the same email address: {}", m); + continue } - self.users_by_email.insert(m.clone(), u.clone()); + users_by_email.insert(m.clone(), u.clone()); } } - Ok(()) + + tracing::info!("{} users loaded", users.len()); + up.send(UserDatabase { users, users_by_email }).context("update user db config")?; + stream.recv().await; + tracing::info!("Received SIGUSR1, reloading"); + } +} + +impl StaticLoginProvider { + pub async fn new(config: LoginStaticConfig) -> Result { + let (tx, mut rx) = watch::channel(UserDatabase::default()); + + tokio::spawn(update_user_list(config.user_list, tx)); + rx.changed().await?; + + Ok(Self { user_db: rx }) } } @@ -58,7 +71,8 @@ impl StaticLoginProvider { impl LoginProvider for StaticLoginProvider { async fn login(&self, username: &str, password: &str) -> Result { tracing::debug!(user=%username, "login"); - let user = match self.users.get(username) { + let user_db = self.user_db.borrow(); + let user = match user_db.users.get(username) { None => bail!("User {} does not exist", username), Some(u) => u, }; @@ -89,7 +103,8 @@ impl LoginProvider for StaticLoginProvider { } async fn public_login(&self, email: &str) -> Result { - let user = match self.users_by_email.get(email) { + let user_db = self.user_db.borrow(); + let user = match user_db.users_by_email.get(email) { None => bail!("No user for email address {}", email), Some(u) => u, }; diff --git a/src/main.rs b/src/main.rs index 3d87d11..02ba5e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,7 +42,7 @@ enum Command { Provider(ProviderCommand), #[clap(subcommand)] - /// Specific tooling, should not be part of a normal workflow, for debug & experimenting only + /// Specific tooling, should not be part of a normal workflow, for debug & experimentation only Tools(ToolsCommand), //Test, } diff --git a/src/server.rs b/src/server.rs index 2321da8..8abdb86 100644 --- a/src/server.rs +++ b/src/server.rs @@ -18,21 +18,21 @@ pub struct Server { impl Server { pub async fn from_companion_config(config: CompanionConfig) -> Result { - let login = Arc::new(StaticLoginProvider::new(config.users)?); + let login = Arc::new(StaticLoginProvider::new(config.users).await?); let lmtp_server = None; - let imap_server = Some(imap::new(config.imap, login).await?); + let imap_server = Some(imap::new(config.imap, login.clone()).await?); Ok(Self { lmtp_server, imap_server }) } pub async fn from_provider_config(config: ProviderConfig) -> Result { let login: ArcLoginProvider = match config.users { - UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x)?), + UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x).await?), UserManagement::Ldap(x) => Arc::new(LdapLoginProvider::new(x)?), }; let lmtp_server = Some(LmtpServer::new(config.lmtp, login.clone())); - let imap_server = Some(imap::new(config.imap, login).await?); + let imap_server = Some(imap::new(config.imap, login.clone()).await?); Ok(Self { lmtp_server, imap_server }) }