use std::collections::HashSet; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; use async_trait::async_trait; use log::{debug, info, trace, warn}; use lru::LruCache; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; use crate::message::*; use crate::netapp::*; use crate::NodeID; // -- Protocol messages -- #[derive(Serialize, Deserialize)] struct PullMessage {} impl Message for PullMessage { type Response = PushMessage; } #[derive(Serialize, Deserialize)] struct PushMessage { peers: Vec, } impl Message for PushMessage { type Response = (); } // -- Algorithm data structures -- type Seed = [u8; 32]; #[derive(Hash, Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Serialize, Deserialize)] struct Peer { id: NodeID, addr: SocketAddr, } type Cost = [u8; 40]; const MAX_COST: Cost = [0xffu8; 40]; impl Peer { fn cost(&self, seed: &Seed) -> Cost { let mut hasher = hash::State::new(); hasher.update(&seed[..]); let hasher = hasher; let mut cost = [0u8; 40]; match self.addr { SocketAddr::V4(v4addr) => { let v4ip = v4addr.ip().octets(); for i in 0..4 { let mut h = hasher; h.update(&v4ip[..i + 1]); cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]); } } SocketAddr::V6(v6addr) => { let v6ip = v6addr.ip().octets(); for i in 0..4 { let mut h = hasher; h.update(&v6ip[..i + 2]); cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]); } } } { let mut h5 = hasher; h5.update(&format!("{} {}", self.addr, hex::encode(self.id)).into_bytes()[..]); cost[32..40].copy_from_slice(&h5.finalize()[..8]); } cost } } struct BasaltSlot { seed: Seed, peer: Option, } impl BasaltSlot { fn cost(&self) -> Cost { self.peer.map(|p| p.cost(&self.seed)).unwrap_or(MAX_COST) } } struct BasaltView { i_reset: usize, slots: Vec, } impl BasaltView { fn new(size: usize) -> Self { let slots = (0..size) .map(|_| BasaltSlot { seed: rand_seed(), peer: None, }) .collect::>(); Self { i_reset: 0, slots } } fn current_peers(&self) -> HashSet { self.slots .iter() .filter_map(|s| s.peer) .collect::>() } fn current_peers_vec(&self) -> Vec { self.current_peers().drain().collect::>() } fn sample(&self, count: usize) -> Vec { let possibles = self .slots .iter() .enumerate() .filter(|(_i, s)| s.peer.is_some()) .map(|(i, _s)| i) .collect::>(); if possibles.is_empty() { vec![] } else { let mut ret = vec![]; let mut rng = thread_rng(); for _i in 0..count { let idx = rng.gen_range(0..possibles.len()); ret.push(self.slots[possibles[idx]].peer.unwrap()); } ret } } fn update_slot(&mut self, i: usize, peers: &[Peer]) { let mut slot_cost = self.slots[i].cost(); for peer in peers.iter() { let peer_cost = peer.cost(&self.slots[i].seed); if self.slots[i].peer.is_none() || peer_cost < slot_cost { trace!( "Best match for slot {}: {}@{} (cost {})", i, hex::encode(peer.id), peer.addr, hex::encode(peer_cost) ); self.slots[i].peer = Some(*peer); slot_cost = peer_cost; } } } fn update_all_slots(&mut self, peers: &[Peer]) { for i in 0..self.slots.len() { self.update_slot(i, peers); } } fn disconnected(&mut self, id: NodeID) { let mut cleared_slots = vec![]; for i in 0..self.slots.len() { if let Some(p) = self.slots[i].peer { if p.id == id { self.slots[i].peer = None; cleared_slots.push(i); } } } let remaining_peers = self.current_peers_vec(); for i in cleared_slots { self.update_slot(i, &remaining_peers[..]); } } fn should_try_list(&self, peers: &[Peer]) -> Vec { // Select peers that have lower cost than any of our slots let mut ret = HashSet::new(); for i in 0..self.slots.len() { if self.slots[i].peer.is_none() { return peers.to_vec(); } let mut min_cost = self.slots[i].cost(); let mut min_peer = None; for peer in peers.iter() { if ret.contains(peer) { continue; } let peer_cost = peer.cost(&self.slots[i].seed); if peer_cost < min_cost { min_cost = peer_cost; min_peer = Some(*peer); } } if let Some(p) = min_peer { ret.insert(p); if ret.len() == peers.len() { break; } } } ret.drain().collect::>() } fn reset_some_slots(&mut self, count: usize) { for _i in 0..count { trace!("Reset slot {}", self.i_reset); self.slots[self.i_reset].seed = rand_seed(); self.i_reset = (self.i_reset + 1) % self.slots.len(); } } } pub struct BasaltParams { pub view_size: usize, pub cache_size: usize, pub exchange_interval: Duration, pub reset_interval: Duration, pub reset_count: usize, } pub struct Basalt { netapp: Arc, pull_endpoint: Arc>, push_endpoint: Arc>, param: BasaltParams, bootstrap_peers: Vec, view: RwLock, current_attempts: RwLock>, backlog: RwLock>, } impl Basalt { pub fn new( netapp: Arc, bootstrap_list: Vec<(NodeID, SocketAddr)>, param: BasaltParams, ) -> Arc { let bootstrap_peers = bootstrap_list .iter() .map(|(id, addr)| Peer { id: *id, addr: *addr, }) .collect::>(); let view = BasaltView::new(param.view_size); let backlog = LruCache::new(param.cache_size); let basalt = Arc::new(Self { netapp: netapp.clone(), pull_endpoint: netapp.endpoint("__netapp/peering/basalt.rs/Pull".into()), push_endpoint: netapp.endpoint("__netapp/peering/basalt.rs/Push".into()), param, bootstrap_peers, view: RwLock::new(view), current_attempts: RwLock::new(HashSet::new()), backlog: RwLock::new(backlog), }); basalt.pull_endpoint.set_handler(basalt.clone()); basalt.push_endpoint.set_handler(basalt.clone()); let basalt2 = basalt.clone(); netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| { basalt2.on_connected(id, addr, is_incoming); }); let basalt2 = basalt.clone(); netapp.on_disconnected(move |id: NodeID, is_incoming: bool| { basalt2.on_disconnected(id, is_incoming); }); basalt } pub fn sample(&self, count: usize) -> Vec { self.view .read() .unwrap() .sample(count) .iter() .map(|p| { debug!("KYEV S {}", hex::encode(p.id)); p.id }) .collect::>() } pub async fn run(self: Arc, must_exit: watch::Receiver) { for peer in self.bootstrap_peers.iter() { tokio::spawn(self.clone().try_connect(*peer)); } tokio::join!( self.clone().run_pushpull_loop(must_exit.clone()), self.clone().run_reset_loop(must_exit.clone()), ); } async fn run_pushpull_loop(self: Arc, must_exit: watch::Receiver) { while !*must_exit.borrow() { tokio::time::sleep(self.param.exchange_interval).await; let peers = self.view.read().unwrap().sample(2); if peers.len() == 2 { tokio::spawn(self.clone().do_pull(peers[0].id)); tokio::spawn(self.clone().do_push(peers[1].id)); } } } async fn do_pull(self: Arc, peer: NodeID) { match self .pull_endpoint .call(&peer, PullMessage {}, PRIO_NORMAL) .await { Ok(resp) => { self.handle_peer_list(&resp.peers[..]); trace!("KYEV PEXi {}", hex::encode(peer)); } Err(e) => { warn!("Error during pull exchange: {}", e); } }; } async fn do_push(self: Arc, peer: NodeID) { let push_msg = self.make_push_message(); match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await { Ok(_) => { trace!("KYEV PEXo {}", hex::encode(peer)); } Err(e) => { warn!("Error during push exchange: {}", e); } } } fn make_push_message(&self) -> PushMessage { let current_peers = self.view.read().unwrap().current_peers_vec(); PushMessage { peers: current_peers, } } async fn run_reset_loop(self: Arc, must_exit: watch::Receiver) { while !*must_exit.borrow() { tokio::time::sleep(self.param.reset_interval).await; { debug!("KYEV R {}", self.param.reset_count); let mut view = self.view.write().unwrap(); let prev_peers = view.current_peers(); let prev_peers_vec = prev_peers.iter().cloned().collect::>(); view.reset_some_slots(self.param.reset_count); view.update_all_slots(&prev_peers_vec[..]); let new_peers = view.current_peers(); drop(view); self.close_all_diff(&prev_peers, &new_peers); } let mut to_retry_maybe = self.bootstrap_peers.clone(); for (peer, _) in self.backlog.read().unwrap().iter() { if !self.bootstrap_peers.contains(peer) { to_retry_maybe.push(*peer); } } self.handle_peer_list(&to_retry_maybe[..]); } } fn handle_peer_list(self: &Arc, peers: &[Peer]) { let to_connect = self.view.read().unwrap().should_try_list(peers); for peer in to_connect.iter() { tokio::spawn(self.clone().try_connect(*peer)); } } async fn try_connect(self: Arc, peer: Peer) { { let view = self.view.read().unwrap(); let mut attempts = self.current_attempts.write().unwrap(); if view.slots.iter().any(|x| x.peer == Some(peer)) { return; } if attempts.contains(&peer) { return; } attempts.insert(peer); } let res = self.netapp.clone().try_connect(peer.addr, peer.id).await; trace!("Connection attempt to {}: {:?}", peer.addr, res); self.current_attempts.write().unwrap().remove(&peer); if res.is_err() { self.backlog.write().unwrap().pop(&peer); } } fn on_connected(self: &Arc, id: NodeID, addr: SocketAddr, is_incoming: bool) { if is_incoming { self.handle_peer_list(&[Peer { id, addr }][..]); } else { info!("KYEV C {} {}", hex::encode(id), addr); let peer = Peer { id, addr }; let mut backlog = self.backlog.write().unwrap(); if backlog.get(&peer).is_none() { backlog.put(peer, ()); } drop(backlog); let mut view = self.view.write().unwrap(); let prev_peers = view.current_peers(); view.update_all_slots(&[peer][..]); let new_peers = view.current_peers(); drop(view); self.close_all_diff(&prev_peers, &new_peers); } } fn on_disconnected(&self, id: NodeID, is_incoming: bool) { if !is_incoming { info!("KYEV D {}", hex::encode(id)); self.view.write().unwrap().disconnected(id); } } fn close_all_diff(&self, prev_peers: &HashSet, new_peers: &HashSet) { for peer in prev_peers.iter() { if !new_peers.contains(peer) { self.netapp.disconnect(&peer.id); } } } } #[async_trait] impl EndpointHandler for Basalt { async fn handle(self: &Arc, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage { self.make_push_message() } } #[async_trait] impl EndpointHandler for Basalt { async fn handle(self: &Arc, pushmsg: &PushMessage, _from: NodeID) { self.handle_peer_list(&pushmsg.peers[..]); } } fn rand_seed() -> Seed { let mut seed = [0u8; 32]; sodiumoxide::randombytes::randombytes_into(&mut seed[..]); seed }