diff --git a/Cargo.lock b/Cargo.lock
index 5395ae9..638b2b7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -135,6 +135,16 @@ version = "1.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
 
+[[package]]
+name = "assert-json-diff"
+version = "2.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "assert_cfg"
 version = "0.1.0"
@@ -171,6 +181,19 @@ dependencies = [
  "syn 2.0.95",
 ]
 
+[[package]]
+name = "async-compression"
+version = "0.4.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522"
+dependencies = [
+ "flate2",
+ "futures-core",
+ "memchr",
+ "pin-project-lite",
+ "tokio",
+]
+
 [[package]]
 name = "async-stream"
 version = "0.3.6"
@@ -801,6 +824,16 @@ version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
 
+[[package]]
+name = "colored"
+version = "2.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8"
+dependencies = [
+ "lazy_static",
+ "windows-sys 0.48.0",
+]
+
 [[package]]
 name = "combine"
 version = "4.6.7"
@@ -2338,6 +2371,15 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "intervaltree"
+version = "0.2.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "270bc34e57047cab801a8c871c124d9dc7132f6473c6401f645524f4e6edd111"
+dependencies = [
+ "smallvec",
+]
+
 [[package]]
 name = "ipnet"
 version = "2.10.1"
@@ -2629,6 +2671,7 @@ dependencies = [
  "anyhow",
  "arc-swap",
  "async-backtrace",
+ "async-compression",
  "async-stream",
  "async-trait",
  "axum 0.8.1",
@@ -2645,6 +2688,7 @@ dependencies = [
  "hex 0.4.3",
  "home",
  "http",
+ "intervaltree",
  "itertools 0.14.0",
  "librqbit-bencode",
  "librqbit-buffers",
@@ -2659,6 +2703,7 @@ dependencies = [
  "lru",
  "memmap2",
  "mime_guess",
+ "mockito",
  "notify",
  "parking_lot",
  "rand 0.8.5",
@@ -3038,6 +3083,30 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "mockito"
+version = "1.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2"
+dependencies = [
+ "assert-json-diff",
+ "bytes",
+ "colored",
+ "futures-util",
+ "http",
+ "http-body",
+ "http-body-util",
+ "hyper",
+ "hyper-util",
+ "log",
+ "rand 0.8.5",
+ "regex",
+ "serde_json",
+ "serde_urlencoded",
+ "similar",
+ "tokio",
+]
+
 [[package]]
 name = "muda"
 version = "0.15.3"
@@ -4859,6 +4928,12 @@ version = "0.3.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
 
+[[package]]
+name = "similar"
+version = "2.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1de1d4f81173b03af4c0cbed3c898f6bff5b870e4a7f5d6f4057d62a7a4b686e"
+
 [[package]]
 name = "siphasher"
 version = "0.3.11"
diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml
index 29da015..a280c13 100644
--- a/crates/librqbit/Cargo.toml
+++ b/crates/librqbit/Cargo.toml
@@ -81,6 +81,7 @@ regex = "1"
 reqwest = { version = "0.12", default-features = false, features = [
     "json",
     "socks",
+    "stream",
 ] }
 urlencoding = "2"
 byteorder = "1"
@@ -117,6 +118,8 @@ async-backtrace = { version = "0.2", optional = true }
 notify = { version = "7", optional = true }
 walkdir = "2.5.0"
 arc-swap = "1.7.1"
+intervaltree = "0.2.7"
+async-compression = {version="0.4.18", features= ["tokio", "gzip"] }
 
 [build-dependencies]
 anyhow = "1"
@@ -127,3 +130,4 @@ tracing-subscriber = "0.3"
 tokio-test = "0.4"
 tempfile = "3"
 rand = { version = "0.8", features = ["small_rng"] }
+mockito = "1.2"
diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs
new file mode 100644
index 0000000..b09e045
--- /dev/null
+++ b/crates/librqbit/src/blocklist.rs
@@ -0,0 +1,434 @@
+use anyhow::{Context, Result};
+use async_compression::tokio::bufread::GzipDecoder;
+use futures::TryStreamExt;
+use intervaltree::IntervalTree;
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+use std::ops::Range;
+use std::pin::Pin;
+use std::str::FromStr;
+use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader};
+use tokio_util::io::StreamReader;
+use tracing::{debug, info, trace};
+
+/// A set of blocked IP address ranges.
+///
+/// Ranges are stored half-open (`start..end + 1`), keyed by the big-endian numeric value of
+/// the address.
+struct Blocklist {
+    // Separate trees for IPv4 and IPv6 since they have different numeric ranges
+    ipv4_ranges: IntervalTree<u32, ()>,
+    ipv6_ranges: IntervalTree<u128, ()>,
+    // A half-open range cannot contain the highest address of a family (`end + 1` would
+    // overflow), so whether that address is blocked is tracked separately.
+    ipv4_max_blocked: bool,
+    ipv6_max_blocked: bool,
+}
+
+impl Blocklist {
+    /// Builds a blocklist from half-open numeric address ranges.
+    pub fn new(ipv4_ranges: &[Range<u32>], ipv6_ranges: &[Range<u128>]) -> Self {
+        Self {
+            ipv4_ranges: IntervalTree::from_iter(ipv4_ranges.iter().map(|r| (r.clone(), ()))),
+            ipv6_ranges: IntervalTree::from_iter(ipv6_ranges.iter().map(|r| (r.clone(), ()))),
+            ipv4_max_blocked: false,
+            ipv6_max_blocked: false,
+        }
+    }
+
+    /// Converts an address to its big-endian numeric value, widened to `u128`.
+    fn ip_to_num(ip: &IpAddr) -> u128 {
+        match ip {
+            IpAddr::V4(ip) => u32::from_be_bytes(ip.octets()) as u128,
+            IpAddr::V6(ip) => u128::from_be_bytes(ip.octets()),
+        }
+    }
+
+    /// Downloads a blocklist (plain text or gzip) from `url` and parses it.
+    ///
+    /// Fails if the request fails, the status is not 200, or the body cannot be read.
+    pub async fn load_from_url(url: &str) -> Result<Self> {
+        let response = reqwest::get(url)
+            .await
+            .with_context(|| format!("Failed to fetch blocklist from {url}"))?;
+        if response.status() != 200 {
+            return Err(anyhow::anyhow!(
+                "Failed to fetch blocklist: HTTP {}",
+                response.status()
+            ));
+        }
+
+        // No Content-Length check here: chunked responses do not carry one, and
+        // `create_from_stream` already rejects bodies that are too short.
+        let reader = StreamReader::new(
+            response
+                .bytes_stream()
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
+        );
+        Self::create_from_stream(reader).await
+    }
+
+    /// Reads a blocklist (plain text or gzip) from the file at `path` and parses it.
+    pub async fn load_from_file(path: &str) -> Result<Self> {
+        let file = tokio::fs::File::open(path)
+            .await
+            .with_context(|| format!("Failed to open blocklist file {path}"))?;
+        Self::create_from_stream(file).await
+    }
+
+    /// Parses a blocklist from `reader`, transparently decompressing gzip input.
+    ///
+    /// Lines that are not valid `RuleName:StartIp-EndIp` entries are skipped.
+    async fn create_from_stream<R>(mut reader: R) -> Result<Self>
+    where
+        R: AsyncRead + Unpin + Send,
+    {
+        // Read exactly two bytes for format detection. A single buffered read may return
+        // fewer bytes than that when the source delivers its data in small pieces.
+        let mut magic = [0u8; 2];
+        if let Err(e) = reader.read_exact(&mut magic).await {
+            if e.kind() == std::io::ErrorKind::UnexpectedEof {
+                return Err(anyhow::anyhow!(
+                    "Content too short: not enough data to determine compression"
+                ));
+            }
+            return Err(e.into());
+        }
+
+        // Check for Gzip magic bytes (1F 8B)
+        let is_gzip = magic == [0x1F, 0x8B];
+
+        // Put the two consumed bytes back in front of the rest of the stream.
+        let raw = BufReader::new(std::io::Cursor::new(magic).chain(reader));
+        let reader: Pin<Box<dyn AsyncRead + Send>> = if is_gzip {
+            trace!("Detected Gzip file, decompressing...");
+            Box::pin(GzipDecoder::new(raw))
+        } else {
+            trace!("Plain text file detected.");
+            Box::pin(raw)
+        };
+
+        let mut lines = BufReader::new(reader).lines();
+        let mut ipv4_ranges: Vec<Range<u32>> = Vec::new();
+        let mut ipv6_ranges: Vec<Range<u128>> = Vec::new();
+        let mut ipv4_max_blocked = false;
+        let mut ipv6_max_blocked = false;
+        while let Some(line) = lines.next_line().await? {
+            // Comments, empty lines and unparsable lines all yield `None`.
+            let Some((start_ip, end_ip)) = parse_ip_range(&line) else {
+                continue;
+            };
+            match (start_ip, end_ip) {
+                (IpAddr::V4(start), IpAddr::V4(end)) => {
+                    let start_num = u32::from(start);
+                    let end_num = u32::from(end);
+                    if start_num > end_num {
+                        debug!(line = %line, "Range start is after range end, skipping");
+                        continue;
+                    }
+                    match end_num.checked_add(1) {
+                        Some(end_exclusive) => ipv4_ranges.push(start_num..end_exclusive),
+                        None => {
+                            // The range ends at 255.255.255.255.
+                            ipv4_max_blocked = true;
+                            if start_num < end_num {
+                                ipv4_ranges.push(start_num..end_num);
+                            }
+                        }
+                    }
+                }
+                (IpAddr::V6(start), IpAddr::V6(end)) => {
+                    let start_num = u128::from(start);
+                    let end_num = u128::from(end);
+                    if start_num > end_num {
+                        debug!(line = %line, "Range start is after range end, skipping");
+                        continue;
+                    }
+                    match end_num.checked_add(1) {
+                        Some(end_exclusive) => ipv6_ranges.push(start_num..end_exclusive),
+                        None => {
+                            // The range ends at ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff.
+                            ipv6_max_blocked = true;
+                            if start_num < end_num {
+                                ipv6_ranges.push(start_num..end_num);
+                            }
+                        }
+                    }
+                }
+                // `parse_ip_range` only returns same-family pairs; nothing to do otherwise.
+                _ => continue,
+            }
+        }
+
+        info!(
+            ipv6_entry_count = ipv6_ranges.len(),
+            ipv4_entry_count = ipv4_ranges.len(),
+            "Finished loading blocklist"
+        );
+
+        let mut blocklist = Self::new(&ipv4_ranges, &ipv6_ranges);
+        blocklist.ipv4_max_blocked = ipv4_max_blocked;
+        blocklist.ipv6_max_blocked = ipv6_max_blocked;
+        Ok(blocklist)
+    }
+
+    /// Returns `true` if `ip` falls inside any loaded range of its address family.
+    pub fn is_blocked(&self, ip: &IpAddr) -> bool {
+        match ip {
+            IpAddr::V4(ipv4) => {
+                let num = u32::from(*ipv4);
+                (self.ipv4_max_blocked && num == u32::MAX)
+                    || self.ipv4_ranges.query_point(num).next().is_some()
+            }
+            IpAddr::V6(ipv6) => {
+                let num = u128::from(*ipv6);
+                (self.ipv6_max_blocked && num == u128::MAX)
+                    || self.ipv6_ranges.query_point(num).next().is_some()
+            }
+        }
+    }
+}
+
+/// Parses one blocklist line of the form `RuleName:StartIp-EndIp`.
+///
+/// Returns `None` for comments, empty lines and lines that do not contain two addresses
+/// of the same family. The rule name may itself contain `:` and `-`.
+fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> {
+    // Skip comments and empty lines
+    let line = line.trim();
+    if line.is_empty() || line.starts_with('#') {
+        return None;
+    }
+
+    // Addresses never contain '-', so the last '-' separates the start from the end even
+    // when the rule name contains dashes.
+    let parsed = line.rsplit_once('-').and_then(|(head, end)| {
+        match IpAddr::from_str(end.trim()).ok()? {
+            // IPv4 addresses contain no ':', so the last ':' ends the rule name.
+            IpAddr::V4(end_ip) => {
+                let (_, start) = head.rsplit_once(':')?;
+                let start_ip = Ipv4Addr::from_str(start.trim()).ok()?;
+                Some((IpAddr::V4(start_ip), IpAddr::V4(end_ip)))
+            }
+            // IPv6 addresses contain ':' themselves, so splitting at the last ':' would cut
+            // the start address apart. Take the longest suffix following a ':' that parses.
+            IpAddr::V6(end_ip) => {
+                let start_ip = head
+                    .match_indices(':')
+                    .find_map(|(idx, _)| Ipv6Addr::from_str(head[idx + 1..].trim()).ok())?;
+                Some((IpAddr::V6(start_ip), IpAddr::V6(end_ip)))
+            }
+        }
+    });
+
+    if parsed.is_none() {
+        debug!(line, "Could not be parsed");
+    }
+    parsed
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use async_compression::tokio::write::GzipEncoder;
+    use mockito::{Server, ServerGuard};
+    use std::net::{Ipv4Addr, Ipv6Addr};
+    use std::thread::{self, JoinHandle};
+    use tokio::io::AsyncWriteExt;
+
+    /// A mockito server serving a fixed body at `/`.
+    struct TestServer {
+        server: ServerGuard,
+        mock: mockito::Mock,
+        url: String,
+        _thread: JoinHandle<()>,
+    }
+
+    impl TestServer {
+        fn new(content: &[u8], headers: &[(&str, &str)]) -> Self {
+            // The server is created on a separate thread that stays parked until this
+            // value is dropped.
+            let (tx, rx) = std::sync::mpsc::channel();
+            let server_thread = thread::spawn(move || {
+                let mut server = Server::new();
+                let url = server.url();
+                let mock = server.mock("GET", "/").with_status(200);
+
+                tx.send((server, mock, url)).unwrap();
+                thread::park();
+            });
+
+            let (server, mut mock, url) = rx.recv().unwrap();
+
+            // Add response body and headers
+            mock = mock.with_body(content);
+            for &(key, value) in headers {
+                mock = mock.with_header(key, value);
+            }
+            let mock = mock.create();
+
+            TestServer {
+                server,
+                mock,
+                url,
+                _thread: server_thread,
+            }
+        }
+    }
+
+    impl Drop for TestServer {
+        fn drop(&mut self) {
+            self._thread.thread().unpark();
+        }
+    }
+
+    fn ip(text: &str) -> IpAddr {
+        text.parse().unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_blocklist_gzipped() -> Result<()> {
+        let blocklist = r#"
+            # test
+            local:192.168.1.1-192.168.1.255
+            localv6:2001:db8::1-2001:db8::ffff
+        "#;
+        let mut gzipped_blocklist = Vec::new();
+        {
+            let mut encoder = GzipEncoder::new(&mut gzipped_blocklist);
+            encoder.write_all(blocklist.as_bytes()).await.unwrap();
+            encoder.flush().await.unwrap();
+            encoder.shutdown().await.unwrap();
+        }
+
+        let server = TestServer::new(&gzipped_blocklist, &[("Content-Encoding", "gzip")]);
+
+        let blocklist = Blocklist::load_from_url(&server.url).await?;
+        assert!(blocklist.is_blocked(&ip("192.168.1.1")));
+        assert!(!blocklist.is_blocked(&ip("8.8.8.8")));
+        assert!(blocklist.is_blocked(&ip("2001:db8::1")));
+
+        server.mock.assert();
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_blocklist_plaintext() -> Result<()> {
+        let blocklist = r#"
+            # test
+            local:192.168.1.1-192.168.1.255
+            localv6:2001:db8::1-2001:db8::ffff
+        "#;
+
+        let server = TestServer::new(blocklist.as_bytes(), &[]);
+
+        let blocklist = Blocklist::load_from_url(&server.url).await?;
+        assert!(blocklist.is_blocked(&ip("192.168.1.1")));
+        assert!(!blocklist.is_blocked(&ip("8.8.8.8")));
+        assert!(blocklist.is_blocked(&ip("2001:db8::ffff")));
+
+        server.mock.assert();
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_blocklist_from_plaintext_file() -> Result<()> {
+        let blocklist_content = r#"
+            # test
+            local:192.168.1.1-192.168.1.255
+            localv6:2001:db8::1-2001:db8::ffff
+        "#;
+
+        // A uniquely named temporary file that is removed on drop, even if an assertion fails.
+        let temp_file = tempfile::NamedTempFile::new()?;
+        tokio::fs::write(temp_file.path(), blocklist_content).await?;
+        let path = temp_file.path().to_str().expect("temp path is valid UTF-8");
+
+        // Load the blocklist from the file
+        let blocklist = Blocklist::load_from_file(path).await?;
+
+        // Verify the blocklist
+        assert!(blocklist.is_blocked(&ip("192.168.1.1")));
+        assert!(!blocklist.is_blocked(&ip("8.8.8.8")));
+        assert!(blocklist.is_blocked(&ip("2001:db8::1")));
+        assert!(!blocklist.is_blocked(&ip("2001:4860:4860::8888")));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_range_ending_at_highest_address() -> Result<()> {
+        let content = concat!(
+            "v4:255.255.255.0-255.255.255.255\n",
+            "v6:ffff::ff00-ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff\n",
+        );
+        let blocklist = Blocklist::create_from_stream(content.as_bytes()).await?;
+
+        assert!(blocklist.is_blocked(&ip("255.255.255.255")));
+        assert!(blocklist.is_blocked(&ip("255.255.255.0")));
+        assert!(!blocklist.is_blocked(&ip("255.255.254.255")));
+        assert!(blocklist.is_blocked(&ip("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")));
+        assert!(!blocklist.is_blocked(&ip("fffe::1")));
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_content_too_short() {
+        assert!(Blocklist::create_from_stream(&b"x"[..]).await.is_err());
+    }
+
+    #[test]
+    fn test_parse_ip_range() {
+        assert_eq!(
+            parse_ip_range("local:192.168.1.1-192.168.1.255"),
+            Some((ip("192.168.1.1"), ip("192.168.1.255")))
+        );
+        assert_eq!(
+            parse_ip_range("  localv6:2001:db8::1-2001:db8::ffff  "),
+            Some((ip("2001:db8::1"), ip("2001:db8::ffff")))
+        );
+        assert_eq!(
+            parse_ip_range("Some-Org: v2:10.0.0.0 - 10.0.0.255"),
+            Some((ip("10.0.0.0"), ip("10.0.0.255")))
+        );
+        assert_eq!(parse_ip_range("# comment"), None);
+        assert_eq!(parse_ip_range("mixed:10.0.0.1-2001:db8::1"), None);
+        assert_eq!(parse_ip_range("not a range"), None);
+    }
+
+    #[test]
+    fn test_blocklist_empty() {
+        let blocklist = Blocklist::new(
+            &Vec::<std::ops::Range<u32>>::new(),
+            &Vec::<std::ops::Range<u128>>::new(),
+        );
+        assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap()));
+        assert!(!blocklist.is_blocked(&"::1".parse().unwrap()));
+    }
+
+    #[test]
+    fn test_manual_ranges() {
+        // Add IPv4 range
+        let start_v4: Ipv4Addr = "192.168.0.0".parse().unwrap();
+        let end_v4: Ipv4Addr = "192.168.255.255".parse().unwrap();
+        let start_num = u32::from_be_bytes(start_v4.octets());
+        let end_num = u32::from_be_bytes(end_v4.octets());
+        let ipv4_range = start_num..(end_num + 1);
+
+        // Add IPv6 range
+        let start_v6: Ipv6Addr = "2001:db8::".parse().unwrap();
+        let end_v6: Ipv6Addr = "2001:db8::ffff".parse().unwrap();
+        let start_num = u128::from_be_bytes(start_v6.octets());
+        let end_num = u128::from_be_bytes(end_v6.octets());
+        let ipv6_range = start_num..(end_num + 1);
+
+        let blocklist = Blocklist::new(&vec![ipv4_range], &vec![ipv6_range]);
+
+        // Test IPv4 addresses
+        assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
+        assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
+
+        // Test IPv6 addresses
+        assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap()));
+        assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap()));
+    }
+}
diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs
index 7f33882..ef2e96b 100644
--- a/crates/librqbit/src/lib.rs
+++ b/crates/librqbit/src/lib.rs
@@ -42,6 +42,7 @@ pub mod api;
 mod api_error;
 mod bitv;
 mod bitv_factory;
+mod blocklist;
 mod chunk_tracker;
 mod create_torrent_file;
 mod dht_utils;