router/src/rules.rs
2022-01-31 20:28:57 -06:00

310 lines
8.9 KiB
Rust

use crate::{
iptables::{self, Proto},
startup::Interfaces,
};
use once_cell::sync::OnceCell;
use sled::{Db, Tree};
use std::net::Ipv4Addr;
/// Process-wide cache for the sled tree that holds the persisted rules.
static RULES_TREE: OnceCell<Tree> = OnceCell::new();

/// Returns the cached `rules` tree, opening it from `db` on first use.
///
/// The tree is opened at most once per process: whichever `db` is passed on
/// the first call wins, and later calls return that same tree regardless of
/// the `db` they are given.
///
/// # Panics
///
/// Panics if the `rules` tree cannot be opened on the first call.
fn rules_tree(db: &Db) -> &'static Tree {
    let open_rules = || db.open_tree("rules").unwrap();
    RULES_TREE.get_or_init(open_rules)
}
/// A single firewall rule: a protocol/port pair plus the action to take.
///
/// Rules are stored as JSON in the `rules` tree (see `save` and `read`).
#[derive(Debug, serde::Deserialize, serde::Serialize)]
pub struct Rule {
/// Protocol the rule applies to.
pub(crate) proto: Proto,
/// Port the rule matches; `set_rule` hands it to iptables together with the
/// external interface's address.
pub(crate) port: u16,
/// Action performed for matching traffic.
pub(crate) kind: RuleKind,
}
impl Rule {
    /// Returns the forwarding target as `(destination ip, destination port)`.
    ///
    /// Yields `None` for rules that are not of the `Forward` kind.
    pub(crate) fn as_forward(&self) -> Option<(Ipv4Addr, u16)> {
        // Both payload fields are `Copy`, so matching by value copies them out
        // without moving anything out of `self`.
        match self.kind {
            RuleKind::Forward { dest_ip, dest_port } => Some((dest_ip, dest_port)),
            RuleKind::Accept => None,
        }
    }
}
/// The action attached to a `Rule`.
///
/// Serialized with an internal `type` tag that names the variant.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type")]
pub(crate) enum RuleKind {
/// Accept the traffic; `set_rule` turns this into a single `iptables::input` call
/// for the external interface.
Accept,
/// Forward the traffic to another address and port.
Forward {
/// Address the traffic is forwarded to.
dest_ip: Ipv4Addr,
/// Port the traffic is forwarded to. (De)serialized through its string form
/// (`Display` / `FromStr`) rather than as a bare number.
#[serde(with = "serde_with::rust::display_fromstr")]
dest_port: u16,
},
}
/// Reads every stored rule from `db` and applies them one after another.
///
/// # Errors
///
/// Fails if the rules cannot be read, or as soon as applying one rule fails;
/// rules after the failing one are not applied.
pub(crate) async fn apply_all(db: &Db, interfaces: &Interfaces) -> Result<(), anyhow::Error> {
    let stored = read(db)?;
    // The storage ids are irrelevant here; only the rules themselves matter.
    for rule in stored.into_iter().map(|(_id, rule)| rule) {
        apply(interfaces, rule).await?;
    }
    Ok(())
}
/// Loads every rule from the `rules` tree as `(id, rule)` pairs, in the
/// tree's iteration order.
///
/// Ids are decoded lossily as UTF-8; values are parsed as JSON.
///
/// # Errors
///
/// Stops at, and returns, the first storage or JSON decoding error.
pub(crate) fn read(db: &Db) -> Result<Vec<(String, Rule)>, anyhow::Error> {
    let mut rules = Vec::new();
    for entry in rules_tree(db).iter() {
        let (key, value) = entry?;
        let id = String::from_utf8_lossy(&key).into_owned();
        tide::log::debug!("id: {}", id);
        tide::log::debug!("rule str: {}", String::from_utf8_lossy(&value));
        let rule: Rule = serde_json::from_slice(&value)?;
        tide::log::debug!("rule: {:?}", rule);
        rules.push((id, rule));
    }
    Ok(rules)
}
/// Removes the rule stored under `rule_id` and returns it.
///
/// # Errors
///
/// Fails when no rule is stored under `rule_id`, when the tree cannot be
/// updated or flushed, or when the removed value is not valid rule JSON.
/// In the last case the entry has already been removed.
pub(crate) async fn delete(db: &Db, rule_id: &str) -> Result<Rule, anyhow::Error> {
    let tree = rules_tree(db);
    let removed = match tree.remove(rule_id.as_bytes())? {
        Some(bytes) => bytes,
        None => return Err(anyhow::anyhow!("No rule with id {}", rule_id)),
    };
    // Flush the removal before decoding what was removed.
    tree.flush_async().await?;
    Ok(serde_json::from_slice::<Rule>(&removed)?)
}
/// Issues the iptables commands that implement `rule`.
///
/// `func` is handed to every `iptables::*` helper so the caller decides which
/// operation flag the commands carry (see `apply` and `unset`).
///
/// * `Accept` rules result in a single `iptables::input` call for the
///   external interface.
/// * `Forward` rules pick the first internal, tunnel or VLAN interface whose
///   subnet contains the destination address. If there is none, nothing is
///   done and `Ok(())` is returned. Otherwise forwarding from the external
///   interface plus DNAT and SNAT are set up, followed by forwarding from
///   every other local interface that is not associated with a NAT interface.
///
/// # Errors
///
/// Returns the first error reported by an `iptables` helper; commands issued
/// before the failure are not rolled back.
async fn set_rule(
    interfaces: &Interfaces,
    rule: Rule,
    func: impl Fn(&mut async_process::Command) -> &mut async_process::Command + Copy,
) -> anyhow::Result<()> {
    let external = &interfaces.external;

    // Accept rules are a single command; handle them and leave early.
    let (dest_ip, dest_port) = match rule.kind {
        RuleKind::Accept => {
            iptables::input(
                rule.proto,
                &external.interface,
                external.ip,
                rule.port,
                external.mask,
                func,
            )
            .await?;
            return Ok(());
        }
        RuleKind::Forward { dest_ip, dest_port } => (dest_ip, dest_port),
    };

    // Every non-external interface, always visited in the same order:
    // internal, then tunnel, then VLAN.
    let local_ifaces = || {
        interfaces
            .internal
            .iter()
            .chain(&interfaces.tunnel)
            .chain(&interfaces.vlan)
    };

    // The interface whose subnet contains the forwarding destination.
    let target = match local_ifaces().find(|info| mask_matches(info.ip, info.mask, dest_ip)) {
        Some(found) => found,
        // No local subnet contains the destination: nothing to configure.
        None => return Ok(()),
    };

    iptables::forward(
        &external.interface,
        &target.interface,
        rule.proto,
        dest_port,
        func,
    )
    .await?;
    iptables::forward_prerouting_dnat(
        rule.proto,
        external.ip,
        external.mask,
        rule.port,
        dest_ip,
        dest_port,
        func,
    )
    .await?;
    iptables::forward_postrouting_snat(
        rule.proto,
        target.ip,
        target.mask,
        dest_port,
        external.ip,
        dest_ip,
        func,
    )
    .await?;

    for other in local_ifaces().filter(|iface| *iface != target) {
        // An interface counts as NAT-associated when it is listed in `nats`
        // itself, or when a listed interface has the same address and mask.
        let natted = interfaces.nats.iter().any(|nat_name| {
            *nat_name == other.interface
                || local_ifaces().any(|peer| {
                    *nat_name == peer.interface && peer.ip == other.ip && peer.mask == other.mask
                })
        });
        if natted {
            continue;
        }

        iptables::forward(
            &other.interface,
            &target.interface,
            rule.proto,
            dest_port,
            func,
        )
        .await?;
        iptables::outbound_forward_established(
            &other.interface,
            &target.interface,
            rule.proto,
            dest_port,
            func,
        )
        .await?;
    }

    Ok(())
}
/// Installs `rule` by running `set_rule` with `-I` appended to every command.
///
/// NOTE(review): `-I` is iptables' insert flag; this assumes the helpers in
/// `crate::iptables` build iptables invocations — confirm there.
///
/// # Errors
///
/// Propagates any error returned by `set_rule`.
pub(crate) async fn apply(interfaces: &Interfaces, rule: Rule) -> Result<(), anyhow::Error> {
set_rule(interfaces, rule, |cmd| cmd.arg("-I")).await
}
/// Removes `rule` by running `set_rule` with `-D` appended to every command.
///
/// NOTE(review): `-D` is iptables' delete flag; this assumes the helpers in
/// `crate::iptables` build iptables invocations — confirm there.
///
/// # Errors
///
/// Propagates any error returned by `set_rule`.
pub(crate) async fn unset(interfaces: &Interfaces, rule: Rule) -> Result<(), anyhow::Error> {
set_rule(interfaces, rule, |cmd| cmd.arg("-D")).await
}
/// Persists `rule` as JSON under a freshly generated `rule-<id>` key and
/// flushes the tree.
///
/// The generated key is not returned; callers can obtain it through `read`.
///
/// # Errors
///
/// Fails if the rule cannot be serialized, no id can be generated, or the
/// tree cannot be written or flushed.
pub(crate) async fn save(db: &Db, rule: &Rule) -> Result<(), anyhow::Error> {
    let tree = rules_tree(db);
    let payload = serde_json::to_vec(rule)?;
    // Only ask the database for an id once serialization has succeeded.
    let key = rule_id(db.generate_id()?);
    tree.insert(key.as_bytes(), payload)?;
    tree.flush_async().await?;
    Ok(())
}
/// Octet masks that keep the leading (network) bits when the trailing
/// `index + 1` bits of the octet belong to the host part, i.e.
/// `0xFF << (index + 1)`.
const MASK_BITS: [u8; 7] = [254, 252, 248, 240, 224, 192, 128];

