From ac883c1ddffb2015e0c09b8c4376baa16d278692 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Fri, 21 Feb 2025 22:03:49 +0100 Subject: [PATCH] :ok_hand: Improve Parsing logic --- crates/librqbit/src/blocklist.rs | 133 +++++++++++++------------------ 1 file changed, 54 insertions(+), 79 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index 5638497..f72f7e7 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_compression::tokio::bufread::GzipDecoder; use futures::TryStreamExt; use intervaltree::IntervalTree; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::pin::Pin; use std::str::FromStr; use tokio::io::AsyncRead; @@ -11,26 +11,19 @@ use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; pub struct Blocklist { - // Separate trees for IPv4 and IPv6 since they have different numeric ranges - ipv4_ranges: IntervalTree, - ipv6_ranges: IntervalTree, + // ipv4 and ipv6 do not overlap + // see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5 + blocked_ranges: IntervalTree, } impl Blocklist { pub fn empty() -> Self { - return Self::new( - &Vec::>::new(), - &Vec::>::new(), - ); + return Self::new(std::iter::empty()); } - pub fn new( - ipv4_ranges: &Vec>, - ipv6_ranges: &Vec>, - ) -> Self { + pub fn new(ip_ranges: impl IntoIterator>) -> Self { Self { - ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))), - ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))), + blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))), } } @@ -97,87 +90,75 @@ impl Blocklist { let reader = BufReader::new(reader); let mut lines = reader.lines(); - let mut ipv4_ranges: Vec> = Vec::new(); - let mut ipv6_ranges: Vec> = Vec::new(); + let mut ip_ranges: Vec> = Vec::new(); while let Some(line) = lines.next_line().await? { // Skip comments and empty lines if line.starts_with('#') || line.trim().is_empty() { continue; } - // Parse IP ranges in format: "RuleName:StartIp-EndIp" if let Some((start_ip, end_ip)) = parse_ip_range(&line) { - match (start_ip, end_ip) { - (IpAddr::V4(start), IpAddr::V4(end)) => { - let start_num = u32::from_be_bytes(start.octets()); - let end_num = u32::from_be_bytes(end.octets()); - let range = if end_num == u32::MAX { - start_num..end_num // Special case: Use inclusive range when max - } else { - start_num..(end_num + 1) // Normal case - }; - ipv4_ranges.push(range); - } - (IpAddr::V6(start), IpAddr::V6(end)) => { - let start_num = u128::from_be_bytes(start.octets()); - let end_num = u128::from_be_bytes(end.octets()); - let range = if end_num == u128::MAX { - start_num..end_num // Special case: Use inclusive range when max - } else { - start_num..(end_num + 1) // Normal case - }; - ipv6_ranges.push(range); - } - _ => { - continue; - } - } + let range = start_ip..(increment_ip(end_ip).unwrap()); + ip_ranges.push(range); } } info!( - ipv6_entry_count = ipv6_ranges.len(), - ipv4_entry_count = ipv4_ranges.len(), + ip_entry_count = ip_ranges.len(), "Finished loading blocklist" ); - let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges); + let blocklist = Self::new(ip_ranges); Ok(blocklist) } pub fn is_blocked(&self, ip: &IpAddr) -> bool { - match ip { - IpAddr::V4(ipv4) => { - let num = u32::from_be_bytes(ipv4.octets()); - self.ipv4_ranges.query_point(num).next().is_some() - } - IpAddr::V6(ipv6) => { - let num = u128::from_be_bytes(ipv6.octets()); - self.ipv6_ranges.query_point(num).next().is_some() - } + self.blocked_ranges.query_point(*ip).next().is_some() + } +} + +/// Safely increments an `IpAddr`, returning `None` if it would overflow. +fn increment_ip(ip: IpAddr) -> Option { + match ip { + IpAddr::V4(ipv4) => { + let num = u32::from_be_bytes(ipv4.octets()); + num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n))) + } + IpAddr::V6(ipv6) => { + let num = u128::from_be_bytes(ipv6.octets()); + num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n))) } } } fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { // Skip comments and empty lines - if line.starts_with('#') || line.trim().is_empty() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { return None; } - // Parse IP ranges in format: "RuleName:StartIp-EndIp" - if let Some((rule_name, ip_range)) = line.rsplit_once(':') { - if let Some((start, end)) = ip_range.split_once('-') { - if let (Ok(start_ip), Ok(end_ip)) = - (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) - { - return Some((start_ip, end_ip)); - } else { - // Mismatched IP versions, skip this range - debug!(rulen_name = rule_name, "Could not be parsed"); - } + let is_ipv4 = line.matches('.').count() >= 6; + // Find the split point based on whether it's IPv4 or not + let split_point: usize = if is_ipv4 { + line.rfind(':') + } else { + line.find(':') + } + .unwrap_or(0); + + let (rule_name, ip_range) = line.split_at(split_point + 1); + if let Some((start, end)) = ip_range.split_once('-') { + if let (Ok(start_ip), Ok(end_ip)) = + (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) + { + return Some((start_ip, end_ip)); + } else { + // Mismatched IP versions, skip this range + debug!(rulen_name = rule_name, "Could not be parsed"); } } + None } @@ -186,7 +167,6 @@ mod tests { use super::*; use async_compression::tokio::write::GzipEncoder; use mockito::{Server, ServerGuard}; - use std::net::{Ipv4Addr, Ipv6Addr}; use std::thread::{self, JoinHandle}; use tokio::io::AsyncWriteExt; @@ -313,21 +293,16 @@ mod tests { #[test] fn test_manual_ranges() { // Add IPv4 range - let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap(); - let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap(); - let start_num = u32::from_be_bytes(start_v4.octets()); - let end_num = u32::from_be_bytes(end_v4.octets()); - let ipv4_range = start_num..(end_num + 1); + let start_v4: IpAddr = "192.168.0.0".parse().unwrap(); + let end_v4: IpAddr = "192.168.255.255".parse().unwrap(); + let ipv4_range = start_v4..end_v4; // Add IPv6 range - let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap(); - let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap(); - let start_num = u128::from_be_bytes(start_v6.octets()); - let end_num = u128::from_be_bytes(end_v6.octets()); - let ipv6_range = start_num..(end_num + 1); - - let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]); + let start_v6: IpAddr = "2001:db8::".parse().unwrap(); + let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap(); + let ipv6_range = start_v6..end_v6; + let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]); // Test IPv4 addresses assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));