Add S3 as storage backend, refactor db & storage code

This commit is contained in:
Armaël Guéneau 2024-12-23 00:50:01 +01:00
parent af38eae2c3
commit edc49a6d1d
11 changed files with 1608 additions and 535 deletions

1118
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
[package]
name = "forgejo-antispam"
name = "forgery"
version = "0.1.0"
edition = "2021"
@ -21,6 +21,8 @@ actix-files = "0.6"
unicode-segmentation = "1"
lettre = { version = "0.11", features = ["builder", "smtp-transport", "rustls-tls"], default-features = false }
include_dir = "0.7"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.66.0"
[profile.profiling]
inherits = "dev"

View file

@ -27,12 +27,28 @@ Forgery reads the following environment variables:
default) or set to `false`, no actual action is taken: spammers are only
listed in the database. The variable should be set in production, but probably
not for testing.
- `STORAGE_BACKEND`: either `local` (default) or `s3`. Chose `local` to store
the application state to local files, or `s3` to store them in S3-compatible
storage (see below for corresponding configuration variables).
Environment variables that are relevant when `ACTUALLY_BAN_USERS=true`:
Environment variables read when `ACTUALLY_BAN_USERS=true`:
- `SMTP_ADDRESS`: address of the SMTP relay used to send email notifications
- `SMTP_USERNAME`: SMTP username
- `SMTP_PASSWORD`: SMTP password
Environment variables read when `STORAGE_BACKEND=local`:
- `STORAGE_LOCAL_DIR`: path to a local directory where to store the application
data (as two files `db.json` and `model.json`). Defaults to `.` if not
defined.
Environment variables read when `STORAGE_BACKEND=s3`:
- `STORAGE_S3_BUCKET`: name of the bucket where to store the application data
(as two entries `db.json` and `model.json`).
- `AWS_DEFAULT_REGION`: S3 endpoint region
- `AWS_ENDPOINT_URL`: S3 endpoint URL
- `AWS_ACCESS_KEY_ID`: S3 key id
- `AWS_SECRET_ACCESS_KEY`: S3 key secret
## Todos
- discuss the current design choices for when locking the account/sending a
@ -40,5 +56,5 @@ Environment variables that are relevant when `ACTUALLY_BAN_USERS=true`:
(Current behavior is to periodically retry, avoid deleting if the account
could not be locked, but delete the account after the grace period even if
the email could not be sent…)
- add backend to store data on garage instead of local files
- auth: add support for connecting to the forge using oauth?
- improve error handling

View file