/// Reports whether `rule_ip` lies in the network described by `mask_ip` and
/// the prefix length `netmask`.
///
/// Only the first `netmask` bits of the two addresses are compared. A prefix
/// of `0` matches every address; a prefix of `32` or more requires the
/// addresses to be identical.
fn mask_matches(mask_ip: Ipv4Addr, netmask: u8, rule_ip: Ipv4Addr) -> bool {
    let network = mask_ip.octets();
    let candidate = rule_ip.octets();

    network
        .iter()
        .zip(candidate.iter())
        // Number of prefix bits covered once each octet is fully included.
        .zip((8u8..=32).step_by(8))
        .all(|((net_byte, rule_byte), covered_bits)| {
            if covered_bits <= netmask {
                // The whole octet belongs to the network part.
                return net_byte == rule_byte;
            }
            let host_bits = covered_bits - netmask;
            if host_bits >= 8 {
                // The whole octet belongs to the host part: nothing to compare.
                return true;
            }
            // The prefix ends inside this octet: compare only its leading bits.
            let keep = MASK_BITS[usize::from(host_bits - 1)];
            (net_byte & keep) == (rule_byte & keep)
        })
}
/// Builds the storage key for the rule with numeric id `id`: the decimal id
/// prefixed with `rule-`.
fn rule_id(id: u64) -> String {
    let mut key = String::from("rule-");
    key.push_str(&id.to_string());
    key
}
#[cfg(test)]
mod tests {
    use super::mask_matches;
    use crate::{
        iptables::Proto,
        rules::{Rule, RuleKind},
    };

