only store the non-cache part of the Db

This commit is contained in:
Armaël Guéneau 2024-11-23 11:46:47 +01:00
parent ff95f3807b
commit b420e1608d
2 changed files with 56 additions and 56 deletions

View file

@ -3,7 +3,7 @@ use forgejo_api::{Auth, Forgejo};
use lazy_static::lazy_static;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
@ -48,18 +48,9 @@ struct UserData {
issues: Vec<(IssueId, IssueData)>,
}
#[derive(Debug, Serialize, Deserialize)]
enum Classification {
Spam,
Legit,
Unknown,
}
use Classification::*;
#[derive(Debug, Serialize, Deserialize)]
struct Db {
users: HashMap<UserId, UserData>,
classification: HashMap<UserId, Classification>,
is_spam: HashMap<UserId, bool>,
// caches: derived from the rest
score: HashMap<UserId, f32>,
tokens: HashMap<UserId, Vec<String>>,
@ -113,14 +104,45 @@ impl Db {
fn new() -> Db {
Db {
users: HashMap::new(),
is_spam: HashMap::new(),
tokens: HashMap::new(),
classification: HashMap::new(),
score: HashMap::new(),
}
}
fn all_users(&self) -> Vec<UserId> {
self.users.iter().map(|(id, _)| *id).collect()
/// Rebuilds the `tokens` cache from `users`.
///
/// Every user currently in `users` gets its entry in `tokens` set to the
/// result of `UserData::to_tokens`, overwriting any previous entry for that
/// id. Entries for ids that are not in `users` are left as they are.
fn recompute_tokens(&mut self) {
    self.tokens.extend(
        self.users
            .iter()
            .map(|(user_id, data)| (*user_id, data.to_tokens())),
    );
}
/// Rebuilds the `score` cache from the `tokens` cache.
///
/// For every id present in `tokens`, stores `classifier.score(tokens)` under
/// that id in `score`, overwriting any previous value. The `tokens` cache
/// must already be populated for the users that should be scored.
fn recompute_scores(&mut self, classifier: &Classifier) {
    self.score.extend(
        self.tokens
            .iter()
            .map(|(user_id, user_tokens)| (*user_id, classifier.score(user_tokens))),
    );
}
/// Loads a `Db` from the JSON file at `path`.
///
/// The file is expected to contain the `(users, is_spam)` pair; the `tokens`
/// and `score` caches are not read from disk but recomputed here using
/// `classifier`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or its contents cannot be
/// deserialized.
fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
    let reader = BufReader::new(File::open(path)?);
    let (users, is_spam) = serde_json::from_reader(reader)?;

    // Start with empty caches, then derive them from the loaded data.
    let mut loaded = Db {
        users,
        is_spam,
        tokens: HashMap::new(),
        score: HashMap::new(),
    };
    loaded.recompute_tokens();
    loaded.recompute_scores(classifier);
    Ok(loaded)
}
/// Builds a `Db` from already-loaded user data and spam labels.
///
/// The `tokens` and `score` caches start empty and are then recomputed from
/// `users` using `classifier`.
fn from_users(
    users: HashMap<UserId, UserData>,
    is_spam: HashMap<UserId, bool>,
    classifier: &Classifier,
) -> Db {
    let mut fresh = Db {
        users,
        is_spam,
        tokens: HashMap::new(),
        score: HashMap::new(),
    };
    // Tokens first: scores are computed from the token cache.
    fresh.recompute_tokens();
    fresh.recompute_scores(classifier);
    fresh
}
/// Writes the persistent part of the `Db` to `path` as JSON.
///
/// Only the `(users, is_spam)` pair is serialized; the `tokens` and `score`
/// caches are not stored.
///
/// # Errors
///
/// Returns an error if the file cannot be created, if serialization fails,
/// or if flushing the buffered output to the file fails.
fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
    let mut writer = BufWriter::new(File::create(path)?);
    let persisted: (&HashMap<UserId, UserData>, &HashMap<UserId, bool>) =
        (&self.users, &self.is_spam);
    serde_json::to_writer(&mut writer, &persisted)?;
    // Flush explicitly: dropping a `BufWriter` also flushes, but any error
    // from that implicit flush is discarded, which could leave a truncated
    // file while still reporting success.
    std::io::Write::flush(&mut writer)?;
    Ok(())
}
}
@ -298,7 +320,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserD
async fn load_db() -> anyhow::Result<(Db, Classifier)> {
let model_path = Path::new("model.json");
let mut classifier = if model_path.is_file() {
let classifier = if model_path.is_file() {
Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
} else {
Classifier::new()
@ -314,70 +336,46 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
let db_path = Path::new("db.json");
let db: Db = if db_path.is_file() {
let file = File::open(db_path)?;
serde_json::from_reader(BufReader::new(file))?
Db::from_path(db_path, &classifier)?
} else {
let mut db = Db::new();
db.users = get_users_data(&forge).await?;
eprintln!("Scoring users...");
for &user_id in &db.all_users() {
update_user(&mut db, &mut classifier, user_id);
}
let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &db)?;
let db =
Db::from_users(
get_users_data(&forge).await?,
HashMap::new(),
&classifier
);
db.store_to_path(db_path)?;
db
};
Ok((db, classifier))
}
fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) {
let userdata = db.users.get(&id).unwrap();
let tokens =
match db.tokens.get(&id) {
Some(tokens) => tokens,
None => {
let tokens = userdata.to_tokens();
db.tokens.insert(id, tokens);
db.tokens.get(&id).unwrap()
}
};
db.score.insert(id, classifier.score(&tokens));
}
fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> {
db.users
.iter()
.filter(|(user_id, _)| !db.classification.contains_key(&user_id))
.filter(|(user_id, _)| !db.is_spam.contains_key(&user_id))
.collect()
}
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) {
let mut all_tokens = HashSet::new();
eprintln!("updating classifier");
for (id, is_spam) in ids {
let tokens = db.tokens.get(id).unwrap();
if *is_spam {
classifier.train_spam(tokens);
} else {
classifier.train_ham(tokens);
}
for tok in tokens {
all_tokens.insert(tok.clone());
}
}
eprintln!("recomputing user scores");
db.recompute_scores(&classifier);
for &user_id in &db.all_users() {
update_user(db, classifier, user_id)
eprintln!("updating db with new classification");
for (user_id, is_spam) in ids {
db.is_spam.insert(*user_id, *is_spam);
}
}
@ -489,7 +487,6 @@ async fn apply(
set_spam(db, classifier, &updates);
eprintln!("{:#?}", req);
HttpResponse::SeeOther()
.insert_header(("Location", "/"))
.finish()

View file

@ -38,6 +38,9 @@
flex-direction: row;
gap: 10px;
align-items: center;
border: 1px dotted #000;
padding: 3px 8px 3px 8px;
margin: 3px;
}
.user-card {
@ -87,7 +90,7 @@
.score {
padding-left: 3px;
padding-right: 3px;
width: 3em;
width: 2.8em;
text-align: center;
flex-grow: 0;
flex-shrink: 0;
@ -134,8 +137,8 @@
</div>
<div class="user-card">
<div class="user-name">
<div>{{ user.login }}</div>
{%- if user.full_name %}<div>({{ user.full_name }})</div>{% endif -%}
<div><strong>{{ user.login }}</strong></div>
{%- if user.full_name %}<div><strong>({{ user.full_name }})</strong></div>{% endif -%}
</div>
<div class="user-info">
{%- if user.location %}<div>[L] {{ user.location }}</div>{% endif -%}