👌 Improve Parsing logic
This commit is contained in:
parent
6e9ecf8a26
commit
ac883c1ddf
1 changed files with 54 additions and 79 deletions
|
|
@ -2,7 +2,7 @@ use anyhow::Result;
|
||||||
use async_compression::tokio::bufread::GzipDecoder;
|
use async_compression::tokio::bufread::GzipDecoder;
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use intervaltree::IntervalTree;
|
use intervaltree::IntervalTree;
|
||||||
use std::net::IpAddr;
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use tokio::io::AsyncRead;
|
use tokio::io::AsyncRead;
|
||||||
|
|
@ -11,26 +11,19 @@ use tokio_util::io::StreamReader;
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
pub struct Blocklist {
|
pub struct Blocklist {
|
||||||
// Separate trees for IPv4 and IPv6 since they have different numeric ranges
|
// ipv4 and ipv6 do not overlap
|
||||||
ipv4_ranges: IntervalTree<u32, ()>,
|
// see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5
|
||||||
ipv6_ranges: IntervalTree<u128, ()>,
|
blocked_ranges: IntervalTree<IpAddr, ()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Blocklist {
|
impl Blocklist {
|
||||||
pub fn empty() -> Self {
|
pub fn empty() -> Self {
|
||||||
return Self::new(
|
return Self::new(std::iter::empty());
|
||||||
&Vec::<std::ops::Range<u32>>::new(),
|
|
||||||
&Vec::<std::ops::Range<u128>>::new(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(ip_ranges: impl IntoIterator<Item = std::ops::Range<IpAddr>>) -> Self {
|
||||||
ipv4_ranges: &Vec<std::ops::Range<u32>>,
|
|
||||||
ipv6_ranges: &Vec<std::ops::Range<u128>>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))),
|
blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))),
|
||||||
ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -97,87 +90,75 @@ impl Blocklist {
|
||||||
|
|
||||||
let reader = BufReader::new(reader);
|
let reader = BufReader::new(reader);
|
||||||
let mut lines = reader.lines();
|
let mut lines = reader.lines();
|
||||||
let mut ipv4_ranges: Vec<std::ops::Range<u32>> = Vec::new();
|
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
|
||||||
let mut ipv6_ranges: Vec<std::ops::Range<u128>> = Vec::new();
|
|
||||||
while let Some(line) = lines.next_line().await? {
|
while let Some(line) = lines.next_line().await? {
|
||||||
// Skip comments and empty lines
|
// Skip comments and empty lines
|
||||||
if line.starts_with('#') || line.trim().is_empty() {
|
if line.starts_with('#') || line.trim().is_empty() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
|
|
||||||
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
|
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
|
||||||
match (start_ip, end_ip) {
|
let range = start_ip..(increment_ip(end_ip).unwrap());
|
||||||
(IpAddr::V4(start), IpAddr::V4(end)) => {
|
ip_ranges.push(range);
|
||||||
let start_num = u32::from_be_bytes(start.octets());
|
|
||||||
let end_num = u32::from_be_bytes(end.octets());
|
|
||||||
let range = if end_num == u32::MAX {
|
|
||||||
start_num..end_num // Special case: Use inclusive range when max
|
|
||||||
} else {
|
|
||||||
start_num..(end_num + 1) // Normal case
|
|
||||||
};
|
|
||||||
ipv4_ranges.push(range);
|
|
||||||
}
|
|
||||||
(IpAddr::V6(start), IpAddr::V6(end)) => {
|
|
||||||
let start_num = u128::from_be_bytes(start.octets());
|
|
||||||
let end_num = u128::from_be_bytes(end.octets());
|
|
||||||
let range = if end_num == u128::MAX {
|
|
||||||
start_num..end_num // Special case: Use inclusive range when max
|
|
||||||
} else {
|
|
||||||
start_num..(end_num + 1) // Normal case
|
|
||||||
};
|
|
||||||
ipv6_ranges.push(range);
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
ipv6_entry_count = ipv6_ranges.len(),
|
ip_entry_count = ip_ranges.len(),
|
||||||
ipv4_entry_count = ipv4_ranges.len(),
|
|
||||||
"Finished loading blocklist"
|
"Finished loading blocklist"
|
||||||
);
|
);
|
||||||
|
|
||||||
let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges);
|
let blocklist = Self::new(ip_ranges);
|
||||||
Ok(blocklist)
|
Ok(blocklist)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
||||||
match ip {
|
self.blocked_ranges.query_point(*ip).next().is_some()
|
||||||
IpAddr::V4(ipv4) => {
|
}
|
||||||
let num = u32::from_be_bytes(ipv4.octets());
|
}
|
||||||
self.ipv4_ranges.query_point(num).next().is_some()
|
|
||||||
}
|
/// Safely increments an `IpAddr`, returning `None` if it would overflow.
|
||||||
IpAddr::V6(ipv6) => {
|
fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
|
||||||
let num = u128::from_be_bytes(ipv6.octets());
|
match ip {
|
||||||
self.ipv6_ranges.query_point(num).next().is_some()
|
IpAddr::V4(ipv4) => {
|
||||||
}
|
let num = u32::from_be_bytes(ipv4.octets());
|
||||||
|
num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n)))
|
||||||
|
}
|
||||||
|
IpAddr::V6(ipv6) => {
|
||||||
|
let num = u128::from_be_bytes(ipv6.octets());
|
||||||
|
num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
|
fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
|
||||||
// Skip comments and empty lines
|
// Skip comments and empty lines
|
||||||
if line.starts_with('#') || line.trim().is_empty() {
|
let line = line.trim();
|
||||||
|
if line.starts_with('#') || line.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse IP ranges in format: "RuleName:StartIp-EndIp"
|
let is_ipv4 = line.matches('.').count() >= 6;
|
||||||
if let Some((rule_name, ip_range)) = line.rsplit_once(':') {
|
// Find the split point based on whether it's IPv4 or not
|
||||||
if let Some((start, end)) = ip_range.split_once('-') {
|
let split_point: usize = if is_ipv4 {
|
||||||
if let (Ok(start_ip), Ok(end_ip)) =
|
line.rfind(':')
|
||||||
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
|
} else {
|
||||||
{
|
line.find(':')
|
||||||
return Some((start_ip, end_ip));
|
}
|
||||||
} else {
|
.unwrap_or(0);
|
||||||
// Mismatched IP versions, skip this range
|
|
||||||
debug!(rulen_name = rule_name, "Could not be parsed");
|
let (rule_name, ip_range) = line.split_at(split_point + 1);
|
||||||
}
|
if let Some((start, end)) = ip_range.split_once('-') {
|
||||||
|
if let (Ok(start_ip), Ok(end_ip)) =
|
||||||
|
(IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim()))
|
||||||
|
{
|
||||||
|
return Some((start_ip, end_ip));
|
||||||
|
} else {
|
||||||
|
// Mismatched IP versions, skip this range
|
||||||
|
debug!(rulen_name = rule_name, "Could not be parsed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -186,7 +167,6 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use async_compression::tokio::write::GzipEncoder;
|
use async_compression::tokio::write::GzipEncoder;
|
||||||
use mockito::{Server, ServerGuard};
|
use mockito::{Server, ServerGuard};
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
|
||||||
use std::thread::{self, JoinHandle};
|
use std::thread::{self, JoinHandle};
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
|
@ -313,21 +293,16 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_manual_ranges() {
|
fn test_manual_ranges() {
|
||||||
// Add IPv4 range
|
// Add IPv4 range
|
||||||
let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap();
|
let start_v4: IpAddr = "192.168.0.0".parse().unwrap();
|
||||||
let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap();
|
let end_v4: IpAddr = "192.168.255.255".parse().unwrap();
|
||||||
let start_num = u32::from_be_bytes(start_v4.octets());
|
let ipv4_range = start_v4..end_v4;
|
||||||
let end_num = u32::from_be_bytes(end_v4.octets());
|
|
||||||
let ipv4_range = start_num..(end_num + 1);
|
|
||||||
|
|
||||||
// Add IPv6 range
|
// Add IPv6 range
|
||||||
let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap();
|
let start_v6: IpAddr = "2001:db8::".parse().unwrap();
|
||||||
let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap();
|
let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap();
|
||||||
let start_num = u128::from_be_bytes(start_v6.octets());
|
let ipv6_range = start_v6..end_v6;
|
||||||
let end_num = u128::from_be_bytes(end_v6.octets());
|
|
||||||
let ipv6_range = start_num..(end_num + 1);
|
|
||||||
|
|
||||||
let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]);
|
|
||||||
|
|
||||||
|
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
|
||||||
// Test IPv4 addresses
|
// Test IPv4 addresses
|
||||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
|
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue