Store the time a user was classified as spam in the database

This commit is contained in:
Armaël Guéneau 2024-12-19 11:49:08 +01:00
parent b2406dd883
commit 8f2bc7ebc6
4 changed files with 64 additions and 43 deletions

File diff suppressed because one or more lines are too long

View file

@ -2,17 +2,51 @@ use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::time::{Duration, SystemTime}; use std::time::SystemTime;
use std::fmt;
use serde::{Serialize, Deserialize};
use crate::data::*; use crate::data::*;
use crate::classifier::Classifier; use crate::classifier::Classifier;
/// Classification recorded for a user: either legitimate or spam.
///
/// Derives serde's `Serialize`/`Deserialize` so it can be written to and
/// read back from the persisted database alongside the user data.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
/// The user was classified as legitimate.
Legit,
/// The user was classified as spam. `classified_at` is filled with the
/// current time when the value is built through `IsSpam::from_bool`.
Spam { classified_at: SystemTime },
}
impl IsSpam {
    /// Returns `true` when the user is classified as spam and `false` when
    /// the user is legit.
    ///
    /// This is the counterpart of [`IsSpam::from_bool`]: for any `b`,
    /// `IsSpam::from_bool(b).as_bool() == b`. Callers compare the result
    /// against boolean "is spam" flags, so `Spam` must map to `true`.
    pub fn as_bool(&self) -> bool {
        // Exhaustive match on purpose: adding a variant must force a decision here.
        match self {
            IsSpam::Legit => false,
            IsSpam::Spam { .. } => true,
        }
    }

