cargo fmt

This commit is contained in:
Armaël Guéneau 2024-12-19 12:49:58 +01:00
parent ddda6cc1cf
commit 45ff1f3ea5
5 changed files with 93 additions and 48 deletions

View file

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use crate::classifier::Classifier; use crate::classifier::Classifier;
use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserId(pub i64); pub struct UserId(pub i64);

View file

@ -1,12 +1,12 @@
use crate::classifier::Classifier;
use crate::data::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::fmt;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::SystemTime; use std::time::SystemTime;
use std::fmt;
use serde::{Serialize, Deserialize};
use crate::data::*;
use crate::classifier::Classifier;
#[derive(Serialize, Deserialize, Debug, Clone, Copy)] #[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam { pub enum IsSpam {
@ -24,7 +24,9 @@ impl IsSpam {
pub fn from_bool(b: bool) -> IsSpam { pub fn from_bool(b: bool) -> IsSpam {
if b { if b {
IsSpam::Spam { classified_at: SystemTime::now() } IsSpam::Spam {
classified_at: SystemTime::now(),
}
} else { } else {
IsSpam::Legit IsSpam::Legit
} }
@ -67,8 +69,7 @@ impl Db {
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> { pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?; let file = File::open(path)?;
let (users, is_spam, last_scrape) = let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?;
serde_json::from_reader(BufReader::new(file))?;
let mut db = Db { let mut db = Db {
users, users,
is_spam, is_spam,
@ -100,8 +101,11 @@ impl Db {
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> { pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?; let file = File::create(path)?;
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, IsSpam>, SystemTime) = let dat: (
(&self.users, &self.is_spam, self.last_scrape); &HashMap<UserId, UserData>,
&HashMap<UserId, IsSpam>,
SystemTime,
) = (&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?; serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(()) Ok(())
} }
@ -116,9 +120,11 @@ impl Db {
pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, IsSpam)> { pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, IsSpam)> {
self.users self.users
.iter() .iter()
.filter_map(|(user_id, user_data)| .filter_map(|(user_id, user_data)| {
self.is_spam.get(&user_id).map(|is_spam| (user_id, user_data, *is_spam)) self.is_spam
) .get(&user_id)
.map(|is_spam| (user_id, user_data, *is_spam))
})
.collect() .collect()
} }
} }

View file

@ -2,7 +2,7 @@ use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Resp
use forgejo_api::{Auth, Forgejo}; use forgejo_api::{Auth, Forgejo};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rand::prelude::*; use rand::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::path::Path; use std::path::Path;
@ -18,7 +18,7 @@ mod workers;
use classifier::Classifier; use classifier::Classifier;
use data::*; use data::*;
use db::{IsSpam, Db}; use db::{Db, IsSpam};
// Fetch user data from forgejo from time to time // Fetch user data from forgejo from time to time
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -54,7 +54,11 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
let db: Db = if db_path.is_file() { let db: Db = if db_path.is_file() {
Db::from_path(db_path, &classifier)? Db::from_path(db_path, &classifier)?
} else { } else {
let db = Db::from_users(scrape::get_user_data(&forge).await?, HashMap::new(), &classifier); let db = Db::from_users(
scrape::get_user_data(&forge).await?,
HashMap::new(),
&classifier,
);
db.store_to_path(db_path)?; db.store_to_path(db_path)?;
db db
}; };
@ -72,7 +76,8 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
eprintln!( eprintln!(
"User {}: changing classification from {} to {}", "User {}: changing classification from {} to {}",
db.users.get(&user_id).unwrap().login, db.users.get(&user_id).unwrap().login,
was_spam, is_spam was_spam,
is_spam
); );
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam)); db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// This is somewhat hackish: we already trained the classifier // This is somewhat hackish: we already trained the classifier
@ -84,7 +89,7 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
// tokens that we have now are the same as the one that were // tokens that we have now are the same as the one that were
// used previously). // used previously).
train_classifier = true; train_classifier = true;
}, }
Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => { Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => {
// Classification conflict between concurrent queries. // Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification // In this case we play it safe and discard the classification
@ -94,11 +99,11 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
db.users.get(&user_id).unwrap().login db.users.get(&user_id).unwrap().login
); );
db.is_spam.remove(&user_id); db.is_spam.remove(&user_id);
}, }
None => { None => {
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam)); db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
train_classifier = true; train_classifier = true;
}, }
Some(was_spam) => { Some(was_spam) => {
assert!(was_spam.as_bool() == is_spam); assert!(was_spam.as_bool() == is_spam);
// nothing to do // nothing to do
@ -143,7 +148,11 @@ struct SortSetting {
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
enum ApproxScore { Low, Mid, High } enum ApproxScore {
Low,
Mid,
High,
}
// approximated score, for feeding to the template // approximated score, for feeding to the template
fn approx_score(score: f32) -> ApproxScore { fn approx_score(score: f32) -> ApproxScore {
@ -157,7 +166,11 @@ fn approx_score(score: f32) -> ApproxScore {
} }
#[get("/")] #[get("/")]
async fn index(data: web::Data<AppState>, q: web::Query<SortSetting>, req: HttpRequest) -> impl Responder { async fn index(
data: web::Data<AppState>,
q: web::Query<SortSetting>,
req: HttpRequest,
) -> impl Responder {
eprintln!("GET {}", req.uri()); eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap(); let db = &data.db.lock().unwrap();
@ -187,9 +200,17 @@ async fn index(data: web::Data<AppState>, q: web::Query<SortSetting>, req: HttpR
} }
// compute the rough "spam score" (low/mid/high) and spam guess (true/false) // compute the rough "spam score" (low/mid/high) and spam guess (true/false)
let users: Vec<(&UserId, &UserData, f32, ApproxScore, bool)> = let users: Vec<(&UserId, &UserData, f32, ApproxScore, bool)> = users
users.into_iter() .into_iter()
.map(|(id, u, score)| (id, u, score, approx_score(score), score >= GUESS_SPAM_THRESHOLD)) .map(|(id, u, score)| {
(
id,
u,
score,
approx_score(score),
score >= GUESS_SPAM_THRESHOLD,
)
})
.collect(); .collect();
let users_count = db.users.len(); let users_count = db.users.len();
@ -212,7 +233,7 @@ async fn post_classified(
data: web::Data<AppState>, data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>, form: web::Form<HashMap<i64, String>>,
req: HttpRequest, req: HttpRequest,
overwrite: bool overwrite: bool,
) -> impl Responder { ) -> impl Responder {
eprintln!("POST {}", req.uri()); eprintln!("POST {}", req.uri());
@ -238,17 +259,29 @@ async fn post_classified(
} }
#[post("/")] #[post("/")]
async fn post_classified_index(data: web::Data<AppState>, form: web::Form<HashMap<i64, String>>, req: HttpRequest) -> impl Responder { async fn post_classified_index(
data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>,
req: HttpRequest,
) -> impl Responder {
post_classified(data, form, req, false).await post_classified(data, form, req, false).await
} }
#[post("/classified")] #[post("/classified")]
async fn post_classified_edit(data: web::Data<AppState>, form: web::Form<HashMap<i64, String>>, req: HttpRequest) -> impl Responder { async fn post_classified_edit(
data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>,
req: HttpRequest,
) -> impl Responder {
post_classified(data, form, req, true).await post_classified(data, form, req, true).await
} }
#[get("/classified")] #[get("/classified")]
async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>, req: HttpRequest) -> impl Responder { async fn classified(
data: web::Data<AppState>,
_q: web::Query<SortSetting>,
req: HttpRequest,
) -> impl Responder {
eprintln!("GET {}", req.uri()); eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap(); let db = &data.db.lock().unwrap();
@ -261,8 +294,8 @@ async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>, req:
// sort "spam first" // sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64); users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);
let users: Vec<_> = let users: Vec<_> = users
users.into_iter() .into_iter()
.map(|(id, u, score, is_spam)| (id, u, score, approx_score(score), is_spam)) .map(|(id, u, score, is_spam)| (id, u, score, approx_score(score), is_spam))
.collect(); .collect();

View file

@ -1,6 +1,6 @@
use forgejo_api::Forgejo; use forgejo_api::Forgejo;
use tokio::time::{sleep, Duration};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::time::{sleep, Duration};
use crate::data::*; use crate::data::*;
@ -71,12 +71,10 @@ async fn scrape_users(forge: &Forgejo) -> anyhow::Result<Vec<forgejo_api::struct
pub async fn get_user_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserData>> { pub async fn get_user_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserData>> {
let mut data = HashMap::new(); let mut data = HashMap::new();
let discard_empty = |o: Option<String>| { let discard_empty = |o: Option<String>| match o {
match o {
None => None, None => None,
Some(s) if s.is_empty() => None, Some(s) if s.is_empty() => None,
Some(s) => Some(s), Some(s) => Some(s),
}
}; };
eprintln!("Fetching users..."); eprintln!("Fetching users...");

View file

@ -1,15 +1,19 @@
use std::sync::{Arc, Mutex}; use crate::classifier::Classifier;
use crate::db::Db;
use crate::scrape;
use forgejo_api::Forgejo;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use forgejo_api::Forgejo; use std::sync::{Arc, Mutex};
use crate::db::Db;
use crate::classifier::Classifier;
use crate::scrape;
use crate::FORGEJO_POLL_DELAY; use crate::FORGEJO_POLL_DELAY;
use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD}; use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) -> anyhow::Result<()> { async fn try_refresh_user_data(
forge: &Forgejo,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) -> anyhow::Result<()> {
{ {
let db = &db.lock().unwrap(); let db = &db.lock().unwrap();
let d = db.last_scrape.elapsed()?; let d = db.last_scrape.elapsed()?;
@ -36,8 +40,8 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
for (&user_id, user_data) in &newdb.users { for (&user_id, user_data) in &newdb.users {
let &score = newdb.score.get(&user_id).unwrap(); let &score = newdb.score.get(&user_id).unwrap();
if let Some(&user_was_spam) = db.is_spam.get(&user_id) { if let Some(&user_was_spam) = db.is_spam.get(&user_id) {
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD) || if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD)
(! user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD) || (!user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{ {
eprintln!( eprintln!(
"Score for user {} changed past threshold; discarding our current classification", "Score for user {} changed past threshold; discarding our current classification",
@ -57,7 +61,11 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
Ok(()) Ok(())
} }
pub async fn refresh_user_data(forge: Arc<Forgejo>, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) { pub async fn refresh_user_data(
forge: Arc<Forgejo>,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) {
loop { loop {
tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await; tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await;
if let Err(e) = try_refresh_user_data(&forge, db.clone(), classifier.clone()).await { if let Err(e) = try_refresh_user_data(&forge, db.clone(), classifier.clone()).await {