👌 Improve Parsing logic

This commit is contained in:
Alexander WB 2025-02-21 22:03:49 +01:00
parent 6e9ecf8a26
commit ac883c1ddf

View file

@ -2,7 +2,7 @@ use anyhow::Result;
use async_compression::tokio::bufread::GzipDecoder;
use futures::TryStreamExt;
use intervaltree::IntervalTree;
use std::net::IpAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::str::FromStr;
use tokio::io::AsyncRead;
@ -11,26 +11,19 @@ use tokio_util::io::StreamReader;
use tracing::{debug, info, trace};
pub struct Blocklist {
// Separate trees for IPv4 and IPv6 since they have different numeric ranges
ipv4_ranges: IntervalTree<u32, ()>,
ipv6_ranges: IntervalTree<u128, ()>,
// ipv4 and ipv6 do not overlap
// see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5
blocked_ranges: IntervalTree<IpAddr, ()>,
}
impl Blocklist {
pub fn empty() -> Self {
return Self::new(
&Vec::<std::ops::Range<u32>>::new(),
&Vec::<std::ops::Range<u128>>::new(),
);
return Self::new(std::iter::empty());
}
pub fn new(
ipv4_ranges: &Vec<std::ops::Range<u32>>,
ipv6_ranges: &Vec<std::ops::Range<u128>>,
) -> Self {
pub fn new(ip_ranges: impl IntoIterator<Item = std::ops::Range<IpAddr>>) -> Self {
Self {
ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))),
ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))),
blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))),
}
}
@ -97,87 +90,75 @@ impl Blocklist {
let reader = BufReader::new(reader);
let mut lines = reader.lines();
let mut ipv4_ranges: Vec<std::ops::Range<u32>> = Vec::new();
let mut ipv6_ranges: Vec<std::ops::Range<u128>> = Vec::new();
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
while let Some(line) = lines.next_line().await? {
// Skip comments and empty lines
if line.starts_with('#') || line.trim().is_empty() {
continue;
}
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
match (start_ip, end_ip) {
(IpAddr::V4(start), IpAddr::V4(end)) => {
let start_num = u32::from_be_bytes(start.octets());
let end_num = u32::from_be_bytes(end.octets());
let range = if end_num == u32::MAX {
start_num..end_num // Special case: Use inclusive range when max
} else {
start_num..(end_num + 1) // Normal case
};
ipv4_ranges.push(range);
}
(IpAddr::V6(start), IpAddr::V6(end)) => {
let start_num = u128::from_be_bytes(start.octets());
let end_num = u128::from_be_bytes(end.octets());
let range = if end_num == u128::MAX {
start_num..end_num // Special case: Use inclusive range when max
} else {
start_num..(end_num + 1) // Normal case
};
ipv6_ranges.push(range);
}
_ => {
continue;
}
}
let range = start_ip..(increment_ip(end_ip).unwrap());
ip_ranges.push(range);
}
}
info!(
ipv6_entry_count = ipv6_ranges.len(),
ipv4_entry_count = ipv4_ranges.len(),
ip_entry_count = ip_ranges.len(),
"Finished loading blocklist"
);
let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges);
let blocklist = Self::new(ip_ranges);
Ok(blocklist)
}
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
match ip {
IpAddr::V4(ipv4) => {
let num = u32::from_be_bytes(ipv4.octets());
self.ipv4_ranges.query_point(num).next().is_some()
}
IpAddr::V6(ipv6) => {
let num = u128::from_be_bytes(ipv6.octets());
self.ipv6_ranges.query_point(num).next().is_some()
}
self.blocked_ranges.query_point(*ip).next().is_some()
}
}
/// Returns the address that directly follows `ip` within the same family.
///
/// Yields `None` when `ip` is already the highest address of its family,
/// i.e. when adding one would overflow the underlying integer.
fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
    let successor = match ip {
        IpAddr::V4(v4) => {
            let next = u32::from(v4).checked_add(1)?;
            IpAddr::V4(Ipv4Addr::from(next))
        }
        IpAddr::V6(v6) => {
            let next = u128::from(v6).checked_add(1)?;
            IpAddr::V6(Ipv6Addr::from(next))
        }
    };
    Some(successor)
}
fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
// Skip comments and empty lines
if line.starts_with('#') || line.trim().is_empty() {
let line = line.trim();
if line.starts_with('#') || line.is_empty() {
return None;
}
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
if let Some((rule_name, ip_range)) = line.rsplit_once(':') {
if let Some((start, end)) = ip_range.split_once('-') {
if let (Ok(start_ip), Ok(end_ip)) =
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
{
return Some((start_ip, end_ip));
} else {
// Mismatched IP versions, skip this range
debug!(rulen_name = rule_name, "Could not be parsed");
}
let is_ipv4 = line.matches('.').count() >= 6;
// Find the split point based on whether it's IPv4 or not
let split_point: usize = if is_ipv4 {
line.rfind(':')
} else {
line.find(':')
}
.unwrap_or(0);
let (rule_name, ip_range) = line.split_at(split_point + 1);
if let Some((start, end)) = ip_range.split_once('-') {
if let (Ok(start_ip), Ok(end_ip)) =
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
{
return Some((start_ip, end_ip));
} else {
// Mismatched IP versions, skip this range
debug!(rulen_name = rule_name, "Could not be parsed");
}
}
None
}
@ -186,7 +167,6 @@ mod tests {
use super::*;
use async_compression::tokio::write::GzipEncoder;
use mockito::{Server, ServerGuard};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::thread::{self, JoinHandle};
use tokio::io::AsyncWriteExt;
@ -313,21 +293,16 @@ mod tests {
#[test]
fn test_manual_ranges() {
// Add IPv4 range
let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap();
let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap();
let start_num = u32::from_be_bytes(start_v4.octets());
let end_num = u32::from_be_bytes(end_v4.octets());
let ipv4_range = start_num..(end_num + 1);
let start_v4: IpAddr = "192.168.0.0".parse().unwrap();
let end_v4: IpAddr = "192.168.255.255".parse().unwrap();
let ipv4_range = start_v4..end_v4;
// Add IPv6 range
let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap();
let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap();
let start_num = u128::from_be_bytes(start_v6.octets());
let end_num = u128::from_be_bytes(end_v6.octets());
let ipv6_range = start_num..(end_num + 1);
let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]);
let start_v6: IpAddr = "2001:db8::".parse().unwrap();
let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap();
let ipv6_range = start_v6..end_v6;
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
// Test IPv4 addresses
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));