Classification UI, etc.
Problem: scores are slow to update when submitting classifications.
This commit is contained in:
parent
5d22662499
commit
d9251ce395
8 changed files with 325 additions and 49 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -326,17 +326,6 @@ version = "1.6.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bayespam"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fc30da46312e6ef33841929c2b2b40a969631cb954561591c8a76c8ebcbd029"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
|
@ -672,7 +661,6 @@ dependencies = [
|
|||
"actix-files",
|
||||
"actix-web",
|
||||
"anyhow",
|
||||
"bayespam",
|
||||
"forgejo-api",
|
||||
"lazy_static",
|
||||
"rand",
|
||||
|
@ -681,6 +669,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"tera",
|
||||
"tokio",
|
||||
"unicode-segmentation",
|
||||
"url",
|
||||
]
|
||||
|
||||
|
|
12
Cargo.toml
12
Cargo.toml
|
@ -11,11 +11,11 @@ reqwest = { version = "0.12", features = ["json"] }
|
|||
serde = { version = "1", features = ["derive"] }
|
||||
forgejo-api = "0.4"
|
||||
url = "2"
|
||||
anyhow = "1.0.93"
|
||||
bayespam = "1.1.0"
|
||||
serde_json = "1.0.133"
|
||||
rand = "0.8.5"
|
||||
anyhow = "1"
|
||||
serde_json = "1"
|
||||
rand = "0.8"
|
||||
actix-web = "4"
|
||||
tera = "1"
|
||||
lazy_static = "1.5.0"
|
||||
actix-files = "0.6.6"
|
||||
lazy_static = "1"
|
||||
actix-files = "0.6"
|
||||
unicode-segmentation = "1"
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
{"users":{"2176":"Spam","1376":"Legit","1552":"Legit","2101":"Spam","5366":"Spam","946":"Legit","1863":"Legit","5400":"Spam","4827":"Spam","5968":"Spam","5620":"Spam","5571":"Spam","3879":"Spam","548":"Legit","2487":"Spam","2103":"Spam","3881":"Unknown","4640":"Spam","1905":"Spam","4357":"Spam","3299":"Spam","5611":"Spam","3859":"Spam","5184":"Spam","2934":"Unknown","2897":"Spam","4485":"Unknown","5593":"Spam","5847":"Spam","2887":"Spam","5006":"Spam","5513":"Spam","5524":"Spam","5628":"Spam","5212":"Spam","1985":"Legit","768":"Legit","4683":"Spam","4759":"Spam","4743":"Unknown","4832":"Spam","2630":"Unknown","5516":"Spam","4780":"Spam","2077":"Spam","1231":"Legit","4950":"Spam","2651":"Unknown","4248":"Spam","3489":"Spam","4940":"Spam","2655":"Unknown","12":"Legit","4629":"Spam","2209":"Spam","3626":"Legit","5335":"Unknown","400":"Legit","3590":"Spam","3760":"Spam","5637":"Spam","3077":"Spam","1790":"Spam","5695":"Spam","5235":"Spam","2850":"Legit","2117":"Spam","137":"Legit","3851":"Spam","5778":"Spam","4261":"Unknown"}}
|
1
model.json
Normal file
1
model.json
Normal file
File diff suppressed because one or more lines are too long
136
src/classifier.rs
Normal file
136
src/classifier.rs
Normal file
|
@ -0,0 +1,136 @@
|
|||
// code based on the bayespam crate
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{from_reader, to_writer, to_writer_pretty};
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
const INITIAL_RATING: f32 = 0.5;
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct Counter {
|
||||
ham: u32,
|
||||
spam: u32,
|
||||
}
|
||||
|
||||
/// A bayesian spam classifier.
|
||||
#[derive(Default, Debug, Serialize, Deserialize)]
|
||||
pub struct Classifier {
|
||||
token_table: HashMap<String, Counter>,
|
||||
}
|
||||
|
||||
impl Classifier {
|
||||
/// Build a new classifier with an empty model.
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// Build a new classifier with a pre-trained model loaded from `file`.
|
||||
pub fn new_from_pre_trained(file: &mut File) -> Result<Self, io::Error> {
|
||||
let pre_trained_model = from_reader(file)?;
|
||||
Ok(pre_trained_model)
|
||||
}
|
||||
|
||||
/// Save the classifier to `file` as JSON.
|
||||
/// The JSON will be pretty printed if `pretty` is `true`.
|
||||
pub fn save(&self, file: &mut File, pretty: bool) -> Result<(), io::Error> {
|
||||
if pretty {
|
||||
to_writer_pretty(file, &self)?;
|
||||
} else {
|
||||
to_writer(file, &self)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Split `msg` into a list of words.
|
||||
pub fn into_word_list(msg: &str) -> Vec<String> {
|
||||
let word_list = msg.unicode_words().collect::<Vec<&str>>();
|
||||
word_list.iter().map(|word| word.to_string()).collect()
|
||||
}
|
||||
|
||||
/// Train the classifier with spam `tokens`.
|
||||
pub fn train_spam(&mut self, tokens: &[String]) {
|
||||
for word in tokens {
|
||||
let counter = self.token_table.entry(word.to_string()).or_default();
|
||||
counter.spam += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Train the classifier with ham `tokens`.
|
||||
pub fn train_ham(&mut self, tokens: &[String]) {
|
||||
for word in tokens {
|
||||
let counter = self.token_table.entry(word.to_string()).or_default();
|
||||
counter.ham += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the total number of spam in token table.
|
||||
fn spam_total_count(&self) -> u32 {
|
||||
self.token_table.values().map(|x| x.spam).sum()
|
||||
}
|
||||
|
||||
/// Return the total number of ham in token table.
|
||||
fn ham_total_count(&self) -> u32 {
|
||||
self.token_table.values().map(|x| x.ham).sum()
|
||||
}
|
||||
|
||||
/// Compute the probability of `tokens` to be part of a spam.
|
||||
fn rate_words(&self, tokens: &[String]) -> Vec<f32> {
|
||||
tokens
|
||||
.into_iter()
|
||||
.map(|word| {
|
||||
// If word was previously added in the model
|
||||
if let Some(counter) = self.token_table.get(word) {
|
||||
// If the word has only been part of spam messages,
|
||||
// assign it a probability of 0.99 to be part of a spam
|
||||
if counter.spam > 0 && counter.ham == 0 {
|
||||
return 0.99;
|
||||
// If the word has only been part of ham messages,
|
||||
// assign it a probability of 0.01 to be part of a spam
|
||||
} else if counter.spam == 0 && counter.ham > 0 {
|
||||
return 0.01;
|
||||
// If the word has been part of both spam and ham messages,
|
||||
// calculate the probability to be part of a spam
|
||||
} else if self.spam_total_count() > 0 && self.ham_total_count() > 0 {
|
||||
let ham_prob = (counter.ham as f32) / (self.ham_total_count() as f32);
|
||||
let spam_prob = (counter.spam as f32) / (self.spam_total_count() as f32);
|
||||
return (spam_prob / (ham_prob + spam_prob)).max(0.01);
|
||||
}
|
||||
}
|
||||
// If word was never added to the model,
|
||||
// assign it an initial probability to be part of a spam
|
||||
INITIAL_RATING
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute the spam score of `tokens`.
|
||||
/// The higher the score, the stronger the liklihood that `tokens` are spam is.
|
||||
pub fn score(&self, tokens: &[String]) -> f32 {
|
||||
// Compute the probability of each word to be part of a spam
|
||||
let ratings = self.rate_words(tokens);
|
||||
|
||||
let ratings = match ratings.len() {
|
||||
// If there are no ratings, return a score of 0
|
||||
0 => return 0.0,
|
||||
// If there are more than 20 ratings, keep only the 10 first
|
||||
// and 10 last ratings to calculate a score
|
||||
x if x > 20 => {
|
||||
let length = ratings.len();
|
||||
let mut ratings = ratings;
|
||||
ratings.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
[&ratings[..10], &ratings[length - 10..]].concat()
|
||||
}
|
||||
// In all other cases, keep ratings to calculate a score
|
||||
_ => ratings,
|
||||
};
|
||||
|
||||
// Combine individual ratings
|
||||
let product: f32 = ratings.iter().product();
|
||||
let alt_product: f32 = ratings.iter().map(|x| 1.0 - x).product();
|
||||
product / (product + alt_product)
|
||||
}
|
||||
}
|
125
src/main.rs
125
src/main.rs
|
@ -1,10 +1,9 @@
|
|||
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
||||
use bayespam::classifier::Classifier;
|
||||
use forgejo_api::{Auth, Forgejo};
|
||||
use lazy_static::lazy_static;
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::path::Path;
|
||||
|
@ -12,6 +11,9 @@ use std::sync::Mutex;
|
|||
use tera::Tera;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
mod classifier;
|
||||
use classifier::Classifier;
|
||||
|
||||
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
struct RepoId(i64);
|
||||
|
||||
|
@ -57,12 +59,15 @@ use Classification::*;
|
|||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Db {
|
||||
users: HashMap<UserId, UserData>,
|
||||
text: HashMap<UserId, String>,
|
||||
classification: HashMap<UserId, Classification>,
|
||||
// caches: derived from the rest
|
||||
score: HashMap<UserId, f32>,
|
||||
tokens: HashMap<UserId, Vec<String>>,
|
||||
users_of_token: HashMap<String, Vec<UserId>>,
|
||||
}
|
||||
|
||||
impl UserData {
|
||||
fn to_text(&self) -> String {
|
||||
fn to_tokens(&self) -> Vec<String> {
|
||||
let mut text = String::new();
|
||||
let mut add = |s: &str| {
|
||||
text += s;
|
||||
|
@ -101,7 +106,7 @@ impl UserData {
|
|||
add(&issue.body);
|
||||
}
|
||||
|
||||
text
|
||||
Classifier::into_word_list(&text)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -109,8 +114,10 @@ impl Db {
|
|||
fn new() -> Db {
|
||||
Db {
|
||||
users: HashMap::new(),
|
||||
text: HashMap::new(),
|
||||
tokens: HashMap::new(),
|
||||
classification: HashMap::new(),
|
||||
score: HashMap::new(),
|
||||
users_of_token: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -289,7 +296,7 @@ async fn get_users_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserD
|
|||
|
||||
async fn load_db() -> anyhow::Result<(Db, Classifier)> {
|
||||
let model_path = Path::new("model.json");
|
||||
let classifier = if model_path.is_file() {
|
||||
let mut classifier = if model_path.is_file() {
|
||||
Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
|
||||
} else {
|
||||
Classifier::new()
|
||||
|
@ -312,8 +319,10 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
|
|||
|
||||
db.users = get_users_data(&forge).await?;
|
||||
|
||||
for (user_id, user) in &db.users {
|
||||
db.text.insert(*user_id, user.to_text());
|
||||
eprintln!("Scoring users...");
|
||||
let ids: Vec<_> = db.users.iter().map(|(id, _)| *id).collect();
|
||||
for &user_id in &ids {
|
||||
update_user(&mut db, &mut classifier, user_id);
|
||||
}
|
||||
|
||||
let file = File::create(db_path)?;
|
||||
|
@ -324,13 +333,66 @@ async fn load_db() -> anyhow::Result<(Db, Classifier)> {
|
|||
Ok((db, classifier))
|
||||
}
|
||||
|
||||
fn unclassified_users<'a>(db: &'a Db, classifier: &Classifier) -> Vec<(&'a UserId, &'a UserData)> {
|
||||
fn update_user(db: &mut Db, classifier: &mut Classifier, id: UserId) {
|
||||
let tokens = db.users.get(&id).unwrap().to_tokens();
|
||||
let score = classifier.score(&tokens);
|
||||
|
||||
for tok in &tokens {
|
||||
db.users_of_token.entry(tok.to_string()).or_default().push(id)
|
||||
};
|
||||
|
||||
db.tokens.insert(id, tokens);
|
||||
db.score.insert(id, score);
|
||||
}
|
||||
|
||||
/// Return every user that has not yet been manually classified.
fn unclassified_users<'a>(db: &'a Db) -> Vec<(&'a UserId, &'a UserData)> {
    let mut pending = Vec::new();
    for (user_id, user_data) in &db.users {
        // keep only users without a stored classification
        if !db.classification.contains_key(user_id) {
            pending.push((user_id, user_data));
        }
    }
    pending
}
|
||||
|
||||
fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)]) {
|
||||
let mut all_tokens = HashSet::new();
|
||||
|
||||
eprintln!("training classifier");
|
||||
|
||||
for (id, is_spam) in ids {
|
||||
let tokens = db.tokens.get(id).unwrap();
|
||||
|
||||
if *is_spam {
|
||||
classifier.train_spam(tokens);
|
||||
} else {
|
||||
classifier.train_ham(tokens);
|
||||
}
|
||||
|
||||
for tok in tokens {
|
||||
all_tokens.insert(tok.clone());
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("computing users to update");
|
||||
|
||||
let mut users_to_update = HashSet::new();
|
||||
|
||||
for token in all_tokens {
|
||||
match db.users_of_token.get(&token) {
|
||||
None => (),
|
||||
Some(users) => {
|
||||
for user in users {
|
||||
users_to_update.insert(*user);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("recomputing scores for {}/{} users", users_to_update.len(), db.users.len());
|
||||
|
||||
for user in users_to_update {
|
||||
update_user(db, classifier, user)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
async fn main_() -> anyhow::Result<()> {
|
||||
println!("got {} users", db.users.len());
|
||||
|
@ -401,19 +463,20 @@ async fn index(data: web::Data<AppState>) -> impl Responder {
|
|||
eprintln!("GET /");
|
||||
|
||||
let db = &data.db.lock().unwrap();
|
||||
let classifier = &data.classifier.lock().unwrap();
|
||||
|
||||
eprintln!("compute unclassified users");
|
||||
|
||||
let users: Vec<&UserData> = unclassified_users(db, classifier)
|
||||
.into_iter()
|
||||
.map(|(_id, u)| u)
|
||||
eprintln!("scoring users...");
|
||||
let mut users: Vec<_> =
|
||||
unclassified_users(db).into_iter()
|
||||
.map(|(id, u)| (id, u, *db.score.get(id).unwrap()))
|
||||
.collect();
|
||||
let mut rng = rand::thread_rng();
|
||||
eprintln!("randomizing...");
|
||||
users.shuffle(&mut rng);
|
||||
eprintln!("sorting...");
|
||||
users.sort_by_key(|(_, _, score)| 1000 - (score * 1000.) as u64);
|
||||
users.truncate(50);
|
||||
|
||||
let mut context = tera::Context::new();
|
||||
|
||||
eprintln!("insert users into tera context");
|
||||
|
||||
context.insert("users", &users);
|
||||
eprintln!("rendering template...");
|
||||
let page = TEMPLATES.render("index.html", &context).unwrap();
|
||||
|
@ -424,9 +487,21 @@ async fn index(data: web::Data<AppState>) -> impl Responder {
|
|||
#[post("/")]
|
||||
async fn apply(
|
||||
data: web::Data<AppState>,
|
||||
req: web::Form<HashMap<String, String>>,
|
||||
req: web::Form<HashMap<i64, String>>,
|
||||
) -> impl Responder {
|
||||
println!("{:#?}", req);
|
||||
eprintln!("POST /");
|
||||
|
||||
let db = &mut data.db.lock().unwrap();
|
||||
let classifier = &mut data.classifier.lock().unwrap();
|
||||
|
||||
let updates: Vec<(UserId, bool)> =
|
||||
req.iter()
|
||||
.map(|(id, classification)| (UserId(*id), classification == "spam"))
|
||||
.collect();
|
||||
|
||||
set_spam(db, classifier, &updates);
|
||||
|
||||
eprintln!("{:#?}", req);
|
||||
HttpResponse::SeeOther()
|
||||
.insert_header(("Location", "/"))
|
||||
.finish()
|
||||
|
@ -434,16 +509,18 @@ async fn apply(
|
|||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let (db, classifier) = load_db().await.unwrap(); // FIXME
|
||||
eprintln!("Eval templates");
|
||||
let _ = *TEMPLATES;
|
||||
|
||||
println!("Done loading DB");
|
||||
eprintln!("Load users and repos");
|
||||
let (db, classifier) = load_db().await.unwrap(); // FIXME
|
||||
|
||||
let st = web::Data::new(AppState {
|
||||
db: Mutex::new(db),
|
||||
classifier: Mutex::new(classifier),
|
||||
});
|
||||
|
||||
println!("Launching web server at http://127.0.0.1:8080...");
|
||||
println!("Launch web server at http://127.0.0.1:8080");
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
|
|
0
src/scrape.rs
Normal file
0
src/scrape.rs
Normal file
|
@ -20,16 +20,24 @@
|
|||
<!-- <link rel="stylesheet" href="css/mdb.min.css" /> -->
|
||||
</head>
|
||||
<style>
|
||||
.flex-wrapper {
|
||||
.main {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 30px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.users {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
|
||||
.user {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 10px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.user-card {
|
||||
|
@ -50,16 +58,79 @@
|
|||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.user-classification {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
input.radio-classify {
|
||||
display: none;
|
||||
}
|
||||
input.radio-classify + label {
|
||||
border: 1px solid #000;
|
||||
padding: 2px;
|
||||
text-align: center;
|
||||
}
|
||||
input.radio-spam:checked + label {
|
||||
border: 1px solid #d00400;
|
||||
background: #d00400;
|
||||
color: #fff;
|
||||
}
|
||||
input.radio-legit:checked + label {
|
||||
border: 1px solid #048e02;
|
||||
background: #048e02;
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
.score {
|
||||
padding-left: 3px;
|
||||
padding-right: 3px;
|
||||
width: 3em;
|
||||
text-align: center;
|
||||
flex-grow: 0;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.score-high {
|
||||
background: #ff696b;
|
||||
}
|
||||
.score-mid {
|
||||
background: #ffa769;
|
||||
}
|
||||
.score-low {
|
||||
background: #5fd770;
|
||||
}
|
||||
</style>
|
||||
<body>
|
||||
<form method="post">
|
||||
<input type="submit" value="Apply"/>
|
||||
<div class="flex-wrapper">
|
||||
<div class="main">
|
||||
<div class="users">
|
||||
|
||||
{% for user in users %}
|
||||
{% for id_user_score in users %}
|
||||
{% set user_id = id_user_score[0] %}
|
||||
{% set user = id_user_score[1] %}
|
||||
{% set score = id_user_score[2] %}
|
||||
<div class="user">
|
||||
<div class="user-classification">
|
||||
<input type="checkbox" name="{{user.login}}" style="scale: 1.2"/>
|
||||
<input type="radio" name="{{user_id}}" id="{{user_id}}-spam" value="spam"
|
||||
class="radio-classify radio-spam"
|
||||
{% if score >= 0.8 %}checked{% endif %}
|
||||
/>
|
||||
<label for="{{user_id}}-spam">Spam</label>
|
||||
<input type="radio" name="{{user_id}}" id="{{user_id}}-legit" value="legit"
|
||||
class="radio-classify radio-legit"
|
||||
{% if score < 0.8 %}checked{% endif %}
|
||||
/>
|
||||
<label for="{{user_id}}-legit">Legit</label>
|
||||
</div>
|
||||
<div class="score
|
||||
{% if score >= 0.8 %} score-high {% endif %}
|
||||
{% if score < 0.8 and score > 0.3 %} score-mid {% endif %}
|
||||
{% if score <= 0.3 %} score-low {% endif %}
|
||||
">
|
||||
{{ score | round(precision=2) }}
|
||||
</div>
|
||||
<div class="user-card">
|
||||
<div class="user-name">
|
||||
|
@ -92,6 +163,9 @@
|
|||
{% endfor %}
|
||||
|
||||
</div>
|
||||
|
||||
<input type="submit" value="Apply" class="button" style="width: 200px; height: 30px"/>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
</body>
|
||||
|
|
Loading…
Reference in a new issue