diff --git a/Cargo.toml b/Cargo.toml index 0c2eac9..1816f59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ lazy_static = "1" actix-files = "0.6" unicode-segmentation = "1" lettre = { version = "0.11", features = ["builder", "smtp-transport", "rustls-tls"], default-features = false } -include_dir = "0.7.4" +include_dir = "0.7" [profile.profiling] inherits = "dev" diff --git a/src/classifier.rs b/src/classifier.rs index a34cd61..176845c 100644 --- a/src/classifier.rs +++ b/src/classifier.rs @@ -84,7 +84,7 @@ impl Classifier { /// Compute the probability of `tokens` to be part of a spam. fn rate_words(&self, tokens: &[String]) -> Vec { tokens - .into_iter() + .iter() .map(|word| { // If word was previously added in the model if let Some(counter) = self.token_table.get(word) { diff --git a/src/data.rs b/src/data.rs index 516e881..8420a5d 100644 --- a/src/data.rs +++ b/src/data.rs @@ -57,17 +57,17 @@ impl UserData { } match &self.location { - Some(s) => add(&s), + Some(s) => add(s), None => add("__NO_LOCATION__"), } match &self.website { - Some(s) => add(&s), + Some(s) => add(s), None => add("__NO_WEBSITE__"), } match &self.description { - Some(s) => add(&s), + Some(s) => add(s), None => add("__NO_USER_DESCRIPTION__"), } diff --git a/src/db.rs b/src/db.rs index a510eff..13bb7cb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -116,20 +116,20 @@ impl Db { Ok(()) } - pub fn unclassified_users<'a>(&'a self) -> Vec<(UserId, &'a UserData)> { + pub fn unclassified_users(&self) -> Vec<(UserId, &UserData)> { self.users .iter() - .filter(|(user_id, _)| !self.is_spam.contains_key(&user_id)) + .filter(|(user_id, _)| !self.is_spam.contains_key(user_id)) .map(|(id, d)| (*id, d)) .collect() } - pub fn classified_users<'a>(&'a self) -> Vec<(UserId, &'a UserData, IsSpam)> { + pub fn classified_users(&self) -> Vec<(UserId, &UserData, IsSpam)> { self.users .iter() .filter_map(|(user_id, user_data)| { self.is_spam - .get(&user_id) + .get(user_id) .map(|is_spam| (user_id, user_data, *is_spam)) }) .map(|(id, d, s)| (*id, d, s)) diff --git a/src/main.rs b/src/main.rs index ed0be23..593075f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,19 +60,15 @@ impl Config { let admin_contact_email = std::env::var("ADMIN_CONTACT_EMAIL") .context("reading the ADMIN_CONTACT_EMAIL environment variable")?; - let actually_ban = match std::env::var("ACTUALLY_BAN_USERS") { - Ok(s) => { - if &s == "true" { - ActuallyBan::Yes { - smtp: SmtpConfig::from_env().await?, - } - } else if &s == "false" { - ActuallyBan::No - } else { - return Err(anyhow!( - "ACTUALLY_BAN_USERS: unknown value (expected: true/false)" - )); - } + let actually_ban = match std::env::var("ACTUALLY_BAN_USERS").as_deref() { + Ok("true") => ActuallyBan::Yes { + smtp: SmtpConfig::from_env().await?, + }, + Ok("false") => ActuallyBan::No, + Ok(_) => { + return Err(anyhow!( + "ACTUALLY_BAN_USERS: unknown value (expected: true/false)" + )); } Err(_) => ActuallyBan::No, }; @@ -125,7 +121,7 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> { Db::from_path(db_path, &classifier)? } else { let db = Db::from_users( - scrape::get_user_data(&forge).await?, + scrape::get_user_data(forge).await?, HashMap::new(), &classifier, ); @@ -216,7 +212,7 @@ fn set_spam( } eprintln!("recomputing user scores"); - db.recompute_scores(&classifier); + db.recompute_scores(classifier); spammers } @@ -225,11 +221,14 @@ async fn apply_classification( config: &Config, forge: &Forgejo, db: Arc>, - classifier: &mut Classifier, + classifier: Arc>, ids: &[(UserId, bool)], overwrite: bool, ) { - let spammers = set_spam(&mut db.lock().unwrap(), classifier, ids, overwrite); + let spammers = { + let classifier = &mut classifier.lock().unwrap(); + set_spam(&mut db.lock().unwrap(), classifier, ids, overwrite) + }; for user in spammers { let login = db.lock().unwrap().users.get(&user).unwrap().login.clone(); @@ -247,7 +246,6 @@ lazy_static! { pub static ref TEMPLATES: Tera = { let files: Vec<_> = TEMPLATES_DIR .files() - .into_iter() .map(|f| { ( f.path().to_str().unwrap(), @@ -310,7 +308,7 @@ async fn index( users.shuffle(&mut rng); - let sorting_req = q.sort.as_ref().map(|s| s.as_str()); + let sorting_req = q.sort.as_deref(); match &sorting_req { // sort "legit first": by increasing score Some("legit") => users.sort_by_key(|(_, _, score)| (score * 1000.) as u64), @@ -364,9 +362,6 @@ async fn post_classified( ) -> impl Responder { eprintln!("POST {}", req.uri()); - let classifier = &mut data.classifier.lock().unwrap(); - let db = data.db.clone(); - let updates: Vec<(UserId, bool)> = form .iter() .map(|(id, classification)| (UserId(*id), classification == "spam")) @@ -376,17 +371,21 @@ async fn post_classified( &data.config, &data.forge, data.db.clone(), - classifier, + data.classifier.clone(), &updates, overwrite, ) .await; - db.lock() + data.db + .lock() .unwrap() .store_to_path(Path::new("db.json")) .unwrap(); // FIXME - classifier + + data.classifier + .lock() + .unwrap() .save(&mut File::create(Path::new("model.json")).unwrap(), false) .unwrap(); // FIXME @@ -479,28 +478,30 @@ async fn main() -> anyhow::Result<()> { config: config.clone(), }); + let mut workers = tokio::task::JoinSet::new(); + let _ = { let forge = forge.clone(); let db = db.clone(); let classifier = classifier.clone(); - tokio::spawn(async move { workers::refresh_user_data(forge, db, classifier) }) + workers.spawn(async move { workers::refresh_user_data(forge, db, classifier).await }) }; let _ = { let config = config.clone(); let forge = forge.clone(); let db = db.clone(); - tokio::spawn(async move { workers::purge_spammer_accounts(config, forge, db) }) + workers.spawn(async move { workers::purge_spammer_accounts(config, forge, db).await }) }; let _ = { let config = config.clone(); let forge = forge.clone(); let db = db.clone(); - tokio::spawn(async move { workers::lock_and_notify_users(config, forge, db) }) + workers.spawn(async move { workers::lock_and_notify_users(config, forge, db).await }) }; println!("Listening on http://127.0.0.1:8080"); - HttpServer::new(move || { + let webserver = HttpServer::new(move || { App::new() .app_data(st.clone()) .service(static_) @@ -510,8 +511,18 @@ async fn main() -> anyhow::Result<()> { .service(post_classified_edit) }) .bind(("127.0.0.1", 8080))? - .run() - .await?; + .run(); + + tokio::select! { + _ = workers.join_all() => { + unreachable!() + }, + _ = tokio::signal::ctrl_c() => { + }, + res = webserver => { + res? + } + }; Ok(()) } diff --git a/src/scrape.rs b/src/scrape.rs index a8c07d0..9338ba8 100644 --- a/src/scrape.rs +++ b/src/scrape.rs @@ -6,8 +6,10 @@ use crate::data::*; async fn scrape_repos(forge: &Forgejo) -> anyhow::Result> { let mut repos = Vec::new(); - let mut query = forgejo_api::structs::RepoSearchQuery::default(); - query.limit = Some(50); + let mut query = forgejo_api::structs::RepoSearchQuery { + limit: Some(50), + ..Default::default() + }; let mut page: u32 = 1; loop { query.page = Some(page); @@ -29,8 +31,10 @@ async fn scrape_repos(forge: &Forgejo) -> anyhow::Result anyhow::Result> { let mut issues = Vec::new(); - let mut query = forgejo_api::structs::IssueSearchIssuesQuery::default(); - query.limit = Some(50); + let mut query = forgejo_api::structs::IssueSearchIssuesQuery { + limit: Some(50), + ..Default::default() + }; let mut page: u32 = 1; loop { query.page = Some(page); @@ -47,8 +51,10 @@ async fn scrape_issues(forge: &Forgejo) -> anyhow::Result anyhow::Result> { let mut users = Vec::new(); - let mut query = forgejo_api::structs::UserSearchQuery::default(); - query.limit = Some(50); + let mut query = forgejo_api::structs::UserSearchQuery { + limit: Some(50), + ..Default::default() + }; let mut page: u32 = 1; loop { query.page = Some(page); @@ -78,7 +84,7 @@ pub async fn get_user_data(forge: &Forgejo) -> anyhow::Result anyhow::Result anyhow::Result { eprintln!("Sending notification email to user {login}"); - email::send_locked_account_notice(config, &smtp, &login, &email).await?; + email::send_locked_account_notice(config, smtp, &login, &email).await?; eprintln!("Success"); } ActuallyBan::No => {