Add S3 as storage backend, refactor db & storage code

This commit is contained in:
Armaël Guéneau 2024-12-23 00:50:01 +01:00
parent af38eae2c3
commit edc49a6d1d
11 changed files with 1608 additions and 535 deletions

1118
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
[package] [package]
name = "forgejo-antispam" name = "forgery"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -21,6 +21,8 @@ actix-files = "0.6"
unicode-segmentation = "1" unicode-segmentation = "1"
lettre = { version = "0.11", features = ["builder", "smtp-transport", "rustls-tls"], default-features = false } lettre = { version = "0.11", features = ["builder", "smtp-transport", "rustls-tls"], default-features = false }
include_dir = "0.7" include_dir = "0.7"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.66.0"
[profile.profiling] [profile.profiling]
inherits = "dev" inherits = "dev"

View file

@ -27,12 +27,28 @@ Forgery reads the following environment variables:
default) or set to `false`, no actual action is taken: spammers are only default) or set to `false`, no actual action is taken: spammers are only
listed in the database. The variable should be set in production, but probably listed in the database. The variable should be set in production, but probably
not for testing. not for testing.
- `STORAGE_BACKEND`: either `local` (default) or `s3`. Chose `local` to store
the application state to local files, or `s3` to store them in S3-compatible
storage (see below for corresponding configuration variables).
Environment variables that are relevant when `ACTUALLY_BAN_USERS=true`: Environment variables read when `ACTUALLY_BAN_USERS=true`:
- `SMTP_ADDRESS`: address of the SMTP relay used to send email notifications - `SMTP_ADDRESS`: address of the SMTP relay used to send email notifications
- `SMTP_USERNAME`: SMTP username - `SMTP_USERNAME`: SMTP username
- `SMTP_PASSWORD`: SMTP password - `SMTP_PASSWORD`: SMTP password
Environment variables read when `STORAGE_BACKEND=local`:
- `STORAGE_LOCAL_DIR`: path to a local directory where to store the application
data (as two files `db.json` and `model.json`). Defaults to `.` if not
defined.
Environment variables read when `STORAGE_BACKEND=s3`:
- `STORAGE_S3_BUCKET`: name of the bucket where to store the application data
(as two entries `db.json` and `model.json`).
- `AWS_DEFAULT_REGION`: S3 endpoint region
- `AWS_ENDPOINT_URL`: S3 endpoint URL
- `AWS_ACCESS_KEY_ID`: S3 key id
- `AWS_SECRET_ACCESS_KEY`: S3 key secret
## Todos ## Todos
- discuss the current design choices for when locking the account/sending a - discuss the current design choices for when locking the account/sending a
@ -40,5 +56,5 @@ Environment variables that are relevant when `ACTUALLY_BAN_USERS=true`:
(Current behavior is to periodically retry, avoid deleting if the account (Current behavior is to periodically retry, avoid deleting if the account
could not be locked, but delete the account after the grace period even if could not be locked, but delete the account after the grace period even if
the email could not be sent…) the email could not be sent…)
- add backend to store data on garage instead of local files - auth: add support for connecting to the forge using oauth?
- improve error handling - improve error handling

View file

