From e0a0456402b2bc951ededebeae97e9df4a3a0e0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arma=C3=ABl=20Gu=C3=A9neau?= Date: Tue, 19 Nov 2024 14:18:44 +0100 Subject: [PATCH] refactoring --- src/main.rs | 64 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/src/main.rs b/src/main.rs index f4cc17e..133d32a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use std::fs::File; use std::io::{BufReader, BufWriter}; use rand::prelude::*; -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize)] struct RepoId(i64); @@ -19,7 +19,7 @@ struct RepoData { description: Option, } -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize)] struct IssueId(i64); @@ -30,7 +30,7 @@ struct IssueData { body: String, } -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] #[derive(Serialize, Deserialize)] struct UserId(i64); @@ -57,7 +57,9 @@ use Classification::*; #[derive(Debug, Serialize, Deserialize)] struct Db { - users: HashMap, + users: HashMap, + text: HashMap, + classification: HashMap, } impl UserData { @@ -107,6 +109,8 @@ impl Db { fn new() -> Db { Db { users: HashMap::new(), + text: HashMap::new(), + classification: HashMap::new(), } } } @@ -286,14 +290,6 @@ async fn main() -> anyhow::Result<()> { Classifier::new() }; - let classification_path = Path::new("classification.json"); - let mut classification = if classification_path.is_file() { - let file = File::open(classification_path)?; - serde_json::from_reader(BufReader::new(file))? - } else { - Db::new() - }; - let api_token = std::fs::read_to_string(Path::new("api_token"))? .trim().to_string(); @@ -302,37 +298,49 @@ async fn main() -> anyhow::Result<()> { url::Url::parse("https://git.deuxfleurs.fr")? )?; - let data_path = Path::new("data.json"); - let data = if data_path.is_file() { - let file = File::open(data_path)?; + let db_path = Path::new("db.json"); + let mut db = if db_path.is_file() { + let file = File::open(db_path)?; serde_json::from_reader(BufReader::new(file))? } else { - let data = get_users_data(&forge).await?; - let file = File::create(data_path)?; - serde_json::to_writer(BufWriter::new(file), &classification)?; - data + let mut db = Db::new(); + + db.users = get_users_data(&forge).await?; + + for (user_id, user) in &db.users { + db.text.insert(*user_id, user.to_text()); + } + + let file = File::create(db_path)?; + serde_json::to_writer(BufWriter::new(file), &db)?; + + db }; - println!("got {} users", data.len()); + + println!("got {} users", db.users.len()); let mut users: Vec<_> = - data.into_iter() + db.users.iter() .filter_map( |(user_id, user)| - if classification.users.contains_key(&user_id) { + if db.classification.contains_key(&user_id) { None } else { - let text = user.to_text(); - let score = classifier.score(&text); + let text = db.text.get(&user_id).unwrap(); + let score = classifier.score(text); Some((user_id, user, text, score)) } ) .collect(); + let mut rng = rand::thread_rng(); users.shuffle(&mut rng); users.sort_by_key(|(_, _, _, score)| 1000 - (score * 1000.) as u64); - for (user_id, user, text, score) in users { + for (user_id, user, text, _) in users { println!("{:#?}", user); + + let score = classifier.score(text); println!("SCORE: {}", score); let c = { @@ -355,13 +363,13 @@ async fn main() -> anyhow::Result<()> { Unknown => () } - classification.users.insert(user_id, c); + db.classification.insert(*user_id, c); { classifier.save(&mut File::create(model_path)?, false)?; - let file = File::create(classification_path)?; - serde_json::to_writer(BufWriter::new(file), &classification)?; + let file = File::create(db_path)?; + serde_json::to_writer(BufWriter::new(file), &db)?; } }