diff --git a/Cargo.lock b/Cargo.lock index 571545d..08ff622 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1257,6 +1257,7 @@ name = "librqbit" version = "5.5.0" dependencies = [ "anyhow", + "async-stream", "axum 0.7.4", "backoff", "base64", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 9b54625..b840cf8 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -68,6 +68,7 @@ serde_with = "3.4.0" tokio-util = "0.7.10" bytes = "1.5.0" rlimit = "0.10.1" +async-stream = "0.3.5" [dev-dependencies] futures = {version = "0.3"} diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 95d65ea..529be59 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -756,15 +756,13 @@ impl Session { self.tcp_listen_port }; - let cancellation_token = self.cancellation_token.child_token(); - let cancellation_token_drop_guard = cancellation_token.clone().drop_guard(); let paused = opts.list_only || opts.paused; // The main difference between magnet link and torrent file, is that we need to resolve the magnet link // into a torrent file by connecting to peers that support extended handshakes. // So we must discover at least one peer and connect to it to be able to proceed further. - let (info_hash, info, trackers, peer_rx, initial_peers, cancellation_token) = match add { + let (info_hash, info, trackers, peer_rx, initial_peers) = match add { AddTorrent::Url(magnet) if magnet.starts_with("magnet:") => { let magnet = Magnet::parse(&magnet).context("provided path is not a valid magnet URL")?; @@ -772,11 +770,9 @@ impl Session { .as_id20() .context("magnet link didn't contain a BTv1 infohash")?; - let peer_token = cancellation_token.child_token(); let peer_rx = self.make_peer_rx( info_hash, magnet.trackers.clone(), - peer_token.clone(), announce_port, opts.force_tracker_interval, )?; @@ -800,9 +796,6 @@ impl Session { anyhow::bail!("DHT died, no way to discover torrent metainfo") } }; - if paused { - peer_token.cancel(); - } debug!(?info, "received result from DHT"); ( info_hash, @@ -810,7 +803,6 @@ impl Session { magnet.trackers, Some(peer_rx), initial_peers, - cancellation_token, ) } other => { @@ -849,7 +841,6 @@ impl Session { self.make_peer_rx( torrent.info_hash, trackers.clone(), - cancellation_token.clone(), announce_port, opts.force_tracker_interval, )? @@ -865,13 +856,10 @@ impl Session { .unwrap_or_default() .into_iter() .collect(), - cancellation_token, ) } }; - cancellation_token_drop_guard.disarm(); - self.main_torrent_info( info_hash, info, @@ -879,7 +867,6 @@ impl Session { peer_rx, initial_peers.into_iter().collect(), opts, - cancellation_token, ) .await } @@ -893,12 +880,9 @@ impl Session { peer_rx: Option, initial_peers: Vec, opts: AddTorrentOptions, - cancellation_token: CancellationToken, ) -> anyhow::Result { debug!("Torrent info: {:#?}", &info); - let drop_guard = cancellation_token.clone().drop_guard(); - let get_only_files = |only_files: Option>, only_files_regex: Option, list_only: bool| { match (only_files, only_files_regex) { @@ -1016,20 +1000,16 @@ impl Session { let span = managed_torrent.info.span.clone(); let _ = span.enter(); - // Just in case, cancel all tasks started for this torrent so far. - // This is defensive, and not proven necessary. - let token = if opts.paused { - cancellation_token.cancel(); - self.cancellation_token.child_token() - } else { - cancellation_token - }; managed_torrent - .start(initial_peers, peer_rx, opts.paused, token) + .start( + initial_peers, + peer_rx, + opts.paused, + self.cancellation_token.child_token(), + ) .context("error starting torrent")?; } - drop_guard.disarm(); Ok(AddTorrentResponse::Added(id, managed_torrent)) } @@ -1080,7 +1060,6 @@ impl Session { &self, info_hash: Id20, trackers: Vec, - cancel: CancellationToken, announce_port: Option, force_tracker_interval: Option, ) -> anyhow::Result> { @@ -1097,7 +1076,6 @@ impl Session { // TODO: report actual bytes, not zeroes. Box::new(()), force_tracker_interval, - cancel, announce_port, ); @@ -1111,15 +1089,18 @@ impl Session { } pub fn unpause(&self, handle: &ManagedTorrentHandle) -> anyhow::Result<()> { - let token = self.cancellation_token.child_token(); let peer_rx = self.make_peer_rx( handle.info_hash(), handle.info().trackers.clone().into_iter().collect(), - token.clone(), self.tcp_listen_port, handle.info().options.force_tracker_interval, )?; - handle.start(Default::default(), peer_rx, false, token)?; + handle.start( + Default::default(), + peer_rx, + false, + self.cancellation_token.child_token(), + )?; Ok(()) } } diff --git a/crates/librqbit/src/tracker_comms.rs b/crates/librqbit/src/tracker_comms.rs index d1a7813..dae75ef 100644 --- a/crates/librqbit/src/tracker_comms.rs +++ b/crates/librqbit/src/tracker_comms.rs @@ -4,13 +4,15 @@ use std::time::Duration; use anyhow::bail; use anyhow::Context; +use futures::future::Either; +use futures::stream::FuturesUnordered; +use futures::FutureExt; use futures::Stream; -use librqbit_core::spawn_utils::spawn_with_cancel; -use tokio_util::sync::CancellationToken; +use futures::StreamExt; use tracing::debug; use tracing::error_span; -use tracing::info; use tracing::trace; +use tracing::Instrument; use url::Url; use crate::tracker_comms_http; @@ -22,7 +24,6 @@ pub struct TrackerComms { peer_id: Id20, stats: Box, force_tracker_interval: Option, - cancellation_token: CancellationToken, tx: Sender, tcp_listen_port: Option, } @@ -64,69 +65,98 @@ impl TrackerComms { trackers: Vec, stats: Box, force_interval: Option, - cancellation_token: CancellationToken, tcp_listen_port: Option, - ) -> Option + Send + Sync + Unpin + 'static> { - let (tx, rx) = tokio::sync::mpsc::channel::(16); - let comms = Arc::new(Self { - info_hash, - peer_id, - stats, - force_tracker_interval: force_interval, - cancellation_token, - tx, - tcp_listen_port, - }); - let mut added = false; - for tracker in trackers { - if let Err(e) = comms.clone().add_tracker(&tracker) { - info!(tracker = tracker, "error adding tracker: {:#}", e) - } else { - added = true; - } - } - if !added { + ) -> Option + Unpin + Send + 'static> { + let trackers = trackers + .into_iter() + .filter_map(|t| match Url::parse(&t) { + Ok(parsed) => Some(parsed), + Err(e) => { + debug!("error parsing tracker URL: {}", e); + None + } + }) + .collect::>(); + if trackers.is_empty() { return None; } - Some(tokio_stream::wrappers::ReceiverStream::new(rx)) + + let (tx, mut rx) = tokio::sync::mpsc::channel::(16); + + let s = async_stream::stream! { + use futures::StreamExt; + let mut rx_done = false; + let comms = Arc::new(Self { + info_hash, + peer_id, + stats, + force_tracker_interval: force_interval, + tx, + tcp_listen_port, + }); + let mut futures = FuturesUnordered::new(); + for tracker in trackers { + if let Ok(fut) = comms.add_tracker(tracker) { + futures.push(fut); + } + } + if futures.is_empty() { + return; + } + while !(futures.is_empty() && rx_done) { + tokio::select! { + addr = rx.recv(), if !rx_done => { + match addr { + Some(addr) => yield addr, + None => rx_done = true + } + } + e = futures.next(), if !futures.is_empty() => { + if let Some(Err(e)) = e { + debug!("error: {e}"); + } + } + } + } + }; + + Some(s.boxed()) } - fn add_tracker(self: Arc, tracker: &str) -> anyhow::Result<()> { - if tracker.starts_with("http://") || tracker.starts_with("https://") { - spawn_with_cancel( - error_span!( - parent: None, - "http_tracker", - tracker = tracker, - info_hash = ?self.info_hash - ), - self.cancellation_token.clone(), - { - let comms = self; - let url = Url::parse(tracker).context("can't parse URL")?; - async move { comms.task_single_tracker_monitor_http(url).await } - }, - ); - } else if tracker.starts_with("udp://") { - spawn_with_cancel( - error_span!(parent: None, "udp_tracker", tracker = tracker, info_hash = ?self.info_hash), - self.cancellation_token.clone(), - { - let comms = self; - let url = Url::parse(tracker).context("can't parse URL")?; - async move { comms.task_single_tracker_monitor_udp(url).await } - }, + fn add_tracker( + &self, + url: Url, + ) -> anyhow::Result< + Either< + impl std::future::Future> + '_ + Send, + impl std::future::Future> + '_ + Send, + >, + > { + let info_hash = self.info_hash; + if url.scheme() == "http" || url.scheme() == "https" { + let span = error_span!( + parent: None, + "http_tracker", + tracker = %url, + info_hash = ?info_hash ); + Ok(self + .task_single_tracker_monitor_http(url) + .instrument(span) + .left_future()) + } else if url.scheme() == "udp" { + let span = + error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash); + Ok(self + .task_single_tracker_monitor_udp(url) + .instrument(span) + .right_future()) } else { - bail!("unsupported tracker url {}", tracker) + bail!("unsupported tracker url {}", url) } - Ok(()) } - async fn task_single_tracker_monitor_http( - self: Arc, - mut tracker_url: Url, - ) -> anyhow::Result<()> { + async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> { let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started); loop { let stats = self.stats.get(); diff --git a/crates/librqbit/src/type_aliases.rs b/crates/librqbit/src/type_aliases.rs index 2b6efa9..d68f5bc 100644 --- a/crates/librqbit/src/type_aliases.rs +++ b/crates/librqbit/src/type_aliases.rs @@ -5,4 +5,4 @@ use futures::Stream; pub type BF = bitvec::vec::BitVec; pub type PeerHandle = SocketAddr; -pub type PeerStream = Box + Unpin + Send + Sync + 'static>; +pub type PeerStream = Box + Unpin + Send + 'static>;