👌 Fix remaining pr comments
This commit is contained in:
parent
4c6e19ceab
commit
1a31340076
3 changed files with 37 additions and 31 deletions
|
|
@ -5,7 +5,7 @@ use intervaltree::IntervalTree;
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use tokio::io::AsyncRead;
|
use tokio::io::{AsyncBufRead, AsyncRead};
|
||||||
use tokio::{io::AsyncBufReadExt, io::BufReader};
|
use tokio::{io::AsyncBufReadExt, io::BufReader};
|
||||||
use tokio_util::io::StreamReader;
|
use tokio_util::io::StreamReader;
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
@ -85,7 +85,7 @@ impl Blocklist {
|
||||||
// Check for Gzip magic bytes (1F 8B)
|
// Check for Gzip magic bytes (1F 8B)
|
||||||
let is_gzip = peek_bytes == [0x1F, 0x8B];
|
let is_gzip = peek_bytes == [0x1F, 0x8B];
|
||||||
|
|
||||||
let reader: Pin<Box<dyn AsyncRead + Send>> = if is_gzip {
|
let mut reader: Pin<Box<dyn AsyncBufRead + Send>> = if is_gzip {
|
||||||
trace!("Detected Gzip file, decompressing...");
|
trace!("Detected Gzip file, decompressing...");
|
||||||
Box::pin(BufReader::new(GzipDecoder::new(reader)))
|
Box::pin(BufReader::new(GzipDecoder::new(reader)))
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -93,19 +93,14 @@ impl Blocklist {
|
||||||
Box::pin(reader)
|
Box::pin(reader)
|
||||||
};
|
};
|
||||||
|
|
||||||
let reader = BufReader::new(reader);
|
let mut line: String = Default::default();
|
||||||
let mut lines = reader.lines();
|
|
||||||
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
|
let mut ip_ranges: Vec<std::ops::Range<IpAddr>> = Vec::new();
|
||||||
while let Some(line) = lines.next_line().await? {
|
while reader.read_line(&mut line).await? > 0 {
|
||||||
// Skip comments and empty lines
|
|
||||||
if line.starts_with('#') || line.trim().is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
|
if let Some((start_ip, end_ip)) = parse_ip_range(&line) {
|
||||||
let range = start_ip..(increment_ip(end_ip).unwrap());
|
let range = start_ip..(increment_ip(end_ip));
|
||||||
ip_ranges.push(range);
|
ip_ranges.push(range);
|
||||||
}
|
}
|
||||||
|
line.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
|
@ -117,21 +112,21 @@ impl Blocklist {
|
||||||
Ok(blocklist)
|
Ok(blocklist)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
pub fn is_blocked(&self, ip: IpAddr) -> bool {
|
||||||
self.blocked_ranges.query_point(*ip).next().is_some()
|
self.blocked_ranges.query_point(ip).next().is_some()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Safely increments an `IpAddr`, returning `None` if it would overflow.
|
/// Safely increments an `IpAddr`, returning `None` if it would overflow.
|
||||||
fn increment_ip(ip: IpAddr) -> Option<IpAddr> {
|
fn increment_ip(ip: IpAddr) -> IpAddr {
|
||||||
match ip {
|
match ip {
|
||||||
IpAddr::V4(ipv4) => {
|
IpAddr::V4(ipv4) => {
|
||||||
let num = u32::from_be_bytes(ipv4.octets());
|
let num = u32::from_be_bytes(ipv4.octets());
|
||||||
num.checked_add(1).map(|n| IpAddr::V4(Ipv4Addr::from(n)))
|
std::net::IpAddr::V4(Ipv4Addr::from(num.saturating_add(1)))
|
||||||
}
|
}
|
||||||
IpAddr::V6(ipv6) => {
|
IpAddr::V6(ipv6) => {
|
||||||
let num = u128::from_be_bytes(ipv6.octets());
|
let num = u128::from_be_bytes(ipv6.octets());
|
||||||
num.checked_add(1).map(|n| IpAddr::V6(Ipv6Addr::from(n)))
|
std::net::IpAddr::V6(Ipv6Addr::from(num.saturating_add(1)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -195,8 +190,8 @@ mod tests {
|
||||||
Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist))
|
Ok::<_, std::io::Error>(Cursor::new(gzipped_blocklist))
|
||||||
})));
|
})));
|
||||||
let blocklist = Blocklist::create_from_stream(stream).await?;
|
let blocklist = Blocklist::create_from_stream(stream).await?;
|
||||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
|
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -213,8 +208,8 @@ mod tests {
|
||||||
Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec()))
|
Ok::<_, std::io::Error>(Cursor::new(blocklist.as_bytes().to_vec()))
|
||||||
})));
|
})));
|
||||||
let blocklist = Blocklist::create_from_stream(stream).await?;
|
let blocklist = Blocklist::create_from_stream(stream).await?;
|
||||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
|
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -236,10 +231,10 @@ mod tests {
|
||||||
let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?;
|
let blocklist = Blocklist::load_from_file("temp_blocklist.txt").await?;
|
||||||
|
|
||||||
// Verify the blocklist
|
// Verify the blocklist
|
||||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"8.8.8.8".parse().unwrap()));
|
assert!(!blocklist.is_blocked("8.8.8.8".parse().unwrap()));
|
||||||
assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap()));
|
assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"2001:4860:4860::8888".parse().unwrap()));
|
assert!(!blocklist.is_blocked("2001:4860:4860::8888".parse().unwrap()));
|
||||||
|
|
||||||
// Clean up the temporary file
|
// Clean up the temporary file
|
||||||
tokio::fs::remove_file("temp_blocklist.txt").await?;
|
tokio::fs::remove_file("temp_blocklist.txt").await?;
|
||||||
|
|
@ -250,8 +245,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_blocklist_empty() {
|
fn test_blocklist_empty() {
|
||||||
let blocklist = Blocklist::empty();
|
let blocklist = Blocklist::empty();
|
||||||
assert!(!blocklist.is_blocked(&"127.0.0.1".parse().unwrap()));
|
assert!(!blocklist.is_blocked("127.0.0.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"::1".parse().unwrap()));
|
assert!(!blocklist.is_blocked("::1".parse().unwrap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -268,11 +263,11 @@ mod tests {
|
||||||
|
|
||||||
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
|
let blocklist = Blocklist::new(vec![ipv4_range, ipv6_range]);
|
||||||
// Test IPv4 addresses
|
// Test IPv4 addresses
|
||||||
assert!(blocklist.is_blocked(&"192.168.1.1".parse().unwrap()));
|
assert!(blocklist.is_blocked("192.168.1.1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"10.0.0.1".parse().unwrap()));
|
assert!(!blocklist.is_blocked("10.0.0.1".parse().unwrap()));
|
||||||
|
|
||||||
// Test IPv6 addresses
|
// Test IPv6 addresses
|
||||||
assert!(blocklist.is_blocked(&"2001:db8::1".parse().unwrap()));
|
assert!(blocklist.is_blocked("2001:db8::1".parse().unwrap()));
|
||||||
assert!(!blocklist.is_blocked(&"2001:db9::1".parse().unwrap()));
|
assert!(!blocklist.is_blocked("2001:db9::1".parse().unwrap()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -732,7 +732,7 @@ impl Session {
|
||||||
.unwrap_or_else(|| Duration::from_secs(10));
|
.unwrap_or_else(|| Duration::from_secs(10));
|
||||||
|
|
||||||
let incoming_ip = addr.ip();
|
let incoming_ip = addr.ip();
|
||||||
if self.blocklist.is_blocked(&incoming_ip) {
|
if self.blocklist.is_blocked(incoming_ip) {
|
||||||
bail!("Incoming ip {incoming_ip} is in blocklist");
|
bail!("Incoming ip {incoming_ip} is in blocklist");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -572,6 +572,17 @@ impl TorrentStateLive {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let outgoing_ip = addr.ip();
|
||||||
|
let is_blocked_ip = state.shared.session.upgrade().map_or_else(
|
||||||
|
|| false,
|
||||||
|
|session| session.blocklist.is_blocked(outgoing_ip),
|
||||||
|
);
|
||||||
|
|
||||||
|
if is_blocked_ip {
|
||||||
|
info!("Outgoing ip {outgoing_ip} for peer is in blocklist skipping");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let permit = state.peer_semaphore.clone().acquire_owned().await?;
|
let permit = state.peer_semaphore.clone().acquire_owned().await?;
|
||||||
state.spawn(
|
state.spawn(
|
||||||
error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()),
|
error_span!(parent: state.shared.span.clone(), "manage_peer", peer = addr.to_string()),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue