From 90cdffb425c6222f4234db54a16c079d8c058724 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 7 Apr 2020 18:10:20 +0200 Subject: [PATCH] custom data type for hashes and identifiers --- Cargo.lock | 10 +++++++ Cargo.toml | 1 + src/data.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++-- src/error.rs | 18 ++++++------ src/main.rs | 8 +++--- src/membership.rs | 33 +++++++++++----------- src/proto.rs | 2 +- src/server.rs | 8 +++--- 8 files changed, 114 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c524c6e6..bed369c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -284,6 +284,7 @@ dependencies = [ "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", "rmp-serde 0.14.3 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_bytes 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)", "sha2 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "structopt 0.3.12 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.2.16 (registry+https://github.com/rust-lang/crates.io-index)", @@ -759,6 +760,14 @@ dependencies = [ "serde_derive 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "serde_bytes" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "serde_derive" version = "1.0.106" @@ -1143,6 +1152,7 @@ dependencies = [ "checksum rustversion 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b3bba175698996010c4f6dce5e7f173b6eb781fce25d2cfc45e27091ce0b79f6" "checksum scopeguard 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" "checksum serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)" = "36df6ac6412072f67cf767ebbde4133a5b2e88e76dc6187fa7104cd16f783399" +"checksum serde_bytes 0.11.3 (registry+https://github.com/rust-lang/crates.io-index)" = "325a073952621257820e7a3469f55ba4726d8b28657e7e36653d1c36dc2c84ae" "checksum serde_derive 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)" = "9e549e3abf4fb8621bd1609f11dfc9f5e50320802273b12f3811a67e6716ea6c" "checksum sha2 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "27044adfd2e1f077f649f59deb9490d3941d674002f7d062870a60ebe9bd47a0" "checksum signal-hook-registry 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "94f478ede9f64724c5d173d7bb56099ec3e2d9fc2774aac65d34b8b890405f41" diff --git a/Cargo.toml b/Cargo.toml index 73ad1d19..d8225737 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ futures-channel = "0.3" futures-util = "0.3" tokio = { version = "0.2", features = ["full"] } serde = { version = "1.0", features = ["derive"] } +serde_bytes = "0.11" bincode = "1.2.1" err-derive = "0.2.3" rmp-serde = "0.14.3" diff --git a/src/data.rs b/src/data.rs index c649b289..f54c4cc1 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,8 +1,73 @@ +use std::fmt; use std::collections::HashMap; -use serde::{Serialize, Deserialize}; +use serde::{Serializer, Deserializer, Serialize, Deserialize}; +use serde::de::{self, Visitor}; -pub type UUID = [u8; 32]; -pub type Hash = [u8; 32]; +#[derive(Default, PartialOrd, Ord, Clone, Hash, PartialEq)] +pub struct FixedBytes32([u8; 32]); + +impl From<[u8; 32]> for FixedBytes32 { + fn from(x: [u8; 32]) -> FixedBytes32 { + FixedBytes32(x) + } +} + +impl std::convert::AsRef<[u8]> for FixedBytes32 { + fn as_ref(&self) -> &[u8] { + &self.0[..] + } +} + +impl Eq for FixedBytes32 {} + +impl fmt::Debug for FixedBytes32 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", hex::encode(self.0)) + } +} + +struct FixedBytes32Visitor; +impl<'de> Visitor<'de> for FixedBytes32Visitor { + type Value = FixedBytes32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a byte slice of size 32") + } + + fn visit_bytes(self, value: &[u8]) -> Result { + if value.len() == 32 { + let mut res = [0u8; 32]; + res.copy_from_slice(value); + Ok(res.into()) + } else { + Err(E::custom(format!("Invalid byte string length {}, expected 32", value.len()))) + } + } +} + +impl<'de> Deserialize<'de> for FixedBytes32 { + fn deserialize>(deserializer: D) -> Result { + deserializer.deserialize_bytes(FixedBytes32Visitor) + } +} + +impl Serialize for FixedBytes32 { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_bytes(&self.0[..]) + } +} + +impl FixedBytes32 { + pub fn as_slice(&self) -> &[u8] { + &self.0[..] + } + pub fn as_slice_mut(&mut self) -> &mut [u8] { + &mut self.0[..] + } +} + +pub type UUID = FixedBytes32; +pub type Hash = FixedBytes32; // Network management diff --git a/src/error.rs b/src/error.rs index 1e611adb..fd717638 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,29 +3,29 @@ use std::io; #[derive(Debug, Error)] pub enum Error { - #[error(display = "IO error")] + #[error(display = "IO error: {}", _0)] Io(#[error(source)] io::Error), - #[error(display = "Hyper error")] + #[error(display = "Hyper error: {}", _0)] Hyper(#[error(source)] hyper::Error), - #[error(display = "HTTP error")] + #[error(display = "HTTP error: {}", _0)] HTTP(#[error(source)] http::Error), - #[error(display = "Messagepack encode error")] + #[error(display = "Messagepack encode error: {}", _0)] RMPEncode(#[error(source)] rmp_serde::encode::Error), - #[error(display = "Messagepack decode error")] + #[error(display = "Messagepack decode error: {}", _0)] RMPDecode(#[error(source)] rmp_serde::decode::Error), - #[error(display = "TOML decode error")] + #[error(display = "TOML decode error: {}", _0)] TomlDecode(#[error(source)] toml::de::Error), - #[error(display = "Timeout")] + #[error(display = "Timeout: {}", _0)] RPCTimeout(#[error(source)] tokio::time::Elapsed), - #[error(display = "RPC error")] + #[error(display = "RPC error: {}", _0)] RPCError(String), - #[error(display = "")] + #[error(display = "{}", _0)] Message(String), } diff --git a/src/main.rs b/src/main.rs index 2cb4b720..1e4107c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -103,7 +103,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro println!("Healthy nodes:"); for adv in status.iter() { if let Some(cfg) = config.members.get(&adv.id) { - println!("{}\t{}\t{}\t{}", hex::encode(adv.id), cfg.datacenter, cfg.n_tokens, adv.addr); + println!("{}\t{}\t{}\t{}", hex::encode(&adv.id), cfg.datacenter, cfg.n_tokens, adv.addr); } } @@ -112,7 +112,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro println!("\nFailed nodes:"); for (id, cfg) in config.members.iter() { if !status.iter().any(|x| x.id == *id) { - println!("{}\t{}\t{}", hex::encode(id), cfg.datacenter, cfg.n_tokens); + println!("{}\t{}\t{}", hex::encode(&id), cfg.datacenter, cfg.n_tokens); } } } @@ -121,7 +121,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro println!("\nUnconfigured nodes:"); for adv in status.iter() { if !config.members.contains_key(&adv.id) { - println!("{}\t{}", hex::encode(adv.id), adv.addr); + println!("{}\t{}", hex::encode(&adv.id), adv.addr); } } } @@ -139,7 +139,7 @@ async fn cmd_configure(rpc_cli: RpcClient, rpc_host: SocketAddr, args: Configure let mut candidates = vec![]; for adv in status.iter() { - if hex::encode(adv.id).starts_with(&args.node_id) { + if hex::encode(&adv.id).starts_with(&args.node_id) { candidates.push(adv.id.clone()); } } diff --git a/src/membership.rs b/src/membership.rs index 1ce567a7..b7b99bb1 100644 --- a/src/membership.rs +++ b/src/membership.rs @@ -61,7 +61,7 @@ impl Members { }); match old_status { None => { - eprintln!("Newly pingable node: {}", hex::encode(info.id)); + eprintln!("Newly pingable node: {}", hex::encode(&info.id)); true } Some(x) => x.addr != addr, @@ -70,16 +70,16 @@ impl Members { fn recalculate_status_hash(&mut self) { let mut nodes = self.status.iter().collect::>(); - nodes.sort_by_key(|(id, _status)| *id); + nodes.sort_unstable_by_key(|(id, _status)| *id); let mut hasher = Sha256::new(); eprintln!("Current set of pingable nodes: --"); for (id, status) in nodes { - eprintln!("{} {}", hex::encode(id), status.addr); - hasher.input(format!("{} {}\n", hex::encode(id), status.addr)); + eprintln!("{} {}", hex::encode(&id), status.addr); + hasher.input(format!("{} {}\n", hex::encode(&id), status.addr)); } eprintln!("END --"); - self.status_hash.copy_from_slice(&hasher.result()[..]); + self.status_hash.as_slice_mut().copy_from_slice(&hasher.result()[..]); } fn rebuild_ring(&mut self) { @@ -97,19 +97,19 @@ impl Members { for i in 0..config.n_tokens { let mut location_hasher = Sha256::new(); - location_hasher.input(format!("{} {}", hex::encode(id), i)); + location_hasher.input(format!("{} {}", hex::encode(&id), i)); let mut location = [0u8; 32]; location.copy_from_slice(&location_hasher.result()[..]); new_ring.push(RingEntry{ - location, + location: location.into(), node: id.clone(), datacenter, }) } } - new_ring.sort_by_key(|x| x.location); + new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location)); self.ring = new_ring; self.n_datacenters = datacenters.len(); } @@ -119,7 +119,7 @@ impl Members { return self.config.members.keys().cloned().collect::>(); } - let start = match self.ring.binary_search_by_key(from, |x| x.location) { + let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) { Ok(i) => i, Err(i) => if i == 0 { self.ring.len() - 1 @@ -178,7 +178,7 @@ impl System { }; let mut members = Members{ status: HashMap::new(), - status_hash: [0u8; 32], + status_hash: Hash::default(), config: net_config, ring: Vec::new(), n_datacenters: 0, @@ -193,7 +193,7 @@ impl System { } } - pub async fn save_network_config(&self) { + async fn save_network_config(self: Arc) { let mut path = self.config.metadata_dir.clone(); path.push("network_config"); @@ -211,7 +211,7 @@ impl System { pub async fn make_ping(&self) -> Message { let members = self.members.read().await; Message::Ping(PingMessage{ - id: self.id, + id: self.id.clone(), rpc_port: self.config.rpc_port, status_hash: members.status_hash.clone(), config_version: members.config.version, @@ -271,8 +271,8 @@ impl System { } else if let Some(id) = id_option { let remaining_attempts = members.status.get(id).map(|x| x.remaining_ping_attempts).unwrap_or(0); if remaining_attempts == 0 { - eprintln!("Removing node {} after too many failed pings", hex::encode(id)); - members.status.remove(id); + eprintln!("Removing node {} after too many failed pings", hex::encode(&id)); + members.status.remove(&id); has_changes = true; } else { if let Some(st) = members.status.get_mut(id) { @@ -376,11 +376,12 @@ impl System { { let mut members = self.members.write().await; if adv.version > members.config.version { - tokio::spawn(self.clone().broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT)); members.config = adv.clone(); - self.save_network_config().await; members.rebuild_ring(); + + tokio::spawn(self.clone().broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT)); + tokio::spawn(self.clone().save_network_config()); } Ok(Message::Ok) diff --git a/src/proto.rs b/src/proto.rs index 18bc339e..d1d4fb59 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -4,7 +4,7 @@ use serde::{Serialize, Deserialize}; use crate::data::*; -pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(2); +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Debug, Serialize, Deserialize)] pub enum Message { diff --git a/src/server.rs b/src/server.rs index 1450911b..5cac1c70 100644 --- a/src/server.rs +++ b/src/server.rs @@ -47,13 +47,13 @@ fn gen_node_id(metadata_dir: &PathBuf) -> Result { let mut id = [0u8; 32]; id.copy_from_slice(&d[..]); - Ok(id) + Ok(id.into()) } else { - let id = rand::thread_rng().gen::(); + let id = rand::thread_rng().gen::<[u8; 32]>(); let mut f = std::fs::File::create(id_file.as_path())?; f.write_all(&id[..])?; - Ok(id) + Ok(id.into()) } } @@ -78,7 +78,7 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> { let id = gen_node_id(&config.metadata_dir) .expect("Unable to read or generate node ID"); - println!("Node ID: {}", hex::encode(id)); + println!("Node ID: {}", hex::encode(&id)); let sys = Arc::new(System::new(config, id));