classification UI, etc

problem: scores are slow to update when submitting classifications
This commit is contained in:
Armaël Guéneau 2024-11-22 16:26:17 +01:00
parent 5d22662499
commit d9251ce395
8 changed files with 325 additions and 49 deletions

13
Cargo.lock generated
View file

@ -326,17 +326,6 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bayespam"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fc30da46312e6ef33841929c2b2b40a969631cb954561591c8a76c8ebcbd029"
dependencies = [
"serde",
"serde_json",
"unicode-segmentation",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -672,7 +661,6 @@ dependencies = [
"actix-files", "actix-files",
"actix-web", "actix-web",
"anyhow", "anyhow",
"bayespam",
"forgejo-api", "forgejo-api",
"lazy_static", "lazy_static",
"rand", "rand",
@ -681,6 +669,7 @@ dependencies = [
"serde_json", "serde_json",
"tera", "tera",
"tokio", "tokio",
"unicode-segmentation",
"url", "url",
] ]

View file

@ -11,11 +11,11 @@ reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
forgejo-api = "0.4" forgejo-api = "0.4"
url = "2" url = "2"
anyhow = "1.0.93" anyhow = "1"
bayespam = "1.1.0" serde_json = "1"
serde_json = "1.0.133" rand = "0.8"
rand = "0.8.5"
actix-web = "4" actix-web = "4"
tera = "1" tera = "1"
lazy_static = "1.5.0" lazy_static = "1"
actix-files = "0.6.6" actix-files = "0.6"
unicode-segmentation = "1"

View file

@ -1 +0,0 @@
{"users":{"2176":"Spam","1376":"Legit","1552":"Legit","2101":"Spam","5366":"Spam","946":"Legit","1863":"Legit","5400":"Spam","4827":"Spam","5968":"Spam","5620":"Spam","5571":"Spam","3879":"Spam","548":"Legit","2487":"Spam","2103":"Spam","3881":"Unknown","4640":"Spam","1905":"Spam","4357":"Spam","3299":"Spam","5611":"Spam","3859":"Spam","5184":"Spam","2934":"Unknown","2897":"Spam","4485":"Unknown","5593":"Spam","5847":"Spam","2887":"Spam","5006":"Spam","5513":"Spam","5524":"Spam","5628":"Spam","5212":"Spam","1985":"Legit","768":"Legit","4683":"Spam","4759":"Spam","4743":"Unknown","4832":"Spam","2630":"Unknown","5516":"Spam","4780":"Spam","2077":"Spam","1231":"Legit","4950":"Spam","2651":"Unknown","4248":"Spam","3489":"Spam","4940":"Spam","2655":"Unknown","12":"Legit","4629":"Spam","2209":"Spam","3626":"Legit","5335":"Unknown","400":"Legit","3590":"Spam","3760":"Spam","5637":"Spam","3077":"Spam","1790":"Spam","5695":"Spam","5235":"Spam","2850":"Legit","2117":"Spam","137":"Legit","3851":"Spam","5778":"Spam","4261":"Unknown"}}

1
model.json Normal file

File diff suppressed because one or more lines are too long

136
src/classifier.rs Normal file
View file

@ -0,0 +1,136 @@
// code based on the bayespam crate
use std::collections::HashMap;
use std::fs::File;
use std::io;
use serde::{Deserialize, Serialize};
use serde_json::{from_reader, to_writer, to_writer_pretty};
use unicode_segmentation::UnicodeSegmentation;
// Probability assigned to a token the model has never seen:
// 0.5 = no evidence either way.
const INITIAL_RATING: f32 = 0.5;
// Per-token occurrence counts. Serialized into the model JSON, so the
// field names ("ham"/"spam") are part of the on-disk format — do not rename.
#[derive(Debug, Default, Serialize, Deserialize)]
struct Counter {
// number of ham (legit) messages this token appeared in
ham: u32,
// number of spam messages this token appeared in
spam: u32,
}
/// A naive-Bayes spam classifier.
///
/// The whole model is the token table: one [`Counter`] of ham/spam
/// occurrences per word. Serialized as-is to JSON by `save` and loaded
/// back by `new_from_pre_trained`, so the field name is part of the
/// on-disk format.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct Classifier {
token_table: HashMap<String, Counter>,
}
impl Classifier {
    /// Build a new classifier with an empty model.
    pub fn new() -> Self {
        Default::default()
    }

    /// Build a new classifier with a pre-trained model loaded from `file`.
    ///
    /// `file` must contain JSON previously written by [`Classifier::save`].
    pub fn new_from_pre_trained(file: &mut File) -> Result<Self, io::Error> {
        let pre_trained_model = from_reader(file)?;
        Ok(pre_trained_model)
    }