@ -1,11 +1,8 @@
// code based on the bayespam crate
use std::collections::HashMap;
use std::fs::File;
use std::io;
use serde::{Deserialize, Serialize};
use serde_json::{from_reader, to_writer, to_writer_pretty};
use unicode_segmentation::UnicodeSegmentation;
const INITIAL_RATING: f32 = 0.5;
@ -30,23 +27,6 @@ impl Classifier {
Default::default()
}
/// Build a new classifier with a pre-trained model loaded from `file`.
pub fn new_from_pre_trained(file: &mut File) -> Result<Self, io::Error> {
let pre_trained_model = from_reader(file)?;
Ok(pre_trained_model)
}
/// Save the classifier to `file` as JSON.
/// The JSON will be pretty printed if `pretty` is `true`.
pub fn save(&self, file: &mut File, pretty: bool) -> Result<(), io::Error> {
if pretty {
to_writer_pretty(file, &self)?;
} else {
to_writer(file, &self)?;
}
Ok(())
}
/// Split `msg` into a list of words.
pub fn into_word_list(msg: &str) -> Vec<String> {
let word_list = msg.unicode_words().collect::<Vec<&str>>();

View file

@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserId(pub i64);
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UserData {
pub login: String,
pub email: String,
@ -20,7 +20,7 @@ pub struct UserData {
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct RepoId(pub i64);
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RepoData {
pub name: String,
pub description: Option<String>,
@ -29,7 +29,7 @@ pub struct RepoData {
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct IssueId(pub i64);
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IssueData {
pub title: String,
pub body: String,

218
src/db.rs
View file

@ -1,138 +1,126 @@
use crate::classifier::Classifier;
use crate::data::*;
use serde::{Deserialize, Serialize};
use crate::userdb::{IsSpam, UserDb};
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::time::SystemTime;
use std::sync::{Arc, Mutex};
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
Legit,
Spam {
classified_at: SystemTime,
locked: bool,
notified: bool,
},
}
impl IsSpam {
pub fn as_bool(&self) -> bool {
match self {
IsSpam::Legit => true,
IsSpam::Spam { .. } => false,
}
}
pub fn from_bool(b: bool) -> IsSpam {
if b {
IsSpam::Spam {
classified_at: SystemTime::now(),
locked: false,
notified: false,
}
} else {
IsSpam::Legit
}
}
}
impl fmt::Display for IsSpam {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IsSpam::Legit => write!(f, "legit"),
IsSpam::Spam { .. } => write!(f, "spam"),
}
}
}
// TODO (?): make the fields private and provide an API that automatically
// recomputes the caches when necessary?
#[derive(Clone)]
pub struct Db {
// persisted data
pub users: HashMap<UserId, UserData>,
pub is_spam: HashMap<UserId, IsSpam>,
pub last_scrape: SystemTime,
// caches: computed from persisted data on load
pub score: HashMap<UserId, f32>,
pub tokens: HashMap<UserId, Vec<String>>,
userdb: Arc<Mutex<UserDb>>,
classifier: Arc<Mutex<Classifier>>,
cache: Arc<Mutex<Cache>>,
}
struct Cache {
score: HashMap<UserId, f32>,
tokens: HashMap<UserId, Vec<String>>,
}
impl Db {
pub fn recompute_tokens(&mut self) {
for (id, user) in &self.users {
self.tokens.insert(*id, user.to_tokens());
// Creating
pub fn create(userdb: UserDb, classifier: Classifier) -> Self {
let cache = Cache::create(&userdb, &classifier);
Self {
userdb: Arc::new(Mutex::new(userdb)),
classifier: Arc::new(Mutex::new(classifier)),
cache: Arc::new(Mutex::new(cache)),
}
}
pub fn recompute_scores(&mut self, classifier: &Classifier) {
for (id, tokens) in &self.tokens {
self.score.insert(*id, classifier.score(tokens));
}
pub fn replace_userdb(&self, newdb: UserDb) {
let userdb: &mut UserDb = &mut self.userdb.lock().unwrap();
let _ = std::mem::replace(userdb, newdb);
let new_cache = Cache::create(userdb, &self.classifier.lock().unwrap());
let cache: &mut Cache = &mut self.cache.lock().unwrap();
let _ = std::mem::replace(cache, new_cache);
}
pub fn from_path(path: &Path, classifier: &Classifier) -> anyhow::Result<Self> {
let file = File::open(path)?;
let (users, is_spam, last_scrape) = serde_json::from_reader(BufReader::new(file))?;
let mut db = Db {
users,
is_spam,
last_scrape,
tokens: HashMap::new(),
score: HashMap::new(),
};
db.recompute_tokens();
db.recompute_scores(classifier);
Ok(db)
// Reading
pub fn with_userdb<F, T>(&self, f: F) -> T
where
F: FnOnce(&UserDb) -> T,
{
let lock = &self.userdb.lock().unwrap();
f(lock)
}
pub fn from_users(
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, IsSpam>,
classifier: &Classifier,
) -> Db {
let mut db = Db {
users,
is_spam,
last_scrape: SystemTime::now(),
tokens: HashMap::new(),
score: HashMap::new(),
};
db.recompute_tokens();
db.recompute_scores(classifier);
db
pub fn login(&self, uid: UserId) -> Option<String> {
self.with_userdb(|u| u.userdata(uid).map(|d| d.login.clone()))
}
pub fn store_to_path(&self, path: &Path) -> anyhow::Result<()> {
let file = File::create(path)?;
let dat: (
&HashMap<UserId, UserData>,
&HashMap<UserId, IsSpam>,
SystemTime,
) = (&self.users, &self.is_spam, self.last_scrape);
serde_json::to_writer(BufWriter::new(file), &dat)?;
Ok(())
pub fn score(&self, uid: UserId) -> Option<f32> {
self.cache.lock().unwrap().score.get(&uid).copied()
}
pub fn unclassified_users(&self) -> Vec<(UserId, &UserData)> {
self.users
.iter()
.filter(|(user_id, _)| !self.is_spam.contains_key(user_id))
.map(|(id, d)| (*id, d))
.collect()
pub fn with_tokens<F>(&self, uid: UserId, f: F)
where
F: FnOnce(Option<&[String]>),
{
let lock = self.cache.lock().unwrap();
f(lock.tokens.get(&uid).map(|v| &**v))
}
pub fn classified_users(&self) -> Vec<(UserId, &UserData, IsSpam)> {
self.users
.iter()
.filter_map(|(user_id, user_data)| {
self.is_spam
.get(user_id)
.map(|is_spam| (user_id, user_data, *is_spam))
})
.map(|(id, d, s)| (*id, d, s))
.collect()
// Updating
// pub fn recompute_scores(&self, classifier: &Classifier) {
// let lock = &mut self.inner.lock().unwrap();
// lock.recompute_scores(classifier)
// }
pub fn set_spam(&self, uid: UserId, is_spam: Option<IsSpam>) {
let udb = &mut self.userdb.lock().unwrap();
udb.set_spam(uid, is_spam)
}
pub fn with_classifier<F, T>(&self, f: F) -> T
where
F: FnOnce(&Classifier) -> T,
{
let classifier = &self.classifier.lock().unwrap();
f(classifier)
}
pub fn with_classifier_mut<F, T>(&self, f: F) -> T
where
F: FnOnce(&mut Classifier) -> T,
{
let classifier = &mut self.classifier.lock().unwrap();
let res = f(classifier);
// recompute scores
let cache: &mut Cache = &mut self.cache.lock().unwrap();
for (id, tokens) in &cache.tokens {
cache.score.insert(*id, classifier.score(tokens));
}
res
}
pub fn remove_user(&self, uid: UserId) {
let userdb = &mut self.userdb.lock().unwrap();
userdb.remove_user(uid);
let cache = &mut self.cache.lock().unwrap();
cache.remove_user(uid);
}
}
impl Cache {
fn create(userdb: &UserDb, classifier: &Classifier) -> Self {
let mut tokens = HashMap::new();
let mut score = HashMap::new();
for (id, user, _) in userdb {
let user_tokens = user.to_tokens();
let user_score = classifier.score(&user_tokens);
tokens.insert(id, user_tokens);
score.insert(id, user_score);
}
Cache { tokens, score }
}
fn remove_user(&mut self, uid: UserId) {
self.score.remove(&uid);
self.tokens.remove(&uid);
}
}

View file

@ -50,6 +50,7 @@ pub async fn send_locked_account_notice(
let email = Message::builder()
.from(smtp.username.parse().unwrap())
.to(email.parse()?)
.reply_to(admin_contact_email.parse().unwrap())
.subject(format!(
"[Forgejo {org_name}] Your account was marked as spam and will be deleted in {} days",
grace_period_days

View file

@ -5,9 +5,7 @@ use lazy_static::lazy_static;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::Duration;
use tera::Tera;
use url::Url;
@ -17,12 +15,15 @@ mod data;
mod db;
mod email;
mod scrape;
mod storage;
mod userdb;
mod workers;
use classifier::Classifier;
use data::*;
use db::{Db, IsSpam};
use db::Db;
use email::SmtpConfig;
use storage::Storage;
use userdb::{IsSpam, UserDb};
// Fetch user data from forgejo from time to time
const FORGEJO_POLL_DELAY: Duration = Duration::from_secs(11 * 3600); // 11 hours
@ -95,9 +96,10 @@ struct AppState {
config: Arc<Config>,
// authenticated access to the forgejo instance
forge: Arc<Forgejo>,
// runtime state (to be persisted in the storage when modified)
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
// handle to the storage backend
storage: Arc<Storage>,
// persistent state (written to the storage when modified)
db: Db,
}
fn forge(url: &Url) -> anyhow::Result<Forgejo> {
@ -108,28 +110,22 @@ fn forge(url: &Url) -> anyhow::Result<Forgejo> {
Ok(forge)
}
async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
let model_path = Path::new("model.json");
let classifier = if model_path.is_file() {
Classifier::new_from_pre_trained(&mut File::open(model_path)?)?
} else {
Classifier::new()
};
let db_path = Path::new("db.json");
let db: Db = if db_path.is_file() {
Db::from_path(db_path, &classifier)?
} else {
let db = Db::from_users(
async fn load_db(storage: &Storage, forge: &Forgejo) -> anyhow::Result<Db> {
let classifier = storage::load_classifier(storage).await?;
let userdb = match storage::load_userdb(storage).await? {
Some(db) => db,
None => {
let db = UserDb::from_users(
scrape::get_user_data(forge).await?,
HashMap::new(),
&classifier,
std::time::SystemTime::now(),
);
db.store_to_path(db_path)?;
storage::store_userdb(storage, &db).await?;
db
}
};
Ok((db, classifier))
Ok(Db::create(userdb, classifier))
}
// Register a list of decisions taken by the admin using the webpage, checking
@ -144,97 +140,89 @@ async fn load_db(forge: &Forgejo) -> anyhow::Result<(Db, Classifier)> {
// NB: some of the input decisions may be no-ops: when using the page to edit
// existing classifications, the webform sends the list of all existing and
// changed classifications.
fn set_spam(
db: &mut Db,
classifier: &mut Classifier,
ids: &[(UserId, bool)],
overwrite: bool,
) -> Vec<UserId> {
let mut spammers = Vec::new();
fn set_spam(db: &Db, ids: &[(UserId, bool)], overwrite: bool) -> Vec<UserId> {
let mut updated_spam = vec![];
for &(user_id, is_spam) in ids {
let mut update_classification = false;
for &(user_id, set_spam) in ids {
let login = db.login(user_id).unwrap();
match db.is_spam.get(&user_id) {
Some(&was_spam) if overwrite && was_spam.as_bool() != is_spam => {
eprintln!(
"User {}: changing classification from {} to {}",
db.users.get(&user_id).unwrap().login,
was_spam,
is_spam
);
// Training the classifier again is somewhat hackish in this
// case: we already trained the classifier on the previous
// classification, possibly with the same tokens.
match db.with_userdb(|u| u.is_spam(user_id)) {
Some(was_spam) if overwrite && was_spam.as_bool() != set_spam => {
eprintln!("User {login}: changing classification from {was_spam} to {set_spam}");
db.set_spam(user_id, Some(IsSpam::from_bool(set_spam)));
// We train the classifier again, which is somewhat hackish: we
// already trained it on the previous classification, possibly
// with the same tokens.
//
// Ideally we would undo the previous training and train with
// the correct classification now, but the classifier has no way
// to easily undo a previous training (we don't know whether the
// tokens that we have now are the same as the one that were
// used previously).
update_classification = true;
updated_spam.push((user_id, set_spam));
}
Some(&was_spam) if !overwrite && was_spam.as_bool() != is_spam => {
Some(was_spam) if !overwrite && was_spam.as_bool() != set_spam => {
// Classification conflict between concurrent queries.
// In this case we play it safe and discard the classification
// for this user; the user will need to be manually classified again.
eprintln!(
"Classification conflict for user {}; discarding our current classification",
db.users.get(&user_id).unwrap().login
"Classification conflict for user {login}; discarding our current classification"
);
db.is_spam.remove(&user_id);
db.set_spam(user_id, None);
}
None => {
update_classification = true;
db.set_spam(user_id, Some(IsSpam::from_bool(set_spam)));
updated_spam.push((user_id, set_spam));
}
Some(was_spam) => {
assert!(was_spam.as_bool() == is_spam);
assert!(was_spam.as_bool() == set_spam);
// nothing to do.
// In particular, keep the spam classification time as is.
}
}
}
if update_classification {
db.is_spam.insert(user_id, IsSpam::from_bool(is_spam));
// if we just classified the user as spam, add it to the list
if is_spam {
spammers.push(user_id)
let mut new_spammers = vec![];
// update the classifier
db.with_classifier_mut(|classifier| {
for &(user_id, set_spam) in &updated_spam {
// if we just classified the user as spam, add it to the list of new
// spammers
if set_spam {
new_spammers.push(user_id)
}
// Train the classifier with tokens from the user
let tokens = db.tokens.get(&user_id).unwrap();
if is_spam {
db.with_tokens(user_id, |tokens| {
let tokens = tokens.unwrap();
if set_spam {
classifier.train_spam(tokens)
} else {
classifier.train_ham(tokens)
}
})
}
}
});
eprintln!("recomputing user scores");
db.recompute_scores(classifier);
spammers
new_spammers
}
async fn apply_classification(
config: &Config,
storage: &Storage,
forge: &Forgejo,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
db: &Db,
ids: &[(UserId, bool)],
overwrite: bool,
) {
let spammers = {
let classifier = &mut classifier.lock().unwrap();
set_spam(&mut db.lock().unwrap(), classifier, ids, overwrite)
};
let spammers = set_spam(db, ids, overwrite);
for user in spammers {
let login = db.lock().unwrap().users.get(&user).unwrap().login.clone();
let login = db.login(user).unwrap();
// It is ok for any of these calls to fail now: a worker will periodically retry
// TODO: signal the worker to wake up instead of performing a manual call here
workers::try_lock_and_notify_user(config, forge, db.clone(), user)
workers::try_lock_and_notify_user(config, storage, forge, db, user)
.await
.unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}"));
}
@ -297,13 +285,14 @@ async fn index(
) -> impl Responder {
eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap();
let db = &data.db;
let mut users: Vec<(UserId, &UserData, f32)> = db
.unclassified_users()
let mut users: Vec<(UserId, UserData, f32)> = db.with_userdb(|udb| {
udb.unclassified_users()
.into_iter()
.map(|(id, u)| (id, u, *db.score.get(&id).unwrap()))
.collect();
.map(|(id, u)| (id, u.clone(), db.score(id).unwrap()))
.collect()
});
let mut rng = rand::thread_rng();
users.shuffle(&mut rng);
@ -324,7 +313,7 @@ async fn index(
}
// compute the rough "spam score" (low/mid/high) and spam guess (true/false)
let users: Vec<(UserId, &UserData, f32, ApproxScore, bool)> = users
let users: Vec<(UserId, UserData, f32, ApproxScore, bool)> = users
.into_iter()
.map(|(id, u, score)| {
(
@ -337,8 +326,8 @@ async fn index(
})
.collect();
let users_count = db.users.len();
let classified_count = db.is_spam.len();
let users_count = db.with_userdb(|udb| udb.nb_users());
let classified_count = db.with_userdb(|udb| udb.nb_classified());
let mut context = tera::Context::new();
context.insert("forge_url", &data.config.forge_url.to_string());
@ -369,30 +358,26 @@ async fn post_classified(
apply_classification(
&data.config,
&data.storage,
&data.forge,
data.db.clone(),
data.classifier.clone(),
&data.db,
&updates,
overwrite,
)
.await;
data.db
.lock()
.unwrap()
.store_to_path(Path::new("db.json"))
.unwrap(); // FIXME
data.classifier
.lock()
.unwrap()
.save(&mut File::create(Path::new("model.json")).unwrap(), false)
.unwrap(); // FIXME
let res = storage::store_db(&data.storage, &data.db).await;
eprintln!("done");
HttpResponse::SeeOther()
match res {
Ok(()) => HttpResponse::SeeOther()
.insert_header(("Location", req.uri().to_string()))
.finish()
.finish(),
Err(e) => {
HttpResponse::InternalServerError().body(format!("Internal server error:\n\n{e}"))
}
}
}
#[post("/")]
@ -421,13 +406,13 @@ async fn classified(
) -> impl Responder {
eprintln!("GET {}", req.uri());
let db = &data.db.lock().unwrap();
let mut users: Vec<(UserId, &UserData, f32, bool)> = db
.classified_users()
let db = &data.db;
let mut users: Vec<(UserId, UserData, f32, bool)> = db.with_userdb(|udb| {
udb.classified_users()
.into_iter()
.map(|(id, u, s)| (id, u, *db.score.get(&id).unwrap(), s.as_bool()))
.collect();
.map(|(id, u, s)| (id, u.clone(), db.score(id).unwrap(), s.as_bool()))
.collect()
});
// sort "spam first"
users.sort_by_key(|(_, _, score, _)| 1000 - (score * 1000.) as u64);
@ -465,15 +450,14 @@ async fn main() -> anyhow::Result<()> {
let config = Arc::new(Config::from_env().await?);
let forge = Arc::new(forge(&config.forge_url)?);
let storage = Arc::new(Storage::from_env().await?);
eprintln!("Load users and repos");
let (db, classifier) = load_db(&forge).await?;
let db = Arc::new(Mutex::new(db));
let classifier = Arc::new(Mutex::new(classifier));
let db = load_db(&storage, &forge).await?;
let st = web::Data::new(AppState {
db: db.clone(),
classifier: classifier.clone(),
storage: storage.clone(),
forge: forge.clone(),
config: config.clone(),
});
@ -481,22 +465,26 @@ async fn main() -> anyhow::Result<()> {
let mut workers = tokio::task::JoinSet::new();
let _ = {
let storage = storage.clone();
let forge = forge.clone();
let db = db.clone();
let classifier = classifier.clone();
workers.spawn(async move { workers::refresh_user_data(forge, db, classifier).await })
workers.spawn(async move { workers::refresh_user_data(storage, forge, db).await })
};
let _ = {
let config = config.clone();
let storage = storage.clone();
let forge = forge.clone();
let db = db.clone();
workers.spawn(async move { workers::purge_spammer_accounts(config, forge, db).await })
workers
.spawn(async move { workers::purge_spammer_accounts(config, storage, forge, db).await })
};
let _ = {
let config = config.clone();
let storage = storage.clone();
let forge = forge.clone();
let db = db.clone();
workers.spawn(async move { workers::lock_and_notify_users(config, forge, db).await })
workers
.spawn(async move { workers::lock_and_notify_users(config, storage, forge, db).await })
};
println!("Listening on http://127.0.0.1:8080");

167
src/storage.rs Normal file
View file

@ -0,0 +1,167 @@
use anyhow::Context;
use aws_sdk_s3 as s3;
use std::fs::File;
use std::io::prelude::{Read, Write};
use std::path::{Path, PathBuf};
pub enum Storage {
LocalFiles { dir: PathBuf },
S3 { client: s3::Client, bucket: String },
}
use Storage::*;
impl Storage {
pub fn from_local_dir(dir: PathBuf) -> Self {
LocalFiles { dir }
}
pub async fn from_s3(bucket: String) -> Self {
let sdk_config = aws_config::load_from_env().await;
let config = aws_sdk_s3::config::Builder::from(&sdk_config)
.force_path_style(true)
.build();
let client = aws_sdk_s3::Client::from_conf(config);
S3 { client, bucket }
}
pub async fn from_env() -> anyhow::Result<Self> {
match std::env::var("STORAGE_BACKEND")
.context("reading the STORAGE_BACKEND environment variable")?
.as_ref()
{
"local" => {
let dir = match std::env::var("STORAGE_LOCAL_DIR") {
Ok(dir) => dir,
Err(_) => ".".to_string(),
};
Ok(Self::from_local_dir(PathBuf::from(dir)))
}
"s3" => {
let bucket = std::env::var("STORAGE_S3_BUCKET")
.context("reading the STORAGE_S3_BUCKET environment variable")?;
Ok(Self::from_s3(bucket).await)
}
other => {
anyhow::bail!("STORAGE_BACKEND: unexpected value {other} (expected: local/s3)")
}
}
}
fn read_file(dir: &Path, path: &str) -> anyhow::Result<Option<Vec<u8>>> {
let path = dir.join(path);
if path.is_file() {
let mut file = File::open(path)?;
let mut data = vec![];
file.read_to_end(&mut data)?;
Ok(Some(data))
} else {
Ok(None)
}
}
fn write_file(dir: &Path, path: &str, data: Vec<u8>) -> anyhow::Result<()> {
let path = dir.join(path);
let mut file = File::create(path)?;
file.write_all(&data)?;
Ok(())
}
async fn read_s3(
client: &s3::Client,
bucket: &str,
path: &str,
) -> anyhow::Result<Option<Vec<u8>>> {
let output = client.get_object().bucket(bucket).key(path).send().await;
match output {
Ok(output) => {
let data = output
.body
.collect()
.await
.context(format!("error reading {} from bucket {}", path, bucket))?
.into_bytes()
.to_vec();
Ok(Some(data))
}
Err(e) if is_no_such_key_error(&e) => Ok(None),
Err(err) => Err(err)?,
}
}
async fn write_s3(
client: &s3::Client,
bucket: &str,
path: &str,
data: Vec<u8>,
) -> anyhow::Result<()> {
client
.put_object()
.bucket(bucket)
.key(path)
.body(s3::primitives::ByteStream::from(data))
.send()
.await?;
Ok(())
}
pub async fn write(&self, path: &str, data: Vec<u8>) -> anyhow::Result<()> {
match self {
LocalFiles { dir } => Self::write_file(dir, path, data),
S3 { client, bucket } => Self::write_s3(client, bucket, path, data).await,
}
}
pub async fn read(&self, path: &str) -> anyhow::Result<Option<Vec<u8>>> {
match self {
LocalFiles { dir } => Self::read_file(dir, path),
S3 { client, bucket } => Self::read_s3(client, bucket, path).await,
}
}
}
use s3::error::SdkError;
use s3::operation::get_object::GetObjectError;
fn is_no_such_key_error<R>(err: &SdkError<GetObjectError, R>) -> bool {
match err {
SdkError::ServiceError(e) => matches!(e.err(), GetObjectError::NoSuchKey(_)),
_ => false,
}
}
use crate::classifier::Classifier;
pub async fn load_classifier(storage: &Storage) -> anyhow::Result<Classifier> {
match storage.read("model.json").await? {
Some(data) => Ok(serde_json::from_slice(&data)?),
None => Ok(Classifier::new()),
}
}
use crate::userdb::UserDb;
pub async fn load_userdb(storage: &Storage) -> anyhow::Result<Option<UserDb>> {
if let Some(data) = storage.read("db.json").await? {
Ok(Some(serde_json::from_slice(&data)?))
} else {
Ok(None)
}
}
pub async fn store_userdb(storage: &Storage, userdb: &UserDb) -> anyhow::Result<()> {
storage
.write("db.json", serde_json::to_vec(userdb)?)
.await?;
Ok(())
}
use crate::db::Db;
pub async fn store_db(storage: &Storage, db: &Db) -> anyhow::Result<()> {
let userdb_bytes = db.with_userdb(serde_json::to_vec)?;
let classifier_bytes = db.with_classifier(serde_json::to_vec)?;
storage.write("db.json", userdb_bytes).await?;
storage.write("model.json", classifier_bytes).await?;
Ok(())
}

169
src/userdb.rs Normal file
View file

@ -0,0 +1,169 @@
use crate::data::*;
use serde::{Deserialize, Serialize};
use std::collections::{hash_map, HashMap};
use std::fmt;
use std::time::SystemTime;
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum IsSpam {
Legit,
Spam {
classified_at: SystemTime,
locked: bool,
notified: bool,
},
}
impl IsSpam {
pub fn as_bool(&self) -> bool {
match self {
IsSpam::Legit => true,
IsSpam::Spam { .. } => false,
}
}
pub fn from_bool(b: bool) -> IsSpam {
if b {
IsSpam::Spam {
classified_at: SystemTime::now(),
locked: false,
notified: false,
}
} else {
IsSpam::Legit
}
}
}
impl fmt::Display for IsSpam {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IsSpam::Legit => write!(f, "legit"),
IsSpam::Spam { .. } => write!(f, "spam"),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct UserDb {
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, IsSpam>,
last_scrape: SystemTime,
}
impl UserDb {
// Creating
pub fn from_users(
users: HashMap<UserId, UserData>,
is_spam: HashMap<UserId, IsSpam>,
last_scrape: SystemTime,
) -> Self {
Self {
users,
is_spam,
last_scrape,
}
}
// Reading
pub fn userdata(&self, uid: UserId) -> Option<&UserData> {
self.users.get(&uid)
}
pub fn is_spam(&self, uid: UserId) -> Option<IsSpam> {
self.is_spam.get(&uid).copied()
}
pub fn last_scrape(&self) -> SystemTime {
self.last_scrape
}
pub fn nb_users(&self) -> usize {
self.users.len()
}
pub fn nb_classified(&self) -> usize {
self.is_spam.len()
}
pub fn unclassified_users(&self) -> Vec<(UserId, &UserData)> {
self.users
.iter()
.filter(|(user_id, _)| !self.is_spam.contains_key(user_id))
.map(|(id, d)| (*id, d))
.collect()
}
pub fn classified_users(&self) -> Vec<(UserId, &UserData, IsSpam)> {
self.users
.iter()
.filter_map(|(user_id, user_data)| {
self.is_spam
.get(user_id)
.map(|is_spam| (user_id, user_data, *is_spam))
})
.map(|(id, d, s)| (*id, d, s))
.collect()
}
// Updating
pub fn set_spam(&mut self, uid: UserId, is_spam: Option<IsSpam>) {
match is_spam {
Some(is_spam) => self.is_spam.insert(uid, is_spam),
None => self.is_spam.remove(&uid),
};
}
pub fn remove_user(&mut self, uid: UserId) {
self.users.remove(&uid);
self.is_spam.remove(&uid);
}
// Internal helpers
// XXX remove?
// fn recompute_tokens_for(&mut self, uid: UserId) {
// self.tokens.insert(uid, self.users.get(&uid).unwrap().to_tokens());
// }
// fn recompute_tokens(&mut self) {
// for (id, user) in &self.users {
// self.tokens.insert(*id, user.to_tokens());
// }
// }
// fn recompute_scores(&mut self, classifier: &Classifier) {
// for (id, tokens) in &self.tokens {
// self.score.insert(*id, classifier.score(tokens));
// }
// }
}
pub struct Iter<'a> {
iter_users: hash_map::Iter<'a, UserId, UserData>,
is_spam: &'a HashMap<UserId, IsSpam>,
}
impl<'a> Iterator for Iter<'a> {
type Item = (UserId, &'a UserData, Option<IsSpam>);
fn next(&mut self) -> Option<(UserId, &'a UserData, Option<IsSpam>)> {
self.iter_users.next().map(|(uid, udata)| {
let is_spam = self.is_spam.get(uid).copied();
(*uid, udata, is_spam)
})
}
}
impl<'a> IntoIterator for &'a UserDb {
type Item = (UserId, &'a UserData, Option<IsSpam>);
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
Iter {
iter_users: self.users.iter(),
is_spam: &self.is_spam,
}
}
}

View file

@ -1,13 +1,14 @@
use crate::classifier::Classifier;
use crate::data::UserId;
use crate::db::{Db, IsSpam};
use crate::db::Db;
use crate::email;
use crate::scrape;
use crate::userdb::{IsSpam, UserDb};
use crate::{storage, storage::Storage};
use anyhow::anyhow;
use forgejo_api::Forgejo;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::SystemTime;
use crate::FORGEJO_POLL_DELAY;
use crate::GRACE_PERIOD;
@ -16,14 +17,9 @@ use crate::{GUESS_LEGIT_THRESHOLD, GUESS_SPAM_THRESHOLD};
// Worker to refresh user data by periodically polling Forgejo
async fn try_refresh_user_data(
forge: &Forgejo,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) -> anyhow::Result<()> {
async fn try_refresh_user_data(storage: &Storage, forge: &Forgejo, db: &Db) -> anyhow::Result<()> {
{
let db = &db.lock().unwrap();
let d = db.last_scrape.elapsed()?;
let d = db.with_userdb(|udb| udb.last_scrape().elapsed())?;
if d < FORGEJO_POLL_DELAY {
return Ok(());
}
@ -32,49 +28,51 @@ async fn try_refresh_user_data(
eprintln!("Fetching user data");
let users = scrape::get_user_data(forge).await?;
let db: &mut Db = &mut db.lock().unwrap();
let classifier = &classifier.lock().unwrap();
{
// NB: Some user accounts may have been deleted since last fetch (hopefully
// they were spammers).
// Such users will appear in the current [db] but not in the new [users].
// We don't want to keep them in the database, so we rebuild a fresh [db]
// containing only data for users who still exist.
let mut newdb = Db::from_users(users, HashMap::new(), classifier);
let mut newdb = UserDb::from_users(users, HashMap::new(), SystemTime::now());
// Import spam classification from the previous Db
for (&user_id, user_data) in &newdb.users {
let &score = newdb.score.get(&user_id).unwrap();
if let Some(&user_was_spam) = db.is_spam.get(&user_id) {
let users: Vec<(UserId, Vec<String>, String)> = newdb
.unclassified_users()
.iter()
.map(|(user_id, user_data)| (*user_id, user_data.to_tokens(), user_data.login.clone()))
.collect();
// Import spam classification from the previous Db.
// (Initially, all users are "unclassified" in newdb.)
for (user_id, tokens, login) in users.into_iter() {
let score = db.with_classifier(|c| c.score(&tokens));
if let Some(user_was_spam) = db.with_userdb(|u| u.is_spam(user_id)) {
if (user_was_spam.as_bool() && score < GUESS_SPAM_THRESHOLD)
|| (!user_was_spam.as_bool() && score > GUESS_LEGIT_THRESHOLD)
{
eprintln!(
"Score for user {} changed past threshold; discarding our current classification",
user_data.login
"Score for user {login} changed past threshold; discarding our current classification",
);
} else {
newdb.is_spam.insert(user_id, user_was_spam);
newdb.set_spam(user_id, Some(user_was_spam));
}
}
}
// switch to [newdb]
let _ = std::mem::replace(db, newdb);
db.replace_userdb(newdb);
}
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
let res = storage::store_db(storage, db).await;
res.unwrap(); // FIXME
Ok(())
}
pub async fn refresh_user_data(
forge: Arc<Forgejo>,
db: Arc<Mutex<Db>>,
classifier: Arc<Mutex<Classifier>>,
) {
pub async fn refresh_user_data(storage: Arc<Storage>, forge: Arc<Forgejo>, db: Db) {
loop {
if let Err(e) = try_refresh_user_data(&forge, db.clone(), classifier.clone()).await {
if let Err(e) = try_refresh_user_data(&storage, &forge, &db).await {
eprintln!("Error refreshing user data: {:?}", e);
}
tokio::time::sleep(FORGEJO_POLL_DELAY.mul_f32(0.1)).await;
@ -100,15 +98,19 @@ async fn try_purge_account(config: &Config, forge: &Forgejo, login: &str) -> any
Ok(())
}
pub async fn purge_spammer_accounts(config: Arc<Config>, forge: Arc<Forgejo>, db: Arc<Mutex<Db>>) {
pub async fn purge_spammer_accounts(
config: Arc<Config>,
storage: Arc<Storage>,
forge: Arc<Forgejo>,
db: Db,
) {
loop {
let mut classified_users = Vec::new();
{
let db = &db.lock().unwrap();
for (id, user, is_spam) in db.classified_users() {
classified_users.push((id, user.login.clone(), is_spam));
}
}
let classified_users: Vec<_> = db.with_userdb(|u| {
u.classified_users()
.into_iter()
.map(|(user_id, user, is_spam)| (user_id, user.login.clone(), is_spam))
.collect()
});
for (user_id, login, is_spam) in classified_users {
if let IsSpam::Spam {
@ -141,12 +143,11 @@ pub async fn purge_spammer_accounts(config: Arc<Config>, forge: Arc<Forgejo>, db
eprintln!("Error while deleting spammer account {login}: {:?}", e)
} else {
eprintln!("Deleted spammer account {login}");
let db = &mut db.lock().unwrap();
db.users.remove(&user_id);
db.is_spam.remove(&user_id);
db.score.remove(&user_id);
db.tokens.remove(&user_id);
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
{
db.remove_user(user_id);
}
let res = storage::store_db(&storage, &db).await;
res.unwrap(); // FIXME
}
}
_ => (),
@ -193,23 +194,23 @@ async fn lock_user_account(forge: &Forgejo, username: &str) -> anyhow::Result<()
pub async fn try_lock_and_notify_user(
config: &Config,
storage: &Storage,
forge: &Forgejo,
db: Arc<Mutex<Db>>,
db: &Db,
user_id: UserId,
) -> anyhow::Result<()> {
let (login, email, is_spam) = {
let db = &db.lock().unwrap();
let user = db.users.get(&user_id).unwrap();
let is_spam = match db.is_spam.get(&user_id) {
let (login, email, is_spam) = db.with_userdb(|u| {
let user = u.userdata(user_id).unwrap();
(user.login.clone(), user.email.clone(), u.is_spam(user_id))
});
let is_spam = match is_spam{
Some(IsSpam::Spam {
classified_at,
locked,
notified,
}) => Some((*classified_at, *locked, *notified)),
}) => Some((classified_at, locked, notified)),
_ => None,
};
(user.login.clone(), user.email.clone(), is_spam)
};
if let Some((classified_at, locked, notified)) = is_spam {
if !locked {
@ -222,16 +223,16 @@ pub async fn try_lock_and_notify_user(
ActuallyBan::No => eprintln!("[Simulating: lock account of user {login}]"),
}
let db = &mut db.lock().unwrap();
db.is_spam.insert(
db.set_spam(
user_id,
IsSpam::Spam {
Some(IsSpam::Spam {
classified_at,
locked: true,
notified,
},
}),
);
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
storage::store_db(storage, db).await.unwrap(); // FIXME
}
if !notified {
@ -245,16 +246,17 @@ pub async fn try_lock_and_notify_user(
eprintln!("[Simulating: send notification email to user {login}]")
}
}
let db = &mut db.lock().unwrap();
db.is_spam.insert(
db.set_spam(
user_id,
IsSpam::Spam {
Some(IsSpam::Spam {
classified_at,
locked: true,
notified: true,
},
}),
);
db.store_to_path(Path::new("db.json")).unwrap(); // FIXME
storage::store_db(storage, db).await.unwrap(); // FIXME
}
Ok(())
@ -266,19 +268,25 @@ pub async fn try_lock_and_notify_user(
}
}
pub async fn lock_and_notify_users(config: Arc<Config>, forge: Arc<Forgejo>, db: Arc<Mutex<Db>>) {
pub async fn lock_and_notify_users(
config: Arc<Config>,
storage: Arc<Storage>,
forge: Arc<Forgejo>,
db: Db,
) {
let mut spammers = Vec::new();
{
let db = &db.lock().unwrap();
for (id, user, is_spam) in db.classified_users() {
db.with_userdb(|udb| {
for (id, user, is_spam) in udb.classified_users() {
if is_spam.as_bool() {
spammers.push((id, user.login.clone()))
}
}
})
}
for (user_id, login) in spammers {
try_lock_and_notify_user(&config, &forge, db.clone(), user_id)
try_lock_and_notify_user(&config, &storage, &forge, &db, user_id)
.await
.unwrap_or_else(|err| eprintln!("Failed to lock or notify user {login}: {err}"));
}