First pass to implement socks5 support

This commit is contained in:
Igor Katson 2024-08-08 00:35:32 +01:00
parent 8c16239a0e
commit 70dcb2e6cb
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
11 changed files with 195 additions and 23 deletions

14
Cargo.lock generated
View file

@ -1346,6 +1346,7 @@ dependencies = [
"size_format",
"tempfile",
"tokio",
"tokio-socks",
"tokio-stream",
"tokio-test",
"tokio-util",
@ -2145,6 +2146,7 @@ dependencies = [
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-socks",
"tower-service",
"url",
"wasm-bindgen",
@ -2700,6 +2702,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-socks"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f"
dependencies = [
"either",
"futures-util",
"thiserror",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"

View file

@ -42,7 +42,10 @@ anyhow = "1"
itertools = "0.12"
http = "1"
regex = "1"
reqwest = { version = "0.12", default-features = false, features = ["json"] }
reqwest = { version = "0.12", default-features = false, features = [
"json",
"socks",
] }
urlencoding = "2"
byteorder = "1"
bincode = "1"
@ -75,6 +78,7 @@ async-stream = "0.3.5"
memmap2 = { version = "0.9.4" }
lru = { version = "0.12.3", optional = true }
mime_guess = { version = "2.0.5", default-features = false }
tokio-socks = "0.5.2"
[dev-dependencies]
futures = { version = "0.3" }

View file

@ -1,4 +1,4 @@
use std::{collections::HashSet, net::SocketAddr};
use std::{collections::HashSet, net::SocketAddr, sync::Arc};
use anyhow::Context;
use buffers::ByteBufOwned;
@ -8,6 +8,7 @@ use tracing::{debug, error_span, Instrument};
use crate::{
peer_connection::PeerConnectionOptions, peer_info_reader, spawn_utils::BlockingSpawner,
stream_connect::StreamConnector,
};
use librqbit_core::hash_id::Id20;
@ -30,6 +31,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
initial_addrs: Vec<SocketAddr>,
addrs_stream: A,
peer_connection_options: Option<PeerConnectionOptions>,
connector: Arc<StreamConnector>,
) -> ReadMetainfoResult<A> {
let mut seen = HashSet::<SocketAddr>::new();
let mut addrs = addrs_stream;
@ -38,6 +40,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
let read_info_guarded = |addr| {
let semaphore = &semaphore;
let connector = connector.clone();
async move {
let token = semaphore.acquire().await?;
let ret = peer_info_reader::read_metainfo_from_peer(
@ -46,6 +49,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
info_hash,
peer_connection_options,
BlockingSpawner::new(true),
connector,
)
.instrument(error_span!("read_metainfo_from_peer", ?addr))
.await
@ -93,7 +97,10 @@ mod tests {
use librqbit_core::peer_id::generate_peer_id;
use super::*;
use std::{str::FromStr, sync::Once};
use std::{
str::FromStr,
sync::{Arc, Once},
};
static LOG_INIT: Once = Once::new();
@ -114,7 +121,15 @@ mod tests {
let peer_rx = dht.get_peers(info_hash, None).unwrap();
let peer_id = generate_peer_id();
match read_metainfo_from_peer_receiver(peer_id, info_hash, Vec::new(), peer_rx, None).await
match read_metainfo_from_peer_receiver(
peer_id,
info_hash,
Vec::new(),
peer_rx,
None,
Arc::new(Default::default()),
)
.await
{
ReadMetainfoResult::Found { info, .. } => dbg!(info),
ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"),

View file

@ -41,6 +41,7 @@ mod read_buf;
mod session;
mod spawn_utils;
pub mod storage;
mod stream_connect;
mod torrent_state;
pub mod tracing_subscriber_config_utils;
mod type_aliases;

View file

@ -1,5 +1,6 @@
use std::{
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};
@ -21,7 +22,7 @@ use serde_with::serde_as;
use tokio::time::timeout;
use tracing::{debug, trace};
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner};
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner, stream_connect::StreamConnector};
pub trait PeerConnectionHandler {
fn on_connected(&self, _connection_time: Duration) {}
@ -65,6 +66,7 @@ pub(crate) struct PeerConnection<H> {
peer_id: Id20,
options: PeerConnectionOptions,
spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
}
pub(crate) async fn with_timeout<T, E>(
@ -88,6 +90,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
handler: H,
options: Option<PeerConnectionOptions>,
spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
) -> Self {
PeerConnection {
handler,
@ -96,6 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
peer_id,
spawner,
options: options.unwrap_or_default(),
connector,
}
}
@ -169,7 +173,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.unwrap_or_else(|| Duration::from_secs(10));
let now = Instant::now();
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
let conn = self.connector.connect(self.addr);
let mut conn = with_timeout(connect_timeout, conn)
.await
.context("error connecting")?;
self.handler.on_connected(now.elapsed());
@ -218,7 +223,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
handshake_supports_extended: bool,
mut read_buf: ReadBuf,
mut write_buf: Vec<u8>,
mut conn: tokio::net::TcpStream,
mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
mut have_broadcast: tokio::sync::broadcast::Receiver<ValidPieceIndex>,
) -> anyhow::Result<()> {

View file

@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use bencode::from_bytes;
use buffers::{ByteBuf, ByteBufOwned};
@ -22,6 +22,7 @@ use crate::{
PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest,
},
spawn_utils::BlockingSpawner,
stream_connect::StreamConnector,
};
pub(crate) async fn read_metainfo_from_peer(
@ -30,6 +31,7 @@ pub(crate) async fn read_metainfo_from_peer(
info_hash: Id20,
peer_connection_options: Option<PeerConnectionOptions>,
spawner: BlockingSpawner,
connector: Arc<StreamConnector>,
) -> anyhow::Result<TorrentMetaV1Info<ByteBufOwned>> {
let (result_tx, result_rx) =
tokio::sync::oneshot::channel::<anyhow::Result<TorrentMetaV1Info<ByteBufOwned>>>();
@ -48,6 +50,7 @@ pub(crate) async fn read_metainfo_from_peer(
handler,
peer_connection_options,
spawner,
connector,
);
let result_reader = async move { result_rx.await? };
@ -234,6 +237,7 @@ impl PeerConnectionHandler for Handler {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::{net::SocketAddr, str::FromStr, sync::Once};
use librqbit_core::hash_id::Id20;
@ -260,10 +264,15 @@ mod tests {
let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap();
let peer_id = generate_peer_id();
let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap();
dbg!(
read_metainfo_from_peer(addr, peer_id, info_hash, None, BlockingSpawner::new(true))
.await
.unwrap()
);
dbg!(read_metainfo_from_peer(
addr,
peer_id,
info_hash,
None,
BlockingSpawner::new(true),
Arc::new(Default::default())
)
.await
.unwrap());
}
}

View file

@ -19,6 +19,7 @@ use crate::{
storage::{
filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage,
},
stream_connect::{SocksProxyConfig, StreamConnector},
torrent_state::{
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
},
@ -197,6 +198,7 @@ pub struct Session {
default_storage_factory: Option<BoxStorageFactory>,
reqwest_client: reqwest::Client,
connector: Arc<StreamConnector>,
// This is stored for all tasks to stop when session is dropped.
_cancellation_token_drop_guard: DropGuard,
@ -413,11 +415,6 @@ impl<'a> AddTorrent<'a> {
}
}
pub struct SocksProxyConfig {
// must start with socks5
pub url: String,
}
#[derive(Default)]
pub struct SessionOptions {
/// Turn on to disable DHT.
@ -449,7 +446,8 @@ pub struct SessionOptions {
pub default_storage_factory: Option<BoxStorageFactory>,
pub socks_proxy: Option<SocksProxyConfig>,
// socks5://[username:password@]host:port
pub socks_proxy_url: Option<String>,
}
async fn create_tcp_listener(
@ -548,9 +546,27 @@ impl Session {
})
.unwrap_or_default();
let reqwest_client = reqwest::Client::builder()
.build()
.context("error building HTTP(S) client")?;
let proxy_config = match opts.socks_proxy_url.as_ref() {
Some(pu) => Some(
SocksProxyConfig::parse(pu)
.with_context(|| format!("error parsing proxy url {}", pu))?,
),
None => None,
};
let reqwest_client = {
let builder = if let Some(proxy_url) = opts.socks_proxy_url.as_ref() {
let proxy = reqwest::Proxy::all(proxy_url)
.context("error creating socks5 proxy for HTTP")?;
reqwest::Client::builder().proxy(proxy)
} else {
reqwest::Client::builder()
};
builder.build().context("error building HTTP(S) client")?
};
let stream_connector = Arc::new(StreamConnector::from(proxy_config));
let session = Arc::new(Self {
persistence_filename,
@ -566,6 +582,7 @@ impl Session {
disk_write_tx,
default_storage_factory: opts.default_storage_factory,
reqwest_client,
connector: stream_connector,
});
if let Some(mut disk_write_rx) = disk_write_rx {
@ -919,6 +936,7 @@ impl Session {
opts.initial_peers.clone().unwrap_or_default(),
peer_rx,
Some(self.merge_peer_opts(opts.peer_opts)),
self.connector.clone(),
)
.await
{
@ -1088,6 +1106,7 @@ impl Session {
.allow_overwrite(opts.overwrite)
.spawner(self.spawner)
.trackers(trackers)
.connector(self.connector.clone())
.peer_id(self.peer_id);
if let Some(d) = self.disk_write_tx.clone() {

View file

@ -0,0 +1,82 @@
use std::net::SocketAddr;
use anyhow::Context;
#[derive(Debug, Clone)]
pub(crate) struct SocksProxyConfig {
pub host: String,
pub port: u16,
pub username_password: Option<(String, String)>,
}
impl SocksProxyConfig {
pub fn parse(url: &str) -> anyhow::Result<Self> {
let url = ::url::Url::parse(url).context("invalid proxy URL")?;
if url.scheme() != "socks5" {
anyhow::bail!("proxy URL should have socks5 scheme");
}
let host = url.host_str().context("missing host")?;
let port = url.port().context("missing port")?;
let up = url
.password()
.map(|p| (url.username().to_owned(), p.to_owned()));
Ok(Self {
host: host.to_owned(),
port,
username_password: up,
})
}
async fn connect(
&self,
addr: SocketAddr,
) -> anyhow::Result<impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> {
let proxy_addr = (self.host.as_str(), self.port);
if let Some((username, password)) = self.username_password.as_ref() {
tokio_socks::tcp::Socks5Stream::connect_with_password(
proxy_addr,
addr,
username.as_str(),
password.as_str(),
)
.await
.context("error connecting to proxy")
} else {
tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr)
.await
.context("error connecting to proxy")
}
}
}
#[derive(Debug, Default)]
pub(crate) struct StreamConnector {
proxy_config: Option<SocksProxyConfig>,
}
impl From<Option<SocksProxyConfig>> for StreamConnector {
fn from(proxy_config: Option<SocksProxyConfig>) -> Self {
Self { proxy_config }
}
}
pub(crate) trait AsyncReadWrite:
tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin
{
}
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
impl StreamConnector {
pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<Box<dyn AsyncReadWrite>> {
if let Some(proxy) = self.proxy_config.as_ref() {
return Ok(Box::new(proxy.connect(addr).await?));
}
Ok(Box::new(
tokio::net::TcpStream::connect(addr)
.await
.context("error connecting")?,
))
}
}

View file

@ -382,6 +382,7 @@ impl TorrentStateLive {
&handler,
Some(options),
self.meta.spawner,
self.meta.connector.clone(),
);
let requester = handler.task_peer_chunk_requester();
@ -444,6 +445,7 @@ impl TorrentStateLive {
&handler,
Some(options),
state.meta.spawner,
state.meta.connector.clone(),
);
let requester = handler.task_peer_chunk_requester();

View file

@ -37,6 +37,7 @@ use crate::chunk_tracker::ChunkTracker;
use crate::file_info::FileInfo;
use crate::spawn_utils::BlockingSpawner;
use crate::storage::BoxStorageFactory;
use crate::stream_connect::StreamConnector;
use crate::torrent_state::stats::LiveStats;
use crate::type_aliases::DiskWorkQueueSender;
use crate::type_aliases::FileInfos;
@ -106,6 +107,7 @@ pub struct ManagedTorrentInfo {
pub file_infos: FileInfos,
pub span: tracing::Span,
pub(crate) options: ManagedTorrentOptions,
pub(crate) connector: Arc<StreamConnector>,
}
pub struct ManagedTorrent {
@ -509,6 +511,7 @@ pub(crate) struct ManagedTorrentBuilder {
allow_overwrite: bool,
storage_factory: BoxStorageFactory,
disk_writer: Option<DiskWorkQueueSender>,
connector: Arc<StreamConnector>,
}
impl ManagedTorrentBuilder {
@ -532,6 +535,7 @@ impl ManagedTorrentBuilder {
output_folder,
storage_factory,
disk_writer: None,
connector: Arc::new(Default::default()),
}
}
@ -580,6 +584,11 @@ impl ManagedTorrentBuilder {
self
}
pub fn connector(&mut self, value: Arc<StreamConnector>) -> &mut Self {
self.connector = value;
self
}
pub fn build(self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> {
let lengths = Lengths::from_torrent(&self.info)?;
let file_infos = self
@ -612,6 +621,7 @@ impl ManagedTorrentBuilder {
output_folder: self.output_folder,
disk_write_queue: self.disk_writer,
},
connector: self.connector,
});
let initializing = Arc::new(TorrentStateInitializing::new(

View file

@ -115,6 +115,13 @@ struct Opts {
/// If you use it, you know what you are doing.
#[arg(long)]
experimental_mmap_storage: bool,
/// Provide a socks5 URL.
/// The format is socks5://[username:password]@host:port
///
/// Alternatively, set this as an environment variable RQBIT_SOCKS_PROXY_URL
#[arg(long)]
socks_url: Option<String>,
}
#[derive(Parser)]
@ -281,6 +288,10 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
Err(e) => warn!("failed increasing open file limit: {:#}", e),
};
let socks_url = opts
.socks_url
.or_else(|| std::env::var("RQBIT_SOCKS_PROXY_URL").ok());
let mut sopts = SessionOptions {
disable_dht: opts.disable_dht,
disable_dht_persistence: opts.disable_dht_persistence,
@ -320,7 +331,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
wrap(FilesystemStorageFactory::default()).boxed()
}
}),
socks_proxy: None,
socks_proxy_url: socks_url,
};
let stats_printer = |session: Arc<Session>| async move {