👌 Improve Parsing logic

This commit is contained in:
Alexander WB 2025-02-21 22:03:49 +01:00
parent 6e9ecf8a26
commit ac883c1ddf

View file

@ -2,7 +2,7 @@ use anyhow::Result;
use async_compression::tokio::bufread::GzipDecoder; use async_compression::tokio::bufread::GzipDecoder;
use futures::TryStreamExt; use futures::TryStreamExt;
use intervaltree::IntervalTree; use intervaltree::IntervalTree;
use std::net::IpAddr; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin; use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use tokio::io::AsyncRead; use tokio::io::AsyncRead;
@ -11,26 +11,19 @@ use tokio_util::io::StreamReader;
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
pub struct Blocklist { pub struct Blocklist {
// Separate trees for IPv4 and IPv6 since they have different numeric ranges // ipv4 and ipv6 do not overlap
ipv4_ranges: IntervalTree<u32, ()>, // see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5
ipv6_ranges: IntervalTree<u128, ()>, blocked_ranges: IntervalTree<IpAddr, ()>,
} }
impl Blocklist { impl Blocklist {
pub fn empty() -> Self { pub fn empty() -> Self {
return Self::new( return Self::new(std::iter::empty());
&Vec::<std::ops::Range<u32>>::new(),
&Vec::<std::ops::Range<u128>>::new(),
);
} }
pub fn new( pub fn new(ip_ranges: impl IntoIterator<Item = std::ops::Range<IpAddr>>) -> Self {
ipv4_ranges: &Vec<std::ops::Range<u32>>,
ipv6_ranges: &Vec<std::ops::Range<u128>>,
) -> Self {
Self { Self {
ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))), blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))),
ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))),
} }
} }
@ -97,87 +90,75 @@ impl Blocklist {
let reader = BufReader::new(reader); let reader = BufReader::new(reader);
let mut lines = reader.lines(); let mut lines = reader.lines();
let mut ipv4_ranges: Vec<std::ops::Range<u32>> = Vec::new(); let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
let mut ipv6_ranges: Vec<std::ops::Range<u128>> = Vec::new();
while let Some(line) = lines.next_line().await? { while let Some(line) = lines.next_line().await? {
// Skip comments and empty lines // Skip comments and empty lines
if line.starts_with('#') || line.trim().is_empty() { if line.starts_with('#') || line.trim().is_empty() {
continue; continue;
} }
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
if let Some((start_ip, end_ip)) = parse_ip_range(&line) { if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
match (start_ip, end_ip) { let range = start_ip..(increment_ip(end_ip).unwrap());
(IpAddr::V4(start), IpAddr::V4(end)) => { ip_ranges.push(range);
let start_num = u32::from_be_bytes(start.octets());
let end_num = u32::from_be_bytes(end.octets());
let range = if end_num == u32::MAX {
start_num..end_num // Special case: Use inclusive range when max
} else {
start_num..(end_num + 1) // Normal case
};
ipv4_ranges.push(range);
}
(IpAddr::V6(start), IpAddr::V6(end)) => {
let start_num = u128::from_be_bytes(start.octets());
let end_num = u128::from_be_bytes(end.octets());
let range = if end_num == u128::MAX {
start_num..end_num // Special case: Use inclusive range when max
} else {
start_num..(end_num + 1) // Normal case
};
ipv6_ranges.push(range);
}
_ => {
continue;
}
}
} }
} }
info!( info!(
ipv6_entry_count = ipv6_ranges.len(), ip_entry_count = ip_ranges.len(),
ipv4_entry_count = ipv4_ranges.len(),
"Finished loading blocklist" "Finished loading blocklist"
); );
let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges); let blocklist = Self::new(ip_ranges);
Ok(blocklist) Ok(blocklist)
} }
pub fn is_blocked(&self, ip: &IpAddr) -> bool { pub fn is_blocked(&self, ip: &IpAddr) -> bool {
match ip { self.blocked_ranges.query_point(*ip).next().is_some()
IpAddr::V4(ipv4) => { }
let num = u32::from_be_bytes(ipv4.octets()); }
self.ipv4_ranges.query_point(num).next().is_some()
} /// Safely increments an `IpAddr`, returning `None` if it would overflow.
IpAddr::V6(ipv6) => { fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
let num = u128::from_be_bytes(ipv6.octets()); match ip {
self.ipv6_ranges.query_point(num).next().is_some() IpAddr::V4(ipv4) => {
} let num = u32::from_be_bytes(ipv4.octets());
num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n)))
}
IpAddr::V6(ipv6) => {
let num = u128::from_be_bytes(ipv6.octets());
num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n)))
} }
} }
} }
fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
// Skip comments and empty lines // Skip comments and empty lines
if line.starts_with('#') || line.trim().is_empty() { let line = line.trim();
if line.starts_with('#') || line.is_empty() {
return None; return None;
} }
// Parse IP ranges in format: "RuleName:StartIp-EndIp" let is_ipv4 = line.matches('.').count() >= 6;
if let Some((rule_name, ip_range)) = line.rsplit_once(':') { // Find the split point based on whether it's IPv4 or not
if let Some((start, end)) = ip_range.split_once('-') { let split_point: usize = if is_ipv4 {
if let (Ok(start_ip), Ok(end_ip)) = line.rfind(':')
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) } else {
{ line.find(':')
return Some((start_ip, end_ip)); }
} else { .unwrap_or(0);
// Mismatched IP versions, skip this range
debug!(rulen_name = rule_name, "Could not be parsed"); let (rule_name, ip_range) = line.split_at(split_point + 1);
} if let Some((start, end)) = ip_range.split_once('-') {
if let (Ok(start_ip), Ok(end_ip)) =
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
{
return Some((start_ip, end_ip));
} else {
// Mismatched IP versions, skip this range
debug!(rulen_name = rule_name, "Could not be parsed");
} }
} }
None None
} }
@ -186,7 +167,6 @@ mod tests {
use super::*; use super::*;
use async_compression::tokio::write::GzipEncoder; use async_compression::tokio::write::GzipEncoder;
use mockito::{Server, ServerGuard}; use mockito::{Server, ServerGuard};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::thread::{self, JoinHandle}; use std::thread::{self, JoinHandle};
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
@ -313,21 +293,16 @@ mod tests {
#[test] #[test]
fn test_manual_ranges() { fn test_manual_ranges() {
// Add IPv4 range // Add IPv4 range
let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap(); let start_v4: IpAddr = "192.168.0.0".parse().unwrap();
let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap(); let end_v4: IpAddr = "192.168.255.255".parse().unwrap();
let start_num = u32::from_be_bytes(start_v4.octets()); let ipv4_range = start_v4..end_v4;
let end_num = u32::from_be_bytes(end_v4.octets());
let ipv4_range = start_num..(end_num + 1);
// Add IPv6 range // Add IPv6 range
let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap(); let start_v6: IpAddr = "2001:db8::".parse().unwrap();
let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap(); let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap();
let start_num = u128::from_be_bytes(start_v6.octets()); let ipv6_range = start_v6..end_v6;
let end_num = u128::from_be_bytes(end_v6.octets());
let ipv6_range = start_num..(end_num + 1);
let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]);
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
// Test IPv4 addresses // Test IPv4 addresses
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));