diff --git a/Cargo.lock b/Cargo.lock index 5395ae9..dc300fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,6 +171,19 @@ dependencies = [ "syn 2.0.95", ] +[[package]] +name = "async-compression" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" +dependencies = [ + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -2338,6 +2351,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "intervaltree" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "270bc34e57047cab801a8c871c124d9dc7132f6473c6401f645524f4e6edd111" +dependencies = [ + "smallvec", +] + [[package]] name = "ipnet" version = "2.10.1" @@ -2629,6 +2651,7 @@ dependencies = [ "anyhow", "arc-swap", "async-backtrace", + "async-compression", "async-stream", "async-trait", "axum 0.8.1", @@ -2645,6 +2668,7 @@ dependencies = [ "hex 0.4.3", "home", "http", + "intervaltree", "itertools 0.14.0", "librqbit-bencode", "librqbit-buffers", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 29da015..a2a4453 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -81,6 +81,7 @@ regex = "1" reqwest = { version = "0.12", default-features = false, features = [ "json", "socks", + "stream", ] } urlencoding = "2" byteorder = "1" @@ -117,6 +118,8 @@ async-backtrace = { version = "0.2", optional = true } notify = { version = "7", optional = true } walkdir = "2.5.0" arc-swap = "1.7.1" +intervaltree = "0.2.7" +async-compression = {version="0.4.18", features= ["tokio", "gzip"] } [build-dependencies] anyhow = "1" diff --git a/crates/librqbit/src/blocklist.rs b/crates/librqbit/src/blocklist.rs new file mode 100644 index 0000000..291de10 --- /dev/null +++ b/crates/librqbit/src/blocklist.rs @@ -0,0 +1,273 @@ +use anyhow::{Context, Result}; +use async_compression::tokio::bufread::GzipDecoder; +use futures::TryStreamExt; +use intervaltree::IntervalTree; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::pin::Pin; +use std::str::FromStr; +use tokio::io::{AsyncBufRead, AsyncRead}; +use tokio::{io::AsyncBufReadExt, io::BufReader}; +use tokio_util::io::StreamReader; +use tracing::{debug, info, trace}; +use url::Url; + +pub struct Blocklist { + // ipv4 and ipv6 do not overlap + // see: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5 + blocked_ranges: IntervalTree, +} + +impl Blocklist { + pub fn empty() -> Self { + return Self::new(std::iter::empty()); + } + + pub fn new(ip_ranges: impl IntoIterator>) -> Self { + Self { + blocked_ranges: IntervalTree::from_iter(ip_ranges.into_iter().map(|r| (r, ()))), + } + } + + pub async fn load_from_url(url: &str) -> Result { + let parsed_url = Url::parse(url).context("Failed to parse URL")?; + + if parsed_url.scheme() == "file" { + let path = parsed_url + .to_file_path() + .map_err(|_| anyhow::anyhow!("Failed to convert file URL to path"))?; + return Self::load_from_file(path.to_str().unwrap()).await; + } + + let response = reqwest::get(parsed_url) + .await + .context("Failed to send request for blocklist")?; + if response.status() != 200 { + anyhow::bail!("Failed to fetch blocklist: HTTP {}", response.status()); + } + + let content_length = response + .content_length() + .ok_or_else(|| anyhow::anyhow!("Failed to get content length"))?; + + if content_length < 2 { + anyhow::bail!("Content too short: not enough data to determine compression"); + } + + let reader = StreamReader::new( + response + .bytes_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)), + ); + Self::create_from_stream(reader).await + } + + pub async fn load_from_file(path: &str) -> Result { + let file = tokio::fs::File::open(path).await?; + let reader = tokio::io::BufReader::new(file); + Self::create_from_stream(reader).await + } + + async fn create_from_stream(reader: R) -> Result + where + R: AsyncRead + Unpin + Send, + { + let mut peek_bytes = [0u8; 2]; + let mut reader = tokio::io::BufReader::new(reader); + + // Peek the first bytes by filling buffer + let buffer = reader.fill_buf().await?; + if buffer.len() >= 2 { + peek_bytes.copy_from_slice(&buffer[0..2]); + } else { + anyhow::bail!("Content too short: not enough data to determine compression"); + } + + // Check for Gzip magic bytes (1F 8B) + let is_gzip = peek_bytes == [0x1F, 0x8B]; + + let mut reader: Pin> = if is_gzip { + trace!("Detected Gzip file, decompressing..."); + Box::pin(BufReader::new(GzipDecoder::new(reader))) + } else { + trace!("Plain text file detected."); + Box::pin(reader) + }; + + let mut line: String = Default::default(); + let mut ip_ranges: Vec> = Vec::new(); + while reader.read_line(&mut line).await? > 0 { + if let Some((start_ip, end_ip)) = parse_ip_range(&line) { + let range = start_ip..(increment_ip(end_ip)); + ip_ranges.push(range); + } + line.clear(); + } + + info!( + ip_entry_count = ip_ranges.len(), + "Finished loading blocklist" + ); + + let blocklist = Self::new(ip_ranges); + Ok(blocklist) + } + + pub fn is_blocked(&self, ip: IpAddr) -> bool { + self.blocked_ranges.query_point(ip).next().is_some() + } +} + +/// Safely increments an `IpAddr`, returning `None` if it would overflow. +fn increment_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V4(ipv4) => { + let num = u32::from_be_bytes(ipv4.octets()); + std::net::IpAddr::V4(Ipv4Addr::from(num.saturating_add(1))) + } + IpAddr::V6(ipv6) => { + let num = u128::from_be_bytes(ipv6.octets()); + std::net::IpAddr::V6(Ipv6Addr::from(num.saturating_add(1))) + } + } +} + +fn parse_ip_range(line: &str) -> Option<(IpAddr, IpAddr)> { + // Skip comments and empty lines + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + return None; + } + + let is_ipv4 = line.matches('.').count() >= 6; + // Find the split point based on whether it's IPv4 or not + let split_point: usize = if is_ipv4 { + line.rfind(':') + } else { + line.find(':') + } + .unwrap_or(0); + + let (rule_name, ip_range) = line.split_at(split_point + 1); + if let Some((start, end)) = ip_range.split_once('-') { + if let (Ok(start_ip), Ok(end_ip)) = + (IpAddr::from_str(start.trim()), IpAddr::from_str(end.trim())) + { + return Some((start_ip, end_ip)); + } else { + // Mismatched IP versions, skip this range + debug!(rulen_name = rule_name, "Could not be parsed"); + } + } + + None +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::*; + use async_compression::tokio::write::GzipEncoder; + use futures::stream::once; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn test_blocklist_gzipped() -> Result<()> { + let blocklist = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + let mut gzipped_blocklist = Vec::new(); + { + let mut encoder = GzipEncoder::new(&mut gzipped_blocklist); + encoder.write_all(blocklist.as_bytes()).await.unwrap(); + encoder.flush().await.unwrap(); + encoder.shutdown().await.unwrap(); + } + + let stream = StreamReader::new(Box::pin(once(async { + Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist)) + }))); + let blocklist = Blocklist::create_from_stream(stream).await?; + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); + + Ok(()) + } + + #[tokio::test] + async fn test_blocklist_plaintext() -> Result<()> { + let blocklist = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + + let stream = StreamReader::new(Box::pin(once(async { + Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec())) + }))); + let blocklist = Blocklist::create_from_stream(stream).await?; + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); + + Ok(()) + } + + #[tokio::test] + async fn test_blocklist_from_plaintext_file() -> Result<()> { + let blocklist_content = r#" + # test + local:192.168.1.1-192.168.1.255 + localv6:2001:db8::1-2001:db8::ffff + "#; + + // Create a temporary file + let mut temp_file = tokio::fs::File::create("temp_blocklist.txt").await?; + tokio::io::AsyncWriteExt::write_all(&mut temp_file, blocklist_content.as_bytes()).await?; + drop(temp_file); // Close the file + + // Load the blocklist from the file + let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?; + + // Verify the blocklist + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap())); + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:4860:4860::8888".parse().unwrap())); + + // Clean up the temporary file + tokio::fs::remove_file("temp_blocklist.txt").await?; + + Ok(()) + } + + #[test] + fn test_blocklist_empty() { + let blocklist = Blocklist::empty(); + assert!(!blocklist.is_blocked("127.0.0.1".parse().unwrap())); + assert!(!blocklist.is_blocked("::1".parse().unwrap())); + } + + #[test] + fn test_manual_ranges() { + // Add IPv4 range + let start_v4: IpAddr = "192.168.0.0".parse().unwrap(); + let end_v4: IpAddr = "192.168.255.255".parse().unwrap(); + let ipv4_range = start_v4..end_v4; + + // Add IPv6 range + let start_v6: IpAddr = "2001:db8::".parse().unwrap(); + let end_v6: IpAddr = "2001:db8::ffff".parse().unwrap(); + let ipv6_range = start_v6..end_v6; + + let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]); + // Test IPv4 addresses + assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap())); + assert!(!blocklist.is_blocked("10.0.0.1".parse().unwrap())); + + // Test IPv6 addresses + assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap())); + assert!(!blocklist.is_blocked("2001:db9::1".parse().unwrap())); + } +} diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 7f33882..ef2e96b 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -42,6 +42,7 @@ pub mod api; mod api_error; mod bitv; mod bitv_factory; +mod blocklist; mod chunk_tracker; mod create_torrent_file; mod dht_utils; diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 0b0f4fb..951e025 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -11,6 +11,7 @@ use std::{ use crate::{ api::TorrentIdOrHash, bitv_factory::{BitVFactory, NonPersistentBitVFactory}, + blocklist, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, limits::{Limits, LimitsConfig}, merge_streams::merge_streams, @@ -125,6 +126,8 @@ pub struct Session { pub(crate) concurrent_initialize_semaphore: Arc, pub ratelimits: Limits, + pub blocklist: blocklist::Blocklist, + // Monitoring / tracing / logging pub(crate) stats: SessionStats, root_span: Option, @@ -417,6 +420,8 @@ pub struct SessionOptions { pub ratelimits: LimitsConfig, + pub blocklist_url: Option, + #[cfg(feature = "disable-upload")] pub disable_upload: bool, } @@ -607,6 +612,15 @@ impl Session { let stream_connector = Arc::new(StreamConnector::from(proxy_config)); + let blocklist: blocklist::Blocklist = if let Some(blocklist_url) = opts.blocklist_url { + blocklist::Blocklist::load_from_url(&blocklist_url) + .await + .inspect_err(|e| warn!("failed to read blocklist: {e}")) + .unwrap() + } else { + blocklist::Blocklist::empty() + }; + let session = Arc::new(Self { persistence, bitv_factory, @@ -632,6 +646,7 @@ impl Session { ratelimits: Limits::new(opts.ratelimits), #[cfg(feature = "disable-upload")] _disable_upload: opts.disable_upload, + blocklist, }); if let Some(mut disk_write_rx) = disk_write_rx { @@ -716,6 +731,11 @@ impl Session { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); + let incoming_ip = addr.ip(); + if self.blocklist.is_blocked(incoming_ip) { + bail!("Incoming ip {incoming_ip} is in blocklist"); + } + let mut read_buf = ReadBuf::new(); let h = read_buf .read_handshake(&mut stream, rwtimeout) diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index d7f9748..3729494 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -572,6 +572,17 @@ impl TorrentStateLive { continue; } + let outgoing_ip = addr.ip(); + let is_blocked_ip = state.shared.session.upgrade().map_or_else( + || false, + |session| session.blocklist.is_blocked(outgoing_ip), + ); + + if is_blocked_ip { + info!("Outgoing ip {outgoing_ip} for peer is in blocklist skipping"); + continue; + } + let permit = state.peer_semaphore.clone().acquire_owned().await?; state.spawn( error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()), diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 3b9cd86..23ada49 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -228,6 +228,10 @@ struct Opts { /// Limit upload to bytes-per-second. #[arg(long = "ratelimit-upload", env = "RQBIT_RATELIMIT_UPLOAD")] ratelimit_upload_bps: Option, + + /// Downloads a p2p blocklist from this url and blocks peers from it + #[arg(long, env = "RQBIT_BLOCKLIST_URL")] + blocklist_url: Option, } #[derive(Parser)] @@ -494,6 +498,7 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> upload_bps: opts.ratelimit_upload_bps, download_bps: opts.ratelimit_download_bps, }, + blocklist_url: opts.blocklist_url, }; let http_api_basic_auth = if let Ok(up) = std::env::var("RQBIT_HTTP_BASIC_AUTH_USERPASS") {