👌 Fix remaining pr comments

This commit is contained in:
Alexander WB 2025-02-25 04:25:24 +01:00
parent 4c6e19ceab
commit 1a31340076
3 changed files with 37 additions and 31 deletions

View file

@ -5,7 +5,7 @@ use intervaltree::IntervalTree;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::pin::Pin;
use std::str::FromStr;
use tokio::io::AsyncRead;
use tokio::io::{AsyncBufRead, AsyncRead};
use tokio::{io::AsyncBufReadExt, io::BufReader};
use tokio_util::io::StreamReader;
use tracing::{debug, info, trace};
@ -85,7 +85,7 @@ impl Blocklist {
// Check for Gzip magic bytes (1F 8B)
let is_gzip = peek_bytes == [0x1F, 0x8B];
let reader: Pin<Box<dyn AsyncRead + Send>> = if is_gzip {
let mut reader: Pin<Box<dyn AsyncBufRead + Send>> = if is_gzip {
trace!("Detected Gzip file, decompressing...");
Box::pin(BufReader::new(GzipDecoder::new(reader)))
} else {
@ -93,19 +93,14 @@ impl Blocklist {
Box::pin(reader)
};
let reader = BufReader::new(reader);
let mut lines = reader.lines();
let mut line: String = Default::default();
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
while let Some(line) = lines.next_line().await? {
// Skip comments and empty lines
if line.starts_with('#') || line.trim().is_empty() {
continue;
}
while reader.read_line(&mut line).await? > 0 {
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
let range = start_ip..(increment_ip(end_ip).unwrap());
let range = start_ip..(increment_ip(end_ip));
ip_ranges.push(range);
}
line.clear();
}
info!(
@ -117,21 +112,21 @@ impl Blocklist {
Ok(blocklist)
}
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
self.blocked_ranges.query_point(*ip).next().is_some()
pub fn is_blocked(&self, ip: IpAddr) -> bool {
self.blocked_ranges.query_point(ip).next().is_some()
}
}
/// Safely increments an `IpAddr`, returning `None` if it would overflow.
fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
fn increment_ip(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V4(ipv4) => {
let num = u32::from_be_bytes(ipv4.octets());
num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n)))
std::net::IpAddr::V4(Ipv4Addr::from(num.saturating_add(1)))
}
IpAddr::V6(ipv6) => {
let num = u128::from_be_bytes(ipv6.octets());
num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n)))
std::net::IpAddr::V6(Ipv6Addr::from(num.saturating_add(1)))
}
}
}
@ -195,8 +190,8 @@ mod tests {
Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist))
})));
let blocklist = Blocklist::create_from_stream(stream).await?;
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
Ok(())
}
@ -213,8 +208,8 @@ mod tests {
Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec()))
})));
let blocklist = Blocklist::create_from_stream(stream).await?;
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
Ok(())
}
@ -236,10 +231,10 @@ mod tests {
let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?;
// Verify the blocklist
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"2001:4860:4860::8888".parse().unwrap()));
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap()));
assert!(!blocklist.is_blocked("2001:4860:4860::8888".parse().unwrap()));
// Clean up the temporary file
tokio::fs::remove_file("temp_blocklist.txt").await?;
@ -250,8 +245,8 @@ mod tests {
#[test]
fn test_blocklist_empty() {
let blocklist = Blocklist::empty();
assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"::1".parse().unwrap()));
assert!(!blocklist.is_blocked("127.0.0.1".parse().unwrap()));
assert!(!blocklist.is_blocked("::1".parse().unwrap()));
}
#[test]
@ -268,11 +263,11 @@ mod tests {
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
// Test IPv4 addresses
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
assert!(!blocklist.is_blocked("10.0.0.1".parse().unwrap()));
// Test IPv6 addresses
assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap()));
assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap()));
assert!(!blocklist.is_blocked("2001:db9::1".parse().unwrap()));
}
}

View file

@ -732,7 +732,7 @@ impl Session {
.unwrap_or_else(|| Duration::from_secs(10));
let incoming_ip = addr.ip();
if self.blocklist.is_blocked(&incoming_ip) {
if self.blocklist.is_blocked(incoming_ip) {
bail!("Incoming ip {incoming_ip} is in blocklist");
}

View file

@ -572,6 +572,17 @@ impl TorrentStateLive {
continue;
}
let outgoing_ip = addr.ip();
let is_blocked_ip = state.shared.session.upgrade().map_or_else(
|| false,
|session| session.blocklist.is_blocked(outgoing_ip),
);
if is_blocked_ip {
info!("Outgoing ip {outgoing_ip} for peer is in blocklist skipping");
continue;
}
let permit = state.peer_semaphore.clone().acquire_owned().await?;
state.spawn(
error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()),