diff --git a/src/main.rs b/src/main.rs index 634ebe4..f92fa6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use forgejo_api::{Auth, Forgejo}; use lazy_static::lazy_static; use rand::prelude::*; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, BufWriter}; use std::path::Path; @@ -48,18 +48,9 @@ struct UserData { issues: Vec<(IssueId, IssueData)>, } -#[derive(Debug, Serialize, Deserialize)] -enum Classification { - Spam, - Legit, - Unknown, -} -use Classification::*; - -#[derive(Debug, Serialize, Deserialize)] struct Db { users: HashMap, - classification: HashMap, + is_spam: HashMap, // caches: derived from the rest score: HashMap, tokens: HashMap>, @@ -113,14 +104,45 @@ impl Db { fn new() -> Db { Db { users: HashMap::new(), + is_spam: HashMap::new(), tokens: HashMap::new(), - classification: HashMap::new(), score: HashMap::new(), } } - fn all_users(&self) -> Vec { - self.users.iter().map(|(id, _)| *id).collect() + fn recompute_tokens(&mut self) { + for (id, user) in &self.users { + self.tokens.insert(*id, user.to_tokens()); + } + } + + fn recompute_scores(&mut self, classifier: &Classifier) { + for (id, tokens) in &self.tokens { + self.score.insert(*id, classifier.score(tokens)); + } + } + + fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result { + let file = File::open(path)?; + let (users, is_spam) = serde_json::from_reader(BufReader::new(file))?; + let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() }; + db.recompute_tokens(); + db.recompute_scores(classifier); + Ok(db) + } + + fn from_users(users: HashMap, is_spam: HashMap, classifier: &Classifier) -> Db { + let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() }; + db.recompute_tokens(); + db.recompute_scores(classifier); + db + } + + fn store_to_path(&self, path: &Path) -> anyhow::Result<()> { + let file = File::create(path)?; + let dat: (&HashMap, &HashMap) = (&self.users, &self.is_spam); + serde_json::to_writer(BufWriter::new(file), &dat)?; + Ok(()) } } @@ -298,7 +320,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result anyhow::Result<(Db, Classifier)> { let model_path = Path::new("model.json"); - let mut classifier = if model_path.is_file() { + let classifier = if model_path.is_file() { Classifier::new_from_pre_trained(&mut File::open(model_path)?)? } else { Classifier::new() @@ -314,70 +336,46 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> { let db_path = Path::new("db.json"); let db: Db = if db_path.is_file() { - let file = File::open(db_path)?; - serde_json::from_reader(BufReader::new(file))? + Db::from_path(db_path, &classifier)? } else { - let mut db = Db::new(); - - db.users = get_users_data(&forge).await?; - - eprintln!("Scoring users..."); - for &user_id in &db.all_users() { - update_user(&mut db, &mut classifier, user_id); - } - - let file = File::create(db_path)?; - serde_json::to_writer(BufWriter::new(file), &db)?; + let db = + Db::from_users( + get_users_data(&forge).await?, + HashMap::new(), + &classifier + ); + db.store_to_path(db_path)?; db }; Ok((db, classifier)) } -fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) { - let userdata = db.users.get(&id).unwrap(); - let tokens = - match db.tokens.get(&id) { - Some(tokens) => tokens, - None => { - let tokens = userdata.to_tokens(); - db.tokens.insert(id, tokens); - db.tokens.get(&id).unwrap() - } - }; - db.score.insert(id, classifier.score(&tokens)); -} - fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> { db.users .iter() - .filter(|(user_id, _)| !db.classification.contains_key(&user_id)) + .filter(|(user_id, _)| !db.is_spam.contains_key(&user_id)) .collect() } fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) { - let mut all_tokens = HashSet::new(); - eprintln!("updating classifier"); for (id, is_spam) in ids { let tokens = db.tokens.get(id).unwrap(); - if *is_spam { classifier.train_spam(tokens); } else { classifier.train_ham(tokens); } - - for tok in tokens { - all_tokens.insert(tok.clone()); - } } eprintln!("recomputing user scores"); + db.recompute_scores(&classifier); - for &user_id in &db.all_users() { - update_user(db, classifier, user_id) + eprintln!("updating db with new classification"); + for (user_id, is_spam) in ids { + db.is_spam.insert(*user_id, *is_spam); } } @@ -489,7 +487,6 @@ async fn apply( set_spam(db, classifier, &updates); - eprintln!("{:#?}", req); HttpResponse::SeeOther() .insert_header(("Location", "/")) .finish() diff --git a/templates/index.html b/templates/index.html index 886ad4e..9ebb325 100644 --- a/templates/index.html +++ b/templates/index.html @@ -38,6 +38,9 @@ flex-direction: row; gap: 10px; align-items: center; + border: 1px dotted #000; + padding: 3px 8px 3px 8px; + margin: 3px; } .user-card { @@ -87,7 +90,7 @@ .score { padding-left: 3px; padding-right: 3px; - width: 3em; + width: 2.8em; text-align: center; flex-grow: 0; flex-shrink: 0; @@ -134,8 +137,8 @@
-
{{ user.login }}
- {%- if user.full_name %}
({{ user.full_name }})
{% endif -%} +
{{ user.login }}
+ {%- if user.full_name %}
({{ user.full_name }})
{% endif -%}