From 187ce8c4629a7300455a343f18d53f98c782ed43 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Mon, 18 Nov 2024 23:02:52 +0100 Subject: [PATCH 1/7] :tada: Add Blocklist Add an implementation of p2p plaintext (and gz compressed) blocklists. The list can be read from an url or from a file. All the IP ranges are then stored in interval trees. --- Cargo.lock | 75 +++++++ crates/librqbit/Cargo.toml | 4 + crates/librqbit/src/blocklist.rs | 342 +++++++++++++++++++++++++++++++ crates/librqbit/src/lib.rs | 1 + 4 files changed, 422 insertions(+) create mode 100644 crates/librqbit/src/blocklist.rs diff --git a/Cargo.lock b/Cargo.lock index 5395ae9..638b2b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,6 +135,16 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "assert_cfg" version = "0.1.0" @@ -171,6 +181,19 @@ dependencies = [ "syn 2.0.95", ] +[[package]] +name = "async-compression" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" +dependencies = [ + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -801,6 +824,16 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] + [[package]] name = "combine" version = "4.6.7" @@ -2338,6 +2371,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "intervaltree" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "270bc34e57047cab801a8c871c124d9dc7132f6473c6401f645524f4e6edd111" +dependencies = [ + "smallvec", +] + [[package]] name = "ipnet" version = "2.10.1" @@ -2629,6 +2671,7 @@ dependencies = [ "anyhow", "arc-swap", "async-backtrace", + "async-compression", "async-stream", "async-trait", "axum 0.8.1", @@ -2645,6 +2688,7 @@ dependencies = [ "hex 0.4.3", "home", "http", + "intervaltree", "itertools 0.14.0", "librqbit-bencode", "librqbit-buffers", @@ -2659,6 +2703,7 @@ dependencies = [ "lru", "memmap2", "mime_guess", + "mockito", "notify", "parking_lot", "rand 0.8.5", @@ -3038,6 +3083,30 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockito" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "rand 0.8.5", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "muda" version = "0.15.3" @@ -4859,6 +4928,12 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +[[package]] +name = "similar" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" + [[package]] name = "siphasher" version = "0.3.11" diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 29da015..a280c13 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -81,6 +81,7 @@ regex = "1" reqwest = { version = "0.12", default-features = false, features = [ "json", "socks", + "stream", ] } urlencoding = "2" byteorder = "1" @@ -117,6 +118,8 @@ async-backtrace = { version = "0.2", optional = true } notify = { version = "7", optional = true } walkdir = "2.5.0" arc-swap = "1.7.1" +intervaltree = "0.2.7" +async-compression = {version="0.4.18", features= ["tokio", "gzip"] } [build-dependencies] anyhow = "1" @@ -127,3 +130,4 @@ tracing-subscriber = "0.3" tokio-test = "0.4" tempfile = "3" rand = { version = "0.8", features = ["small_rng"] } +mockito = "1.2" diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs new file mode 100644 index 0000000..b09e045 --- /dev/null +++ b/crates/librqbit/src/blocklist.rs @@ -0,0 +1,342 @@ +use anyhow::Result; +use async_compression::tokio::bufread::GzipDecoder; +use futures::TryStreamExt; +use intervaltree::IntervalTree; +use std::net::IpAddr; +use std::pin::Pin; +use std::str::FromStr; +use tokio::io::AsyncRead; +use tokio::{io::AsyncBufReadExt, io::BufReader}; +use tokio_util::io::StreamReader; +use tracing::{debug, info, trace}; + +struct Blocklist { + // Separate trees for IPv4 and IPv6 since they have different numeric ranges + ipv4_ranges: IntervalTree, + ipv6_ranges: IntervalTree, +} + +impl Blocklist { + pub fn new( + ipv4_ranges: &Vec>, + ipv6_ranges: &Vec>, + ) -> Self { + Self { + ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))), + ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))), + } + } + + fn ip_to_num(ip: &IpAddr) -> u128 { + match ip { + IpAddr::V4(ip) => u32::from_be_bytes(ip.octets()) as u128, + IpAddr::V6(ip) => u128::from_be_bytes(ip.octets()), + } + } + + pub async fn load_from_url(url: &str) -> Result { + let response = reqwest::get(url).await.map_err(|e| anyhow::anyhow!(e))?; + if response.status() != 200 { + return Err(anyhow::anyhow!( + "Failed to fetch blocklist: HTTP {}", + response.status() + )); + } + + let content_length = response + .content_length() + .ok_or_else(|| anyhow::anyhow!("Failed to get content length"))?; + + if content_length < 2 { + return Err(anyhow::anyhow!( + "Content too short: not enough data to determine compression" + )); + } + + let reader = StreamReader::new( + response + .bytes_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)), + ); + Self::create_from_stream(reader).await + } + + pub async fn load_from_file(path: &str) -> Result { + let file = tokio::fs::File::open(path).await?; + let reader = tokio::io::BufReader::new(file); + Self::create_from_stream(reader).await + } + + async fn create_from_stream(reader: R) -> Result + where + R: AsyncRead + Unpin + Send, + { + let mut peek_bytes = [0u8; 2]; + let mut reader = tokio::io::BufReader::new(reader); + + let buffer = reader.fill_buf().await?; // Get a reference to the buffer + if buffer.len() >= 2 { + peek_bytes.copy_from_slice(&buffer[0..2]); + } else { + return Err(anyhow::anyhow!( + "Content too short: not enough data to determine compression" + )); + } + + // Check for Gzip magic bytes (1F 8B) + let is_gzip = peek_bytes == [0x1F, 0x8B]; + + let reader: Pin> = if is_gzip { + trace!("Detected Gzip file, decompressing..."); + Box::pin(BufReader::new(GzipDecoder::new(reader))) + } else { + trace!("Plain text file detected."); + Box::pin(reader) + }; + + let reader = BufReader::new(reader); + let mut lines = reader.lines(); + let mut ipv4_ranges: Vec> = Vec::new(); + let mut ipv6_ranges: Vec> = Vec::new(); + while let Some(line) = lines.next_line().await? { + // Skip comments and empty lines + if line.starts_with('#') || line.trim().is_empty() { + continue; + } + + // Parse IP ranges in format: "RuleName:StartIp-EndIp" + if let Some((start_ip, end_ip)) = parse_ip_range(&line) { + match (start_ip, end_ip) { + (IpAddr::V4(start), IpAddr::V4(end)) => { + let start_num = u32::from_be_bytes(start.octets()); + let end_num = u32::from_be_bytes(end.octets()); + let range = if end_num == u32::MAX { + start_num..end_num // Special case: Use inclusive range when max + } else { + start_num..(end_num + 1) // Normal case + }; + ipv4_ranges.push(range); + } + (IpAddr::V6(start), IpAddr::V6(end)) => { + let start_num = u128::from_be_bytes(start.octets()); + let end_num = u128::from_be_bytes(end.octets()); + let range = if end_num == u128::MAX { + start_num..end_num // Special case: Use inclusive range when max + } else { + start_num..(end_num + 1) // Normal case + }; + ipv6_ranges.push(range); + } + _ => { + continue; + } + } + } + } + + info!( + ipv6_entry_count = ipv6_ranges.len(), + ipv4_entry_count = ipv4_ranges.len(), + "Finished loading blocklist" + ); + + let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges); + Ok(blocklist) + } + + pub fn is_blocked(&self, ip: &IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => { + let num = u32::from_be_bytes(ipv4.octets()); + self.ipv4_ranges.query_point(num).next().is_some() + } + IpAddr::V6(ipv6) => { + let num = u128::from_be_bytes(ipv6.octets()); + self.ipv6_ranges.query_point(num).next().is_some() + } + } + } +} + +fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { + // Skip comments and empty lines + if line.starts_with('#') || line.trim().is_empty() { + return None; + } + + // Parse IP ranges in format: "RuleName:StartIp-EndIp" + if let Some((rule_name, ip_range)) = line.rsplit_once(':') { + if let Some((start, end)) = ip_range.split_once('-') { + if let (Ok(start_ip), Ok(end_ip)) = + (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) + { + return Some((start_ip, end_ip)); + } else { + // Mismatched IP versions, skip this range + debug!(rulen_name = rule_name, "Could not be parsed"); + } + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use async_compression::tokio::write::GzipEncoder; + use mockito::{Server, ServerGuard}; + use std::net::{Ipv4Addr, Ipv6Addr}; + use std::thread::{self, JoinHandle}; + use tokio::io::AsyncWriteExt; + + struct TestServer { + server: ServerGuard, + mock: mockito::Mock, + url: String, + _thread: JoinHandle<()>, + } + + impl TestServer { + fn new(content: &[u8], headers: &[(&str, &str)]) -> Self { + let (tx, rx) = std::sync::mpsc::channel(); + let server_thread = thread::spawn(move || { + let mut server = Server::new(); + let url = server.url(); + let mock = server.mock("GET", "/").with_status(200); + + tx.send((server, mock, url)).unwrap(); + thread::park(); + }); + + let (server, mut mock, url) = rx.recv().unwrap(); + + // Add response body and headers + mock = mock.with_body(content); + for &(key, value) in headers { + mock = mock.with_header(key, value); + } + let mock = mock.create(); + + TestServer { + server, + mock, + url, + _thread: server_thread, + } + } + } + + impl Drop for TestServer { + fn drop(&mut self) { + self._thread.thread().unpark(); + } + } + + #[tokio::test] + async fn test_blocklist_gzipped() -> Result<()> { + let blocklist = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + let mut gzipped_blocklist = Vec::new(); + { + let mut encoder = GzipEncoder::new(&mut gzipped_blocklist); + encoder.write_all(blocklist.as_bytes()).await.unwrap(); + encoder.flush().await.unwrap(); + encoder.shutdown().await.unwrap(); + } + + let server = TestServer::new(&gzipped_blocklist, &[("Content-Encoding", "gzip")]); + + let blocklist = Blocklist::load_from_url(&server.url).await?; + assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + + server.mock.assert(); + Ok(()) + } + + #[tokio::test] + async fn test_blocklist_plaintext() -> Result<()> { + let blocklist = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + + let server = TestServer::new(blocklist.as_bytes(), &[]); + + let blocklist = Blocklist::load_from_url(&server.url).await?; + assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + + server.mock.assert(); + Ok(()) + } + + #[tokio::test] + async fn test_blocklist_from_plaintext_file() -> Result<()> { + let blocklist_content = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + + // Create a temporary file + let mut temp_file = tokio::fs::File::create("temp_blocklist.txt").await?; + tokio::io::AsyncWriteExt::write_all(&mut temp_file, blocklist_content.as_bytes()).await?; + drop(temp_file); // Close the file + + // Load the blocklist from the file + let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?; + + // Verify the blocklist + assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"2001:4860:4860::8888".parse().unwrap())); + + // Clean up the temporary file + tokio::fs::remove_file("temp_blocklist.txt").await?; + + Ok(()) + } + + #[test] + fn test_blocklist_empty() { + let blocklist = Blocklist::new( + &Vec::>::new(), + &Vec::>::new(), + ); + assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"::1".parse().unwrap())); + } + + #[test] + fn test_manual_ranges() { + // Add IPv4 range + let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap(); + let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap(); + let start_num = u32::from_be_bytes(start_v4.octets()); + let end_num = u32::from_be_bytes(end_v4.octets()); + let ipv4_range = start_num..(end_num + 1); + + // Add IPv6 range + let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap(); + let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap(); + let start_num = u128::from_be_bytes(start_v6.octets()); + let end_num = u128::from_be_bytes(end_v6.octets()); + let ipv6_range = start_num..(end_num + 1); + + let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]); + + // Test IPv4 addresses + assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap())); + + // Test IPv6 addresses + assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap())); + } +} diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 7f33882..ef2e96b 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -42,6 +42,7 @@ pub mod api; mod api_error; mod bitv; mod bitv_factory; +mod blocklist; mod chunk_tracker; mod create_torrent_file; mod dht_utils; From 6e9ecf8a26a60626d1478a6765b29b31c057b215 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Thu, 20 Feb 2025 20:19:08 +0100 Subject: [PATCH 2/7] Add blocklist-url launch parameter Block incoming peers from blocked ips. --- crates/librqbit/src/blocklist.rs | 25 +++++++++++-------------- crates/librqbit/src/session.rs | 19 +++++++++++++++++++ crates/rqbit/src/main.rs | 5 +++++ 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index b09e045..5638497 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -10,13 +10,20 @@ use tokio::{io::AsyncBufReadExt, io::BufReader}; use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; -struct Blocklist { +pub struct Blocklist { // Separate trees for IPv4 and IPv6 since they have different numeric ranges ipv4_ranges: IntervalTree, ipv6_ranges: IntervalTree, } impl Blocklist { + pub fn empty() -> Self { + return Self::new( + &Vec::>::new(), + &Vec::>::new(), + ); + } + pub fn new( ipv4_ranges: &Vec>, ipv6_ranges: &Vec>, @@ -27,13 +34,6 @@ impl Blocklist { } } - fn ip_to_num(ip: &IpAddr) -> u128 { - match ip { - IpAddr::V4(ip) => u32::from_be_bytes(ip.octets()) as u128, - IpAddr::V6(ip) => u128::from_be_bytes(ip.octets()), - } - } - pub async fn load_from_url(url: &str) -> Result { let response = reqwest::get(url).await.map_err(|e| anyhow::anyhow!(e))?; if response.status() != 200 { @@ -74,7 +74,8 @@ impl Blocklist { let mut peek_bytes = [0u8; 2]; let mut reader = tokio::io::BufReader::new(reader); - let buffer = reader.fill_buf().await?; // Get a reference to the buffer + // Peek the first bytes by filling buffer + let buffer = reader.fill_buf().await?; if buffer.len() >= 2 { peek_bytes.copy_from_slice(&buffer[0..2]); } else { @@ -210,7 +211,6 @@ mod tests { let (server, mut mock, url) = rx.recv().unwrap(); - // Add response body and headers mock = mock.with_body(content); for &(key, value) in headers { mock = mock.with_header(key, value); @@ -305,10 +305,7 @@ mod tests { #[test] fn test_blocklist_empty() { - let blocklist = Blocklist::new( - &Vec::>::new(), - &Vec::>::new(), - ); + let blocklist = Blocklist::empty(); assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"::1".parse().unwrap())); } diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 0b0f4fb..376c419 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -11,6 +11,7 @@ use std::{ use crate::{ api::TorrentIdOrHash, bitv_factory::{BitVFactory, NonPersistentBitVFactory}, + blocklist, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, limits::{Limits, LimitsConfig}, merge_streams::merge_streams, @@ -125,6 +126,8 @@ pub struct Session { pub(crate) concurrent_initialize_semaphore: Arc, pub ratelimits: Limits, + pub blocklist: blocklist::Blocklist, + // Monitoring / tracing / logging pub(crate) stats: SessionStats, root_span: Option, @@ -417,6 +420,8 @@ pub struct SessionOptions { pub ratelimits: LimitsConfig, + pub blocklist_url: Option, + #[cfg(feature = "disable-upload")] pub disable_upload: bool, } @@ -607,6 +612,14 @@ impl Session { let stream_connector = Arc::new(StreamConnector::from(proxy_config)); + let blocklist: blocklist::Blocklist = if let Some(blocklist_url) = opts.blocklist_url { + blocklist::Blocklist::load_from_url(&blocklist_url) + .await + .unwrap() + } else { + blocklist::Blocklist::empty() + }; + let session = Arc::new(Self { persistence, bitv_factory, @@ -632,6 +645,7 @@ impl Session { ratelimits: Limits::new(opts.ratelimits), #[cfg(feature = "disable-upload")] _disable_upload: opts.disable_upload, + blocklist, }); if let Some(mut disk_write_rx) = disk_write_rx { @@ -716,6 +730,11 @@ impl Session { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); + let incoming_ip = addr.ip(); + if self.blocklist.is_blocked(&incoming_ip) { + bail!("Incoming ip {incoming_ip} is in blocklist"); + } + let mut read_buf = ReadBuf::new(); let h = read_buf .read_handshake(&mut stream, rwtimeout) diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 3b9cd86..23ada49 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -228,6 +228,10 @@ struct Opts { /// Limit upload to bytes-per-second. #[arg(long = "ratelimit-upload", env = "RQBIT_RATELIMIT_UPLOAD")] ratelimit_upload_bps: Option, + + /// Downloads a p2p blocklist from this url and blocks peers from it + #[arg(long, env = "RQBIT_BLOCKLIST_URL")] + blocklist_url: Option, } #[derive(Parser)] @@ -494,6 +498,7 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> upload_bps: opts.ratelimit_upload_bps, download_bps: opts.ratelimit_download_bps, }, + blocklist_url: opts.blocklist_url, }; let http_api_basic_auth = if let Ok(up) = std::env::var("RQBIT_HTTP_BASIC_AUTH_USERPASS") { From ac883c1ddffb2015e0c09b8c4376baa16d278692 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Fri, 21 Feb 2025 22:03:49 +0100 Subject: [PATCH 3/7] :ok_hand: Improve Parsing logic --- crates/librqbit/src/blocklist.rs | 133 +++++++++++++------------------ 1 file changed, 54 insertions(+), 79 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index 5638497..f72f7e7 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_compression::tokio::bufread::GzipDecoder; use futures::TryStreamExt; use intervaltree::IntervalTree; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::pin::Pin; use std::str::FromStr; use tokio::io::AsyncRead; @@ -11,26 +11,19 @@ use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; pub struct Blocklist { - // Separate trees for IPv4 and IPv6 since they have different numeric ranges - ipv4_ranges: IntervalTree, - ipv6_ranges: IntervalTree, + // ipv4 and ipv6 do not overlap + // see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5 + blocked_ranges: IntervalTree, } impl Blocklist { pub fn empty() -> Self { - return Self::new( - &Vec::>::new(), - &Vec::>::new(), - ); + return Self::new(std::iter::empty()); } - pub fn new( - ipv4_ranges: &Vec>, - ipv6_ranges: &Vec>, - ) -> Self { + pub fn new(ip_ranges: impl IntoIterator>) -> Self { Self { - ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))), - ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))), + blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))), } } @@ -97,87 +90,75 @@ impl Blocklist { let reader = BufReader::new(reader); let mut lines = reader.lines(); - let mut ipv4_ranges: Vec> = Vec::new(); - let mut ipv6_ranges: Vec> = Vec::new(); + let mut ip_ranges: Vec> = Vec::new(); while let Some(line) = lines.next_line().await? { // Skip comments and empty lines if line.starts_with('#') || line.trim().is_empty() { continue; } - // Parse IP ranges in format: "RuleName:StartIp-EndIp" if let Some((start_ip, end_ip)) = parse_ip_range(&line) { - match (start_ip, end_ip) { - (IpAddr::V4(start), IpAddr::V4(end)) => { - let start_num = u32::from_be_bytes(start.octets()); - let end_num = u32::from_be_bytes(end.octets()); - let range = if end_num == u32::MAX { - start_num..end_num // Special case: Use inclusive range when max - } else { - start_num..(end_num + 1) // Normal case - }; - ipv4_ranges.push(range); - } - (IpAddr::V6(start), IpAddr::V6(end)) => { - let start_num = u128::from_be_bytes(start.octets()); - let end_num = u128::from_be_bytes(end.octets()); - let range = if end_num == u128::MAX { - start_num..end_num // Special case: Use inclusive range when max - } else { - start_num..(end_num + 1) // Normal case - }; - ipv6_ranges.push(range); - } - _ => { - continue; - } - } + let range = start_ip..(increment_ip(end_ip).unwrap()); + ip_ranges.push(range); } } info!( - ipv6_entry_count = ipv6_ranges.len(), - ipv4_entry_count = ipv4_ranges.len(), + ip_entry_count = ip_ranges.len(), "Finished loading blocklist" ); - let blocklist = Self::new(&ipv4_ranges, &ipv6_ranges); + let blocklist = Self::new(ip_ranges); Ok(blocklist) } pub fn is_blocked(&self, ip: &IpAddr) -> bool { - match ip { - IpAddr::V4(ipv4) => { - let num = u32::from_be_bytes(ipv4.octets()); - self.ipv4_ranges.query_point(num).next().is_some() - } - IpAddr::V6(ipv6) => { - let num = u128::from_be_bytes(ipv6.octets()); - self.ipv6_ranges.query_point(num).next().is_some() - } + self.blocked_ranges.query_point(*ip).next().is_some() + } +} + +/// Safely increments an `IpAddr`, returning `None` if it would overflow. +fn increment_ip(ip: IpAddr) -> Option { + match ip { + IpAddr::V4(ipv4) => { + let num = u32::from_be_bytes(ipv4.octets()); + num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n))) + } + IpAddr::V6(ipv6) => { + let num = u128::from_be_bytes(ipv6.octets()); + num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n))) } } } fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { // Skip comments and empty lines - if line.starts_with('#') || line.trim().is_empty() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { return None; } - // Parse IP ranges in format: "RuleName:StartIp-EndIp" - if let Some((rule_name, ip_range)) = line.rsplit_once(':') { - if let Some((start, end)) = ip_range.split_once('-') { - if let (Ok(start_ip), Ok(end_ip)) = - (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) - { - return Some((start_ip, end_ip)); - } else { - // Mismatched IP versions, skip this range - debug!(rulen_name = rule_name, "Could not be parsed"); - } + let is_ipv4 = line.matches('.').count() >= 6; + // Find the split point based on whether it's IPv4 or not + let split_point: usize = if is_ipv4 { + line.rfind(':') + } else { + line.find(':') + } + .unwrap_or(0); + + let (rule_name, ip_range) = line.split_at(split_point + 1); + if let Some((start, end)) = ip_range.split_once('-') { + if let (Ok(start_ip), Ok(end_ip)) = + (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) + { + return Some((start_ip, end_ip)); + } else { + // Mismatched IP versions, skip this range + debug!(rulen_name = rule_name, "Could not be parsed"); } } + None } @@ -186,7 +167,6 @@ mod tests { use super::*; use async_compression::tokio::write::GzipEncoder; use mockito::{Server, ServerGuard}; - use std::net::{Ipv4Addr, Ipv6Addr}; use std::thread::{self, JoinHandle}; use tokio::io::AsyncWriteExt; @@ -313,21 +293,16 @@ mod tests { #[test] fn test_manual_ranges() { // Add IPv4 range - let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap(); - let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap(); - let start_num = u32::from_be_bytes(start_v4.octets()); - let end_num = u32::from_be_bytes(end_v4.octets()); - let ipv4_range = start_num..(end_num + 1); + let start_v4: IpAddr = "192.168.0.0".parse().unwrap(); + let end_v4: IpAddr = "192.168.255.255".parse().unwrap(); + let ipv4_range = start_v4..end_v4; // Add IPv6 range - let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap(); - let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap(); - let start_num = u128::from_be_bytes(start_v6.octets()); - let end_num = u128::from_be_bytes(end_v6.octets()); - let ipv6_range = start_num..(end_num + 1); - - let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]); + let start_v6: IpAddr = "2001:db8::".parse().unwrap(); + let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap(); + let ipv6_range = start_v6..end_v6; + let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]); // Test IPv4 addresses assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap())); From 8f019882d06a7c479393a6b015122ea74db37d6c Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Tue, 25 Feb 2025 03:02:13 +0100 Subject: [PATCH 4/7] :ok_hand: Simplify test Remove mock http server --- Cargo.lock | 51 -------------------------- crates/librqbit/Cargo.toml | 1 - crates/librqbit/src/blocklist.rs | 63 ++++++-------------------------- 3 files changed, 11 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 638b2b7..dc300fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,16 +135,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "assert-json-diff" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "assert_cfg" version = "0.1.0" @@ -824,16 +814,6 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" -[[package]] -name = "colored" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" -dependencies = [ - "lazy_static", - "windows-sys 0.48.0", -] - [[package]] name = "combine" version = "4.6.7" @@ -2703,7 +2683,6 @@ dependencies = [ "lru", "memmap2", "mime_guess", - "mockito", "notify", "parking_lot", "rand 0.8.5", @@ -3083,30 +3062,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "mockito" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2" -dependencies = [ - "assert-json-diff", - "bytes", - "colored", - "futures-util", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-util", - "log", - "rand 0.8.5", - "regex", - "serde_json", - "serde_urlencoded", - "similar", - "tokio", -] - [[package]] name = "muda" version = "0.15.3" @@ -4928,12 +4883,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "similar" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e" - [[package]] name = "siphasher" version = "0.3.11" diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index a280c13..a2a4453 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -130,4 +130,3 @@ tracing-subscriber = "0.3" tokio-test = "0.4" tempfile = "3" rand = { version = "0.8", features = ["small_rng"] } -mockito = "1.2" diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index f72f7e7..fe90e0a 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -164,54 +164,13 @@ fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { #[cfg(test)] mod tests { + use std::io::Cursor; + use super::*; use async_compression::tokio::write::GzipEncoder; - use mockito::{Server, ServerGuard}; - use std::thread::{self, JoinHandle}; + use futures::stream::once; use tokio::io::AsyncWriteExt; - struct TestServer { - server: ServerGuard, - mock: mockito::Mock, - url: String, - _thread: JoinHandle<()>, - } - - impl TestServer { - fn new(content: &[u8], headers: &[(&str, &str)]) -> Self { - let (tx, rx) = std::sync::mpsc::channel(); - let server_thread = thread::spawn(move || { - let mut server = Server::new(); - let url = server.url(); - let mock = server.mock("GET", "/").with_status(200); - - tx.send((server, mock, url)).unwrap(); - thread::park(); - }); - - let (server, mut mock, url) = rx.recv().unwrap(); - - mock = mock.with_body(content); - for &(key, value) in headers { - mock = mock.with_header(key, value); - } - let mock = mock.create(); - - TestServer { - server, - mock, - url, - _thread: server_thread, - } - } - } - - impl Drop for TestServer { - fn drop(&mut self) { - self._thread.thread().unpark(); - } - } - #[tokio::test] async fn test_blocklist_gzipped() -> Result<()> { let blocklist = r#" @@ -227,13 +186,13 @@ mod tests { encoder.shutdown().await.unwrap(); } - let server = TestServer::new(&gzipped_blocklist, &[("Content-Encoding", "gzip")]); - - let blocklist = Blocklist::load_from_url(&server.url).await?; + let stream = StreamReader::new(Box::pin(once(async { + Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist)) + }))); + let blocklist = Blocklist::create_from_stream(stream).await?; assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); - server.mock.assert(); Ok(()) } @@ -245,13 +204,13 @@ mod tests { localv6:2001:db8::1-2001:db8::ffff "#; - let server = TestServer::new(blocklist.as_bytes(), &[]); - - let blocklist = Blocklist::load_from_url(&server.url).await?; + let stream = StreamReader::new(Box::pin(once(async { + Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec())) + }))); + let blocklist = Blocklist::create_from_stream(stream).await?; assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); - server.mock.assert(); Ok(()) } From c19ea3979aa55e37695f46a43c382547739f020a Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Tue, 25 Feb 2025 03:27:59 +0100 Subject: [PATCH 5/7] :ok_hand: Improved error handling --- crates/librqbit/src/blocklist.rs | 19 +++++++------------ crates/librqbit/src/session.rs | 1 + 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index fe90e0a..f67f1d4 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Context, Result}; use async_compression::tokio::bufread::GzipDecoder; use futures::TryStreamExt; use intervaltree::IntervalTree; @@ -28,12 +28,11 @@ impl Blocklist { } pub async fn load_from_url(url: &str) -> Result { - let response = reqwest::get(url).await.map_err(|e| anyhow::anyhow!(e))?; + let response = reqwest::get(url) + .await + .context("Failed to send request for blocklist")?; if response.status() != 200 { - return Err(anyhow::anyhow!( - "Failed to fetch blocklist: HTTP {}", - response.status() - )); + anyhow::bail!("Failed to fetch blocklist: HTTP {}", response.status()); } let content_length = response @@ -41,9 +40,7 @@ impl Blocklist { .ok_or_else(|| anyhow::anyhow!("Failed to get content length"))?; if content_length < 2 { - return Err(anyhow::anyhow!( - "Content too short: not enough data to determine compression" - )); + anyhow::bail!("Content too short: not enough data to determine compression"); } let reader = StreamReader::new( @@ -72,9 +69,7 @@ impl Blocklist { if buffer.len() >= 2 { peek_bytes.copy_from_slice(&buffer[0..2]); } else { - return Err(anyhow::anyhow!( - "Content too short: not enough data to determine compression" - )); + anyhow::bail!("Content too short: not enough data to determine compression"); } // Check for Gzip magic bytes (1F 8B) diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 376c419..c68de6f 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -615,6 +615,7 @@ impl Session { let blocklist: blocklist::Blocklist = if let Some(blocklist_url) = opts.blocklist_url { blocklist::Blocklist::load_from_url(&blocklist_url) .await + .inspect_err(|e| warn!("failed to read blocklist: {e}")) .unwrap() } else { blocklist::Blocklist::empty() From 4c6e19ceab7c5c9a73fa8efbe60a33233bb83c36 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Tue, 25 Feb 2025 03:42:41 +0100 Subject: [PATCH 6/7] :ok_hand: Add parse file:// urls --- crates/librqbit/src/blocklist.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index f67f1d4..c194a86 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -9,6 +9,7 @@ use tokio::io::AsyncRead; use tokio::{io::AsyncBufReadExt, io::BufReader}; use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; +use url::Url; pub struct Blocklist { // ipv4 and ipv6 do not overlap @@ -28,7 +29,16 @@ impl Blocklist { } pub async fn load_from_url(url: &str) -> Result { - let response = reqwest::get(url) + let parsed_url = Url::parse(url).context("Failed to parse URL")?; + + if parsed_url.scheme() == "file" { + let path = parsed_url + .to_file_path() + .map_err(|_| anyhow::anyhow!("Failed to convert file URL to path"))?; + return Self::load_from_file(path.to_str().unwrap()).await; + } + + let response = reqwest::get(parsed_url) .await .context("Failed to send request for blocklist")?; if response.status() != 200 { From 1a313400769c29d5ec595734239dc78d2c524d24 Mon Sep 17 00:00:00 2001 From: Alexander WB Date: Tue, 25 Feb 2025 04:25:24 +0100 Subject: [PATCH 7/7] :ok_hand: Fix remaining pr comments --- crates/librqbit/src/blocklist.rs | 55 +++++++++---------- crates/librqbit/src/session.rs | 2 +- crates/librqbit/src/torrent_state/live/mod.rs | 11 ++++ 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs index c194a86..291de10 100644 --- a/crates/librqbit/src/blocklist.rs +++ b/crates/librqbit/src/blocklist.rs @@ -5,7 +5,7 @@ use intervaltree::IntervalTree; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::pin::Pin; use std::str::FromStr; -use tokio::io::AsyncRead; +use tokio::io::{AsyncBufRead, AsyncRead}; use tokio::{io::AsyncBufReadExt, io::BufReader}; use tokio_util::io::StreamReader; use tracing::{debug, info, trace}; @@ -85,7 +85,7 @@ impl Blocklist { // Check for Gzip magic bytes (1F 8B) let is_gzip = peek_bytes == [0x1F, 0x8B]; - let reader: Pin> = if is_gzip { + let mut reader: Pin> = if is_gzip { trace!("Detected Gzip file, decompressing..."); Box::pin(BufReader::new(GzipDecoder::new(reader))) } else { @@ -93,19 +93,14 @@ impl Blocklist { Box::pin(reader) }; - let reader = BufReader::new(reader); - let mut lines = reader.lines(); + let mut line: String = Default::default(); let mut ip_ranges: Vec> = Vec::new(); - while let Some(line) = lines.next_line().await? { - // Skip comments and empty lines - if line.starts_with('#') || line.trim().is_empty() { - continue; - } - + while reader.read_line(&mut line).await? > 0 { if let Some((start_ip, end_ip)) = parse_ip_range(&line) { - let range = start_ip..(increment_ip(end_ip).unwrap()); + let range = start_ip..(increment_ip(end_ip)); ip_ranges.push(range); } + line.clear(); } info!( @@ -117,21 +112,21 @@ impl Blocklist { Ok(blocklist) } - pub fn is_blocked(&self, ip: &IpAddr) -> bool { - self.blocked_ranges.query_point(*ip).next().is_some() + pub fn is_blocked(&self, ip: IpAddr) -> bool { + self.blocked_ranges.query_point(ip).next().is_some() } } /// Safely increments an `IpAddr`, returning `None` if it would overflow. -fn increment_ip(ip: IpAddr) -> Option { +fn increment_ip(ip: IpAddr) -> IpAddr { match ip { IpAddr::V4(ipv4) => { let num = u32::from_be_bytes(ipv4.octets()); - num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n))) + std::net::IpAddr::V4(Ipv4Addr::from(num.saturating_add(1))) } IpAddr::V6(ipv6) => { let num = u128::from_be_bytes(ipv6.octets()); - num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n))) + std::net::IpAddr::V6(Ipv6Addr::from(num.saturating_add(1))) } } } @@ -195,8 +190,8 @@ mod tests { Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist)) }))); let blocklist = Blocklist::create_from_stream(stream).await?; - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); Ok(()) } @@ -213,8 +208,8 @@ mod tests { Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec())) }))); let blocklist = Blocklist::create_from_stream(stream).await?; - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); Ok(()) } @@ -236,10 +231,10 @@ mod tests { let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?; // Verify the blocklist - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap())); - assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"2001:4860:4860::8888".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:4860:4860::8888".parse().unwrap())); // Clean up the temporary file tokio::fs::remove_file("temp_blocklist.txt").await?; @@ -250,8 +245,8 @@ mod tests { #[test] fn test_blocklist_empty() { let blocklist = Blocklist::empty(); - assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"::1".parse().unwrap())); + assert!(!blocklist.is_blocked("127.0.0.1".parse().unwrap())); + assert!(!blocklist.is_blocked("::1".parse().unwrap())); } #[test] @@ -268,11 +263,11 @@ mod tests { let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]); // Test IPv4 addresses - assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap())); + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("10.0.0.1".parse().unwrap())); // Test IPv6 addresses - assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap())); - assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap())); + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:db9::1".parse().unwrap())); } } diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index c68de6f..951e025 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -732,7 +732,7 @@ impl Session { .unwrap_or_else(|| Duration::from_secs(10)); let incoming_ip = addr.ip(); - if self.blocklist.is_blocked(&incoming_ip) { + if self.blocklist.is_blocked(incoming_ip) { bail!("Incoming ip {incoming_ip} is in blocklist"); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index d7f9748..3729494 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -572,6 +572,17 @@ impl TorrentStateLive { continue; } + let outgoing_ip = addr.ip(); + let is_blocked_ip = state.shared.session.upgrade().map_or_else( + || false, + |session| session.blocklist.is_blocked(outgoing_ip), + ); + + if is_blocked_ip { + info!("Outgoing ip {outgoing_ip} for peer is in blocklist skipping"); + continue; + } + let permit = state.peer_semaphore.clone().acquire_owned().await?; state.spawn( error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()),