diff --git a/Cargo.toml b/Cargo.toml index 2567b4c..a939769 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,4 @@ pnet = "0.35.0" rupnp = "2.0.0" tokio = { version = "1.41.1", features = ["rt", "rt-multi-thread", "macros"] } futures = "0.3.31" +ipnet = { version = "2.10.1", features = ["serde"] } diff --git a/src/main.rs b/src/main.rs index a2fe1cc..2f29628 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,6 +29,8 @@ const PERSIST_INTERVAL: Duration = Duration::from_secs(600); const LAN_BROADCAST_INTERVAL: Duration = Duration::from_secs(60); const IGD_INTERVAL: Duration = Duration::from_secs(60); +const IGD_LEASE_DURATION: Duration = Duration::from_secs(300); + type Pubkey = String; #[derive(Deserialize)] @@ -56,6 +58,9 @@ struct Config { /// Settings for the Wireguard interfaces (currently only necessary if you want to use igd features) #[serde(default)] interfaces: Vec, + + #[serde(default)] + forbidden_nets: Vec, } #[derive(Deserialize)] @@ -168,7 +173,11 @@ fn wg_dump(interface: &str) -> Result { }) .collect::>(); - Ok(IfInfo { our_pubkey, listen_port, peers }) + Ok(IfInfo { + our_pubkey, + listen_port, + peers, + }) } // ============ DAEMON CODE ================= @@ -210,26 +219,39 @@ impl Daemon { fn new(config: Config) -> Result { let gossip_key = kdf(config.gossip_secret.as_deref().unwrap_or_default()); - let interface_names = config.peers.iter().map(|peer| peer.interface.clone()).collect::>(); - let interfaces = interface_names.into_iter().map(|interface_name| wg_dump(&interface_name).map(|ifinfo| (interface_name, ifinfo))).collect::>>()?; + let interface_names = config + .peers + .iter() + .map(|peer| peer.interface.clone()) + .collect::>(); + let interfaces = interface_names + .into_iter() + .map(|interface_name| wg_dump(&interface_name).map(|ifinfo| (interface_name, ifinfo))) + .collect::>>()?; let socket = UdpSocket::bind(SocketAddr::new("::".parse()?, config.gossip_port))?; socket.set_broadcast(true)?; - socket.set_ttl(1)?; + //socket.set_ttl(1)?; let our_pubkey = interfaces.iter().next().unwrap().1.our_pubkey.clone(); - let peers = config.peers.iter().map(|peer_cfg| { - ( - peer_cfg.pubkey.clone(), - PeerInfo { - gossip_ip: peer_cfg.address, - gossip_prio: fasthash(format!("{}-{}", our_pubkey, peer_cfg.pubkey).as_bytes()), - endpoint: None, // Is resolved as DNS name later - last_seen: u64::MAX, - lan_endpoint: None, - } - ) - }).collect(); + let peers = config + .peers + .iter() + .map(|peer_cfg| { + ( + peer_cfg.pubkey.clone(), + PeerInfo { + gossip_ip: peer_cfg.address, + gossip_prio: fasthash( + format!("{}-{}", our_pubkey, peer_cfg.pubkey).as_bytes(), + ), + endpoint: None, // Is resolved as DNS name later + last_seen: u64::MAX, + lan_endpoint: None, + }, + ) + }) + .collect(); Ok(Daemon { config, @@ -357,9 +379,7 @@ impl Daemon { self.socket.send_to(&packet, from)?; } } - Gossip::LanBroadcast { - pubkey, - } => { + Gossip::LanBroadcast { pubkey } => { if self.config.lan_discovery { if let Some(peer) = state.peers.get_mut(&pubkey) { peer.lan_endpoint = Some((from.ip(), time())); @@ -528,7 +548,20 @@ impl State { let mut peer_vec = self .peers .iter() - .filter(|(_, info)| info.last_seen != u64::MAX && now < info.last_seen + TIMEOUT.as_secs() && info.endpoint.is_some()) + .filter(|(_, info)| { + let seen = info.last_seen != u64::MAX && now < info.last_seen + TIMEOUT.as_secs(); + let endpoint_valid = info + .endpoint + .map(|ep| { + !daemon + .config + .forbidden_nets + .iter() + .any(|net| net.contains(&ep)) + }) + .unwrap_or(false); + seen && endpoint_valid + }) .map(|(_, info)| (info.gossip_ip, info.gossip_prio)) .collect::>(); peer_vec.sort_by_key(|(_, prio)| *prio); @@ -589,7 +622,6 @@ impl State { } fn read_wg_peers(&mut self) -> Result<()> { - // Clear old known endpoints if any for (_, peer) in self.peers.iter_mut() { peer.endpoint = None; @@ -630,10 +662,22 @@ impl State { (Some((addr1, _)), Some(addr2)) => addr1 != addr2, _ => false, }; + // If the current endpoint is in a forbidden net, reconfigure the peer even if it has a connection + let forbidden_endpoint = peer + .endpoint + .map(|ep| { + daemon + .config + .forbidden_nets + .iter() + .any(|net| net.contains(&ep)) + }) + .unwrap_or(false); // if peer is connected and endpoint is the correct one, // set higher keepalive and then skip reconfiguring it - if !bad_endpoint && peer.last_seen != u64::MAX && now < peer.last_seen + TIMEOUT.as_secs() { + if !bad_endpoint && peer.last_seen != u64::MAX && !forbidden_endpoint && now < peer.last_seen + TIMEOUT.as_secs() + { Command::new("wg") .args([ "set", @@ -673,6 +717,16 @@ impl State { } } endpoints.sort(); + endpoints = endpoints + .into_iter() + .filter(|(ep, _)| { + !daemon + .config + .forbidden_nets + .iter() + .any(|net| net.contains(ep)) + }) + .collect(); endpoints } }; @@ -702,7 +756,7 @@ impl State { "persistent-keepalive", "10", "allowed-ips", - "::/0,0.0.0.0/0" + "::/0,0.0.0.0/0", ]) .output()?; let packet = daemon.make_packet(&Gossip::Ping)?; @@ -720,7 +774,7 @@ impl State { "peer", &peer_cfg.pubkey, "allowed-ips", - "::/0,0.0.0.0/0" + "::/0,0.0.0.0/0", ]) .output()?; }