refactoring

This commit is contained in:
Armaël Guéneau 2024-11-19 14:18:44 +01:00
parent 797377734f
commit e0a0456402

View file

@ -8,7 +8,7 @@ use std::fs::File;
use std::io::{BufReader, BufWriter};
use rand::prelude::*;
// NOTE(review): the two `#[derive(… Debug, Hash, PartialEq, Eq)]` lines below look like
// the before/after pair of this diff hunk (the second one adds `Copy, Clone`);
// presumably only one of them exists in the real file — confirm.
/// Newtype wrapper around an `i64`; going by its name, it identifies a repository.
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
struct RepoId(i64);
@ -19,7 +19,7 @@ struct RepoData {
description: Option<String>,
}
// NOTE(review): the two `#[derive(… Debug, Hash, PartialEq, Eq)]` lines below look like
// the before/after pair of this diff hunk (the second one adds `Copy, Clone`);
// presumably only one of them exists in the real file — confirm.
/// Newtype wrapper around an `i64`; going by its name, it identifies an issue.
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
struct IssueId(i64);
@ -30,7 +30,7 @@ struct IssueData {
body: String,
}
// NOTE(review): the two `#[derive(… Debug, Hash, PartialEq, Eq)]` lines below look like
// the before/after pair of this diff hunk (the second one adds `Copy, Clone`);
// presumably only one of them exists in the real file — confirm.
/// Newtype wrapper around an `i64` identifying a user.
///
/// Used as the key of the `HashMap`s stored in `Db`, which is why it derives
/// `Hash` and `Eq`.
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
struct UserId(i64);
@ -57,7 +57,9 @@ use Classification::*;
// NOTE(review): the two `users` lines below look like the before/after pair of this
// diff hunk (`Classification` values before, `UserData` values after); presumably
// only the second one exists in the real file — confirm.
/// Database that `main` reads from and writes to `db.json` as one JSON document.
///
/// All three maps are keyed by `UserId`.
#[derive(Debug, Serialize, Deserialize)]
struct Db {
users: HashMap<UserId, Classification>,
// Per-user data; filled from `get_users_data` when no database file exists yet.
users: HashMap<UserId, UserData>,
// Result of calling `to_text()` on the matching `users` entry, stored so it can be
// looked up again when scoring.
text: HashMap<UserId, String>,
// Classification recorded for each user handled so far; users already present
// here are filtered out before scoring.
classification: HashMap<UserId, Classification>,
}
impl UserData {
@ -107,6 +109,8 @@ impl Db {
/// Creates a `Db` whose `users`, `text` and `classification` maps are all empty.
fn new() -> Db {
Db {
users: HashMap::new(),
text: HashMap::new(),
classification: HashMap::new(),
}
}
}
@ -286,14 +290,6 @@ async fn main() -> anyhow::Result<()> {
Classifier::new()
};
let classification_path = Path::new("classification.json");
let mut classification = if classification_path.is_file() {
let file = File::open(classification_path)?;
serde_json::from_reader(BufReader::new(file))?
} else {
Db::new()
};
let api_token =
std::fs::read_to_string(Path::new("api_token"))?
.trim().to_string();
@ -302,37 +298,49 @@ async fn main() -> anyhow::Result<()> {
url::Url::parse("https://git.deuxfleurs.fr")?
)?;
let data_path = Path::new("data.json");
let data = if data_path.is_file() {
let file = File::open(data_path)?;
let db_path = Path::new("db.json");
let mut db = if db_path.is_file() {
let file = File::open(db_path)?;
serde_json::from_reader(BufReader::new(file))?
} else {
let data = get_users_data(&forge).await?;
let file = File::create(data_path)?;
serde_json::to_writer(BufWriter::new(file), &classification)?;
data
let mut db = Db::new();
db.users = get_users_data(&forge).await?;
for (user_id, user) in &db.users {
db.text.insert(*user_id, user.to_text());
}
let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &db)?;
db
};
println!("got {} users", data.len());
println!("got {} users", db.users.len());
let mut users: Vec<_> =
data.into_iter()
db.users.iter()
.filter_map(
|(user_id, user)|
if classification.users.contains_key(&user_id) {
if db.classification.contains_key(&user_id) {
None
} else {
let text = user.to_text();
let score = classifier.score(&text);
let text = db.text.get(&user_id).unwrap();
let score = classifier.score(text);
Some((user_id, user, text, score))
}
)
.collect();
let mut rng = rand::thread_rng();
users.shuffle(&mut rng);
users.sort_by_key(|(_, _, _, score)| 1000 - (score * 1000.) as u64);
for (user_id, user, text, score) in users {
for (user_id, user, text, _) in users {
println!("{:#?}", user);
let score = classifier.score(text);
println!("SCORE: {}", score);
let c = {
@ -355,13 +363,13 @@ async fn main() -> anyhow::Result<()> {
Unknown => ()
}
classification.users.insert(user_id, c);
db.classification.insert(*user_id, c);
{
classifier.save(&mut File::create(model_path)?, false)?;
let file = File::create(classification_path)?;
serde_json::to_writer(BufWriter::new(file), &classification)?;
let file = File::create(db_path)?;
serde_json::to_writer(BufWriter::new(file), &db)?;
}
}