cargo fmt

This commit is contained in:
Armaël Guéneau 2024-12-19 12:49:58 +01:00
parent ddda6cc1cf
commit 45ff1f3ea5
5 changed files with 93 additions and 48 deletions

View file

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use crate::classifier::Classifier;
use serde::{Deserialize, Serialize};
/// Newtype wrapper around the `i64` identifier of a user.
///
/// It derives `Hash` and `Eq`, which lets it serve as the key of the
/// `HashMap<UserId, _>` maps kept by the database (`users`, `is_spam`), and
/// `Serialize`/`Deserialize` so those maps can be written and read with serde.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserId(pub i64);

View file

@ -1,12 +1,12 @@
use crate::classifier::Classifier;
use crate::data::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::fmt;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::SystemTime;
use std::fmt;
use serde::{Serialize, Deserialize};
use crate::data::*;
use crate::classifier::Classifier;
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
@ -24,7 +24,9 @@ impl IsSpam {
pub fn from_bool(b: bool) -> IsSpam {
if b {
IsSpam::Spam { classified_at: SystemTime::now() }
IsSpam::Spam {
classified_at: SystemTime::now(),
}
} else {
IsSpam::Legit
}
@ -67,8 +69,7 @@ impl Db {
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?;
let (users, is_spam, last_scrape) =
serde_json::from_reader(BufReader::new(file))?;
let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?;
let mut db = Db {
users,
is_spam,
@ -100,8 +101,11 @@ impl Db {
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?;
let dat: (&HashMap<UserId, UserData>, &HashMap<UserId, IsSpam>, SystemTime) =
(&self.users, &self.is_spam, self.last_scrape);
let dat: (
&HashMap<UserId, UserData>,
&HashMap<UserId, IsSpam>,
SystemTime,
) = (&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(())
}
@ -116,9 +120,11 @@ impl Db {
pub fn classified_users<'a>(&'a self) -> Vec<(&'a UserId, &'a UserData, IsSpam)> {
self.users
.iter()
.filter_map(|(user_id, user_data)|
self.is_spam.get(&user_id).map(|is_spam| (user_id, user_data, *is_spam))
)
.filter_map(|(user_id, user_data)| {
self.is_spam
.get(&user_id)
.map(|is_spam| (user_id, user_data, *is_spam))
})
.collect()
}
}

View file

@ -2,7 +2,7 @@ use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Resp
use forgejo_api::{Auth, Forgejo};
use lazy_static::lazy_static;
use rand::prelude::*;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
@ -18,7 +18,7 @@ mod workers;
use classifier::Classifier;
use data::*;
use db::{IsSpam, Db};
use db::{Db, IsSpam};
// Fetch user data from forgejo from time to time.
// The background refresh loop wakes up every tenth of this delay
// (`FORGEJO_POLL_DELAY.mul_f32(0.1)`) and then attempts a refresh.
// NOTE(review): presumably the refresh is skipped while the last scrape is
// more recent than this delay — confirm in `try_refresh_user_data`.
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -54,7 +54,11 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
let db: Db = if db_path.is_file() {
Db::from_path(db_path, &classifier)?
} else {
let db = Db::from_users(scrape::get_user_data(&forge).await?, HashMap::new(), &classifier);
let db = Db::from_users(
scrape::get_user_data(&forge).await?,
HashMap::new(),
&classifier,
);
db.store_to_path(db_path)?;
db
};
@ -72,7 +76,8 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
eprintln!(
"User {}: changing classification from {} to {}",
db.users.get(&user_id).unwrap().login,
was_spam, is_spam
was_spam,
is_spam
);
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// This is somewhat hackish: we already trained the classifier
@ -84,7 +89,7 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
// tokens that we have now are the same as the ones that were
// used previously).
train_classifier = true;
},
}
Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => {
// Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification
@ -94,11 +99,11 @@ fn set_spam(db: &mut Db, classifier: &mut Classifier, ids: &[(UserId, bool)], ov
db.users.get(&user_id).unwrap().login
);
db.is_spam.remove(&user_id);
},
}
None => {
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
train_classifier = true;
},
}
Some(was_spam) => {
assert!(was_spam.as_bool() == is_spam);
// nothing to do
@ -143,7 +148,11 @@ struct SortSetting {
}
#[derive(Serialize, Deserialize)]
enum ApproxScore { Low, Mid, High }
enum ApproxScore {
Low,
Mid,
High,
}
// approximated score, for feeding to the template
fn approx_score(score: f32) -> ApproxScore {
@ -157,7 +166,11 @@ fn approx_score(score: f32) -> ApproxScore {
}
#[get("/")]
async fn index(data: web::Data<AppState>, q: web::Query<SortSetting>, req: HttpRequest) -> impl Responder {
async fn index(
data: web::Data<AppState>,
q: web::Query<SortSetting>,
req: HttpRequest,
) -> impl Responder {
eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap();
@ -187,9 +200,17 @@ async fn index(data: web::Data<AppState>, q: web::Query<SortSetting>, req: HttpR
}
// compute the rough "spam score" (low/mid/high) and spam guess (true/false)
let users: Vec<(&UserId, &UserData, f32, ApproxScore, bool)> =
users.into_iter()
.map(|(id, u, score)| (id, u, score, approx_score(score), score >= GUESS_SPAM_THRESHOLD))
let users: Vec<(&UserId, &UserData, f32, ApproxScore, bool)> = users
.into_iter()
.map(|(id, u, score)| {
(
id,
u,
score,
approx_score(score),
score >= GUESS_SPAM_THRESHOLD,
)
})
.collect();
let users_count = db.users.len();
@ -212,7 +233,7 @@ async fn post_classified(
data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>,
req: HttpRequest,
overwrite: bool
overwrite: bool,
) -> impl Responder {
eprintln!("POST {}", req.uri());
@ -238,17 +259,29 @@ async fn post_classified(
}
#[post("/")]
async fn post_classified_index(data: web::Data<AppState>, form: web::Form<HashMap<i64, String>>, req: HttpRequest) -> impl Responder {
async fn post_classified_index(
data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>,
req: HttpRequest,
) -> impl Responder {
post_classified(data, form, req, false).await
}
#[post("/classified")]
async fn post_classified_edit(data: web::Data<AppState>, form: web::Form<HashMap<i64, String>>, req: HttpRequest) -> impl Responder {
async fn post_classified_edit(
data: web::Data<AppState>,
form: web::Form<HashMap<i64, String>>,
req: HttpRequest,
) -> impl Responder {
post_classified(data, form, req, true).await
}
#[get("/classified")]
async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>, req: HttpRequest) -> impl Responder {
async fn classified(
data: web::Data<AppState>,
_q: web::Query<SortSetting>,
req: HttpRequest,
) -> impl Responder {
eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap();
@ -261,8 +294,8 @@ async fn classified(data: web::Data<AppState>, _q: web::Query<SortSetting>, req:
// sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);
let users: Vec<_> =
users.into_iter()
let users: Vec<_> = users
.into_iter()
.map(|(id, u, score, is_spam)| (id, u, score, approx_score(score), is_spam))
.collect();

View file

@ -1,6 +1,6 @@
use forgejo_api::Forgejo;
use tokio::time::{sleep, Duration};
use std::collections::HashMap;
use tokio::time::{sleep, Duration};
use crate::data::*;
@ -71,12 +71,10 @@ async fn scrape_users(forge: &Forgejo) -> anyhow::Result<Vec<forgejo_api::struct
pub async fn get_user_data(forge: &Forgejo) -> anyhow::Result<HashMap<UserId, UserData>> {
let mut data = HashMap::new();
let discard_empty = |o: Option<String>| {
match o {
let discard_empty = |o: Option<String>| match o {
None => None,
Some(s) if s.is_empty() => None,
Some(s) => Some(s),
}
};
eprintln!("Fetching users...");

View file

@ -1,15 +1,19 @@
use std::sync::{Arc, Mutex};
use crate::classifier::Classifier;
use crate::db::Db;
use crate::scrape;
use forgejo_api::Forgejo;
use std::collections::HashMap;
use std::path::Path;
use forgejo_api::Forgejo;
use crate::db::Db;
use crate::classifier::Classifier;
use crate::scrape;
use std::sync::{Arc, Mutex};
use crate::FORGEJO_POLL_DELAY;
use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) -> anyhow::Result<()> {
async fn try_refresh_user_data(
forge: &Forgejo,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) -> anyhow::Result<()> {
{
let db = &db.lock().unwrap();
let d = db.last_scrape.elapsed()?;
@ -36,8 +40,8 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
for (&user_id, user_data) in &newdb.users {
let &score = newdb.score.get(&user_id).unwrap();
if let Some(&user_was_spam) = db.is_spam.get(&user_id) {
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD) ||
(! user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD)
|| (!user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{
eprintln!(
"Score for user {} changed past threshold; discarding our current classification",
@ -57,7 +61,11 @@ async fn try_refresh_user_data(forge: &Forgejo, db: Arc<Mutex<Db>>, classifier:
Ok(())
}
pub async fn refresh_user_data(forge: Arc<Forgejo>, db: Arc<Mutex<Db>>, classifier: Arc<Mutex<Classifier>>) {
pub async fn refresh_user_data(
forge: Arc<Forgejo>,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) {
loop {
tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await;
if let Err(e) = try_refresh_user_data(&forge, db.clone(), classifier.clone()).await {