    /// Save the classifier to `file` as JSON.
    /// The JSON will be pretty printed if `pretty` is `true`.
    pub fn save(&self, file: &mut File, pretty: bool) -> Result<(), io::Error> {
        if pretty {
            to_writer_pretty(file, &self)?;
        } else {
            to_writer(file, &self)?;
        }
        Ok(())
    }

    /// Split `msg` into a list of words, using Unicode word segmentation.
    pub fn into_word_list(msg: &str) -> Vec<String> {
        msg.unicode_words().map(|word| word.to_string()).collect()
    }

    /// Train the classifier with spam `tokens`.
    pub fn train_spam(&mut self, tokens: &[String]) {
        for word in tokens {
            let counter = self.token_table.entry(word.to_string()).or_default();
            counter.spam += 1;
        }
    }

    /// Train the classifier with ham `tokens`.
    pub fn train_ham(&mut self, tokens: &[String]) {
        for word in tokens {
            let counter = self.token_table.entry(word.to_string()).or_default();
            counter.ham += 1;
        }
    }

    /// Return the total number of spam occurrences in the token table.
    /// O(vocabulary): sums over every entry.
    fn spam_total_count(&self) -> u32 {
        self.token_table.values().map(|x| x.spam).sum()
    }

    /// Return the total number of ham occurrences in the token table.
    /// O(vocabulary): sums over every entry.
    fn ham_total_count(&self) -> u32 {
        self.token_table.values().map(|x| x.ham).sum()
    }

    /// Compute the probability of each of `tokens` to be part of a spam.
    fn rate_words(&self, tokens: &[String]) -> Vec<f32> {
        // Hoist the totals out of the per-token closure: they are
        // loop-invariant but each one is an O(vocabulary) sum, so computing
        // them per token made scoring O(tokens × vocabulary) — a major cost
        // when rescoring many users.
        let spam_total = self.spam_total_count();
        let ham_total = self.ham_total_count();
        tokens
            .iter()
            .map(|word| {
                // If word was previously added in the model
                if let Some(counter) = self.token_table.get(word) {
                    // If the word has only been part of spam messages,
                    // assign it a probability of 0.99 to be part of a spam
                    if counter.spam > 0 && counter.ham == 0 {
                        return 0.99;
                    // If the word has only been part of ham messages,
                    // assign it a probability of 0.01 to be part of a spam
                    } else if counter.spam == 0 && counter.ham > 0 {
                        return 0.01;
                    // If the word has been part of both spam and ham messages,
                    // calculate the probability to be part of a spam
                    } else if spam_total > 0 && ham_total > 0 {
                        let ham_prob = (counter.ham as f32) / (ham_total as f32);
                        let spam_prob = (counter.spam as f32) / (spam_total as f32);
                        // Clamp below at 0.01 so one ham-free token cannot
                        // zero out the whole product in `score`.
                        return (spam_prob / (ham_prob + spam_prob)).max(0.01);
                    }
                }
                // If word was never added to the model,
                // assign it an initial probability to be part of a spam
                INITIAL_RATING
            })
            .collect()
    }

    /// Compute the spam score of `tokens`, in [0, 1].
    /// The higher the score, the stronger the likelihood that `tokens` are spam.
    pub fn score(&self, tokens: &[String]) -> f32 {
        // Compute the probability of each word to be part of a spam
        let ratings = self.rate_words(tokens);
        let ratings = match ratings.len() {
            // If there are no ratings, return a score of 0
            0 => return 0.0,
            // If there are more than 20 ratings, keep only the 10 first
            // and 10 last ratings (the most extreme evidence) to calculate
            // a score
            x if x > 20 => {
                let length = ratings.len();
                let mut ratings = ratings;
                ratings.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
                [&ratings[..10], &ratings[length - 10..]].concat()
            }
            // In all other cases, keep all ratings to calculate a score
            _ => ratings,
        };
        // Combine individual ratings (naive-Bayes combination):
        // P(spam) = Π p_i / (Π p_i + Π (1 - p_i))
        let product: f32 = ratings.iter().product();
        let alt_product: f32 = ratings.iter().map(|x| 1.0 - x).product();
        product / (product + alt_product)
    }
}

View file

