use std::collections::HashSet; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; use log::{debug, warn}; use lru::LruCache; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash; use sodiumoxide::crypto::sign::ed25519; use crate::conn::*; use crate::message::*; use crate::netapp::*; use crate::proto::*; // -- Protocol messages -- #[derive(Serialize, Deserialize)] struct PullMessage {} impl Message for PullMessage { const KIND: MessageKind = 0x42001100; type Response = PushMessage; } #[derive(Serialize, Deserialize)] struct PushMessage { peers: Vec, } impl Message for PushMessage { const KIND: MessageKind = 0x42001101; type Response = (); } // -- Algorithm data structures -- type Seed = [u8; 32]; #[derive(Hash, Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Serialize, Deserialize)] struct Peer { id: ed25519::PublicKey, addr: SocketAddr, } type Cost = [u8; 40]; const MAX_COST: Cost = [0xffu8; 40]; impl Peer { fn cost(&self, seed: &Seed) -> Cost { let mut hasher = hash::State::new(); hasher.update(&seed[..]); let mut cost = [0u8; 40]; match self.addr { SocketAddr::V4(v4addr) => { let v4ip = v4addr.ip().octets(); for i in 0..4 { let mut h = hasher.clone(); h.update(&v4ip[..i + 1]); cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]); } } SocketAddr::V6(v6addr) => { let v6ip = v6addr.ip().octets(); for i in 0..4 { let mut h = hasher.clone(); h.update(&v6ip[..i + 2]); cost[i * 8..(i + 1) * 8].copy_from_slice(&h.finalize()[..8]); } } } { let mut h5 = hasher.clone(); h5.update(&format!("{}", self.addr).into_bytes()[..]); cost[32..40].copy_from_slice(&h5.finalize()[..8]); } cost } } struct BasaltSlot { seed: Seed, peer: Option, } impl BasaltSlot { fn cost(&self) -> Cost { self.peer.map(|p| p.cost(&self.seed)).unwrap_or(MAX_COST) } } struct BasaltView { i_reset: usize, slots: Vec, } impl BasaltView { fn new(size: usize) -> Self { let slots = (0..size) .map(|_| BasaltSlot { seed: rand_seed(), peer: None, }) .collect::>(); Self { i_reset: 0, slots } } fn current_peers(&self) -> HashSet { self.slots .iter() .filter(|s| s.peer.is_some()) .map(|s| s.peer.unwrap().clone()) .collect::>() } fn current_peers_vec(&self) -> Vec { self.current_peers().drain().collect::>() } fn sample(&self, count: usize) -> Vec { let possibles = self .slots .iter() .enumerate() .filter(|(_i, s)| s.peer.is_some()) .map(|(i, _s)| i) .collect::>(); if possibles.len() == 0 { vec![] } else { let mut ret = vec![]; let mut rng = thread_rng(); for _i in 0..count { let idx = rng.gen_range(0, possibles.len()); ret.push(self.slots[possibles[idx]].peer.unwrap()); } ret } } fn update_slot(&mut self, i: usize, peers: &[Peer]) { let mut slot_cost = self.slots[i].cost(); for peer in peers.iter() { let peer_cost = peer.cost(&self.slots[i].seed); if self.slots[i].peer.is_none() || peer_cost < slot_cost { self.slots[i].peer = Some(*peer); slot_cost = peer_cost; } } } fn update_all_slots(&mut self, peers: &[Peer]) { for i in 0..self.slots.len() { self.update_slot(i, peers); } } fn disconnected(&mut self, id: ed25519::PublicKey) { let mut cleared_slots = vec![]; for i in 0..self.slots.len() { if let Some(p) = self.slots[i].peer { if p.id == id { self.slots[i].peer = None; cleared_slots.push(i); } } } let remaining_peers = self.current_peers_vec(); for i in cleared_slots { self.update_slot(i, &remaining_peers[..]); } } fn should_try_list(&self, peers: &[Peer]) -> Vec { // Select peers that have lower cost than any of our slots let mut ret = HashSet::new(); for i in 0..self.slots.len() { if self.slots[i].peer.is_none() { return peers.to_vec(); } let mut min_cost = self.slots[i].cost(); let mut min_peer = None; for peer in peers.iter() { if ret.contains(peer) { continue; } let peer_cost = peer.cost(&self.slots[i].seed); if peer_cost < min_cost { min_cost = peer_cost; min_peer = Some(*peer); } } if let Some(p) = min_peer { ret.insert(p); if ret.len() == peers.len() { break; } } } ret.drain().collect::>() } fn reset_some_slots(&mut self, count: usize) { for _i in 0..count { self.slots[self.i_reset].seed = rand_seed(); self.i_reset = (self.i_reset + 1) % self.slots.len(); } } } pub struct BasaltParams { pub view_size: usize, pub cache_size: usize, pub exchange_interval: Duration, pub reset_interval: Duration, pub reset_count: usize, } pub struct Basalt { netapp: Arc, param: BasaltParams, bootstrap_peers: Vec, view: RwLock, current_attempts: RwLock>, backlog: RwLock>, } impl Basalt { pub fn new( netapp: Arc, bootstrap_list: Vec<(ed25519::PublicKey, SocketAddr)>, param: BasaltParams, ) -> Arc { let bootstrap_peers = bootstrap_list .iter() .map(|(id, addr)| Peer { id: *id, addr: *addr, }) .collect::>(); let view = BasaltView::new(param.view_size); let backlog = LruCache::new(param.cache_size); let basalt = Arc::new(Self { netapp: netapp.clone(), param, bootstrap_peers, view: RwLock::new(view), current_attempts: RwLock::new(HashSet::new()), backlog: RwLock::new(backlog), }); let basalt2 = basalt.clone(); netapp.on_connected.store(Some(Arc::new(Box::new( move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| { basalt2.on_connected(pk, addr, is_incoming); }, )))); let basalt2 = basalt.clone(); netapp.on_disconnected.store(Some(Arc::new(Box::new( move |pk: ed25519::PublicKey, is_incoming: bool| { basalt2.on_disconnected(pk, is_incoming); }, )))); let basalt2 = basalt.clone(); netapp.add_msg_handler::( move |_from: ed25519::PublicKey, _pullmsg: PullMessage| { let push_msg = basalt2.make_push_message(); async move { Ok(push_msg) } }, ); let basalt2 = basalt.clone(); netapp.add_msg_handler::( move |_from: ed25519::PublicKey, push_msg: PushMessage| { basalt2.handle_peer_list(&push_msg.peers[..]); async move { Ok(()) } }, ); basalt } pub fn sample(&self, count: usize) -> Vec { self.view .read() .unwrap() .sample(count) .iter() .map(|p| p.id) .collect::>() } pub async fn run(self: Arc) { for peer in self.bootstrap_peers.iter() { tokio::spawn(self.clone().try_connect(*peer)); } let pushpull_loop = self.clone().run_pushpull_loop(); let reset_loop = self.run_reset_loop(); tokio::join!(pushpull_loop, reset_loop); } async fn run_pushpull_loop(self: Arc) { loop { tokio::time::delay_for(self.param.exchange_interval).await; let peers = self.view.read().unwrap().sample(2); if peers.len() == 2 { let (c1, c2) = { let client_conns = self.netapp.client_conns.read().unwrap(); ( client_conns.get(&peers[0].id).cloned(), client_conns.get(&peers[1].id).cloned(), ) }; if let Some(c) = c1 { tokio::spawn(self.clone().do_pull(c)); } if let Some(c) = c2 { tokio::spawn(self.clone().do_push(c)); } } } } async fn do_pull(self: Arc, peer: Arc) { match peer.request(PullMessage {}, prio::NORMAL).await { Ok(resp) => { self.handle_peer_list(&resp.peers[..]); } Err(e) => { warn!("Error during pull exchange: {}", e); } }; } async fn do_push(self: Arc, peer: Arc) { let push_msg = self.make_push_message(); if let Err(e) = peer.request(push_msg, prio::NORMAL).await { warn!("Error during push exchange: {}", e); } } fn make_push_message(&self) -> PushMessage { let current_peers = self.view.read().unwrap().current_peers_vec(); PushMessage { peers: current_peers, } } async fn run_reset_loop(self: Arc) { loop { tokio::time::delay_for(self.param.reset_interval).await; { let mut view = self.view.write().unwrap(); let prev_peers = view.current_peers(); let prev_peers_vec = prev_peers.iter().cloned().collect::>(); view.reset_some_slots(self.param.reset_count); view.update_all_slots(&prev_peers_vec[..]); let new_peers = view.current_peers(); drop(view); self.close_all_diff(&prev_peers, &new_peers); } let mut to_retry_maybe = self.bootstrap_peers.clone(); for (peer, _) in self.backlog.read().unwrap().iter() { if !self.bootstrap_peers.contains(peer) { to_retry_maybe.push(*peer); } } self.handle_peer_list(&to_retry_maybe[..]); } } fn handle_peer_list(self: &Arc, peers: &[Peer]) { let to_connect = self.view.read().unwrap().should_try_list(peers); for peer in to_connect.iter() { tokio::spawn(self.clone().try_connect(*peer)); } } async fn try_connect(self: Arc, peer: Peer) { { let view = self.view.read().unwrap(); let mut attempts = self.current_attempts.write().unwrap(); if view.slots.iter().any(|x| x.peer == Some(peer)) { return; } if attempts.contains(&peer) { return; } attempts.insert(peer); } let res = self.netapp.clone().try_connect(peer.addr, peer.id).await; debug!("Connection attempt to {}: {:?}", peer.addr, res); self.current_attempts.write().unwrap().remove(&peer); if res.is_err() { self.backlog.write().unwrap().pop(&peer); } } fn on_connected(self: &Arc, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) { if is_incoming { self.handle_peer_list(&[Peer{id: pk, addr}][..]); } else { let peer = Peer { id: pk, addr }; let mut backlog = self.backlog.write().unwrap(); if backlog.get(&peer).is_none() { backlog.put(peer, ()); } drop(backlog); let mut view = self.view.write().unwrap(); let prev_peers = view.current_peers(); view.update_all_slots(&[peer][..]); let new_peers = view.current_peers(); drop(view); self.close_all_diff(&prev_peers, &new_peers); } } fn on_disconnected(&self, pk: ed25519::PublicKey, is_incoming: bool) { if !is_incoming { self.view.write().unwrap().disconnected(pk); } } fn close_all_diff(&self, prev_peers: &HashSet, new_peers: &HashSet) { let client_conns = self.netapp.client_conns.read().unwrap(); for peer in prev_peers.iter() { if !new_peers.contains(peer) { if let Some(c) = client_conns.get(&peer.id) { debug!("Closing connection to {} ({})", hex::encode(peer.id), peer.addr); c.close(); } } } } } fn rand_seed() -> Seed { let mut seed = [0u8; 32]; sodiumoxide::randombytes::randombytes_into(&mut seed[..]); seed }