    /// Builds a classification from a boolean "is spam" flag.
    ///
    /// `true` yields [`IsSpam::Spam`] stamped with the current system time;
    /// `false` yields [`IsSpam::Legit`].
    pub fn from_bool(b: bool) -> IsSpam {
        if b {
            IsSpam::Spam { classified_at: SystemTime::now() }
        } else {
            IsSpam::Legit
        }
    }
}
impl fmt::Display for IsSpam {
    /// Writes the short human-readable label of the classification:
    /// `legit` for [`IsSpam::Legit`] and `spam` for [`IsSpam::Spam`]. The
    /// classification timestamp is not part of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            IsSpam::Spam { .. } => "spam",
            IsSpam::Legit => "legit",
        };
        f.write_str(label)
    }
}
// TODO (?): make the fields private and provide an API that automatically // TODO (?): make the fields private and provide an API that automatically
// recomputes the caches when necessary? // recomputes the caches when necessary?
pub struct Db { pub struct Db {
// persisted data // persisted data
pub users: HashMap<UserId, UserData>, pub users: HashMap<UserId, UserData>,
pub is_spam: HashMap<UserId, bool>, pub is_spam: HashMap<UserId, IsSpam>,
last_scrape: u64, pub last_scrape: SystemTime,
// caches: computed from persisted data on load // caches: computed from persisted data on load
pub score: HashMap<UserId, f32>, pub score: HashMap<UserId, f32>,
pub tokens: HashMap<UserId, Vec<String>>, pub tokens: HashMap<UserId, Vec<String>>,
@ -31,21 +65,10 @@ impl Db {
} }
} }
/// Returns the time of the last scrape as a `SystemTime`, rebuilt from the
/// stored count of whole seconds since the Unix epoch.
pub fn last_scrape(&self) -> SystemTime {
    let since_epoch = Duration::from_secs(self.last_scrape);
    std::time::UNIX_EPOCH + since_epoch
}
pub fn set_last_scrape_to_now(&mut self) {
self.last_scrape =
SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d: Duration| d.as_secs())
.unwrap_or(0);
}
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> { pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?; let file = File::open(path)?;
let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?; let (users, is_spam, last_scrape) =
serde_json::from_reader(BufReader::new(file))?;
let mut db = Db { let mut db = Db {
users, users,
is_spam, is_spam,
@ -60,25 +83,24 @@ impl Db {
pub fn from_users( pub fn from_users(
users: HashMap<UserId, UserData>, users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, bool>, is_spam: HashMap<UserId, IsSpam>,
classifier: &Classifier, classifier: &Classifier,
) -> Db { ) -> Db {
let mut db = Db { let mut db = Db {
users, users,
is_spam, is_spam,
last_scrape: 0, last_scrape: SystemTime::now(),
tokens: HashMap::new(), tokens: HashMap::new(),
score: HashMap::new(), score: HashMap::new(),
}; };
db.recompute_tokens(); db.recompute_tokens();
db.recompute_scores(classifier); db.recompute_scores(classifier);
db.set_last_scrape_to_now();
db db
} }
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> { pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?; let file = File::create(path)?;
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, bool>, u64) = let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, IsSpam>, SystemTime) =
(&self.users, &self.is_spam, self.last_scrape); (&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?; serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(()) Ok(())
@ -91,7 +113,7 @@ impl Db {
.collect() .collect()
} }
pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, bool)> { pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, IsSpam)> {
self.users self.users
.iter() .iter()
.filter_map(|(user_id, user_data)| .filter_map(|(user_id, user_data)|

View file

@ -18,7 +18,7 @@ mod workers;
use classifier::Classifier; use classifier::Classifier;
use data::*; use data::*;
use db::Db; use db::{IsSpam, Db};
// Fetch user data from forgejo from time to time // Fetch user data from forgejo from time to time
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -64,18 +64,17 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
// XXX: This function looks like it is doing too many things at once. // XXX: This function looks like it is doing too many things at once.
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], overwrite: bool) { fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], overwrite: bool) {
for (user_id, is_spam) in ids { for &(user_id, is_spam) in ids {
let mut train_classifier = false; let mut train_classifier = false;
match db.is_spam.get(user_id) { match db.is_spam.get(&user_id) {
Some(was_spam) if overwrite && was_spam != is_spam => { Some(&was_spam) if overwrite && was_spam.as_bool() != is_spam => {
eprintln!( eprintln!(
"User {}: changing classification from {} to {}", "User {}: changing classification from {} to {}",
db.users.get(user_id).unwrap().login, db.users.get(&user_id).unwrap().login,
(if *was_spam { "spam" } else { "legit" }), was_spam, is_spam
(if *is_spam { "spam" } else { "legit" })
); );
db.is_spam.insert(*user_id, *is_spam); db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// This is somewhat hackish: we already trained the classifier // This is somewhat hackish: we already trained the classifier
// on the previous classification, possibly with the same // on the previous classification, possibly with the same
// tokens. // tokens.
@ -86,33 +85,33 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
// used previously). // used previously).
train_classifier = true; train_classifier = true;
}, },
Some(was_spam) if !overwrite && was_spam != is_spam => { Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => {
// Classification conflict between concurrent queries. // Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification // In this case we play it safe and discard the classification
// for this user; the user will need to be manually classified again. // for this user; the user will need to be manually classified again.
eprintln!( eprintln!(
"Classification conflict for user {}; discarding our current classification", "Classification conflict for user {}; discarding our current classification",
db.users.get(user_id).unwrap().login db.users.get(&user_id).unwrap().login
); );
db.is_spam.remove(user_id); db.is_spam.remove(&user_id);
}, },
None => { None => {
db.is_spam.insert(*user_id, *is_spam); db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
train_classifier = true; train_classifier = true;
}, },
Some(was_spam) => { Some(was_spam) => {
assert!(was_spam == is_spam); assert!(was_spam.as_bool() == is_spam);
// nothing to do // nothing to do
} }
} }
if train_classifier { if train_classifier {
// Train the classifier with tokens from the user // Train the classifier with tokens from the user
let tokens = db.tokens.get(user_id).unwrap(); let tokens = db.tokens.get(&user_id).unwrap();
if *is_spam { if is_spam {
classifier.train_spam(tokens); classifier.train_spam(tokens)
} else { } else {
classifier.train_ham(tokens); classifier.train_ham(tokens)
} }
} }
} }
@ -252,7 +251,7 @@ async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>) -> i
let mut users: Vec<(&UserId, &UserData, f32, bool)> = db let mut users: Vec<(&UserId, &UserData, f32, bool)> = db
.classified_users() .classified_users()
.into_iter() .into_iter()
.map(|(id, u, s)| (id, u, *db.score.get(id).unwrap(), s)) .map(|(id, u, s)| (id, u, *db.score.get(id).unwrap(), s.as_bool()))
.collect(); .collect();
// sort "spam first" // sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64); users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);

View file

@ -12,7 +12,7 @@ use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) -> anyhow::Result<()> { async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) -> anyhow::Result<()> {
{ {
let db = &db.lock().unwrap(); let db = &db.lock().unwrap();
let d = db.last_scrape().elapsed()?; let d = db.last_scrape.elapsed()?;
if d < FORGEJO_POLL_DELAY { if d < FORGEJO_POLL_DELAY {
return Ok(()); return Ok(());
} }
@ -36,8 +36,8 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
for (&user_id, user_data) in &newdb.users { for (&user_id, user_data) in &newdb.users {
let &score = newdb.score.get(&user_id).unwrap(); let &score = newdb.score.get(&user_id).unwrap();
if let Some(&user_was_spam) = db.is_spam.get(&user_id) { if let Some(&user_was_spam) = db.is_spam.get(&user_id) {
if (user_was_spam && score < GUESS_SPAM_THRESHOLD) || if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD) ||
(! user_was_spam && score > GUESS_LEGIT_THRESHOLD) (! user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{ {
eprintln!( eprintln!(
"Score for user {} changed past threshold; discarding our current classification", "Score for user {} changed past threshold; discarding our current classification",