@ -1,10 +1,9 @@
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use bayespam::classifier::Classifier;
use forgejo_api::{Auth, Forgejo}; use forgejo_api::{Auth, Forgejo};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rand::prelude::*; use rand::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::path::Path; use std::path::Path;
@ -12,6 +11,9 @@ use std::sync::Mutex;
use tera::Tera; use tera::Tera;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
mod classifier;
use classifier::Classifier;
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
struct RepoId(i64); struct RepoId(i64);
@ -57,12 +59,15 @@ use Classification::*;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct Db { struct Db {
users: HashMap<UserId, UserData>, users: HashMap<UserId, UserData>,
text: HashMap<UserId, String>,
classification: HashMap<UserId, Classification>, classification: HashMap<UserId, Classification>,
// caches: derived from the rest
score: HashMap<UserId, f32>,
tokens: HashMap<UserId, Vec<String>>,
users_of_token: HashMap<String, Vec<UserId>>,
} }
impl UserData { impl UserData {
fn to_text(&self) -> String { fn to_tokens(&self) -> Vec<String> {
let mut text = String::new(); let mut text = String::new();
let mut add = |s: &str| { let mut add = |s: &str| {
text += s; text += s;
@ -101,7 +106,7 @@ impl UserData {
add(&issue.body); add(&issue.body);
} }
text Classifier::into_word_list(&text)
} }
} }
@ -109,8 +114,10 @@ impl Db {
fn new() -> Db { fn new() -> Db {
Db { Db {
users: HashMap::new(), users: HashMap::new(),
text: HashMap::new(), tokens: HashMap::new(),
classification: HashMap::new(), classification: HashMap::new(),
score: HashMap::new(),
users_of_token: HashMap::new(),
} }
} }
} }
@ -289,7 +296,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserD
async fn load_db() -> anyhow::Result<(Db, Classifier)> { async fn load_db() -> anyhow::Result<(Db, Classifier)> {
let model_path = Path::new("model.json"); let model_path = Path::new("model.json");
let classifier = if model_path.is_file() { let mut classifier = if model_path.is_file() {
Classifier::new_from_pre_trained(&mut File::open(model_path)?)? Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
} else { } else {
Classifier::new() Classifier::new()
@ -312,8 +319,10 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
db.users = get_users_data(&forge).await?; db.users = get_users_data(&forge).await?;
for (user_id, user) in &db.users { eprintln!("Scoring users...");
db.text.insert(*user_id, user.to_text()); let ids: Vec<_> = db.users.iter().map(|(id, _)| *id).collect();
for &user_id in &ids {
update_user(&mut db, &mut classifier, user_id);
} }
let file = File::create(db_path)?; let file = File::create(db_path)?;
@ -324,13 +333,66 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
Ok((db, classifier)) Ok((db, classifier))
} }
fn unclassified_users<'a>(db: &'a Db, classifier: &Classifier) -> Vec<(&'a UserId, &'a UserData)> { fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) {
let tokens = db.users.get(&id).unwrap().to_tokens();
let score = classifier.score(&tokens);
for tok in &tokens {
db.users_of_token.entry(tok.to_string()).or_default().push(id)
};
db.tokens.insert(id, tokens);
db.score.insert(id, score);
}
fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> {
db.users db.users
.iter() .iter()
.filter(|(user_id, _)| !db.classification.contains_key(&user_id)) .filter(|(user_id, _)| !db.classification.contains_key(&user_id))
.collect() .collect()
} }
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) {
let mut all_tokens = HashSet::new();
eprintln!("training classifier");
for (id, is_spam) in ids {
let tokens = db.tokens.get(id).unwrap();
if *is_spam {
classifier.train_spam(tokens);
} else {
classifier.train_ham(tokens);
}
for tok in tokens {
all_tokens.insert(tok.clone());
}
}
eprintln!("computing users to update");
let mut users_to_update = HashSet::new();
for token in all_tokens {
match db.users_of_token.get(&token) {
None => (),
Some(users) => {
for user in users {
users_to_update.insert(*user);
}
},
}
}
eprintln!("recomputing scores for {}/{} users", users_to_update.len(), db.users.len());
for user in users_to_update {
update_user(db, classifier, user)
}
}
/* /*
async fn main_() -> anyhow::Result<()> { async fn main_() -> anyhow::Result<()> {
println!("got {} users", db.users.len()); println!("got {} users", db.users.len());
@ -401,19 +463,20 @@ async fn index(data: web::Data<AppState>) -> impl Responder {
eprintln!("GET /"); eprintln!("GET /");
let db = &data.db.lock().unwrap(); let db = &data.db.lock().unwrap();
let classifier = &data.classifier.lock().unwrap();
eprintln!("compute unclassified users"); eprintln!("scoring users...");
let mut users: Vec<_> =
let users: Vec<&UserData> = unclassified_users(db, classifier) unclassified_users(db).into_iter()
.into_iter() .map(|(id, u)| (id, u, *db.score.get(id).unwrap()))
.map(|(_id, u)| u)
.collect(); .collect();
let mut rng = rand::thread_rng();
eprintln!("randomizing...");
users.shuffle(&mut rng);
eprintln!("sorting...");
users.sort_by_key(|(_, _, score)| 1000 - (score * 1000.) as u64);
users.truncate(50);
let mut context = tera::Context::new(); let mut context = tera::Context::new();
eprintln!("insert users into tera context");
context.insert("users", &users); context.insert("users", &users);
eprintln!("rendering template..."); eprintln!("rendering template...");
let page = TEMPLATES.render("index.html", &context).unwrap(); let page = TEMPLATES.render("index.html", &context).unwrap();
@ -424,9 +487,21 @@ async fn index(data: web::Data<AppState>) -> impl Responder {
#[post("/")] #[post("/")]
async fn apply( async fn apply(
data: web::Data<AppState>, data: web::Data<AppState>,
req: web::Form<HashMap<String, String>>, req: web::Form<HashMap<i64, String>>,
) -> impl Responder { ) -> impl Responder {
println!("{:#?}", req); eprintln!("POST /");
let db = &mut data.db.lock().unwrap();
let classifier = &mut data.classifier.lock().unwrap();
let updates: Vec<(UserId, bool)> =
req.iter()
.map(|(id, classification)| (UserId(*id), classification == "spam"))
.collect();
set_spam(db, classifier, &updates);
eprintln!("{:#?}", req);
HttpResponse::SeeOther() HttpResponse::SeeOther()
.insert_header(("Location", "/")) .insert_header(("Location", "/"))
.finish() .finish()
@ -434,16 +509,18 @@ async fn apply(
#[actix_web::main] #[actix_web::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let (db, classifier) = load_db().await.unwrap(); // FIXME eprintln!("Eval templates");
let _ = *TEMPLATES;
println!("Done loading DB"); eprintln!("Load users and repos");
let (db, classifier) = load_db().await.unwrap(); // FIXME
let st = web::Data::new(AppState { let st = web::Data::new(AppState {
db: Mutex::new(db), db: Mutex::new(db),
classifier: Mutex::new(classifier), classifier: Mutex::new(classifier),
}); });
println!("Launching web server at http://127.0.0.1:8080..."); println!("Launch web server at http://127.0.0.1:8080");
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()

0
src/scrape.rs Normal file
View file

View file

@ -20,16 +20,24 @@
<!-- <link rel="stylesheet" href="css/mdb.min.css" /> --> <!-- <link rel="stylesheet" href="css/mdb.min.css" /> -->
</head> </head>
<style> <style>
.flex-wrapper { .main {
display: flex;
flex-direction: column;
gap: 30px;
align-items: center;
}
.users {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 15px; gap: 15px;
} }
.user { .user {
display: flex; display: flex;
flex-direction: row; flex-direction: row;
gap: 10px; gap: 10px;
align-items: center;
} }
.user-card { .user-card {
@ -50,16 +58,79 @@
flex-wrap: wrap; flex-wrap: wrap;
gap: 10px; gap: 10px;
} }
.user-classification {
display: flex;
flex-direction: column;
gap: 3px;
}
input.radio-classify {
display: none;
}
input.radio-classify + label {
border: 1px solid #000;
padding: 2px;
text-align: center;
}
input.radio-spam:checked + label {
border: 1px solid #d00400;
background: #d00400;
color: #fff;
}
input.radio-legit:checked + label {
border: 1px solid #048e02;
background: #048e02;
color: #fff;
}
.score {
padding-left: 3px;
padding-right: 3px;
width: 3em;
text-align: center;
flex-grow: 0;
flex-shrink: 0;
}
.score-high {
background: #ff696b;
}
.score-mid {
background: #ffa769;
}
.score-low {
background: #5fd770;
}
</style> </style>
<body> <body>
<form method="post"> <form method="post">
<input type="submit" value="Apply"/> <div class="main">
<div class="flex-wrapper"> <div class="users">
{% for user in users %} {% for id_user_score in users %}
{% set user_id = id_user_score[0] %}
{% set user = id_user_score[1] %}
{% set score = id_user_score[2] %}
<div class="user"> <div class="user">
<div class="user-classification"> <div class="user-classification">
<input type="checkbox" name="{{user.login}}" style="scale: 1.2"/> <input type="radio" name="{{user_id}}" id="{{user_id}}-spam" value="spam"
class="radio-classify radio-spam"
{% if score >= 0.8 %}checked{% endif %}
/>
<label for="{{user_id}}-spam">Spam</label>
<input type="radio" name="{{user_id}}" id="{{user_id}}-legit" value="legit"
class="radio-classify radio-legit"
{% if score < 0.8 %}checked{% endif %}
/>
<label for="{{user_id}}-legit">Legit</label>
</div>
<div class="score
{% if score >= 0.8 %} score-high {% endif %}
{% if score < 0.8 and score > 0.3 %} score-mid {% endif %}
{% if score <= 0.3 %} score-low {% endif %}
">
{{ score | round(precision=2) }}
</div> </div>
<div class="user-card"> <div class="user-card">
<div class="user-name"> <div class="user-name">
@ -92,6 +163,9 @@
{% endfor %} {% endfor %}
</div> </div>
<input type="submit" value="Apply" class="button" style="width: 200px; height: 30px"/>
</div>
</form> </form>
</body> </body>