👌 Improve Parsing logic
This commit is contained in:
parent
6e9ecf8a26
commit
ac883c1ddf
1 changed files with 54 additions and 79 deletions
|
|
@ -2,7 +2,7 @@ use anyhow::Result;
|
|||
use async_compression::tokio::bufread::GzipDecoder;
|
||||
use futures::TryStreamExt;
|
||||
use intervaltree::IntervalTree;
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use tokio::io::AsyncRead;
|
||||
|
|
@ -11,26 +11,19 @@ use tokio_util::io::StreamReader;
|
|||
use tracing::{debug, info, trace};
|
||||
|
||||
pub struct Blocklist {
|
||||
// Separate trees for IPv4 and IPv6 since they have different numeric ranges
|
||||
ipv4_ranges: IntervalTree<u32, ()>,
|
||||
ipv6_ranges: IntervalTree<u128, ()>,
|
||||
// ipv4 and ipv6 do not overlap
|
||||
// see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5
|
||||
blocked_ranges: IntervalTree<IpAddr, ()>,
|
||||
}
|
||||
|
||||
impl Blocklist {
|
||||
pub fn empty() -> Self {
|
||||
return Self::new(
|
||||
&Vec::<std::ops::Range<u32>>::new(),
|
||||
&Vec::<std::ops::Range<u128>>::new(),
|
||||
);
|
||||
return Self::new(std::iter::empty());
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
ipv4_ranges: &Vec<std::ops::Range<u32>>,
|
||||
ipv6_ranges: &Vec<std::ops::Range<u128>>,
|
||||
) -> Self {
|
||||
pub fn new(ip_ranges: impl IntoIterator<Item = std::ops::Range<IpAddr>>) -> Self {
|
||||
Self {
|
||||
ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))),
|
||||
ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))),
|
||||
blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -97,87 +90,75 @@ impl Blocklist {
|
|||
|
||||
let reader = BufReader::new(reader);
|
||||
let mut lines = reader.lines();
|
||||
let mut ipv4_ranges: Vec<std::ops::Range<u32>> = Vec::new();
|
||||
let mut ipv6_ranges: Vec<std::ops::Range<u128>> = Vec::new();
|
||||
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
|
||||
while let Some(line) = lines.next_line().await? {
|
||||
// Skip comments and empty lines
|
||||
if line.starts_with('#') || line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
|
||||
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
|
||||
match (start_ip, end_ip) {
|
||||
(IpAddr::V4(start), IpAddr::V4(end)) => {
|
||||
let start_num = u32::from_be_bytes(start.octets());
|
||||
let end_num = u32::from_be_bytes(end.octets());
|
||||
let range = if end_num == u32::MAX {
|
||||
start_num..end_num // Special case: Use inclusive range when max
|
||||
} else {
|
||||
start_num..(end_num + 1) // Normal case
|
||||
};
|
||||
ipv4_ranges.push(range);
|
||||
}
|
||||
(IpAddr::V6(start), IpAddr::V6(end)) => {
|
||||
let start_num = u128::from_be_bytes(start.octets());
|
||||
let end_num = u128::from_be_bytes(end.octets());
|
||||
let range = if end_num == u128::MAX {
|
||||
start_num..end_num // Special case: Use inclusive range when max
|
||||
} else {
|
||||
start_num..(end_num + 1) // Normal case
|
||||
};
|
||||
ipv6_ranges.push(range);
|
||||
}
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let range = start_ip..(increment_ip(end_ip).unwrap());
|
||||
ip_ranges.push(range);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
ipv6_entry_count = ipv6_ranges.len(),
|
||||
ipv4_entry_count = ipv4_ranges.len(),
|
||||
ip_entry_count = ip_ranges.len(),
|
||||
"Finished loading blocklist"
|
||||
);
|
||||
|
||||
let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges);
|
||||
let blocklist = Self::new(ip_ranges);
|
||||
Ok(blocklist)
|
||||
}
|
||||
|
||||
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
let num = u32::from_be_bytes(ipv4.octets());
|
||||
self.ipv4_ranges.query_point(num).next().is_some()
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
let num = u128::from_be_bytes(ipv6.octets());
|
||||
self.ipv6_ranges.query_point(num).next().is_some()
|
||||
}
|
||||
self.blocked_ranges.query_point(*ip).next().is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Safely increments an `IpAddr`, returning `None` if it would overflow.
|
||||
fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
let num = u32::from_be_bytes(ipv4.octets());
|
||||
num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n)))
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
let num = u128::from_be_bytes(ipv6.octets());
|
||||
num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
|
||||
// Skip comments and empty lines
|
||||
if line.starts_with('#') || line.trim().is_empty() {
|
||||
let line = line.trim();
|
||||
if line.starts_with('#') || line.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
|
||||
if let Some((rule_name, ip_range)) = line.rsplit_once(':') {
|
||||
if let Some((start, end)) = ip_range.split_once('-') {
|
||||
if let (Ok(start_ip), Ok(end_ip)) =
|
||||
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
|
||||
{
|
||||
return Some((start_ip, end_ip));
|
||||
} else {
|
||||
// Mismatched IP versions, skip this range
|
||||
debug!(rulen_name = rule_name, "Could not be parsed");
|
||||
}
|
||||
let is_ipv4 = line.matches('.').count() >= 6;
|
||||
// Find the split point based on whether it's IPv4 or not
|
||||
let split_point: usize = if is_ipv4 {
|
||||
line.rfind(':')
|
||||
} else {
|
||||
line.find(':')
|
||||
}
|
||||
.unwrap_or(0);
|
||||
|
||||
let (rule_name, ip_range) = line.split_at(split_point + 1);
|
||||
if let Some((start, end)) = ip_range.split_once('-') {
|
||||
if let (Ok(start_ip), Ok(end_ip)) =
|
||||
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
|
||||
{
|
||||
return Some((start_ip, end_ip));
|
||||
} else {
|
||||
// Mismatched IP versions, skip this range
|
||||
debug!(rulen_name = rule_name, "Could not be parsed");
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
|
|
@ -186,7 +167,6 @@ mod tests {
|
|||
use super::*;
|
||||
use async_compression::tokio::write::GzipEncoder;
|
||||
use mockito::{Server, ServerGuard};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::thread::{self, JoinHandle};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
|
|
@ -313,21 +293,16 @@ mod tests {
|
|||
#[test]
|
||||
fn test_manual_ranges() {
|
||||
// Add IPv4 range
|
||||
let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap();
|
||||
let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap();
|
||||
let start_num = u32::from_be_bytes(start_v4.octets());
|
||||
let end_num = u32::from_be_bytes(end_v4.octets());
|
||||
let ipv4_range = start_num..(end_num + 1);
|
||||
let start_v4: IpAddr = "192.168.0.0".parse().unwrap();
|
||||
let end_v4: IpAddr = "192.168.255.255".parse().unwrap();
|
||||
let ipv4_range = start_v4..end_v4;
|
||||
|
||||
// Add IPv6 range
|
||||
let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap();
|
||||
let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap();
|
||||
let start_num = u128::from_be_bytes(start_v6.octets());
|
||||
let end_num = u128::from_be_bytes(end_v6.octets());
|
||||
let ipv6_range = start_num..(end_num + 1);
|
||||
|
||||
let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]);
|
||||
let start_v6: IpAddr = "2001:db8::".parse().unwrap();
|
||||
let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap();
|
||||
let ipv6_range = start_v6..end_v6;
|
||||
|
||||
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
|
||||
// Test IPv4 addresses
|
||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
||||
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue