refactoring

This commit is contained in:
Armaël Guéneau 2024-11-19 14:18:44 +01:00
parent 797377734f
commit e0a0456402

View file

@ -8,7 +8,7 @@ use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use rand::prelude::*; use rand::prelude::*;
#[derive(Debug, Hash, PartialEq, Eq)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct RepoId(i64); struct RepoId(i64);
@ -19,7 +19,7 @@ struct RepoData {
description: Option<String>, description: Option<String>,
} }
#[derive(Debug, Hash, PartialEq, Eq)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct IssueId(i64); struct IssueId(i64);
@ -30,7 +30,7 @@ struct IssueData {
body: String, body: String,
} }
#[derive(Debug, Hash, PartialEq, Eq)] #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct UserId(i64); struct UserId(i64);
@ -57,7 +57,9 @@ use Classification::*;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct Db { struct Db {
users: HashMap<UserId, Classification>, users: HashMap<UserId, UserData>,
text: HashMap<UserId, String>,
classification: HashMap<UserId, Classification>,
} }
impl UserData { impl UserData {
@ -107,6 +109,8 @@ impl Db {
fn new() -> Db { fn new() -> Db {
Db { Db {
users: HashMap::new(), users: HashMap::new(),
text: HashMap::new(),
classification: HashMap::new(),
} }
} }
} }
@ -286,14 +290,6 @@ async fn main() -> anyhow::Result<()> {
Classifier::new() Classifier::new()
}; };
let classification_path = Path::new("classification.json");
let mut classification = if classification_path.is_file() {
let file = File::open(classification_path)?;
serde_json::from_reader(BufReader::new(file))?
} else {
Db::new()
};
let api_token = let api_token =
std::fs::read_to_string(Path::new("api_token"))? std::fs::read_to_string(Path::new("api_token"))?
.trim().to_string(); .trim().to_string();
@ -302,37 +298,49 @@ async fn main() -> anyhow::Result<()> {
url::Url::parse("https://git.deuxfleurs.fr")? url::Url::parse("https://git.deuxfleurs.fr")?
)?; )?;
let data_path = Path::new("data.json"); let db_path = Path::new("db.json");
let data = if data_path.is_file() { let mut db = if db_path.is_file() {
let file = File::open(data_path)?; let file = File::open(db_path)?;
serde_json::from_reader(BufReader::new(file))? serde_json::from_reader(BufReader::new(file))?
} else { } else {
let data = get_users_data(&forge).await?; let mut db = Db::new();
let file = File::create(data_path)?;
serde_json::to_writer(BufWriter::new(file), &classification)?; db.users = get_users_data(&forge).await?;
data
for (user_id, user) in &db.users {
db.text.insert(*user_id, user.to_text());
}
let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &db)?;
db
}; };
println!("got {} users", data.len());
println!("got {} users", db.users.len());
let mut users: Vec<_> = let mut users: Vec<_> =
data.into_iter() db.users.iter()
.filter_map( .filter_map(
|(user_id, user)| |(user_id, user)|
if classification.users.contains_key(&user_id) { if db.classification.contains_key(&user_id) {
None None
} else { } else {
let text = user.to_text(); let text = db.text.get(&user_id).unwrap();
let score = classifier.score(&text); let score = classifier.score(text);
Some((user_id, user, text, score)) Some((user_id, user, text, score))
} }
) )
.collect(); .collect();
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
users.shuffle(&mut rng); users.shuffle(&mut rng);
users.sort_by_key(|(_, _, _, score)| 1000 - (score * 1000.) as u64); users.sort_by_key(|(_, _, _, score)| 1000 - (score * 1000.) as u64);
for (user_id, user, text, score) in users { for (user_id, user, text, _) in users {
println!("{:#?}", user); println!("{:#?}", user);
let score = classifier.score(text);
println!("SCORE: {}", score); println!("SCORE: {}", score);
let c = { let c = {
@ -355,13 +363,13 @@ async fn main() -> anyhow::Result<()> {
Unknown => () Unknown => ()
} }
classification.users.insert(user_id, c); db.classification.insert(*user_id, c);
{ {
classifier.save(&mut File::create(model_path)?, false)?; classifier.save(&mut File::create(model_path)?, false)?;
let file = File::create(classification_path)?; let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &classification)?; serde_json::to_writer(BufWriter::new(file), &db)?;
} }
} }