only store the non-cache part of the Db

This commit is contained in:
Armaël Guéneau 2024-11-23 11:46:47 +01:00
parent ff95f3807b
commit b420e1608d
2 changed files with 56 additions and 56 deletions

View file

@ -3,7 +3,7 @@ use forgejo_api::{Auth, Forgejo};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rand::prelude::*; use rand::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::path::Path; use std::path::Path;
@ -48,18 +48,9 @@ struct UserData {
issues: Vec<(IssueId, IssueData)>, issues: Vec<(IssueId, IssueData)>,
} }
#[derive(Debug, Serialize, Deserialize)]
enum Classification {
Spam,
Legit,
Unknown,
}
use Classification::*;
#[derive(Debug, Serialize, Deserialize)]
struct Db { struct Db {
users: HashMap<UserId, UserData>, users: HashMap<UserId, UserData>,
classification: HashMap<UserId, Classification>, is_spam: HashMap<UserId, bool>,
// caches: derived from the rest // caches: derived from the rest
score: HashMap<UserId, f32>, score: HashMap<UserId, f32>,
tokens: HashMap<UserId, Vec<String>>, tokens: HashMap<UserId, Vec<String>>,
@ -113,14 +104,45 @@ impl Db {
fn new() -> Db { fn new() -> Db {
Db { Db {
users: HashMap::new(), users: HashMap::new(),
is_spam: HashMap::new(),
tokens: HashMap::new(), tokens: HashMap::new(),
classification: HashMap::new(),
score: HashMap::new(), score: HashMap::new(),
} }
} }
fn all_users(&self) -> Vec<UserId> { fn recompute_tokens(&mut self) {
self.users.iter().map(|(id, _)| *id).collect() for (id, user) in &self.users {
self.tokens.insert(*id, user.to_tokens());
}
}
fn recompute_scores(&mut self, classifier: &Classifier) {
for (id, tokens) in &self.tokens {
self.score.insert(*id, classifier.score(tokens));
}
}
fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?;
let (users, is_spam) = serde_json::from_reader(BufReader::new(file))?;
let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() };
db.recompute_tokens();
db.recompute_scores(classifier);
Ok(db)
}
fn from_users(users: HashMap<UserId, UserData>, is_spam: HashMap<UserId, bool>, classifier: &Classifier) -> Db {
let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() };
db.recompute_tokens();
db.recompute_scores(classifier);
db
}
fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?;
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, bool>) = (&self.users, &self.is_spam);
serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(())
} }
} }
@ -298,7 +320,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserD
async fn load_db() -> anyhow::Result<(Db, Classifier)> { async fn load_db() -> anyhow::Result<(Db, Classifier)> {
let model_path = Path::new("model.json"); let model_path = Path::new("model.json");
let mut classifier = if model_path.is_file() { let classifier = if model_path.is_file() {
Classifier::new_from_pre_trained(&mut File::open(model_path)?)? Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
} else { } else {
Classifier::new() Classifier::new()
@ -314,70 +336,46 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
let db_path = Path::new("db.json"); let db_path = Path::new("db.json");
let db: Db = if db_path.is_file() { let db: Db = if db_path.is_file() {
let file = File::open(db_path)?; Db::from_path(db_path, &classifier)?
serde_json::from_reader(BufReader::new(file))?
} else { } else {
let mut db = Db::new(); let db =
Db::from_users(
db.users = get_users_data(&forge).await?; get_users_data(&forge).await?,
HashMap::new(),
eprintln!("Scoring users..."); &classifier
for &user_id in &db.all_users() { );
update_user(&mut db, &mut classifier, user_id); db.store_to_path(db_path)?;
}
let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &db)?;
db db
}; };
Ok((db, classifier)) Ok((db, classifier))
} }
fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) {
let userdata = db.users.get(&id).unwrap();
let tokens =
match db.tokens.get(&id) {
Some(tokens) => tokens,
None => {
let tokens = userdata.to_tokens();
db.tokens.insert(id, tokens);
db.tokens.get(&id).unwrap()
}
};
db.score.insert(id, classifier.score(&tokens));
}
fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> { fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> {
db.users db.users
.iter() .iter()
.filter(|(user_id, _)| !db.classification.contains_key(&user_id)) .filter(|(user_id, _)| !db.is_spam.contains_key(&user_id))
.collect() .collect()
} }
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) { fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) {
let mut all_tokens = HashSet::new();
eprintln!("updating classifier"); eprintln!("updating classifier");
for (id, is_spam) in ids { for (id, is_spam) in ids {
let tokens = db.tokens.get(id).unwrap(); let tokens = db.tokens.get(id).unwrap();
if *is_spam { if *is_spam {
classifier.train_spam(tokens); classifier.train_spam(tokens);
} else { } else {
classifier.train_ham(tokens); classifier.train_ham(tokens);
} }
for tok in tokens {
all_tokens.insert(tok.clone());
}
} }
eprintln!("recomputing user scores"); eprintln!("recomputing user scores");
db.recompute_scores(&classifier);
for &user_id in &db.all_users() { eprintln!("updating db with new classification");
update_user(db, classifier, user_id) for (user_id, is_spam) in ids {
db.is_spam.insert(*user_id, *is_spam);
} }
} }
@ -489,7 +487,6 @@ async fn apply(
set_spam(db, classifier, &updates); set_spam(db, classifier, &updates);
eprintln!("{:#?}", req);
HttpResponse::SeeOther() HttpResponse::SeeOther()
.insert_header(("Location", "/")) .insert_header(("Location", "/"))
.finish() .finish()

View file

@ -38,6 +38,9 @@
flex-direction: row; flex-direction: row;
gap: 10px; gap: 10px;
align-items: center; align-items: center;
border: 1px dotted #000;
padding: 3px 8px 3px 8px;
margin: 3px;
} }
.user-card { .user-card {
@ -87,7 +90,7 @@
.score { .score {
padding-left: 3px; padding-left: 3px;
padding-right: 3px; padding-right: 3px;
width: 3em; width: 2.8em;
text-align: center; text-align: center;
flex-grow: 0; flex-grow: 0;
flex-shrink: 0; flex-shrink: 0;
@ -134,8 +137,8 @@
</div> </div>
<div class="user-card"> <div class="user-card">
<div class="user-name"> <div class="user-name">
<div>{{ user.login }}</div> <div><strong>{{ user.login }}</strong></div>
{%- if user.full_name %}<div>({{ user.full_name }})</div>{% endif -%} {%- if user.full_name %}<div><strong>({{ user.full_name }})</strong></div>{% endif -%}
</div> </div>
<div class="user-info"> <div class="user-info">
{%- if user.location %}<div>[L] {{ user.location }}</div>{% endif -%} {%- if user.location %}<div>[L] {{ user.location }}</div>{% endif -%}