@ -1,11 +1,8 @@
// code based on the bayespam crate // code based on the bayespam crate
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File;
use std::io;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{from_reader, to_writer, to_writer_pretty};
use unicode_segmentation::UnicodeSegmentation; use unicode_segmentation::UnicodeSegmentation;
const INITIAL_RATING: f32 = 0.5; const INITIAL_RATING: f32 = 0.5;
@ -30,23 +27,6 @@ impl Classifier {
Default::default() Default::default()
} }
/// Build a new classifier with a pre-trained model loaded from `file`.
pub fn new_from_pre_trained(file: &mut File) -> Result<Self, io::Error> {
let pre_trained_model = from_reader(file)?;
Ok(pre_trained_model)
}
/// Save the classifier to `file` as JSON.
/// The JSON will be pretty printed if `pretty` is `true`.
pub fn save(&self, file: &mut File, pretty: bool) -> Result<(), io::Error> {
if pretty {
to_writer_pretty(file, &self)?;
} else {
to_writer(file, &self)?;
}
Ok(())
}
/// Split `msg` into a list of words. /// Split `msg` into a list of words.
pub fn into_word_list(msg: &str) -> Vec<String> { pub fn into_word_list(msg: &str) -> Vec<String> {
let word_list = msg.unicode_words().collect::<Vec<&str>>(); let word_list = msg.unicode_words().collect::<Vec<&str>>();

View file

@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserId(pub i64); pub struct UserId(pub i64);
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UserData { pub struct UserData {
pub login: String, pub login: String,
pub email: String, pub email: String,
@ -20,7 +20,7 @@ pub struct UserData {
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct RepoId(pub i64); pub struct RepoId(pub i64);
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RepoData { pub struct RepoData {
pub name: String, pub name: String,
pub description: Option<String>, pub description: Option<String>,
@ -29,7 +29,7 @@ pub struct RepoData {
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct IssueId(pub i64); pub struct IssueId(pub i64);
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IssueData { pub struct IssueData {
pub title: String, pub title: String,
pub body: String, pub body: String,

218
src/db.rs
View file

@ -1,138 +1,126 @@
use crate::classifier::Classifier; use crate::classifier::Classifier;
use crate::data::*; use crate::data::*;
use serde::{Deserialize, Serialize}; use crate::userdb::{IsSpam, UserDb};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::sync::{Arc, Mutex};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::SystemTime;
#[derive(Serialize, Deserialize, Debug, Clone, Copy)] #[derive(Clone)]
pub enum IsSpam {
Legit,
Spam {
classified_at: SystemTime,
locked: bool,
notified: bool,
},
}
impl IsSpam {
pub fn as_bool(&self) -> bool {
match self {
IsSpam::Legit => true,
IsSpam::Spam { .. } => false,
}
}
pub fn from_bool(b: bool) -> IsSpam {
if b {
IsSpam::Spam {
classified_at: SystemTime::now(),
locked: false,
notified: false,
}
} else {
IsSpam::Legit
}
}
}
impl fmt::Display for IsSpam {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IsSpam::Legit => write!(f, "legit"),
IsSpam::Spam { .. } => write!(f, "spam"),
}
}
}
// TODO (?): make the fields private and provide an API that automatically
// recomputes the caches when necessary?
pub struct Db { pub struct Db {
// persisted data userdb: Arc<Mutex<UserDb>>,
pub users: HashMap<UserId, UserData>, classifier: Arc<Mutex<Classifier>>,
pub is_spam: HashMap<UserId, IsSpam>, cache: Arc<Mutex<Cache>>,
pub last_scrape: SystemTime, }
// caches: computed from persisted data on load
pub score: HashMap<UserId, f32>, struct Cache {
pub tokens: HashMap<UserId, Vec<String>>, score: HashMap<UserId, f32>,
tokens: HashMap<UserId, Vec<String>>,
} }
impl Db { impl Db {
pub fn recompute_tokens(&mut self) { // Creating
for (id, user) in &self.users {
self.tokens.insert(*id, user.to_tokens()); pub fn create(userdb: UserDb, classifier: Classifier) -> Self {
let cache = Cache::create(&userdb, &classifier);
Self {
userdb: Arc::new(Mutex::new(userdb)),
classifier: Arc::new(Mutex::new(classifier)),
cache: Arc::new(Mutex::new(cache)),
} }
} }
pub fn recompute_scores(&mut self, classifier: &Classifier) { pub fn replace_userdb(&self, newdb: UserDb) {
for (id, tokens) in &self.tokens { let userdb: &mut UserDb = &mut self.userdb.lock().unwrap();
self.score.insert(*id, classifier.score(tokens)); let _ = std::mem::replace(userdb, newdb);
} let new_cache = Cache::create(userdb, &self.classifier.lock().unwrap());
let cache: &mut Cache = &mut self.cache.lock().unwrap();
let _ = std::mem::replace(cache, new_cache);
} }
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> { // Reading
let file = File::open(path)?;
let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?; pub fn with_userdb<F, T>(&self, f: F) -> T
let mut db = Db { where
users, F: FnOnce(&UserDb) -> T,
is_spam, {
last_scrape, let lock = &self.userdb.lock().unwrap();
tokens: HashMap::new(), f(lock)
score: HashMap::new(),
};
db.recompute_tokens();
db.recompute_scores(classifier);
Ok(db)
} }
pub fn from_users( pub fn login(&self, uid: UserId) -> Option<String> {
users: HashMap<UserId, UserData>, self.with_userdb(|u| u.userdata(uid).map(|d| d.login.clone()))
is_spam: HashMap<UserId, IsSpam>,
classifier: &Classifier,
) -> Db {
let mut db = Db {
users,
is_spam,
last_scrape: SystemTime::now(),
tokens: HashMap::new(),
score: HashMap::new(),
};
db.recompute_tokens();
db.recompute_scores(classifier);
db
} }
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> { pub fn score(&self, uid: UserId) -> Option<f32> {
let file = File::create(path)?; self.cache.lock().unwrap().score.get(&uid).copied()
let dat: (
&HashMap<UserId, UserData>,
&HashMap<UserId, IsSpam>,
SystemTime,
) = (&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(())
} }
pub fn unclassified_users(&self) -> Vec<(UserId, &UserData)> { pub fn with_tokens<F>(&self, uid: UserId, f: F)
self.users where
.iter() F: FnOnce(Option<&[String]>),
.filter(|(user_id, _)| !self.is_spam.contains_key(user_id)) {
.map(|(id, d)| (*id, d)) let lock = self.cache.lock().unwrap();
.collect() f(lock.tokens.get(&uid).map(|v| &**v))
} }
pub fn classified_users(&self) -> Vec<(UserId, &UserData, IsSpam)> { // Updating
self.users
.iter() // pub fn recompute_scores(&self, classifier: &Classifier) {
.filter_map(|(user_id, user_data)| { // let lock = &mut self.inner.lock().unwrap();
self.is_spam // lock.recompute_scores(classifier)
.get(user_id) // }
.map(|is_spam| (user_id, user_data, *is_spam))
}) pub fn set_spam(&self, uid: UserId, is_spam: Option<IsSpam>) {
.map(|(id, d, s)| (*id, d, s)) let udb = &mut self.userdb.lock().unwrap();
.collect() udb.set_spam(uid, is_spam)
}
pub fn with_classifier<F, T>(&self, f: F) -> T
where
F: FnOnce(&Classifier) -> T,
{
let classifier = &self.classifier.lock().unwrap();
f(classifier)
}
pub fn with_classifier_mut<F, T>(&self, f: F) -> T
where
F: FnOnce(&mut Classifier) -> T,
{
let classifier = &mut self.classifier.lock().unwrap();
let res = f(classifier);
// recompute scores
let cache: &mut Cache = &mut self.cache.lock().unwrap();
for (id, tokens) in &cache.tokens {
cache.score.insert(*id, classifier.score(tokens));
}
res
}
pub fn remove_user(&self, uid: UserId) {
let userdb = &mut self.userdb.lock().unwrap();
userdb.remove_user(uid);
let cache = &mut self.cache.lock().unwrap();
cache.remove_user(uid);
}
}
impl Cache {
fn create(userdb: &UserDb, classifier: &Classifier) -> Self {
let mut tokens = HashMap::new();
let mut score = HashMap::new();
for (id, user, _) in userdb {
let user_tokens = user.to_tokens();
let user_score = classifier.score(&user_tokens);
tokens.insert(id, user_tokens);
score.insert(id, user_score);
}
Cache { tokens, score }
}
fn remove_user(&mut self, uid: UserId) {
self.score.remove(&uid);
self.tokens.remove(&uid);
} }
} }

View file

@ -50,6 +50,7 @@ pub async fn send_locked_account_notice(
let email = Message::builder() let email = Message::builder()
.from(smtp.username.parse().unwrap()) .from(smtp.username.parse().unwrap())
.to(email.parse()?) .to(email.parse()?)
.reply_to(admin_contact_email.parse().unwrap())
.subject(format!( .subject(format!(
"[Forgejo {org_name}] Your account was marked as spam and will be deleted in {} days", "[Forgejo {org_name}] Your account was marked as spam and will be deleted in {} days",
grace_period_days grace_period_days

View file

@ -5,9 +5,7 @@ use lazy_static::lazy_static;
use rand::prelude::*; use rand::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File; use std::sync::Arc;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration; use std::time::Duration;
use tera::Tera; use tera::Tera;
use url::Url; use url::Url;
@ -17,12 +15,15 @@ mod data;
mod db; mod db;
mod email; mod email;
mod scrape; mod scrape;
mod storage;
mod userdb;
mod workers; mod workers;
use classifier::Classifier;
use data::*; use data::*;
use db::{Db, IsSpam}; use db::Db;
use email::SmtpConfig; use email::SmtpConfig;
use storage::Storage;
use userdb::{IsSpam, UserDb};
// Fetch user data from forgejo from time to time // Fetch user data from forgejo from time to time
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -95,9 +96,10 @@ struct AppState {
config: Arc<Config>, config: Arc<Config>,
// authenticated access to the forgejo instance // authenticated access to the forgejo instance
forge: Arc<Forgejo>, forge: Arc<Forgejo>,
// runtime state (to be persisted in the storage when modified) // handle to the storage backend
db: Arc<Mutex<Db>>, storage: Arc<Storage>,
classifier: Arc<Mutex<Classifier>>, // persistent state (written to the storage when modified)
db: Db,
} }
fn forge(url: &Url) -> anyhow::Result<Forgejo> { fn forge(url: &Url) -> anyhow::Result<Forgejo> {
@ -108,28 +110,22 @@ fn forge(url: &Url) -> anyhow::Result<Forgejo> {
Ok(forge) Ok(forge)
} }
async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> { async fn load_db(storage: &Storage, forge: &Forgejo) -> anyhow::Result<Db> {
let model_path = Path::new("model.json"); let classifier = storage::load_classifier(storage).await?;
let classifier = if model_path.is_file() { let userdb = match storage::load_userdb(storage).await? {
Classifier::new_from_pre_trained(&mut File::open(model_path)?)? Some(db) => db,
} else { None => {
Classifier::new() let db = UserDb::from_users(
};
let db_path = Path::new("db.json");
let db: Db = if db_path.is_file() {
Db::from_path(db_path, &classifier)?
} else {
let db = Db::from_users(
scrape::get_user_data(forge).await?, scrape::get_user_data(forge).await?,
HashMap::new(), HashMap::new(),
&classifier, std::time::SystemTime::now(),
); );
db.store_to_path(db_path)?; storage::store_userdb(storage, &db).await?;
db db
}
}; };
Ok((db, classifier)) Ok(Db::create(userdb, classifier))
} }
// Register a list of decisions taken by the admin using the webpage, checking // Register a list of decisions taken by the admin using the webpage, checking
@ -144,97 +140,89 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
// NB: some of the input decisions may be no-ops: when using the page to edit // NB: some of the input decisions may be no-ops: when using the page to edit
// existing classifications, the webform sends the list of all existing and // existing classifications, the webform sends the list of all existing and
// changed classifications. // changed classifications.
fn set_spam( fn set_spam(db: &Db, ids: &[(UserId, bool)], overwrite: bool) -> Vec<UserId> {
db: &mut Db, let mut updated_spam = vec![];
classifier: &mut Classifier,
ids: &[(UserId, bool)],
overwrite: bool,
) -> Vec<UserId> {
let mut spammers = Vec::new();
for &(user_id, is_spam) in ids { for &(user_id, set_spam) in ids {
let mut update_classification = false; let login = db.login(user_id).unwrap();
match db.is_spam.get(&user_id) { match db.with_userdb(|u| u.is_spam(user_id)) {
Some(&was_spam) if overwrite && was_spam.as_bool() != is_spam => { Some(was_spam) if overwrite && was_spam.as_bool() != set_spam => {
eprintln!( eprintln!("User {login}: changing classification from {was_spam} to {set_spam}");
"User {}: changing classification from {} to {}", db.set_spam(user_id, Some(IsSpam::from_bool(set_spam)));
db.users.get(&user_id).unwrap().login, // We train the classifier again, which is somewhat hackish: we
was_spam, // already trained it on the previous classification, possibly
is_spam // with the same tokens.
);
// Training the classifier again is somewhat hackish in this
// case: we already trained the classifier on the previous
// classification, possibly with the same tokens.
// //
// Ideally we would undo the previous training and train with // Ideally we would undo the previous training and train with
// the correct classification now, but the classifier has no way // the correct classification now, but the classifier has no way
// to easily undo a previous training (we don't know whether the // to easily undo a previous training (we don't know whether the
// tokens that we have now are the same as the one that were // tokens that we have now are the same as the one that were
// used previously). // used previously).
update_classification = true; updated_spam.push((user_id, set_spam));
} }
Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => { Some(was_spam) if !overwrite && was_spam.as_bool() != set_spam => {
// Classification conflict between concurrent queries. // Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification // In this case we play it safe and discard the classification
// for this user; the user will need to be manually classified again. // for this user; the user will need to be manually classified again.
eprintln!( eprintln!(
"Classification conflict for user {}; discarding our current classification", "Classification conflict for user {login}; discarding our current classification"
db.users.get(&user_id).unwrap().login
); );
db.is_spam.remove(&user_id); db.set_spam(user_id, None);
} }
None => { None => {
update_classification = true; db.set_spam(user_id, Some(IsSpam::from_bool(set_spam)));
updated_spam.push((user_id, set_spam));
} }
Some(was_spam) => { Some(was_spam) => {
assert!(was_spam.as_bool() == is_spam); assert!(was_spam.as_bool() == set_spam);
// nothing to do. // nothing to do.
// In particular, keep the spam classification time as is. // In particular, keep the spam classification time as is.
} }
} }
}
if update_classification { let mut new_spammers = vec![];
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// if we just classified the user as spam, add it to the list // update the classifier
if is_spam { db.with_classifier_mut(|classifier| {
spammers.push(user_id) for &(user_id, set_spam) in &updated_spam {
// if we just classified the user as spam, add it to the list of new
// spammers
if set_spam {
new_spammers.push(user_id)
} }
// Train the classifier with tokens from the user // Train the classifier with tokens from the user
let tokens = db.tokens.get(&user_id).unwrap(); db.with_tokens(user_id, |tokens| {
if is_spam { let tokens = tokens.unwrap();
if set_spam {
classifier.train_spam(tokens) classifier.train_spam(tokens)
} else { } else {
classifier.train_ham(tokens) classifier.train_ham(tokens)
} }
})
} }
} });
eprintln!("recomputing user scores"); new_spammers
db.recompute_scores(classifier);
spammers
} }
async fn apply_classification( async fn apply_classification(
config: &Config, config: &Config,
storage: &Storage,
forge: &Forgejo, forge: &Forgejo,
db: Arc<Mutex<Db>>, db: &Db,
classifier: Arc<Mutex<Classifier>>,
ids: &[(UserId, bool)], ids: &[(UserId, bool)],
overwrite: bool, overwrite: bool,
) { ) {
let spammers = { let spammers = set_spam(db, ids, overwrite);
let classifier = &mut classifier.lock().unwrap();
set_spam(&mut db.lock().unwrap(), classifier, ids, overwrite)
};
for user in spammers { for user in spammers {
let login = db.lock().unwrap().users.get(&user).unwrap().login.clone(); let login = db.login(user).unwrap();
// It is ok for any of these calls to fail now: a worker will periodically retry // It is ok for any of these calls to fail now: a worker will periodically retry
// TODO: signal the worker to wake up instead of performing a manual call here // TODO: signal the worker to wake up instead of performing a manual call here
workers::try_lock_and_notify_user(config, forge, db.clone(), user) workers::try_lock_and_notify_user(config, storage, forge, db, user)
.await .await
.unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}")); .unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}"));
} }
@ -297,13 +285,14 @@ async fn index(
) -> impl Responder { ) -> impl Responder {
eprintln!("GET {}", req.uri()); eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap(); let db = &data.db;
let mut users: Vec<(UserId, &UserData, f32)> = db let mut users: Vec<(UserId, UserData, f32)> = db.with_userdb(|udb| {
.unclassified_users() udb.unclassified_users()
.into_iter() .into_iter()
.map(|(id, u)| (id, u, *db.score.get(&id).unwrap())) .map(|(id, u)| (id, u.clone(), db.score(id).unwrap()))
.collect(); .collect()
});
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
users.shuffle(&mut rng); users.shuffle(&mut rng);
@ -324,7 +313,7 @@ async fn index(
} }
// compute the rough "spam score" (low/mid/high) and spam guess (true/false) // compute the rough "spam score" (low/mid/high) and spam guess (true/false)
let users: Vec<(UserId, &UserData, f32, ApproxScore, bool)> = users let users: Vec<(UserId, UserData, f32, ApproxScore, bool)> = users
.into_iter() .into_iter()
.map(|(id, u, score)| { .map(|(id, u, score)| {
( (
@ -337,8 +326,8 @@ async fn index(
}) })
.collect(); .collect();
let users_count = db.users.len(); let users_count = db.with_userdb(|udb| udb.nb_users());
let classified_count = db.is_spam.len(); let classified_count = db.with_userdb(|udb| udb.nb_classified());
let mut context = tera::Context::new(); let mut context = tera::Context::new();
context.insert("forge_url", &data.config.forge_url.to_string()); context.insert("forge_url", &data.config.forge_url.to_string());
@ -369,30 +358,26 @@ async fn post_classified(
apply_classification( apply_classification(
&data.config, &data.config,
&data.storage,
&data.forge, &data.forge,
data.db.clone(), &data.db,
data.classifier.clone(),
&updates, &updates,
overwrite, overwrite,
) )
.await; .await;
data.db let res = storage::store_db(&data.storage, &data.db).await;
.lock()
.unwrap()
.store_to_path(Path::new("db.json"))
.unwrap(); // FIXME
data.classifier
.lock()
.unwrap()
.save(&mut File::create(Path::new("model.json")).unwrap(), false)
.unwrap(); // FIXME
eprintln!("done"); eprintln!("done");
HttpResponse::SeeOther()
match res {
Ok(()) => HttpResponse::SeeOther()
.insert_header(("Location", req.uri().to_string())) .insert_header(("Location", req.uri().to_string()))
.finish() .finish(),
Err(e) => {
HttpResponse::InternalServerError().body(format!("Internal server error:\n\n{e}"))
}
}
} }
#[post("/")] #[post("/")]
@ -421,13 +406,13 @@ async fn classified(
) -> impl Responder { ) -> impl Responder {
eprintln!("GET {}", req.uri()); eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap(); let db = &data.db;
let mut users: Vec<(UserId, UserData, f32, bool)> = db.with_userdb(|udb| {
let mut users: Vec<(UserId, &UserData, f32, bool)> = db udb.classified_users()
.classified_users()
.into_iter() .into_iter()
.map(|(id, u, s)| (id, u, *db.score.get(&id).unwrap(), s.as_bool())) .map(|(id, u, s)| (id, u.clone(), db.score(id).unwrap(), s.as_bool()))
.collect(); .collect()
});
// sort "spam first" // sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64); users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);
@ -465,15 +450,14 @@ async fn main() -> anyhow::Result<()> {
let config = Arc::new(Config::from_env().await?); let config = Arc::new(Config::from_env().await?);
let forge = Arc::new(forge(&config.forge_url)?); let forge = Arc::new(forge(&config.forge_url)?);
let storage = Arc::new(Storage::from_env().await?);
eprintln!("Load users and repos"); eprintln!("Load users and repos");
let (db, classifier) = load_db(&forge).await?; let db = load_db(&storage, &forge).await?;
let db = Arc::new(Mutex::new(db));
let classifier = Arc::new(Mutex::new(classifier));
let st = web::Data::new(AppState { let st = web::Data::new(AppState {
db: db.clone(), db: db.clone(),
classifier: classifier.clone(), storage: storage.clone(),
forge: forge.clone(), forge: forge.clone(),
config: config.clone(), config: config.clone(),
}); });
@ -481,22 +465,26 @@ async fn main() -> anyhow::Result<()> {
let mut workers = tokio::task::JoinSet::new(); let mut workers = tokio::task::JoinSet::new();
let _ = { let _ = {
let storage = storage.clone();
let forge = forge.clone(); let forge = forge.clone();
let db = db.clone(); let db = db.clone();
let classifier = classifier.clone(); workers.spawn(async move { workers::refresh_user_data(storage, forge, db).await })
workers.spawn(async move { workers::refresh_user_data(forge, db, classifier).await })
}; };
let _ = { let _ = {
let config = config.clone(); let config = config.clone();
let storage = storage.clone();
let forge = forge.clone(); let forge = forge.clone();
let db = db.clone(); let db = db.clone();
workers.spawn(async move { workers::purge_spammer_accounts(config, forge, db).await }) workers
.spawn(async move { workers::purge_spammer_accounts(config, storage, forge, db).await })
}; };
let _ = { let _ = {
let config = config.clone(); let config = config.clone();
let storage = storage.clone();
let forge = forge.clone(); let forge = forge.clone();
let db = db.clone(); let db = db.clone();
workers.spawn(async move { workers::lock_and_notify_users(config, forge, db).await }) workers
.spawn(async move { workers::lock_and_notify_users(config, storage, forge, db).await })
}; };
println!("Listening on http://127.0.0.1:8080"); println!("Listening on http://127.0.0.1:8080");

167
src/storage.rs Normal file
View file

@ -0,0 +1,167 @@
use anyhow::Context;
use aws_sdk_s3 as s3;
use std::fs::File;
use std::io::prelude::{Read, Write};
use std::path::{Path, PathBuf};
pub enum Storage {
LocalFiles { dir: PathBuf },
S3 { client: s3::Client, bucket: String },
}
use Storage::*;
impl Storage {
pub fn from_local_dir(dir: PathBuf) -> Self {
LocalFiles { dir }
}
pub async fn from_s3(bucket: String) -> Self {
let sdk_config = aws_config::load_from_env().await;
let config = aws_sdk_s3::config::Builder::from(&sdk_config)
.force_path_style(true)
.build();
let client = aws_sdk_s3::Client::from_conf(config);
S3 { client, bucket }
}
pub async fn from_env() -> anyhow::Result<Self> {
match std::env::var("STORAGE_BACKEND")
.context("reading the STORAGE_BACKEND environment variable")?
.as_ref()
{
"local" => {
let dir = match std::env::var("STORAGE_LOCAL_DIR") {
Ok(dir) => dir,
Err(_) => ".".to_string(),
};
Ok(Self::from_local_dir(PathBuf::from(dir)))
}
"s3" => {
let bucket = std::env::var("STORAGE_S3_BUCKET")
.context("reading the STORAGE_S3_BUCKET environment variable")?;
Ok(Self::from_s3(bucket).await)
}
other => {
anyhow::bail!("STORAGE_BACKEND: unexpected value {other} (expected: local/s3)")
}
}
}
fn read_file(dir: &Path, path: &str) -> anyhow::Result<Option<Vec<u8>>> {
let path = dir.join(path);
if path.is_file() {
let mut file = File::open(path)?;
let mut data = vec![];
file.read_to_end(&mut data)?;
Ok(Some(data))
} else {
Ok(None)
}
}
fn write_file(dir: &Path, path: &str, data: Vec<u8>) -> anyhow::Result<()> {
let path = dir.join(path);
let mut file = File::create(path)?;
file.write_all(&data)?;
Ok(())
}
async fn read_s3(
client: &s3::Client,
bucket: &str,
path: &str,
) -> anyhow::Result<Option<Vec<u8>>> {
let output = client.get_object().bucket(bucket).key(path).send().await;
match output {
Ok(output) => {
let data = output
.body
.collect()
.await
.context(format!("error reading {} from bucket {}", path, bucket))?
.into_bytes()
.to_vec();
Ok(Some(data))
}
Err(e) if is_no_such_key_error(&e) => Ok(None),
Err(err) => Err(err)?,
}
}
async fn write_s3(
client: &s3::Client,
bucket: &str,
path: &str,
data: Vec<u8>,
) -> anyhow::Result<()> {
client
.put_object()
.bucket(bucket)
.key(path)
.body(s3::primitives::ByteStream::from(data))
.send()
.await?;
Ok(())
}
pub async fn write(&self, path: &str, data: Vec<u8>) -> anyhow::Result<()> {
match self {
LocalFiles { dir } => Self::write_file(dir, path, data),
S3 { client, bucket } => Self::write_s3(client, bucket, path, data).await,
}
}
pub async fn read(&self, path: &str) -> anyhow::Result<Option<Vec<u8>>> {
match self {
LocalFiles { dir } => Self::read_file(dir, path),
S3 { client, bucket } => Self::read_s3(client, bucket, path).await,
}
}
}
use s3::error::SdkError;
use s3::operation::get_object::GetObjectError;
fn is_no_such_key_error<R>(err: &SdkError<GetObjectError, R>) -> bool {
match err {
SdkError::ServiceError(e) => matches!(e.err(), GetObjectError::NoSuchKey(_)),
_ => false,
}
}
use crate::classifier::Classifier;
pub async fn load_classifier(storage: &Storage) -> anyhow::Result<Classifier> {
match storage.read("model.json").await? {
Some(data) => Ok(serde_json::from_slice(&data)?),
None => Ok(Classifier::new()),
}
}
use crate::userdb::UserDb;
pub async fn load_userdb(storage: &Storage) -> anyhow::Result<Option<UserDb>> {
if let Some(data) = storage.read("db.json").await? {
Ok(Some(serde_json::from_slice(&data)?))
} else {
Ok(None)
}
}
pub async fn store_userdb(storage: &Storage, userdb: &UserDb) -> anyhow::Result<()> {
storage
.write("db.json", serde_json::to_vec(userdb)?)
.await?;
Ok(())
}
use crate::db::Db;
pub async fn store_db(storage: &Storage, db: &Db) -> anyhow::Result<()> {
let userdb_bytes = db.with_userdb(serde_json::to_vec)?;
let classifier_bytes = db.with_classifier(serde_json::to_vec)?;
storage.write("db.json", userdb_bytes).await?;
storage.write("model.json", classifier_bytes).await?;
Ok(())
}

169
src/userdb.rs Normal file
View file

@ -0,0 +1,169 @@
use crate::data::*;
use serde::{Deserialize, Serialize};
use std::collections::{hash_map, HashMap};
use std::fmt;
use std::time::SystemTime;
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
Legit,
Spam {
classified_at: SystemTime,
locked: bool,
notified: bool,
},
}
impl IsSpam {
pub fn as_bool(&self) -> bool {
match self {
IsSpam::Legit => true,
IsSpam::Spam { .. } => false,
}
}
pub fn from_bool(b: bool) -> IsSpam {
if b {
IsSpam::Spam {
classified_at: SystemTime::now(),
locked: false,
notified: false,
}
} else {
IsSpam::Legit
}
}
}
impl fmt::Display for IsSpam {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IsSpam::Legit => write!(f, "legit"),
IsSpam::Spam { .. } => write!(f, "spam"),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct UserDb {
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, IsSpam>,
last_scrape: SystemTime,
}
impl UserDb {
// Creating
pub fn from_users(
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, IsSpam>,
last_scrape: SystemTime,
) -> Self {
Self {
users,
is_spam,
last_scrape,
}
}
// Reading
pub fn userdata(&self, uid: UserId) -> Option<&UserData> {
self.users.get(&uid)
}
pub fn is_spam(&self, uid: UserId) -> Option<IsSpam> {
self.is_spam.get(&uid).copied()
}
pub fn last_scrape(&self) -> SystemTime {
self.last_scrape
}
pub fn nb_users(&self) -> usize {
self.users.len()
}
pub fn nb_classified(&self) -> usize {
self.is_spam.len()
}
pub fn unclassified_users(&self) -> Vec<(UserId, &UserData)> {
self.users
.iter()
.filter(|(user_id, _)| !self.is_spam.contains_key(user_id))
.map(|(id, d)| (*id, d))
.collect()
}
pub fn classified_users(&self) -> Vec<(UserId, &UserData, IsSpam)> {
self.users
.iter()
.filter_map(|(user_id, user_data)| {
self.is_spam
.get(user_id)
.map(|is_spam| (user_id, user_data, *is_spam))
})
.map(|(id, d, s)| (*id, d, s))
.collect()
}
// Updating
pub fn set_spam(&mut self, uid: UserId, is_spam: Option<IsSpam>) {
match is_spam {
Some(is_spam) => self.is_spam.insert(uid, is_spam),
None => self.is_spam.remove(&uid),
};
}
pub fn remove_user(&mut self, uid: UserId) {
self.users.remove(&uid);
self.is_spam.remove(&uid);
}
// Internal helpers
// XXX remove?
// fn recompute_tokens_for(&mut self, uid: UserId) {
// self.tokens.insert(uid, self.users.get(&uid).unwrap().to_tokens());
// }
// fn recompute_tokens(&mut self) {
// for (id, user) in &self.users {
// self.tokens.insert(*id, user.to_tokens());
// }
// }
// fn recompute_scores(&mut self, classifier: &Classifier) {
// for (id, tokens) in &self.tokens {
// self.score.insert(*id, classifier.score(tokens));
// }
// }
}
pub struct Iter<'a> {
iter_users: hash_map::Iter<'a, UserId, UserData>,
is_spam: &'a HashMap<UserId, IsSpam>,
}
impl<'a> Iterator for Iter<'a> {
type Item = (UserId, &'a UserData, Option<IsSpam>);
fn next(&mut self) -> Option<(UserId, &'a UserData, Option<IsSpam>)> {
self.iter_users.next().map(|(uid, udata)| {
let is_spam = self.is_spam.get(uid).copied();
(*uid, udata, is_spam)
})
}
}
impl<'a> IntoIterator for &'a UserDb {
type Item = (UserId, &'a UserData, Option<IsSpam>);
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
Iter {
iter_users: self.users.iter(),
is_spam: &self.is_spam,
}
}
}

View file

@ -1,13 +1,14 @@
use crate::classifier::Classifier;
use crate::data::UserId; use crate::data::UserId;
use crate::db::{Db, IsSpam}; use crate::db::Db;
use crate::email; use crate::email;
use crate::scrape; use crate::scrape;
use crate::userdb::{IsSpam, UserDb};
use crate::{storage, storage::Storage};
use anyhow::anyhow; use anyhow::anyhow;
use forgejo_api::Forgejo; use forgejo_api::Forgejo;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::sync::Arc;
use std::sync::{Arc, Mutex}; use std::time::SystemTime;
use crate::FORGEJO_POLL_DELAY; use crate::FORGEJO_POLL_DELAY;
use crate::GRACE_PERIOD; use crate::GRACE_PERIOD;
@ -16,14 +17,9 @@ use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
// Worker to refresh user data by periodically polling Forgejo // Worker to refresh user data by periodically polling Forgejo
async fn try_refresh_user_data( async fn try_refresh_user_data(storage: &Storage, forge: &Forgejo, db: &Db) -> anyhow::Result<()> {
forge: &Forgejo,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) -> anyhow::Result<()> {
{ {
let db = &db.lock().unwrap(); let d = db.with_userdb(|udb| udb.last_scrape().elapsed())?;
let d = db.last_scrape.elapsed()?;
if d < FORGEJO_POLL_DELAY { if d < FORGEJO_POLL_DELAY {
return Ok(()); return Ok(());
} }
@ -32,49 +28,51 @@ async fn try_refresh_user_data(
eprintln!("Fetching user data"); eprintln!("Fetching user data");
let users = scrape::get_user_data(forge).await?; let users = scrape::get_user_data(forge).await?;
let db: &mut Db = &mut db.lock().unwrap(); {
let classifier = &classifier.lock().unwrap();
// NB: Some user accounts may have been deleted since last fetch (hopefully // NB: Some user accounts may have been deleted since last fetch (hopefully
// they were spammers). // they were spammers).
// Such users will appear in the current [db] but not in the new [users]. // Such users will appear in the current [db] but not in the new [users].
// We don't want to keep them in the database, so we rebuild a fresh [db] // We don't want to keep them in the database, so we rebuild a fresh [db]
// containing only data for users who still exist. // containing only data for users who still exist.
let mut newdb = Db::from_users(users, HashMap::new(), classifier); let mut newdb = UserDb::from_users(users, HashMap::new(), SystemTime::now());
// Import spam classification from the previous Db let users: Vec<(UserId, Vec<String>, String)> = newdb
for (&user_id, user_data) in &newdb.users { .unclassified_users()
let &score = newdb.score.get(&user_id).unwrap(); .iter()
if let Some(&user_was_spam) = db.is_spam.get(&user_id) { .map(|(user_id, user_data)| (*user_id, user_data.to_tokens(), user_data.login.clone()))
.collect();
// Import spam classification from the previous Db.
// (Initially, all users are "unclassified" in newdb.)
for (user_id, tokens, login) in users.into_iter() {
let score = db.with_classifier(|c| c.score(&tokens));
if let Some(user_was_spam) = db.with_userdb(|u| u.is_spam(user_id)) {
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD) if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD)
|| (!user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD) || (!user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{ {
eprintln!( eprintln!(
"Score for user {} changed past threshold; discarding our current classification", "Score for user {login} changed past threshold; discarding our current classification",
user_data.login
); );
} else { } else {
newdb.is_spam.insert(user_id, user_was_spam); newdb.set_spam(user_id, Some(user_was_spam));
} }
} }
} }
// switch to [newdb] // switch to [newdb]
let _ = std::mem::replace(db, newdb); db.replace_userdb(newdb);
}
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME let res = storage::store_db(storage, db).await;
res.unwrap(); // FIXME
Ok(()) Ok(())
} }
pub async fn refresh_user_data( pub async fn refresh_user_data(storage: Arc<Storage>, forge: Arc<Forgejo>, db: Db) {
forge: Arc<Forgejo>,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) {
loop { loop {
if let Err(e) = try_refresh_user_data(&forge, db.clone(), classifier.clone()).await { if let Err(e) = try_refresh_user_data(&storage, &forge, &db).await {
eprintln!("Error refreshing user data: {:?}", e); eprintln!("Error refreshing user data: {:?}", e);
} }
tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await; tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await;
@ -100,15 +98,19 @@ async fn try_purge_account(config: &Config, forge: &Forgejo, login: &str) -> any
Ok(()) Ok(())
} }
pub async fn purge_spammer_accounts(config: Arc<Config>, forge: Arc<Forgejo>, db: Arc<Mutex<Db>>) { pub async fn purge_spammer_accounts(
config: Arc<Config>,
storage: Arc<Storage>,
forge: Arc<Forgejo>,
db: Db,
) {
loop { loop {
let mut classified_users = Vec::new(); let classified_users: Vec<_> = db.with_userdb(|u| {
{ u.classified_users()
let db = &db.lock().unwrap(); .into_iter()
for (id, user, is_spam) in db.classified_users() { .map(|(user_id, user, is_spam)| (user_id, user.login.clone(), is_spam))
classified_users.push((id, user.login.clone(), is_spam)); .collect()
} });
}
for (user_id, login, is_spam) in classified_users { for (user_id, login, is_spam) in classified_users {
if let IsSpam::Spam { if let IsSpam::Spam {
@ -141,12 +143,11 @@ pub async fn purge_spammer_accounts(config: Arc<Config>, forge: Arc<Forgejo>, db
eprintln!("Error while deleting spammer account {login}: {:?}", e) eprintln!("Error while deleting spammer account {login}: {:?}", e)
} else { } else {
eprintln!("Deleted spammer account {login}"); eprintln!("Deleted spammer account {login}");
let db = &mut db.lock().unwrap(); {
db.users.remove(&user_id); db.remove_user(user_id);
db.is_spam.remove(&user_id); }
db.score.remove(&user_id); let res = storage::store_db(&storage, &db).await;
db.tokens.remove(&user_id); res.unwrap(); // FIXME
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
} }
} }
_ => (), _ => (),
@ -193,23 +194,23 @@ async fn lock_user_account(forge: &Forgejo, username: &str) -> anyhow::Result<()
pub async fn try_lock_and_notify_user( pub async fn try_lock_and_notify_user(
config: &Config, config: &Config,
storage: &Storage,
forge: &Forgejo, forge: &Forgejo,
db: Arc<Mutex<Db>>, db: &Db,
user_id: UserId, user_id: UserId,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let (login, email, is_spam) = { let (login, email, is_spam) = db.with_userdb(|u| {
let db = &db.lock().unwrap(); let user = u.userdata(user_id).unwrap();
let user = db.users.get(&user_id).unwrap(); (user.login.clone(), user.email.clone(), u.is_spam(user_id))
let is_spam = match db.is_spam.get(&user_id) { });
let is_spam = match is_spam{
Some(IsSpam::Spam { Some(IsSpam::Spam {
classified_at, classified_at,
locked, locked,
notified, notified,
}) => Some((*classified_at, *locked, *notified)), }) => Some((classified_at, locked, notified)),
_ => None, _ => None,
}; };
(user.login.clone(), user.email.clone(), is_spam)
};
if let Some((classified_at, locked, notified)) = is_spam { if let Some((classified_at, locked, notified)) = is_spam {
if !locked { if !locked {
@ -222,16 +223,16 @@ pub async fn try_lock_and_notify_user(
ActuallyBan::No => eprintln!("[Simulating: lock account of user {login}]"), ActuallyBan::No => eprintln!("[Simulating: lock account of user {login}]"),
} }
let db = &mut db.lock().unwrap(); db.set_spam(
db.is_spam.insert(
user_id, user_id,
IsSpam::Spam { Some(IsSpam::Spam {
classified_at, classified_at,
locked: true, locked: true,
notified, notified,
}, }),
); );
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
storage::store_db(storage, db).await.unwrap(); // FIXME
} }
if !notified { if !notified {
@ -245,16 +246,17 @@ pub async fn try_lock_and_notify_user(
eprintln!("[Simulating: send notification email to user {login}]") eprintln!("[Simulating: send notification email to user {login}]")
} }
} }
let db = &mut db.lock().unwrap();
db.is_spam.insert( db.set_spam(
user_id, user_id,
IsSpam::Spam { Some(IsSpam::Spam {
classified_at, classified_at,
locked: true, locked: true,
notified: true, notified: true,
}, }),
); );
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
storage::store_db(storage, db).await.unwrap(); // FIXME
} }
Ok(()) Ok(())
@ -266,19 +268,25 @@ pub async fn try_lock_and_notify_user(
} }
} }
pub async fn lock_and_notify_users(config: Arc<Config>, forge: Arc<Forgejo>, db: Arc<Mutex<Db>>) { pub async fn lock_and_notify_users(
config: Arc<Config>,
storage: Arc<Storage>,
forge: Arc<Forgejo>,
db: Db,
) {
let mut spammers = Vec::new(); let mut spammers = Vec::new();
{ {
let db = &db.lock().unwrap(); db.with_userdb(|udb| {
for (id, user, is_spam) in db.classified_users() { for (id, user, is_spam) in udb.classified_users() {
if is_spam.as_bool() { if is_spam.as_bool() {
spammers.push((id, user.login.clone())) spammers.push((id, user.login.clone()))
} }
} }
})
} }
for (user_id, login) in spammers { for (user_id, login) in spammers {
try_lock_and_notify_user(&config, &forge, db.clone(), user_id) try_lock_and_notify_user(&config, &storage, &forge, &db, user_id)
.await .await
.unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}")); .unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}"));
} }