First pass to implement socks5 support

This commit is contained in:
Igor Katson 2024-08-08 00:35:32 +01:00
parent 8c16239a0e
commit 70dcb2e6cb
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
11 changed files with 195 additions and 23 deletions

14
Cargo.lock generated
View file

@ -1346,6 +1346,7 @@ dependencies = [
"size_format", "size_format",
"tempfile", "tempfile",
"tokio", "tokio",
"tokio-socks",
"tokio-stream", "tokio-stream",
"tokio-test", "tokio-test",
"tokio-util", "tokio-util",
@ -2145,6 +2146,7 @@ dependencies = [
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-rustls", "tokio-rustls",
"tokio-socks",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
@ -2700,6 +2702,18 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-socks"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f"
dependencies = [
"either",
"futures-util",
"thiserror",
"tokio",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.15" version = "0.1.15"

View file

@ -42,7 +42,10 @@ anyhow = "1"
itertools = "0.12" itertools = "0.12"
http = "1" http = "1"
regex = "1" regex = "1"
reqwest = { version = "0.12", default-features = false, features = ["json"] } reqwest = { version = "0.12", default-features = false, features = [
"json",
"socks",
] }
urlencoding = "2" urlencoding = "2"
byteorder = "1" byteorder = "1"
bincode = "1" bincode = "1"
@ -75,6 +78,7 @@ async-stream = "0.3.5"
memmap2 = { version = "0.9.4" } memmap2 = { version = "0.9.4" }
lru = { version = "0.12.3", optional = true } lru = { version = "0.12.3", optional = true }
mime_guess = { version = "2.0.5", default-features = false } mime_guess = { version = "2.0.5", default-features = false }
tokio-socks = "0.5.2"
[dev-dependencies] [dev-dependencies]
futures = { version = "0.3" } futures = { version = "0.3" }

View file

@ -1,4 +1,4 @@
use std::{collections::HashSet, net::SocketAddr}; use std::{collections::HashSet, net::SocketAddr, sync::Arc};
use anyhow::Context; use anyhow::Context;
use buffers::ByteBufOwned; use buffers::ByteBufOwned;
@ -8,6 +8,7 @@ use tracing::{debug, error_span, Instrument};
use crate::{ use crate::{
peer_connection::PeerConnectionOptions, peer_info_reader, spawn_utils::BlockingSpawner, peer_connection::PeerConnectionOptions, peer_info_reader, spawn_utils::BlockingSpawner,
stream_connect::StreamConnector,
}; };
use librqbit_core::hash_id::Id20; use librqbit_core::hash_id::Id20;
@ -30,6 +31,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
initial_addrs: Vec<SocketAddr>, initial_addrs: Vec<SocketAddr>,
addrs_stream: A, addrs_stream: A,
peer_connection_options: Option<PeerConnectionOptions>, peer_connection_options: Option<PeerConnectionOptions>,
connector: Arc<StreamConnector>,
) -> ReadMetainfoResult<A> { ) -> ReadMetainfoResult<A> {
let mut seen = HashSet::<SocketAddr>::new(); let mut seen = HashSet::<SocketAddr>::new();
let mut addrs = addrs_stream; let mut addrs = addrs_stream;
@ -38,6 +40,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
let read_info_guarded = |addr| { let read_info_guarded = |addr| {
let semaphore = &semaphore; let semaphore = &semaphore;
let connector = connector.clone();
async move { async move {
let token = semaphore.acquire().await?; let token = semaphore.acquire().await?;
let ret = peer_info_reader::read_metainfo_from_peer( let ret = peer_info_reader::read_metainfo_from_peer(
@ -46,6 +49,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
info_hash, info_hash,
peer_connection_options, peer_connection_options,
BlockingSpawner::new(true), BlockingSpawner::new(true),
connector,
) )
.instrument(error_span!("read_metainfo_from_peer", ?addr)) .instrument(error_span!("read_metainfo_from_peer", ?addr))
.await .await
@ -93,7 +97,10 @@ mod tests {
use librqbit_core::peer_id::generate_peer_id; use librqbit_core::peer_id::generate_peer_id;
use super::*; use super::*;
use std::{str::FromStr, sync::Once}; use std::{
str::FromStr,
sync::{Arc, Once},
};
static LOG_INIT: Once = Once::new(); static LOG_INIT: Once = Once::new();
@ -114,7 +121,15 @@ mod tests {
let peer_rx = dht.get_peers(info_hash, None).unwrap(); let peer_rx = dht.get_peers(info_hash, None).unwrap();
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();
match read_metainfo_from_peer_receiver(peer_id, info_hash, Vec::new(), peer_rx, None).await match read_metainfo_from_peer_receiver(
peer_id,
info_hash,
Vec::new(),
peer_rx,
None,
Arc::new(Default::default()),
)
.await
{ {
ReadMetainfoResult::Found { info, .. } => dbg!(info), ReadMetainfoResult::Found { info, .. } => dbg!(info),
ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"), ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"),

View file

@ -41,6 +41,7 @@ mod read_buf;
mod session; mod session;
mod spawn_utils; mod spawn_utils;
pub mod storage; pub mod storage;
mod stream_connect;
mod torrent_state; mod torrent_state;
pub mod tracing_subscriber_config_utils; pub mod tracing_subscriber_config_utils;
mod type_aliases; mod type_aliases;

View file

@ -1,5 +1,6 @@
use std::{ use std::{
net::SocketAddr, net::SocketAddr,
sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -21,7 +22,7 @@ use serde_with::serde_as;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, trace}; use tracing::{debug, trace};
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner, stream_connect::StreamConnector};
pub trait PeerConnectionHandler { pub trait PeerConnectionHandler {
fn on_connected(&self, _connection_time: Duration) {} fn on_connected(&self, _connection_time: Duration) {}
@ -65,6 +66,7 @@ pub(crate) struct PeerConnection<H> {
peer_id: Id20, peer_id: Id20,
options: PeerConnectionOptions, options: PeerConnectionOptions,
spawner: BlockingSpawner, spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
} }
pub(crate) async fn with_timeout<T, E>( pub(crate) async fn with_timeout<T, E>(
@ -88,6 +90,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
handler: H, handler: H,
options: Option<PeerConnectionOptions>, options: Option<PeerConnectionOptions>,
spawner: BlockingSpawner, spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
) -> Self { ) -> Self {
PeerConnection { PeerConnection {
handler, handler,
@ -96,6 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
peer_id, peer_id,
spawner, spawner,
options: options.unwrap_or_default(), options: options.unwrap_or_default(),
connector,
} }
} }
@ -169,7 +173,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.unwrap_or_else(|| Duration::from_secs(10)); .unwrap_or_else(|| Duration::from_secs(10));
let now = Instant::now(); let now = Instant::now();
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr)) let conn = self.connector.connect(self.addr);
let mut conn = with_timeout(connect_timeout, conn)
.await .await
.context("error connecting")?; .context("error connecting")?;
self.handler.on_connected(now.elapsed()); self.handler.on_connected(now.elapsed());
@ -218,7 +223,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
handshake_supports_extended: bool, handshake_supports_extended: bool,
mut read_buf: ReadBuf, mut read_buf: ReadBuf,
mut write_buf: Vec<u8>, mut write_buf: Vec<u8>,
mut conn: tokio::net::TcpStream, mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
mut have_broadcast: tokio::sync::broadcast::Receiver<ValidPieceIndex>, mut have_broadcast: tokio::sync::broadcast::Receiver<ValidPieceIndex>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {

View file

@ -1,4 +1,4 @@
use std::net::SocketAddr; use std::{net::SocketAddr, sync::Arc};
use bencode::from_bytes; use bencode::from_bytes;
use buffers::{ByteBuf, ByteBufOwned}; use buffers::{ByteBuf, ByteBufOwned};
@ -22,6 +22,7 @@ use crate::{
PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest,
}, },
spawn_utils::BlockingSpawner, spawn_utils::BlockingSpawner,
stream_connect::StreamConnector,
}; };
pub(crate) async fn read_metainfo_from_peer( pub(crate) async fn read_metainfo_from_peer(
@ -30,6 +31,7 @@ pub(crate) async fn read_metainfo_from_peer(
info_hash: Id20, info_hash: Id20,
peer_connection_options: Option<PeerConnectionOptions>, peer_connection_options: Option<PeerConnectionOptions>,
spawner: BlockingSpawner, spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
) -> anyhow::Result<TorrentMetaV1Info<ByteBufOwned>> { ) -> anyhow::Result<TorrentMetaV1Info<ByteBufOwned>> {
let (result_tx, result_rx) = let (result_tx, result_rx) =
tokio::sync::oneshot::channel::<anyhow::Result<TorrentMetaV1Info<ByteBufOwned>>>(); tokio::sync::oneshot::channel::<anyhow::Result<TorrentMetaV1Info<ByteBufOwned>>>();
@ -48,6 +50,7 @@ pub(crate) async fn read_metainfo_from_peer(
handler, handler,
peer_connection_options, peer_connection_options,
spawner, spawner,
connector,
); );
let result_reader = async move { result_rx.await? }; let result_reader = async move { result_rx.await? };
@ -234,6 +237,7 @@ impl PeerConnectionHandler for Handler {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use std::{net::SocketAddr, str::FromStr, sync::Once}; use std::{net::SocketAddr, str::FromStr, sync::Once};
use librqbit_core::hash_id::Id20; use librqbit_core::hash_id::Id20;
@ -260,10 +264,15 @@ mod tests {
let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap(); let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap();
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();
let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap(); let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap();
dbg!( dbg!(read_metainfo_from_peer(
read_metainfo_from_peer(addr, peer_id, info_hash, None, BlockingSpawner::new(true)) addr,
.await peer_id,
.unwrap() info_hash,
); None,
BlockingSpawner::new(true),
Arc::new(Default::default())
)
.await
.unwrap());
} }
} }

View file

@ -19,6 +19,7 @@ use crate::{
storage::{ storage::{
filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage, filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage,
}, },
stream_connect::{SocksProxyConfig, StreamConnector},
torrent_state::{ torrent_state::{
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
}, },
@ -197,6 +198,7 @@ pub struct Session {
default_storage_factory: Option<BoxStorageFactory>, default_storage_factory: Option<BoxStorageFactory>,
reqwest_client: reqwest::Client, reqwest_client: reqwest::Client,
connector: Arc<StreamConnector>,
// This is stored for all tasks to stop when session is dropped. // This is stored for all tasks to stop when session is dropped.
_cancellation_token_drop_guard: DropGuard, _cancellation_token_drop_guard: DropGuard,
@ -413,11 +415,6 @@ impl<'a> AddTorrent<'a> {
} }
} }
pub struct SocksProxyConfig {
// must start with socks5
pub url: String,
}
#[derive(Default)] #[derive(Default)]
pub struct SessionOptions { pub struct SessionOptions {
/// Turn on to disable DHT. /// Turn on to disable DHT.
@ -449,7 +446,8 @@ pub struct SessionOptions {
pub default_storage_factory: Option<BoxStorageFactory>, pub default_storage_factory: Option<BoxStorageFactory>,
pub socks_proxy: Option<SocksProxyConfig>, // socks5://[username:password@]host:port
pub socks_proxy_url: Option<String>,
} }
async fn create_tcp_listener( async fn create_tcp_listener(
@ -548,9 +546,27 @@ impl Session {
}) })
.unwrap_or_default(); .unwrap_or_default();
let reqwest_client = reqwest::Client::builder() let proxy_config = match opts.socks_proxy_url.as_ref() {
.build() Some(pu) => Some(
.context("error building HTTP(S) client")?; SocksProxyConfig::parse(pu)
.with_context(|| format!("error parsing proxy url {}", pu))?,
),
None => None,
};
let reqwest_client = {
let builder = if let Some(proxy_url) = opts.socks_proxy_url.as_ref() {
let proxy = reqwest::Proxy::all(proxy_url)
.context("error creating socks5 proxy for HTTP")?;
reqwest::Client::builder().proxy(proxy)
} else {
reqwest::Client::builder()
};
builder.build().context("error building HTTP(S) client")?
};
let stream_connector = Arc::new(StreamConnector::from(proxy_config));
let session = Arc::new(Self { let session = Arc::new(Self {
persistence_filename, persistence_filename,
@ -566,6 +582,7 @@ impl Session {
disk_write_tx, disk_write_tx,
default_storage_factory: opts.default_storage_factory, default_storage_factory: opts.default_storage_factory,
reqwest_client, reqwest_client,
connector: stream_connector,
}); });
if let Some(mut disk_write_rx) = disk_write_rx { if let Some(mut disk_write_rx) = disk_write_rx {
@ -919,6 +936,7 @@ impl Session {
opts.initial_peers.clone().unwrap_or_default(), opts.initial_peers.clone().unwrap_or_default(),
peer_rx, peer_rx,
Some(self.merge_peer_opts(opts.peer_opts)), Some(self.merge_peer_opts(opts.peer_opts)),
self.connector.clone(),
) )
.await .await
{ {
@ -1088,6 +1106,7 @@ impl Session {
.allow_overwrite(opts.overwrite) .allow_overwrite(opts.overwrite)
.spawner(self.spawner) .spawner(self.spawner)
.trackers(trackers) .trackers(trackers)
.connector(self.connector.clone())
.peer_id(self.peer_id); .peer_id(self.peer_id);
if let Some(d) = self.disk_write_tx.clone() { if let Some(d) = self.disk_write_tx.clone() {

View file

@ -0,0 +1,82 @@
use std::net::SocketAddr;
use anyhow::Context;
#[derive(Debug, Clone)]
pub(crate) struct SocksProxyConfig {
pub host: String,
pub port: u16,
pub username_password: Option<(String, String)>,
}
impl SocksProxyConfig {
pub fn parse(url: &str) -> anyhow::Result<Self> {
let url = ::url::Url::parse(url).context("invalid proxy URL")?;
if url.scheme() != "socks5" {
anyhow::bail!("proxy URL should have socks5 scheme");
}
let host = url.host_str().context("missing host")?;
let port = url.port().context("missing port")?;
let up = url
.password()
.map(|p| (url.username().to_owned(), p.to_owned()));
Ok(Self {
host: host.to_owned(),
port,
username_password: up,
})
}
async fn connect(
&self,
addr: SocketAddr,
) -> anyhow::Result<impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> {
let proxy_addr = (self.host.as_str(), self.port);
if let Some((username, password)) = self.username_password.as_ref() {
tokio_socks::tcp::Socks5Stream::connect_with_password(
proxy_addr,
addr,
username.as_str(),
password.as_str(),
)
.await
.context("error connecting to proxy")
} else {
tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr)
.await
.context("error connecting to proxy")
}
}
}
#[derive(Debug, Default)]
pub(crate) struct StreamConnector {
proxy_config: Option<SocksProxyConfig>,
}
impl From<Option<SocksProxyConfig>> for StreamConnector {
fn from(proxy_config: Option<SocksProxyConfig>) -> Self {
Self { proxy_config }
}
}
pub(crate) trait AsyncReadWrite:
tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin
{
}
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
impl StreamConnector {
pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<Box<dyn AsyncReadWrite>> {
if let Some(proxy) = self.proxy_config.as_ref() {
return Ok(Box::new(proxy.connect(addr).await?));
}
Ok(Box::new(
tokio::net::TcpStream::connect(addr)
.await
.context("error connecting")?,
))
}
}

View file

@ -382,6 +382,7 @@ impl TorrentStateLive {
&handler, &handler,
Some(options), Some(options),
self.meta.spawner, self.meta.spawner,
self.meta.connector.clone(),
); );
let requester = handler.task_peer_chunk_requester(); let requester = handler.task_peer_chunk_requester();
@ -444,6 +445,7 @@ impl TorrentStateLive {
&handler, &handler,
Some(options), Some(options),
state.meta.spawner, state.meta.spawner,
state.meta.connector.clone(),
); );
let requester = handler.task_peer_chunk_requester(); let requester = handler.task_peer_chunk_requester();

View file

@ -37,6 +37,7 @@ use crate::chunk_tracker::ChunkTracker;
use crate::file_info::FileInfo; use crate::file_info::FileInfo;
use crate::spawn_utils::BlockingSpawner; use crate::spawn_utils::BlockingSpawner;
use crate::storage::BoxStorageFactory; use crate::storage::BoxStorageFactory;
use crate::stream_connect::StreamConnector;
use crate::torrent_state::stats::LiveStats; use crate::torrent_state::stats::LiveStats;
use crate::type_aliases::DiskWorkQueueSender; use crate::type_aliases::DiskWorkQueueSender;
use crate::type_aliases::FileInfos; use crate::type_aliases::FileInfos;
@ -106,6 +107,7 @@ pub struct ManagedTorrentInfo {
pub file_infos: FileInfos, pub file_infos: FileInfos,
pub span: tracing::Span, pub span: tracing::Span,
pub(crate) options: ManagedTorrentOptions, pub(crate) options: ManagedTorrentOptions,
pub(crate) connector: Arc<StreamConnector>,
} }
pub struct ManagedTorrent { pub struct ManagedTorrent {
@ -509,6 +511,7 @@ pub(crate) struct ManagedTorrentBuilder {
allow_overwrite: bool, allow_overwrite: bool,
storage_factory: BoxStorageFactory, storage_factory: BoxStorageFactory,
disk_writer: Option<DiskWorkQueueSender>, disk_writer: Option<DiskWorkQueueSender>,
connector: Arc<StreamConnector>,
} }
impl ManagedTorrentBuilder { impl ManagedTorrentBuilder {
@ -532,6 +535,7 @@ impl ManagedTorrentBuilder {
output_folder, output_folder,
storage_factory, storage_factory,
disk_writer: None, disk_writer: None,
connector: Arc::new(Default::default()),
} }
} }
@ -580,6 +584,11 @@ impl ManagedTorrentBuilder {
self self
} }
pub fn connector(&mut self, value: Arc<StreamConnector>) -> &mut Self {
self.connector = value;
self
}
pub fn build(self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> { pub fn build(self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> {
let lengths = Lengths::from_torrent(&self.info)?; let lengths = Lengths::from_torrent(&self.info)?;
let file_infos = self let file_infos = self
@ -612,6 +621,7 @@ impl ManagedTorrentBuilder {
output_folder: self.output_folder, output_folder: self.output_folder,
disk_write_queue: self.disk_writer, disk_write_queue: self.disk_writer,
}, },
connector: self.connector,
}); });
let initializing = Arc::new(TorrentStateInitializing::new( let initializing = Arc::new(TorrentStateInitializing::new(

View file

@ -115,6 +115,13 @@ struct Opts {
/// If you use it, you know what you are doing. /// If you use it, you know what you are doing.
#[arg(long)] #[arg(long)]
experimental_mmap_storage: bool, experimental_mmap_storage: bool,
/// Provide a socks5 URL.
/// The format is socks5://[username:password]@host:port
///
/// Alternatively, set this as an environment variable RQBIT_SOCKS_PROXY_URL
#[arg(long)]
socks_url: Option<String>,
} }
#[derive(Parser)] #[derive(Parser)]
@ -281,6 +288,10 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
Err(e) => warn!("failed increasing open file limit: {:#}", e), Err(e) => warn!("failed increasing open file limit: {:#}", e),
}; };
let socks_url = opts
.socks_url
.or_else(|| std::env::var("RQBIT_SOCKS_PROXY_URL").ok());
let mut sopts = SessionOptions { let mut sopts = SessionOptions {
disable_dht: opts.disable_dht, disable_dht: opts.disable_dht,
disable_dht_persistence: opts.disable_dht_persistence, disable_dht_persistence: opts.disable_dht_persistence,
@ -320,7 +331,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
wrap(FilesystemStorageFactory::default()).boxed() wrap(FilesystemStorageFactory::default()).boxed()
} }
}), }),
socks_proxy: None, socks_proxy_url: socks_url,
}; };
let stats_printer = |session: Arc<Session>| async move { let stats_printer = |session: Arc<Session>| async move {