Add blocklist-url launch parameter

Block incoming peers from blocked ips.
This commit is contained in:
Alexander WB 2025-02-20 20:19:08 +01:00
parent 187ce8c462
commit 6e9ecf8a26
3 changed files with 35 additions and 14 deletions

View file

@ -10,13 +10,20 @@ use tokio::{io::AsyncBufReadExt, io::BufReader};
use tokio_util::io::StreamReader;
use tracing::{debug, info, trace};
struct Blocklist {
pub struct Blocklist {
// Separate trees for IPv4 and IPv6 since they have different numeric ranges
ipv4_ranges: IntervalTree<u32, ()>,
ipv6_ranges: IntervalTree<u128, ()>,
}
impl Blocklist {
pub fn empty() -> Self {
return Self::new(
&Vec::<std::ops::Range<u32>>::new(),
&Vec::<std::ops::Range<u128>>::new(),
);
}
pub fn new(
ipv4_ranges: &Vec<std::ops::Range<u32>>,
ipv6_ranges: &Vec<std::ops::Range<u128>>,
@ -27,13 +34,6 @@ impl Blocklist {
}
}
fn ip_to_num(ip: &IpAddr) -> u128 {
match ip {
IpAddr::V4(ip) => u32::from_be_bytes(ip.octets()) as u128,
IpAddr::V6(ip) => u128::from_be_bytes(ip.octets()),
}
}
pub async fn load_from_url(url: &str) -> Result<Self> {
let response = reqwest::get(url).await.map_err(|e| anyhow::anyhow!(e))?;
if response.status() != 200 {
@ -74,7 +74,8 @@ impl Blocklist {
let mut peek_bytes = [0u8; 2];
let mut reader = tokio::io::BufReader::new(reader);
let buffer = reader.fill_buf().await?; // Get a reference to the buffer
// Peek the first bytes by filling buffer
let buffer = reader.fill_buf().await?;
if buffer.len() >= 2 {
peek_bytes.copy_from_slice(&buffer[0..2]);
} else {
@ -210,7 +211,6 @@ mod tests {
let (server, mut mock, url) = rx.recv().unwrap();
// Add response body and headers
mock = mock.with_body(content);
for &(key, value) in headers {
mock = mock.with_header(key, value);
@ -305,10 +305,7 @@ mod tests {
#[test]
fn test_blocklist_empty() {
let blocklist = Blocklist::new(
&Vec::<std::ops::Range<u32>>::new(),
&Vec::<std::ops::Range<u128>>::new(),
);
let blocklist = Blocklist::empty();
assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap()));
assert!(!blocklist.is_blocked(&"::1".parse().unwrap()));
}

View file

@ -11,6 +11,7 @@ use std::{
use crate::{
api::TorrentIdOrHash,
bitv_factory::{BitVFactory, NonPersistentBitVFactory},
blocklist,
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
limits::{Limits, LimitsConfig},
merge_streams::merge_streams,
@ -125,6 +126,8 @@ pub struct Session {
pub(crate) concurrent_initialize_semaphore: Arc<tokio::sync::Semaphore>,
pub ratelimits: Limits,
pub blocklist: blocklist::Blocklist,
// Monitoring / tracing / logging
pub(crate) stats: SessionStats,
root_span: Option<Span>,
@ -417,6 +420,8 @@ pub struct SessionOptions {
pub ratelimits: LimitsConfig,
pub blocklist_url: Option<String>,
#[cfg(feature = "disable-upload")]
pub disable_upload: bool,
}
@ -607,6 +612,14 @@ impl Session {
let stream_connector = Arc::new(StreamConnector::from(proxy_config));
let blocklist: blocklist::Blocklist = if let Some(blocklist_url) = opts.blocklist_url {
blocklist::Blocklist::load_from_url(&blocklist_url)
.await
.unwrap()
} else {
blocklist::Blocklist::empty()
};
let session = Arc::new(Self {
persistence,
bitv_factory,
@ -632,6 +645,7 @@ impl Session {
ratelimits: Limits::new(opts.ratelimits),
#[cfg(feature = "disable-upload")]
_disable_upload: opts.disable_upload,
blocklist,
});
if let Some(mut disk_write_rx) = disk_write_rx {
@ -716,6 +730,11 @@ impl Session {
.read_write_timeout
.unwrap_or_else(|| Duration::from_secs(10));
let incoming_ip = addr.ip();
if self.blocklist.is_blocked(&incoming_ip) {
bail!("Incoming ip {incoming_ip} is in blocklist");
}
let mut read_buf = ReadBuf::new();
let h = read_buf
.read_handshake(&mut stream, rwtimeout)

View file

@ -228,6 +228,10 @@ struct Opts {
/// Limit upload to bytes-per-second.
#[arg(long = "ratelimit-upload", env = "RQBIT_RATELIMIT_UPLOAD")]
ratelimit_upload_bps: Option<NonZeroU32>,
/// Downloads a p2p blocklist from this url and blocks peers from it
#[arg(long, env = "RQBIT_BLOCKLIST_URL")]
blocklist_url: Option<String>,
}
#[derive(Parser)]
@ -494,6 +498,7 @@ async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()>
upload_bps: opts.ratelimit_upload_bps,
download_bps: opts.ratelimit_download_bps,
},
blocklist_url: opts.blocklist_url,
};
let http_api_basic_auth = if let Ok(up) = std::env::var("RQBIT_HTTP_BASIC_AUTH_USERPASS") {