    /// Parses both addresses and reports whether `candidate` lies inside
    /// `network`/`prefix`.
    fn in_subnet(network: &str, prefix: u8, candidate: &str) -> bool {
        mask_matches(network.parse().unwrap(), prefix, candidate.parse().unwrap())
    }

    /// Addresses inside the given network must match.
    #[test]
    fn ips_match_mask() {
        const INSIDE: [(&str, u8, &str); 12] = [
            ("192.168.6.0", 24, "192.168.6.1"),
            ("192.168.6.0", 24, "192.168.6.254"),
            ("192.168.6.0", 25, "192.168.6.1"),
            ("192.168.6.0", 26, "192.168.6.1"),
            ("192.168.6.0", 27, "192.168.6.1"),
            ("192.168.6.0", 31, "192.168.6.1"),
            ("192.168.6.0", 32, "192.168.6.0"),
            ("192.168.0.0", 16, "192.168.255.0"),
            ("192.168.0.0", 16, "192.168.0.255"),
            ("192.168.0.0", 16, "192.168.1.0"),
            ("192.168.0.0", 16, "192.168.0.1"),
            ("192.168.0.0", 16, "192.168.128.0"),
        ];
        for &(network, prefix, candidate) in INSIDE.iter() {
            assert!(in_subnet(network, prefix, candidate));
        }
    }

    /// Addresses outside the given network must not match.
    #[test]
    fn ips_dont_match_mask() {
        const OUTSIDE: [(&str, u8, &str); 9] = [
            ("192.168.6.0", 24, "192.168.5.0"),
            ("192.168.6.0", 25, "192.168.6.128"),
            ("192.168.6.0", 31, "192.168.6.2"),
            ("192.169.0.0", 16, "192.168.0.1"),
            ("192.1.0.0", 16, "192.0.0.0"),
            ("192.0.0.0", 16, "192.1.0.0"),
            ("192.255.0.0", 16, "192.0.0.0"),
            ("192.0.0.0", 16, "192.255.0.0"),
            ("192.168.0.0", 17, "192.168.128.0"),
        ];
        for &(network, prefix, candidate) in OUTSIDE.iter() {
            // Printed so a failing case can be identified from the test output.
            println!("comparing: {}/{} to {}", network, prefix, candidate);
            assert!(!in_subnet(network, prefix, candidate));
        }
    }

    /// A forward rule serializes to the expected query-string form.
    #[test]
    fn can_serialize() {
        let expected =
            "proto=Tcp&port=53&kind[type]=Forward&kind[dest_ip]=192.168.6.84&kind[dest_port]=53";
        let forward = RuleKind::Forward {
            dest_ip: "192.168.6.84".parse().unwrap(),
            dest_port: 53,
        };
        let rule = Rule {
            proto: Proto::Tcp,
            port: 53,
            kind: forward,
        };
        let encoded = serde_qs::to_string(&rule).unwrap();
        assert_eq!(encoded, expected);
    }
}