only store the non-cache part of the Db
This commit is contained in:
parent
ff95f3807b
commit
b420e1608d
2 changed files with 56 additions and 56 deletions
103
src/main.rs
103
src/main.rs
|
@ -3,7 +3,7 @@ use forgejo_api::{Auth, Forgejo};
|
|||
use lazy_static::lazy_static;
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::Path;
|
||||
|
@ -48,18 +48,9 @@ struct UserData {
|
|||
issues: Vec<(IssueId, IssueData)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
enum Classification {
|
||||
Spam,
|
||||
Legit,
|
||||
Unknown,
|
||||
}
|
||||
use Classification::*;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Db {
|
||||
users: HashMap<UserId, UserData>,
|
||||
classification: HashMap<UserId, Classification>,
|
||||
is_spam: HashMap<UserId, bool>,
|
||||
// caches: derived from the rest
|
||||
score: HashMap<UserId, f32>,
|
||||
tokens: HashMap<UserId, Vec<String>>,
|
||||
|
@ -113,14 +104,45 @@ impl Db {
|
|||
fn new() -> Db {
|
||||
Db {
|
||||
users: HashMap::new(),
|
||||
is_spam: HashMap::new(),
|
||||
tokens: HashMap::new(),
|
||||
classification: HashMap::new(),
|
||||
score: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn all_users(&self) -> Vec<UserId> {
|
||||
self.users.iter().map(|(id, _)| *id).collect()
|
||||
fn recompute_tokens(&mut self) {
|
||||
for (id, user) in &self.users {
|
||||
self.tokens.insert(*id, user.to_tokens());
|
||||
}
|
||||
}
|
||||
|
||||
fn recompute_scores(&mut self, classifier: &Classifier) {
|
||||
for (id, tokens) in &self.tokens {
|
||||
self.score.insert(*id, classifier.score(tokens));
|
||||
}
|
||||
}
|
||||
|
||||
fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
|
||||
let file = File::open(path)?;
|
||||
let (users, is_spam) = serde_json::from_reader(BufReader::new(file))?;
|
||||
let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() };
|
||||
db.recompute_tokens();
|
||||
db.recompute_scores(classifier);
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn from_users(users: HashMap<UserId, UserData>, is_spam: HashMap<UserId, bool>, classifier: &Classifier) -> Db {
|
||||
let mut db = Db { users, is_spam, tokens: HashMap::new(), score: HashMap::new() };
|
||||
db.recompute_tokens();
|
||||
db.recompute_scores(classifier);
|
||||
db
|
||||
}
|
||||
|
||||
fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
|
||||
let file = File::create(path)?;
|
||||
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, bool>) = (&self.users, &self.is_spam);
|
||||
serde_json::to_writer(BufWriter::new(file), &dat)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,7 +320,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserD
|
|||
|
||||
async fn load_db() -> anyhow::Result<(Db, Classifier)> {
|
||||
let model_path = Path::new("model.json");
|
||||
let mut classifier = if model_path.is_file() {
|
||||
let classifier = if model_path.is_file() {
|
||||
Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
|
||||
} else {
|
||||
Classifier::new()
|
||||
|
@ -314,70 +336,46 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
|
|||
|
||||
let db_path = Path::new("db.json");
|
||||
let db: Db = if db_path.is_file() {
|
||||
let file = File::open(db_path)?;
|
||||
serde_json::from_reader(BufReader::new(file))?
|
||||
Db::from_path(db_path, &classifier)?
|
||||
} else {
|
||||
let mut db = Db::new();
|
||||
|
||||
db.users = get_users_data(&forge).await?;
|
||||
|
||||
eprintln!("Scoring users...");
|
||||
for &user_id in &db.all_users() {
|
||||
update_user(&mut db, &mut classifier, user_id);
|
||||
}
|
||||
|
||||
let file = File::create(db_path)?;
|
||||
serde_json::to_writer(BufWriter::new(file), &db)?;
|
||||
let db =
|
||||
Db::from_users(
|
||||
get_users_data(&forge).await?,
|
||||
HashMap::new(),
|
||||
&classifier
|
||||
);
|
||||
db.store_to_path(db_path)?;
|
||||
db
|
||||
};
|
||||
|
||||
Ok((db, classifier))
|
||||
}
|
||||
|
||||
fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) {
|
||||
let userdata = db.users.get(&id).unwrap();
|
||||
let tokens =
|
||||
match db.tokens.get(&id) {
|
||||
Some(tokens) => tokens,
|
||||
None => {
|
||||
let tokens = userdata.to_tokens();
|
||||
db.tokens.insert(id, tokens);
|
||||
db.tokens.get(&id).unwrap()
|
||||
}
|
||||
};
|
||||
db.score.insert(id, classifier.score(&tokens));
|
||||
}
|
||||
|
||||
fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> {
|
||||
db.users
|
||||
.iter()
|
||||
.filter(|(user_id, _)| !db.classification.contains_key(&user_id))
|
||||
.filter(|(user_id, _)| !db.is_spam.contains_key(&user_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) {
|
||||
let mut all_tokens = HashSet::new();
|
||||
|
||||
eprintln!("updating classifier");
|
||||
|
||||
for (id, is_spam) in ids {
|
||||
let tokens = db.tokens.get(id).unwrap();
|
||||
|
||||
if *is_spam {
|
||||
classifier.train_spam(tokens);
|
||||
} else {
|
||||
classifier.train_ham(tokens);
|
||||
}
|
||||
|
||||
for tok in tokens {
|
||||
all_tokens.insert(tok.clone());
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("recomputing user scores");
|
||||
db.recompute_scores(&classifier);
|
||||
|
||||
for &user_id in &db.all_users() {
|
||||
update_user(db, classifier, user_id)
|
||||
eprintln!("updating db with new classification");
|
||||
for (user_id, is_spam) in ids {
|
||||
db.is_spam.insert(*user_id, *is_spam);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -489,7 +487,6 @@ async fn apply(
|
|||
|
||||
set_spam(db, classifier, &updates);
|
||||
|
||||
eprintln!("{:#?}", req);
|
||||
HttpResponse::SeeOther()
|
||||
.insert_header(("Location", "/"))
|
||||
.finish()
|
||||
|
|
|
@ -38,6 +38,9 @@
|
|||
flex-direction: row;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
border: 1px dotted #000;
|
||||
padding: 3px 8px 3px 8px;
|
||||
margin: 3px;
|
||||
}
|
||||
|
||||
.user-card {
|
||||
|
@ -87,7 +90,7 @@
|
|||
.score {
|
||||
padding-left: 3px;
|
||||
padding-right: 3px;
|
||||
width: 3em;
|
||||
width: 2.8em;
|
||||
text-align: center;
|
||||
flex-grow: 0;
|
||||
flex-shrink: 0;
|
||||
|
@ -134,8 +137,8 @@
|
|||
</div>
|
||||
<div class="user-card">
|
||||
<div class="user-name">
|
||||
<div>{{ user.login }}</div>
|
||||
{%- if user.full_name %}<div>({{ user.full_name }})</div>{% endif -%}
|
||||
<div><strong>{{ user.login }}</strong></div>
|
||||
{%- if user.full_name %}<div><strong>({{ user.full_name }})</strong></div>{% endif -%}
|
||||
</div>
|
||||
<div class="user-info">
|
||||
{%- if user.location %}<div>[L] {{ user.location }}</div>{% endif -%}
|
||||
|
|
Loading…
Add table
Reference in a new issue