Store the time a user was classified as spam in the database

This commit is contained in:
Armaël Guéneau 2024-12-19 11:49:08 +01:00
parent b2406dd883
commit 8f2bc7ebc6
4 changed files with 64 additions and 43 deletions

File diff suppressed because one or more lines are too long

View file

@ -2,17 +2,51 @@ use std::collections::HashMap;
use std::path::Path;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::time::{Duration, SystemTime};
use std::time::SystemTime;
use std::fmt;
use serde::{Serialize, Deserialize};
use crate::data::*;
use crate::classifier::Classifier;
/// Classification of a user: either legit, or spam together with the time at
/// which the spam classification was recorded.
///
/// The type is serializable so that it can be persisted as part of the
/// database, and `Copy` so that it can be passed around by value.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
/// The user is classified as legitimate.
Legit,
/// The user is classified as spam; `classified_at` is the time attached to
/// that classification.
Spam { classified_at: SystemTime },
}
impl IsSpam {
    /// Collapses the classification into a boolean: `true` means spam,
    /// `false` means legit.
    ///
    /// This is the inverse of [`IsSpam::from_bool`], except that the
    /// classification timestamp is dropped. Callers compare the result
    /// directly against "is spam" booleans, so `Spam` must map to `true`.
    pub fn as_bool(&self) -> bool {
        // Explicit match (no catch-all) so that adding a variant forces this
        // function to be revisited.
        match self {
            IsSpam::Legit => false,
            IsSpam::Spam { .. } => true,
        }
    }

    /// Builds a classification from a boolean where `true` means spam.
    ///
    /// A spam classification is stamped with the current system time; a
    /// legit classification carries no timestamp.
    pub fn from_bool(b: bool) -> IsSpam {
        if b {
            IsSpam::Spam { classified_at: SystemTime::now() }
        } else {
            IsSpam::Legit
        }
    }
}
impl fmt::Display for IsSpam {
    /// Writes the classification as a short lowercase label: `legit` or
    /// `spam`. The classification timestamp is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            IsSpam::Spam { .. } => "spam",
            IsSpam::Legit => "legit",
        };
        f.write_str(label)
    }
}
// TODO (?): make the fields private and provide an API that automatically
// recomputes the caches when necessary?
pub struct Db {
// persisted data
pub users: HashMap<UserId, UserData>,
pub is_spam: HashMap<UserId, bool>,
last_scrape: u64,
pub is_spam: HashMap<UserId, IsSpam>,
pub last_scrape: SystemTime,
// caches: computed from persisted data on load
pub score: HashMap<UserId, f32>,
pub tokens: HashMap<UserId, Vec<String>>,
@ -31,21 +65,10 @@ impl Db {
}
}
pub fn last_scrape(&self) -> SystemTime {
std::time::UNIX_EPOCH + Duration::from_secs(self.last_scrape)
}
pub fn set_last_scrape_to_now(&mut self) {
self.last_scrape =
SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d: Duration| d.as_secs())
.unwrap_or(0);
}
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?;
let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?;
let (users, is_spam, last_scrape) =
serde_json::from_reader(BufReader::new(file))?;
let mut db = Db {
users,
is_spam,
@ -60,25 +83,24 @@ impl Db {
pub fn from_users(
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, bool>,
is_spam: HashMap<UserId, IsSpam>,
classifier: &Classifier,
) -> Db {
let mut db = Db {
users,
is_spam,
last_scrape: 0,
last_scrape: SystemTime::now(),
tokens: HashMap::new(),
score: HashMap::new(),
};
db.recompute_tokens();
db.recompute_scores(classifier);
db.set_last_scrape_to_now();
db
}
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?;
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, bool>, u64) =
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, IsSpam>, SystemTime) =
(&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(())
@ -91,7 +113,7 @@ impl Db {
.collect()
}
pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, bool)> {
pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, IsSpam)> {
self.users
.iter()
.filter_map(|(user_id, user_data)|

View file

@ -18,7 +18,7 @@ mod workers;
use classifier::Classifier;
use data::*;
use db::Db;
use db::{IsSpam, Db};
// Fetch user data from forgejo from time to time
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -64,18 +64,17 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
// XXX: This function looks like it is doing too many things at once.
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], overwrite: bool) {
for (user_id, is_spam) in ids {
for &(user_id, is_spam) in ids {
let mut train_classifier = false;
match db.is_spam.get(user_id) {
Some(was_spam) if overwrite && was_spam != is_spam => {
match db.is_spam.get(&user_id) {
Some(&was_spam) if overwrite && was_spam.as_bool() != is_spam => {
eprintln!(
"User {}: changing classification from {} to {}",
db.users.get(user_id).unwrap().login,
(if *was_spam { "spam" } else { "legit" }),
(if *is_spam { "spam" } else { "legit" })
db.users.get(&user_id).unwrap().login,
was_spam, is_spam
);
db.is_spam.insert(*user_id, *is_spam);
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// This is somewhat hackish: we already trained the classifier
// on the previous classification, possibly with the same
// tokens.
@ -86,33 +85,33 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
// used previously).
train_classifier = true;
},
Some(was_spam) if !overwrite && was_spam != is_spam => {
Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => {
// Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification
// for this user; the user will need to be manually classified again.
eprintln!(
"Classification conflict for user {}; discarding our current classification",
db.users.get(user_id).unwrap().login
db.users.get(&user_id).unwrap().login
);
db.is_spam.remove(user_id);
db.is_spam.remove(&user_id);
},
None => {
db.is_spam.insert(*user_id, *is_spam);
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
train_classifier = true;
},
Some(was_spam) => {
assert!(was_spam == is_spam);
assert!(was_spam.as_bool() == is_spam);
// nothing to do
}
}
if train_classifier {
// Train the classifier with tokens from the user
let tokens = db.tokens.get(user_id).unwrap();
if *is_spam {
classifier.train_spam(tokens);
let tokens = db.tokens.get(&user_id).unwrap();
if is_spam {
classifier.train_spam(tokens)
} else {
classifier.train_ham(tokens);
classifier.train_ham(tokens)
}
}
}
@ -252,7 +251,7 @@ async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>) -> i
let mut users: Vec<(&UserId, &UserData, f32, bool)> = db
.classified_users()
.into_iter()
.map(|(id, u, s)| (id, u, *db.score.get(id).unwrap(), s))
.map(|(id, u, s)| (id, u, *db.score.get(id).unwrap(), s.as_bool()))
.collect();
// sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);

View file

@ -12,7 +12,7 @@ use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) -> anyhow::Result<()> {
{
let db = &db.lock().unwrap();
let d = db.last_scrape().elapsed()?;
let d = db.last_scrape.elapsed()?;
if d < FORGEJO_POLL_DELAY {
return Ok(());
}
@ -36,8 +36,8 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
for (&user_id, user_data) in &newdb.users {
let &score = newdb.score.get(&user_id).unwrap();
if let Some(&user_was_spam) = db.is_spam.get(&user_id) {
if (user_was_spam && score < GUESS_SPAM_THRESHOLD) ||
(! user_was_spam && score > GUESS_LEGIT_THRESHOLD)
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD) ||
(! user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{
eprintln!(
"Score for user {} changed past threshold; discarding our current classification",