use crate::{ iptables::{self, Proto}, startup::Interfaces, }; use once_cell::sync::OnceCell; use sled::{Db, Tree}; use std::net::Ipv4Addr; static RULES_TREE: OnceCell = OnceCell::new(); fn rules_tree(db: &Db) -> &'static Tree { RULES_TREE.get_or_init(|| db.open_tree("rules").unwrap()) } #[derive(Debug, serde::Deserialize, serde::Serialize)] pub struct Rule { pub(crate) proto: Proto, pub(crate) port: u16, pub(crate) kind: RuleKind, } impl Rule { pub(crate) fn as_forward(&self) -> Option<(Ipv4Addr, u16)> { match &self.kind { RuleKind::Forward { dest_ip, dest_port } => Some((*dest_ip, *dest_port)), _ => None, } } } #[derive(Debug, serde::Deserialize, serde::Serialize)] #[serde(tag = "type")] pub(crate) enum RuleKind { Accept, Forward { dest_ip: Ipv4Addr, #[serde(with = "serde_with::rust::display_fromstr")] dest_port: u16, }, } pub(crate) async fn apply_all(db: &Db, interfaces: &Interfaces) -> Result<(), anyhow::Error> { for (_, rule) in read(db)? { apply(interfaces, rule).await?; } Ok(()) } pub(crate) fn read(db: &Db) -> Result, anyhow::Error> { rules_tree(db) .iter() .map(|res| { let (id, rule) = res?; let id = String::from_utf8_lossy(&id).to_string(); tide::log::debug!("id: {}", id); tide::log::debug!("rule str: {}", String::from_utf8_lossy(&rule)); let rule: Rule = serde_json::from_slice(&rule)?; tide::log::debug!("rule: {:?}", rule); Ok((id, rule)) as Result<(String, Rule), anyhow::Error> }) .collect::, anyhow::Error>>() } pub(crate) async fn delete(db: &Db, rule_id: &str) -> Result { let tree = rules_tree(db); let rule = tree .remove(rule_id.as_bytes())? .ok_or_else(|| anyhow::anyhow!("No rule with id {}", rule_id))?; tree.flush_async().await?; let rule: Rule = serde_json::from_slice(&rule)?; Ok(rule) } async fn set_rule( interfaces: &Interfaces, rule: Rule, func: impl Fn(&mut async_process::Command) -> &mut async_process::Command + Copy, ) -> anyhow::Result<()> { match rule.kind { RuleKind::Accept => { iptables::input( rule.proto, &interfaces.external.interface, interfaces.external.ip, rule.port, interfaces.external.mask, func, ) .await?; } RuleKind::Forward { dest_ip, dest_port } => { let internal_iface = interfaces .internal .iter() .chain(&interfaces.tunnel) .chain(&interfaces.vlan) .find(|info| mask_matches(info.ip, info.mask, dest_ip)); let internal_iface = if let Some(internal_iface) = internal_iface { internal_iface } else { return Ok(()); }; iptables::forward( &interfaces.external.interface, &internal_iface.interface, rule.proto, dest_port, func, ) .await?; iptables::forward_prerouting_dnat( rule.proto, interfaces.external.ip, interfaces.external.mask, rule.port, dest_ip, dest_port, func, ) .await?; iptables::forward_postrouting_snat( rule.proto, internal_iface.ip, internal_iface.mask, dest_port, interfaces.external.ip, dest_ip, func, ) .await?; for iface in interfaces .internal .iter() .chain(&interfaces.tunnel) .chain(&interfaces.vlan) .filter(|iface| *iface != internal_iface) { let has_nat_subnet = interfaces.nats.iter().any(|nat_iface| { *nat_iface == iface.interface || interfaces .internal .iter() .chain(&interfaces.tunnel) .chain(&interfaces.vlan) .any(|other_iface| { *nat_iface == other_iface.interface && other_iface.ip == iface.ip && other_iface.mask == iface.mask }) }); if !has_nat_subnet { iptables::forward( &iface.interface, &internal_iface.interface, rule.proto, dest_port, func, ) .await?; iptables::outbound_forward_established( &iface.interface, &internal_iface.interface, rule.proto, dest_port, func, ) .await?; } } } } Ok(()) } pub(crate) async fn apply(interfaces: &Interfaces, rule: Rule) -> Result<(), anyhow::Error> { set_rule(interfaces, rule, |cmd| cmd.arg("-I")).await } pub(crate) async fn unset(interfaces: &Interfaces, rule: Rule) -> Result<(), anyhow::Error> { set_rule(interfaces, rule, |cmd| cmd.arg("-D")).await } pub(crate) async fn save(db: &Db, rule: &Rule) -> Result<(), anyhow::Error> { let tree = rules_tree(db); let s = serde_json::to_string(rule)?; let id = db.generate_id()?; tree.insert(rule_id(id).as_bytes(), s.as_bytes())?; tree.flush_async().await?; Ok(()) } // 255 - 2^n where n is the index const MASK_BITS: [u8; 7] = [254, 252, 248, 240, 224, 192, 128]; fn mask_matches(mask_ip: Ipv4Addr, netmask: u8, rule_ip: Ipv4Addr) -> bool { let mut count: u8 = 8; let mut matches = true; for (mask_byte, rule_byte) in mask_ip.octets().iter().zip(&rule_ip.octets()) { if count <= netmask { matches = matches && mask_byte == rule_byte } else { let remaining = count.saturating_sub(netmask); if remaining < 8 { let mask = MASK_BITS[remaining.saturating_sub(1) as usize]; matches = matches && (mask_byte & mask) == (rule_byte & mask); } } count += 8; } matches } fn rule_id(id: u64) -> String { format!("rule-{}", id) } #[cfg(test)] mod tests { use super::mask_matches; use crate::{ iptables::Proto, rules::{Rule, RuleKind}, }; #[test] fn ips_match_mask() { let tests = [ ("192.168.6.0", 24, "192.168.6.1"), ("192.168.6.0", 24, "192.168.6.254"), ("192.168.6.0", 25, "192.168.6.1"), ("192.168.6.0", 26, "192.168.6.1"), ("192.168.6.0", 27, "192.168.6.1"), ("192.168.6.0", 31, "192.168.6.1"), ("192.168.6.0", 32, "192.168.6.0"), ("192.168.0.0", 16, "192.168.255.0"), ("192.168.0.0", 16, "192.168.0.255"), ("192.168.0.0", 16, "192.168.1.0"), ("192.168.0.0", 16, "192.168.0.1"), ("192.168.0.0", 16, "192.168.128.0"), ]; for (mask_ip, mask, rule_ip) in tests { let matches = mask_matches(mask_ip.parse().unwrap(), mask, rule_ip.parse().unwrap()); assert!(matches); } } #[test] fn ips_dont_match_mask() { let tests = [ ("192.168.6.0", 24, "192.168.5.0"), ("192.168.6.0", 25, "192.168.6.128"), ("192.168.6.0", 31, "192.168.6.2"), ("192.169.0.0", 16, "192.168.0.1"), ("192.1.0.0", 16, "192.0.0.0"), ("192.0.0.0", 16, "192.1.0.0"), ("192.255.0.0", 16, "192.0.0.0"), ("192.0.0.0", 16, "192.255.0.0"), ("192.168.0.0", 17, "192.168.128.0"), ]; for (mask_ip, mask, rule_ip) in tests { println!("comparing: {}/{} to {}", mask_ip, mask, rule_ip); let matches = mask_matches(mask_ip.parse().unwrap(), mask, rule_ip.parse().unwrap()); assert!(!matches); } } #[test] fn can_serialize() { let rule = Rule { proto: Proto::Tcp, port: 53, kind: RuleKind::Forward { dest_ip: "192.168.6.84".parse().unwrap(), dest_port: 53, }, }; let s = serde_qs::to_string(&rule).unwrap(); assert_eq!( s, "proto=Tcp&port=53&kind[type]=Forward&kind[dest_ip]=192.168.6.84&kind[dest_port]=53" ); } }