add support for hot reloading
This commit is contained in:
parent
65f4ceae78
commit
1f6e64d34e
3 changed files with 50 additions and 35 deletions
|
@ -1,6 +1,8 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::watch;
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use async_trait::async_trait;
|
||||
|
@ -9,48 +11,59 @@ use crate::config::*;
|
|||
use crate::login::*;
|
||||
use crate::storage;
|
||||
|
||||
pub struct StaticLoginProvider {
|
||||
user_list: PathBuf,
|
||||
#[derive(Default)]
|
||||
pub struct UserDatabase {
|
||||
users: HashMap<String, Arc<UserEntry>>,
|
||||
users_by_email: HashMap<String, Arc<UserEntry>>,
|
||||
}
|
||||
|
||||
impl StaticLoginProvider {
|
||||
pub fn new(config: LoginStaticConfig) -> Result<Self> {
|
||||
let mut lp = Self {
|
||||
user_list: config.user_list.clone(),
|
||||
users: HashMap::new(),
|
||||
users_by_email: HashMap::new(),
|
||||
pub struct StaticLoginProvider {
|
||||
user_db: watch::Receiver<UserDatabase>,
|
||||
}
|
||||
|
||||
pub async fn update_user_list(config: PathBuf, up: watch::Sender<UserDatabase>) -> Result<()> {
|
||||
let mut stream = signal(SignalKind::user_defined1()).expect("failed to install SIGUSR1 signal hander for reload");
|
||||
|
||||
loop {
|
||||
let ulist: UserList = match read_config(config.clone()) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
tracing::warn!(path=%config.as_path().to_string_lossy(), error=%e, "Unable to load config");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
lp
|
||||
.update_user_list()
|
||||
.context(
|
||||
format!(
|
||||
"failed to read {:?}, make sure it exists and it's correctly formatted",
|
||||
config.user_list))?;
|
||||
|
||||
Ok(lp)
|
||||
}
|
||||
|
||||
pub fn update_user_list(&mut self) -> Result<()> {
|
||||
let ulist: UserList = read_config(self.user_list.clone())?;
|
||||
|
||||
self.users = ulist
|
||||
let users = ulist
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, Arc::new(v)))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
self.users_by_email.clear();
|
||||
for (_, u) in self.users.iter() {
|
||||
let mut users_by_email = HashMap::new();
|
||||
for (_, u) in users.iter() {
|
||||
for m in u.email_addresses.iter() {
|
||||
if self.users_by_email.contains_key(m) {
|
||||
bail!("Several users have same email address: {}", m);
|
||||
if users_by_email.contains_key(m) {
|
||||
tracing::warn!("Several users have the same email address: {}", m);
|
||||
continue
|
||||
}
|
||||
self.users_by_email.insert(m.clone(), u.clone());
|
||||
users_by_email.insert(m.clone(), u.clone());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
||||
tracing::info!("{} users loaded", users.len());
|
||||
up.send(UserDatabase { users, users_by_email }).context("update user db config")?;
|
||||
stream.recv().await;
|
||||
tracing::info!("Received SIGUSR1, reloading");
|
||||
}
|
||||
}
|
||||
|
||||
impl StaticLoginProvider {
|
||||
pub async fn new(config: LoginStaticConfig) -> Result<Self> {
|
||||
let (tx, mut rx) = watch::channel(UserDatabase::default());
|
||||
|
||||
tokio::spawn(update_user_list(config.user_list, tx));
|
||||
rx.changed().await?;
|
||||
|
||||
Ok(Self { user_db: rx })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,7 +71,8 @@ impl StaticLoginProvider {
|
|||
impl LoginProvider for StaticLoginProvider {
|
||||
async fn login(&self, username: &str, password: &str) -> Result<Credentials> {
|
||||
tracing::debug!(user=%username, "login");
|
||||
let user = match self.users.get(username) {
|
||||
let user_db = self.user_db.borrow();
|
||||
let user = match user_db.users.get(username) {
|
||||
None => bail!("User {} does not exist", username),
|
||||
Some(u) => u,
|
||||
};
|
||||
|
@ -89,7 +103,8 @@ impl LoginProvider for StaticLoginProvider {
|
|||
}
|
||||
|
||||
async fn public_login(&self, email: &str) -> Result<PublicCredentials> {
|
||||
let user = match self.users_by_email.get(email) {
|
||||
let user_db = self.user_db.borrow();
|
||||
let user = match user_db.users_by_email.get(email) {
|
||||
None => bail!("No user for email address {}", email),
|
||||
Some(u) => u,
|
||||
};
|
||||
|
|
|
@ -42,7 +42,7 @@ enum Command {
|
|||
Provider(ProviderCommand),
|
||||
|
||||
#[clap(subcommand)]
|
||||
/// Specific tooling, should not be part of a normal workflow, for debug & experimenting only
|
||||
/// Specific tooling, should not be part of a normal workflow, for debug & experimentation only
|
||||
Tools(ToolsCommand),
|
||||
//Test,
|
||||
}
|
||||
|
|
|
@ -18,21 +18,21 @@ pub struct Server {
|
|||
|
||||
impl Server {
|
||||
pub async fn from_companion_config(config: CompanionConfig) -> Result<Self> {
|
||||
let login = Arc::new(StaticLoginProvider::new(config.users)?);
|
||||
let login = Arc::new(StaticLoginProvider::new(config.users).await?);
|
||||
|
||||
let lmtp_server = None;
|
||||
let imap_server = Some(imap::new(config.imap, login).await?);
|
||||
let imap_server = Some(imap::new(config.imap, login.clone()).await?);
|
||||
Ok(Self { lmtp_server, imap_server })
|
||||
}
|
||||
|
||||
pub async fn from_provider_config(config: ProviderConfig) -> Result<Self> {
|
||||
let login: ArcLoginProvider = match config.users {
|
||||
UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x)?),
|
||||
UserManagement::Static(x) => Arc::new(StaticLoginProvider::new(x).await?),
|
||||
UserManagement::Ldap(x) => Arc::new(LdapLoginProvider::new(x)?),
|
||||
};
|
||||
|
||||
let lmtp_server = Some(LmtpServer::new(config.lmtp, login.clone()));
|
||||
let imap_server = Some(imap::new(config.imap, login).await?);
|
||||
let imap_server = Some(imap::new(config.imap, login.clone()).await?);
|
||||
|
||||
Ok(Self { lmtp_server, imap_server })
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue