From 1a313400769c29d5ec595734239dc78d2c524d24 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Tue, 25 Feb 2025 04:25:24 +0100 Subject: [PATCH] :ok_hand: Fix remaining pr comments --- crates/librqbit/src/blocklist.rs | 55 +++++++++---------- crates/librqbit/src/session.rs | 2 +- crates/librqbit/src/torrent_state/live/mod.rs | 11 ++++ 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index c194a86..291de10 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -5,7 +5,7 @@ use intervaltree::IntervalTree; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::pin::Pin; use std::str::FromStr; -use tokio::io::AsyncRead; +use tokio::io::{AsyncBufRead, AsyncRead}; use tokio::{io::AsyncBufReadExt, io::BufReader}; use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; @@ -85,7 +85,7 @@ impl Blocklist { // Check for Gzip magic bytes (1F 8B) let is_gzip = peek_bytes == [0x1F, 0x8B]; - let reader: Pin> = if is_gzip { + let mut reader: Pin> = if is_gzip { trace!("Detected Gzip file, decompressing..."); Box::pin(BufReader::new(GzipDecoder::new(reader))) } else { @@ -93,19 +93,14 @@ impl Blocklist { Box::pin(reader) }; - let reader = BufReader::new(reader); - let mut lines = reader.lines(); + let mut line: String = Default::default(); let mut ip_ranges: Vec> = Vec::new(); - while let Some(line) = lines.next_line().await? { - // Skip comments and empty lines - if line.starts_with('#') || line.trim().is_empty() { - continue; - } - + while reader.read_line(&mut line).await? > 0 { if let Some((start_ip, end_ip)) = parse_ip_range(&line) { - let range = start_ip..(increment_ip(end_ip).unwrap()); + let range = start_ip..(increment_ip(end_ip)); ip_ranges.push(range); } + line.clear(); } info!( @@ -117,21 +112,21 @@ impl Blocklist { Ok(blocklist) } - pub fn is_blocked(&self, ip: &IpAddr) -> bool { - self.blocked_ranges.query_point(*ip).next().is_some() + pub fn is_blocked(&self, ip: IpAddr) -> bool { + self.blocked_ranges.query_point(ip).next().is_some() } } /// Safely increments an `IpAddr`, returning `None` if it would overflow. -fn increment_ip(ip: IpAddr) -> Option { +fn increment_ip(ip: IpAddr) -> IpAddr { match ip { IpAddr::V4(ipv4) => { let num = u32::from_be_bytes(ipv4.octets()); - num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n))) + std::net::IpAddr::V4(Ipv4Addr::from(num.saturating_add(1))) } IpAddr::V6(ipv6) => { let num = u128::from_be_bytes(ipv6.octets()); - num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n))) + std::net::IpAddr::V6(Ipv6Addr::from(num.saturating_add(1))) } } } @@ -195,8 +190,8 @@ mod tests { Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist)) }))); let blocklist = Blocklist::create_from_stream(stream).await?; - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); Ok(()) } @@ -213,8 +208,8 @@ mod tests { Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec())) }))); let blocklist = Blocklist::create_from_stream(stream).await?; - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); Ok(()) } @@ -236,10 +231,10 @@ mod tests { let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?; // Verify the blocklist - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); - assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"2001:4860:4860::8888".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:4860:4860::8888".parse().unwrap())); // Clean up the temporary file tokio::fs::remove_file("temp_blocklist.txt").await?; @@ -250,8 +245,8 @@ mod tests { #[test] fn test_blocklist_empty() { let blocklist = Blocklist::empty(); - assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"::1".parse().unwrap())); + assert!(!blocklist.is_blocked("127.0.0.1".parse().unwrap())); + assert!(!blocklist.is_blocked("::1".parse().unwrap())); } #[test] @@ -268,11 +263,11 @@ mod tests { let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]); // Test IPv4 addresses - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("10.0.0.1".parse().unwrap())); // Test IPv6 addresses - assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap())); + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:db9::1".parse().unwrap())); } } diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index c68de6f..951e025 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -732,7 +732,7 @@ impl Session { .unwrap_or_else(|| Duration::from_secs(10)); let incoming_ip = addr.ip(); - if self.blocklist.is_blocked(&incoming_ip) { + if self.blocklist.is_blocked(incoming_ip) { bail!("Incoming ip {incoming_ip} is in blocklist"); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index d7f9748..3729494 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -572,6 +572,17 @@ impl TorrentStateLive { continue; } + let outgoing_ip = addr.ip(); + let is_blocked_ip = state.shared.session.upgrade().map_or_else( + || false, + |session| session.blocklist.is_blocked(outgoing_ip), + ); + + if is_blocked_ip { + info!("Outgoing ip {outgoing_ip} for peer is in blocklist skipping"); + continue; + } + let permit = state.peer_semaphore.clone().acquire_owned().await?; state.spawn( error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()),