diff --git a/Cargo.lock b/Cargo.lock index 723eea5..54477f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2827,10 +2827,12 @@ dependencies = [ "librqbit-bencode", "librqbit-buffers", "librqbit-core", + "parking_lot", "rand 0.8.5", "reqwest", "serde", "tokio", + "tokio-util", "tracing", "url", "urlencoding", diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index e65eded..95e1763 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -61,7 +61,7 @@ use tokio::{ }; use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span}; -use tracker_comms::TrackerComms; +use tracker_comms::{TrackerComms, UdpTrackerClient}; pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"]; @@ -110,6 +110,7 @@ pub struct Session { dht: Option, pub(crate) connector: Arc, reqwest_client: reqwest::Client, + udp_tracker_client: UdpTrackerClient, // Lifecycle management cancellation_token: CancellationToken, @@ -625,6 +626,10 @@ impl Session { blocklist::Blocklist::empty() }; + let udp_tracker_client = UdpTrackerClient::new(token.clone()) + .await + .context("error creating UDP tracker client")?; + let session = Arc::new(Self { persistence, bitv_factory, @@ -647,6 +652,7 @@ impl Session { concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new( opts.concurrent_init_limit.unwrap_or(3), )), + udp_tracker_client, ratelimits: Limits::new(opts.ratelimits), trackers: opts.trackers, #[cfg(feature = "disable-upload")] @@ -1355,6 +1361,7 @@ impl Session { force_tracker_interval, announce_port, self.reqwest_client.clone(), + self.udp_tracker_client.clone(), ); let initial_peers_rx = if initial_peers.is_empty() { diff --git a/crates/tracker_comms/Cargo.toml b/crates/tracker_comms/Cargo.toml index 3bb7699..704894e 100644 --- a/crates/tracker_comms/Cargo.toml +++ b/crates/tracker_comms/Cargo.toml @@ -32,3 +32,5 @@ tracing = "0.1.40" reqwest = { version = "0.12", default-features = false, features = ["json"] } bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" } url = { version = "2", default-features = false } +parking_lot = "0.12.3" +tokio-util = "0.7.13" diff --git a/crates/tracker_comms/src/lib.rs b/crates/tracker_comms/src/lib.rs index 74cc980..cdae214 100644 --- a/crates/tracker_comms/src/lib.rs +++ b/crates/tracker_comms/src/lib.rs @@ -3,3 +3,4 @@ mod tracker_comms_http; mod tracker_comms_udp; pub use tracker_comms::*; +pub use tracker_comms_udp::UdpTrackerClient; diff --git a/crates/tracker_comms/src/tracker_comms.rs b/crates/tracker_comms/src/tracker_comms.rs index 0bfc60a..e32a9da 100644 --- a/crates/tracker_comms/src/tracker_comms.rs +++ b/crates/tracker_comms/src/tracker_comms.rs @@ -18,6 +18,7 @@ use url::Url; use crate::tracker_comms_http; use crate::tracker_comms_udp; +use crate::tracker_comms_udp::UdpTrackerClient; use librqbit_core::hash_id::Id20; pub struct TrackerComms { @@ -89,6 +90,7 @@ impl std::fmt::Debug for SupportedTracker { } impl TrackerComms { + #[allow(clippy::too_many_arguments)] pub fn start( info_hash: Id20, peer_id: Id20, @@ -97,6 +99,7 @@ impl TrackerComms { force_interval: Option, tcp_listen_port: Option, reqwest_client: reqwest::Client, + udp_client: UdpTrackerClient, ) -> Option> { let trackers = trackers .into_iter() @@ -131,7 +134,7 @@ impl TrackerComms { }); let mut futures = FuturesUnordered::new(); for tracker in trackers { - futures.push(comms.add_tracker(tracker)) + futures.push(comms.add_tracker(tracker, &udp_client)) } while !(futures.is_empty()) { tokio::select! { @@ -155,6 +158,7 @@ impl TrackerComms { fn add_tracker( &self, url: SupportedTracker, + client: &UdpTrackerClient, ) -> Either< impl std::future::Future> + '_ + Send, impl std::future::Future> + '_ + Send, @@ -163,7 +167,7 @@ impl TrackerComms { match url { SupportedTracker::Udp(url) => { let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash); - self.task_single_tracker_monitor_udp(url) + self.task_single_tracker_monitor_udp(url, client.clone()) .instrument(span) .right_future() } @@ -183,7 +187,7 @@ impl TrackerComms { async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> { let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started); - trace!(url=?tracker_url, "starting monitor"); + trace!(url=%tracker_url, "starting monitor"); loop { let stats = self.stats.get(); let request = tracker_comms_http::TrackerRequest { @@ -227,7 +231,7 @@ impl TrackerComms { } async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result { - debug!(url = ?tracker_url, "calling tracker over http"); + debug!(url = %tracker_url, "calling tracker over http"); let response: reqwest::Response = self.reqwest_client.get(tracker_url).send().await?; if !response.status().is_success() { anyhow::bail!("tracker responded with {:?}", response.status()); @@ -247,19 +251,20 @@ impl TrackerComms { Ok(response.interval) } - async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> { + async fn task_single_tracker_monitor_udp( + &self, + url: Url, + client: UdpTrackerClient, + ) -> anyhow::Result<()> { use tracker_comms_udp::*; if url.scheme() != "udp" { bail!("expected UDP scheme in {}", url); } - let hp: (&str, u16) = ( - url.host_str().context("missing host")?, + let hp: (String, u16) = ( + url.host_str().context("missing host")?.to_owned(), url.port().context("missing port")?, ); - let mut requester = UdpTrackerRequester::new(hp) - .await - .context("error creating UDP tracker requester")?; let mut sleep_interval: Option = None; loop { @@ -291,7 +296,7 @@ impl TrackerComms { port: self.tcp_listen_port.unwrap_or(0), }; - match requester.announce(request).await { + match client.announce(&hp, request).await { Ok(response) => { trace!(len = response.addrs.len(), "received announce response"); for addr in response.addrs { @@ -305,7 +310,7 @@ impl TrackerComms { sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval)); } Err(e) => { - debug!(url = ?url, "error reading announce response: {e:#}"); + debug!(url = %url, "error reading announce response: {e:#}"); if sleep_interval.is_none() { sleep_interval = Some( self.force_tracker_interval diff --git a/crates/tracker_comms/src/tracker_comms_udp.rs b/crates/tracker_comms/src/tracker_comms_udp.rs index bbfa5e8..30fc65d 100644 --- a/crates/tracker_comms/src/tracker_comms_udp.rs +++ b/crates/tracker_comms/src/tracker_comms_udp.rs @@ -1,15 +1,22 @@ -use std::net::{Ipv4Addr, SocketAddrV4}; +use std::{ + collections::{hash_map::Entry, HashMap}, + ffi::CStr, + net::{Ipv4Addr, SocketAddrV4}, + sync::Arc, + time::{Duration, Instant}, +}; use anyhow::{bail, Context}; -use librqbit_core::hash_id::Id20; +use librqbit_core::{hash_id::Id20, spawn_utils::spawn_with_cancel}; +use parking_lot::RwLock; use rand::Rng; -use tokio::net::ToSocketAddrs; -use tracing::trace; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error_span, trace, warn}; const ACTION_CONNECT: u32 = 0; const ACTION_ANNOUNCE: u32 = 1; // const ACTION_SCRAPE: u32 = 2; -// const ACTION_ERROR: u32 = 3; +const ACTION_ERROR: u32 = 3; pub const EVENT_NONE: u32 = 0; pub const EVENT_COMPLETED: u32 = 1; @@ -44,31 +51,51 @@ pub enum Request { } impl Request { - pub fn serialize(&self, transaction_id: TransactionId, buf: &mut Vec) -> usize { - let cur_len = buf.len(); - match self { - Request::Connect => { - buf.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes()); - buf.extend_from_slice(&ACTION_CONNECT.to_be_bytes()); - buf.extend_from_slice(&transaction_id.to_be_bytes()); - } - Request::Announce(connection_id, fields) => { - buf.extend_from_slice(&connection_id.to_be_bytes()); - buf.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes()); - buf.extend_from_slice(&transaction_id.to_be_bytes()); - buf.extend_from_slice(&fields.info_hash.0); - buf.extend_from_slice(&fields.peer_id.0); - buf.extend_from_slice(&fields.downloaded.to_be_bytes()); - buf.extend_from_slice(&fields.left.to_be_bytes()); - buf.extend_from_slice(&fields.uploaded.to_be_bytes()); - buf.extend_from_slice(&fields.event.to_be_bytes()); - buf.extend_from_slice(&0u32.to_be_bytes()); // ip address 0 - buf.extend_from_slice(&fields.key.to_be_bytes()); - buf.extend_from_slice(&(-1i32).to_be_bytes()); // num want -1 - buf.extend_from_slice(&fields.port.to_be_bytes()); + pub fn serialize( + &self, + transaction_id: TransactionId, + buf: &mut [u8], + ) -> anyhow::Result { + struct W<'a> { + buf: &'a mut [u8], + offset: usize, + } + impl W<'_> { + fn extend_from_slice(&mut self, s: &[u8]) -> anyhow::Result<()> { + if self.buf.len() < self.offset + s.len() { + bail!("not enough space in buffer") + } + self.buf[self.offset..self.offset + s.len()].copy_from_slice(s); + self.offset += s.len(); + Ok(()) } } - buf.len() - cur_len + + let mut w = W { buf, offset: 0 }; + + match self { + Request::Connect => { + w.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes())?; + w.extend_from_slice(&ACTION_CONNECT.to_be_bytes())?; + w.extend_from_slice(&transaction_id.to_be_bytes())?; + } + Request::Announce(connection_id, fields) => { + w.extend_from_slice(&connection_id.to_be_bytes())?; + w.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes())?; + w.extend_from_slice(&transaction_id.to_be_bytes())?; + w.extend_from_slice(&fields.info_hash.0)?; + w.extend_from_slice(&fields.peer_id.0)?; + w.extend_from_slice(&fields.downloaded.to_be_bytes())?; + w.extend_from_slice(&fields.left.to_be_bytes())?; + w.extend_from_slice(&fields.uploaded.to_be_bytes())?; + w.extend_from_slice(&fields.event.to_be_bytes())?; + w.extend_from_slice(&0u32.to_be_bytes())?; // ip address 0 + w.extend_from_slice(&fields.key.to_be_bytes())?; + w.extend_from_slice(&(-1i32).to_be_bytes())?; // num want -1 + w.extend_from_slice(&fields.port.to_be_bytes())?; + } + } + Ok(w.offset) } } @@ -86,6 +113,9 @@ pub struct AnnounceResponse { pub enum Response { Connect(ConnectionId), Announce(AnnounceResponse), + #[allow(dead_code)] + Error(String), + Unknown, } fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> { @@ -128,7 +158,20 @@ parse_impl!(i16, 2); impl Response { pub fn parse(buf: &[u8]) -> anyhow::Result<(TransactionId, Self)> { let (action, buf) = u32::parse_num(buf).context("can't parse action")?; - let (tid, mut buf) = u32::parse_num(buf).context("can't parse transaction id")?; + let (tid, buf) = u32::parse_num(buf).context("can't parse transaction id")?; + + let response = match Self::parse_response(action, buf) { + Ok(r) => r, + Err(e) => { + debug!("error parsing: {e:#}"); + Response::Unknown + } + }; + + Ok((tid, response)) + } + + fn parse_response(action: u32, mut buf: &[u8]) -> anyhow::Result { let response = match action { ACTION_CONNECT => { let (connection_id, b) = @@ -158,6 +201,15 @@ impl Response { addrs, }) } + ACTION_ERROR => { + let msg = CStr::from_bytes_with_nul(buf) + .ok() + .and_then(|s| s.to_str().ok()) + .or_else(|| std::str::from_utf8(buf).ok()) + .unwrap_or("") + .to_owned(); + return Ok(Response::Error(msg)); + } _ => bail!("unsupported action {action}"), }; @@ -168,92 +220,182 @@ impl Response { ); } - Ok((tid, response)) + Ok(response) } } -pub struct UdpTrackerRequester { - sock: tokio::net::UdpSocket, - connection_id: ConnectionId, - read_buf: Vec, - write_buf: Vec, +pub type TrackerAddr = (String, u16); + +struct ConnectionIdMeta { + id: ConnectionId, + created: Instant, } -impl UdpTrackerRequester { - // Addr is "host:port" - pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result { +#[derive(Default)] +struct ClientLocked { + connections: HashMap, + transactions: HashMap>, +} + +struct ClientShared { + sock: tokio::net::UdpSocket, + locked: RwLock, +} + +#[derive(Clone)] +pub struct UdpTrackerClient { + state: Arc, +} + +struct TransactionIdGuard<'a> { + tid: TransactionId, + state: &'a ClientShared, +} + +impl Drop for TransactionIdGuard<'_> { + fn drop(&mut self) { + let mut g = self.state.locked.write(); + g.transactions.remove(&self.tid); + } +} + +impl UdpTrackerClient { + pub async fn new(cancel_token: CancellationToken) -> anyhow::Result { let sock = tokio::net::UdpSocket::bind("0.0.0.0:0") .await - .context("error binding UDP socket")?; - sock.connect(addr) - .await - .context("error connecting UDP socket")?; - - let tid = new_transaction_id(); - let mut write_buf = Vec::new(); - let mut read_buf = vec![0u8; 4096]; - - trace!("sending connect request"); - Request::Connect.serialize(tid, &mut write_buf); - - sock.send(&write_buf) - .await - .context("error sending to socket")?; - - let size = sock - .recv(&mut read_buf) - .await - .context("error receiving from socket")?; - - let (rtid, response) = - Response::parse(&read_buf[..size]).context("error parsing response")?; - if tid != rtid { - bail!("expected transaction id {} == {}", tid, rtid); - } - trace!(response=?response, "received"); - - let connection_id = match response { - Response::Connect(connection_id) => connection_id, - other => bail!("unexpected response {other:?}"), + .context("error binding UDP for tracker")?; + let client = Self { + state: Arc::new(ClientShared { + sock, + locked: RwLock::new(Default::default()), + }), }; - trace!(connection_id); + spawn_with_cancel(error_span!("udp_tracker"), cancel_token, { + let client = client.clone(); + async move { client.run().await } + }); - Ok(Self { - sock, - connection_id, - read_buf, - write_buf, - }) + Ok(client) } - pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result { - let request = Request::Announce(self.connection_id, fields); - let response = self.request(request).await?; + async fn run(self) -> anyhow::Result<()> { + let mut buf = [0u8; 16384]; + loop { + let (len, addr) = match self.state.sock.recv_from(&mut buf).await { + Ok(r) => r, + Err(e) => { + warn!("error in UdpSocket::recv_from: {e:#}"); + continue; + } + }; + + let (tid, response) = match Response::parse(&buf[..len]) { + Ok(r) => r, + Err(e) => { + debug!(?addr, "error parsing UDP response: {e:#}"); + continue; + } + }; + + trace!(?tid, ?response, ?addr, "received"); + + let t = self.state.locked.write().transactions.remove(&tid); + match t { + Some(tx) => match tx.send(response) { + Ok(_) => {} + Err(_) => { + debug!(tid, "reader dead"); + } + }, + None => { + debug!(tid, "nowhere to send response"); + } + }; + } + } + + async fn get_connection_id(&self, addr: &TrackerAddr) -> anyhow::Result { + if let Some(m) = self.state.locked.read().connections.get(addr) { + if m.created.elapsed() < Duration::from_secs(60) { + return Ok(m.id); + } + } + + let response = self.request(addr, Request::Connect).await?; + match response { + Response::Connect(connection_id) => { + self.state.locked.write().connections.insert( + addr.clone(), + ConnectionIdMeta { + id: connection_id, + created: Instant::now(), + }, + ); + Ok(connection_id) + } + _ => anyhow::bail!("expected connect response"), + } + } + + async fn request(&self, addr: &TrackerAddr, request: Request) -> anyhow::Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + let tid_g = self.reserve_transaction_id(tx)?; + + let mut write_buf = [0u8; 1024]; + let len = request.serialize(tid_g.tid, &mut write_buf)?; + self.state.sock.send_to(&write_buf[..len], addr).await?; + + let response = tokio::time::timeout(Duration::from_secs(10), rx) + .await + .context("timeout connecting")? + .context("sender dead")?; + match &response { + Response::Error(e) => { + anyhow::bail!("remote errored: {e}") + } + Response::Unknown => { + anyhow::bail!("remote replied with something we could not parse") + } + _ => {} + } + Ok(response) + } + + fn reserve_transaction_id( + &self, + tx: tokio::sync::oneshot::Sender, + ) -> anyhow::Result> { + let mut g = self.state.locked.write(); + for _ in 0..10 { + let t = new_transaction_id(); + match g.transactions.entry(t) { + Entry::Occupied(_) => continue, + Entry::Vacant(vac) => { + vac.insert(tx); + return Ok(TransactionIdGuard { + tid: t, + state: &self.state, + }); + } + } + } + bail!("cant generate transaction id") + } + + pub async fn announce( + &self, + tracker: &TrackerAddr, + fields: AnnounceFields, + ) -> anyhow::Result { + let connection_id = self.get_connection_id(tracker).await?; + let request = Request::Announce(connection_id, fields); + let response = self.request(tracker, request).await?; match response { Response::Announce(r) => Ok(r), other => bail!("unexpected response {other:?}, expected announce"), } } - - pub async fn request(&mut self, request: Request) -> anyhow::Result { - let tid = new_transaction_id(); - self.write_buf.clear(); - let size = request.serialize(tid, &mut self.write_buf); - trace!(request=?request, tid, "sending"); - self.sock - .send(&self.write_buf[..size]) - .await - .context("error sending")?; - let size = self.sock.recv(&mut self.read_buf).await?; - - let (rtid, response) = Response::parse(&self.read_buf[..size])?; - trace!("received response"); - if tid != rtid { - bail!("unexpected transaction id"); - } - Ok(response) - } } #[cfg(test)] @@ -280,12 +422,12 @@ mod tests { sock.connect("opentor.net:6969").await.unwrap(); let tid = new_transaction_id(); - let mut write_buf = Vec::new(); + let mut write_buf = [0u8; 16384]; let mut read_buf = vec![0u8; 4096]; - Request::Connect.serialize(tid, &mut write_buf); + let len = Request::Connect.serialize(tid, &mut write_buf).unwrap(); - sock.send(&write_buf).await.unwrap(); + sock.send(&write_buf[..len]).await.unwrap(); let size = sock.recv(&mut read_buf).await.unwrap(); @@ -314,8 +456,7 @@ mod tests { port: 24563, }, ); - write_buf.clear(); - let size = request.serialize(tid, &mut write_buf); + let size = request.serialize(tid, &mut write_buf).unwrap(); sock.send(&write_buf[..size]).await.unwrap(); let size = sock.recv(&mut read_buf).await.unwrap();