use std::collections::HashSet; use anyhow::Result; use iptables; use log::*; use tokio::{ select, sync::watch, time::{self, Duration}, }; use crate::{fw, messages}; pub struct FirewallActor { pub ipt: iptables::IPTables, rx_ports: watch::Receiver, last_ports: messages::PublicExposedPorts, refresh: Duration, } impl FirewallActor { pub async fn new( _refresh: Duration, rxp: &watch::Receiver, ) -> Result { let ctx = Self { ipt: iptables::new(false)?, rx_ports: rxp.clone(), last_ports: messages::PublicExposedPorts::new(), refresh: _refresh, }; fw::setup(&ctx.ipt)?; return Ok(ctx); } pub async fn listen(&mut self) -> Result<()> { let mut interval = time::interval(self.refresh); loop { // 1. Wait for an event let new_ports = select! { _ = self.rx_ports.changed() => Some(self.rx_ports.borrow().clone()), _ = interval.tick() => None, else => return Ok(()) // Sender dropped, terminate loop. }; // 2. Update last ports if needed if let Some(p) = new_ports { self.last_ports = p; } // 3. Update firewall rules match self.do_fw_update().await { Ok(()) => debug!("Successfully updated firewall rules"), Err(e) => error!("An error occured while updating firewall rules. {}", e), } } } pub async fn do_fw_update(&self) -> Result<()> { let curr_opened_ports = fw::get_opened_ports(&self.ipt)?; let diff_tcp = self .last_ports .tcp_ports .difference(&curr_opened_ports.tcp_ports) .copied() .collect::>(); let diff_udp = self .last_ports .udp_ports .difference(&curr_opened_ports.udp_ports) .copied() .collect::>(); let ports_to_open = messages::PublicExposedPorts { tcp_ports: diff_tcp, udp_ports: diff_udp, }; fw::open_ports(&self.ipt, ports_to_open)?; return Ok(()); } }