From bed7433d8e605786b7b7803c5827e4c7aaf030e2 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Thu, 7 Dec 2023 08:10:17 +0000 Subject: [PATCH] Use tokio_util::CancellationToken everywhere --- Cargo.lock | 3 + crates/dht/Cargo.toml | 2 +- crates/dht/examples/dht.rs | 3 +- crates/dht/src/dht.rs | 42 +++++---- crates/dht/src/lib.rs | 13 +-- crates/dht/src/persistence.rs | 64 ++++++------- crates/librqbit/Cargo.toml | 3 +- crates/librqbit/src/dht_utils.rs | 3 +- crates/librqbit/src/session.rs | 71 +++++---------- crates/librqbit/src/torrent_state/live/mod.rs | 36 ++------ crates/librqbit/src/torrent_state/mod.rs | 90 +++++++++++-------- crates/librqbit_core/Cargo.toml | 1 + crates/librqbit_core/src/spawn_utils.rs | 17 ++++ crates/rqbit/src/main.rs | 1 + desktop/src-tauri/Cargo.lock | 3 + desktop/src-tauri/src/main.rs | 2 +- 16 files changed, 176 insertions(+), 178 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cd5519b..eef876b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1290,6 +1290,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-test", + "tokio-util", "tower-http", "tracing", "tracing-subscriber", @@ -1336,6 +1337,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "tokio-util", "tracing", "url", "uuid", @@ -1362,6 +1364,7 @@ dependencies = [ "serde_json", "tokio", "tokio-stream", + "tokio-util", "tracing", "tracing-subscriber", ] diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 220147c..b3eb407 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -32,10 +32,10 @@ futures = "0.3" rand = "0.8" indexmap = "2" dashmap = {version = "5.5.3", features = ["serde"]} - clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} librqbit-core = {path="../librqbit_core", version = "3.3.0"} chrono = {version = "0.4.31", features = ["serde"]} +tokio-util = "0.7.10" [dev-dependencies] tracing-subscriber = "0.3" diff --git a/crates/dht/examples/dht.rs b/crates/dht/examples/dht.rs index 07a3375..ec436ad 100644 --- a/crates/dht/examples/dht.rs +++ b/crates/dht/examples/dht.rs @@ -16,8 +16,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); - let (dht, worker) = DhtBuilder::new().await.context("error initializing DHT")?; - tokio::spawn(worker); + let dht = DhtBuilder::new().await.context("error initializing DHT")?; let mut stream = dht.get_peers(info_hash, None)?; diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index da58179..e97762e 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -23,10 +23,14 @@ use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; use dashmap::DashMap; -use futures::{stream::FuturesUnordered, Future, Stream, StreamExt, TryFutureExt}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use leaky_bucket::RateLimiter; -use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; +use librqbit_core::{ + id20::Id20, + peer_id::generate_peer_id, + spawn_utils::{spawn, spawn_with_cancel}, +}; use parking_lot::RwLock; use serde::Serialize; @@ -35,6 +39,7 @@ use tokio::{ sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, }; +use tokio_util::sync::CancellationToken; use tracing::{debug, debug_span, error, error_span, info, trace, warn, Instrument}; #[derive(Debug, Serialize)] @@ -535,6 +540,8 @@ pub struct DhtState { // This is to send raw messages worker_sender: UnboundedSender, + cancellation_token: CancellationToken, + pub(crate) peer_store: PeerStore, } @@ -545,6 +552,7 @@ impl DhtState { routing_table: Option, listen_addr: SocketAddr, peer_store: PeerStore, + cancellation_token: CancellationToken, ) -> Self { let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id, None)); Self { @@ -556,6 +564,7 @@ impl DhtState { listen_addr, rate_limiter: make_rate_limiter(), peer_store, + cancellation_token, } } @@ -1124,21 +1133,18 @@ pub struct DhtConfig { pub routing_table: Option, pub listen_addr: Option, pub peer_store: Option, + pub cancellation_token: Option, } impl DhtState { - pub async fn new() -> anyhow::Result<( - Arc, - impl Future> + Send + Sync + 'static, - )> { + pub async fn new() -> anyhow::Result> { Self::with_config(DhtConfig::default()).await } - pub async fn with_config( - config: DhtConfig, - ) -> anyhow::Result<( - Arc, - impl Future> + Send + Sync + 'static, - )> { + pub fn cancellation_token(&self) -> &CancellationToken { + &self.cancellation_token + } + + pub async fn with_config(mut config: DhtConfig) -> anyhow::Result> { let socket = match config.listen_addr { Some(addr) => UdpSocket::bind(addr) .await @@ -1159,6 +1165,8 @@ impl DhtState { .bootstrap_addrs .unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect()); + let token = config.cancellation_token.take().unwrap_or_default(); + let (in_tx, in_rx) = unbounded_channel(); let state = Arc::new(Self::new_internal( peer_id, @@ -1166,17 +1174,17 @@ impl DhtState { config.routing_table, listen_addr, config.peer_store.unwrap_or_else(|| PeerStore::new(peer_id)), + token, )); - let run_worker = { + spawn_with_cancel(error_span!("dht"), state.cancellation_token.clone(), { let state = state.clone(); async move { let worker = DhtWorker { socket, dht: state }; - worker.start(in_rx, &bootstrap_addrs).await?; - Ok(()) + worker.start(in_rx, &bootstrap_addrs).await } - }; - Ok((state, run_worker)) + }); + Ok(state) } pub fn get_peers( diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index 325c789..94188d0 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -10,7 +10,6 @@ use std::time::Duration; pub use crate::dht::DhtStats; pub use crate::dht::{DhtConfig, DhtState, RequestPeersStream}; -use futures::Future; pub use librqbit_core::id20::Id20; pub use persistence::{PersistentDht, PersistentDhtConfig}; @@ -27,19 +26,11 @@ pub struct DhtBuilder {} impl DhtBuilder { #[allow(clippy::new_ret_no_self)] - pub async fn new() -> anyhow::Result<( - Dht, - impl Future> + Send + Sync + 'static, - )> { + pub async fn new() -> anyhow::Result { DhtState::new().await } - pub async fn with_config( - config: DhtConfig, - ) -> anyhow::Result<( - Dht, - impl Future> + Send + Sync + 'static, - )> { + pub async fn with_config(config: DhtConfig) -> anyhow::Result { DhtState::with_config(config).await } } diff --git a/crates/dht/src/persistence.rs b/crates/dht/src/persistence.rs index 4b8a44a..30d3f9b 100644 --- a/crates/dht/src/persistence.rs +++ b/crates/dht/src/persistence.rs @@ -1,16 +1,17 @@ // TODO: this now stores only the routing table, but we also need AT LEAST the same socket address... -use futures::Future; use librqbit_core::directories::get_configuration_directory; +use librqbit_core::spawn_utils::spawn_with_cancel; use serde::{Deserialize, Serialize}; use std::fs::OpenOptions; use std::io::{BufReader, BufWriter}; use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::time::Duration; +use tokio_util::sync::CancellationToken; use anyhow::Context; -use tracing::{debug, error, info, trace, warn}; +use tracing::{debug, error, error_span, info, trace, warn}; use crate::peer_store::PeerStore; use crate::routing_table::RoutingTable; @@ -76,11 +77,8 @@ impl PersistentDht { pub async fn create( config: Option, - ) -> anyhow::Result<( - Dht, - impl Future> + Send + Sync + 'static, - impl Future> + Send + Sync + 'static, - )> { + cancellation_token: Option, + ) -> anyhow::Result { let mut config = config.unwrap_or_default(); let config_filename = match config.config_filename.take() { Some(config_filename) => config_filename, @@ -129,35 +127,41 @@ impl PersistentDht { routing_table, listen_addr, peer_store, + cancellation_token, ..Default::default() }; - let (dht, run_worker) = DhtState::with_config(dht_config).await?; + let dht = DhtState::with_config(dht_config).await?; + spawn_with_cancel( + error_span!("dht_persistence"), + dht.cancellation_token().clone(), + { + let dht = dht.clone(); + let dump_interval = config + .dump_interval + .unwrap_or_else(|| Duration::from_secs(3)); + async move { + let tempfile_name = { + let file_name = format!("dht.json.tmp.{}", std::process::id()); + let mut tmp = config_filename.clone(); + tmp.set_file_name(file_name); + tmp + }; - let run_persistence = { - let dht = dht.clone(); - let dump_interval = config - .dump_interval - .unwrap_or_else(|| Duration::from_secs(3)); - async move { - let tempfile_name = { - let file_name = format!("dht.json.tmp.{}", std::process::id()); - let mut tmp = config_filename.clone(); - tmp.set_file_name(file_name); - tmp - }; + loop { + trace!("sleeping for {:?}", &dump_interval); + tokio::time::sleep(dump_interval).await; - loop { - trace!("sleeping for {:?}", &dump_interval); - tokio::time::sleep(dump_interval).await; - - match dump_dht(&dht, &config_filename, &tempfile_name) { - Ok(_) => debug!("dumped DHT to {:?}", &config_filename), - Err(e) => error!("error dumping DHT to {:?}: {:#}", &config_filename, e), + match dump_dht(&dht, &config_filename, &tempfile_name) { + Ok(_) => debug!("dumped DHT to {:?}", &config_filename), + Err(e) => { + error!("error dumping DHT to {:?}: {:#}", &config_filename, e) + } + } } } - } - }; + }, + ); - Ok((dht, run_worker, run_persistence)) + Ok(dht) } } diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 4259b90..d3a00bf 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -64,8 +64,9 @@ backoff = "0.4.0" dashmap = "5.5.3" base64 = "0.21.5" serde_with = "3.4.0" +tokio-util = "0.7.10" [dev-dependencies] futures = {version = "0.3"} tracing-subscriber = "0.3" -tokio-test = "0.4" \ No newline at end of file +tokio-test = "0.4" diff --git a/crates/librqbit/src/dht_utils.rs b/crates/librqbit/src/dht_utils.rs index ac84fd5..fb4f72c 100644 --- a/crates/librqbit/src/dht_utils.rs +++ b/crates/librqbit/src/dht_utils.rs @@ -107,8 +107,7 @@ mod tests { init_logging(); let info_hash = Id20::from_str("cab507494d02ebb1178b38f2e9d7be299c86b862").unwrap(); - let (dht, run_dht) = DhtBuilder::new().await.unwrap(); - tokio::spawn(run_dht); + let dht = DhtBuilder::new().await.unwrap(); let peer_rx = dht.get_peers(info_hash, None).unwrap(); let peer_id = generate_peer_id(); diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 814ebfc..3a93459 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -21,6 +21,7 @@ use librqbit_core::{ directories::get_configuration_directory, magnet::Magnet, peer_id::generate_peer_id, + spawn_utils::spawn_with_cancel, torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, }; use parking_lot::RwLock; @@ -32,12 +33,13 @@ use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream}, }; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, peer_connection::{with_timeout, PeerConnectionOptions}, - spawn_utils::{spawn, BlockingSpawner}, + spawn_utils::BlockingSpawner, torrent_state::{ ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, }, @@ -150,23 +152,6 @@ struct SerializedSessionDatabase { torrents: HashMap, } -fn spawn_with_cancel_token( - mut cancel_rx: tokio::sync::watch::Receiver<()>, - name: &str, - span: tracing::Span, - fut: impl std::future::Future> + Send + 'static, -) { - spawn(name, span, async move { - tokio::select! { - r = fut => r, - _ = cancel_rx.changed() => { - debug!("task canceled"); - Ok(()) - } - } - }); -} - pub struct Session { peer_id: Id20, dht: Option, @@ -178,8 +163,7 @@ pub struct Session { tcp_listen_port: Option, - cancel_tx: tokio::sync::watch::Sender<()>, - cancel_rx: tokio::sync::watch::Receiver<()>, + cancellation_token: CancellationToken, } async fn torrent_from_url(url: &str) -> anyhow::Result { @@ -395,14 +379,17 @@ impl Session { Ok(dir.data_dir().join("session.json")) } + pub fn cancellation_token(&self) -> &CancellationToken { + &self.cancellation_token + } + /// Create a new session with options. pub async fn new_with_opts( output_folder: PathBuf, mut opts: SessionOptions, ) -> anyhow::Result> { let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); - - let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(()); + let token = CancellationToken::new(); let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range { let (l, p) = create_tcp_listener(port_range) @@ -418,24 +405,17 @@ impl Session { None } else { let dht = if opts.disable_dht_persistence { - let (dht, run_worker) = DhtBuilder::with_config(DhtConfig::default()) - .await - .context("error initializing DHT")?; - spawn_with_cancel_token(cancel_rx.clone(), "dht", error_span!("dht"), run_worker); - dht + DhtBuilder::with_config(DhtConfig { + cancellation_token: Some(token.child_token()), + ..Default::default() + }) + .await + .context("error initializing DHT")? } else { let pdht_config = opts.dht_config.take().unwrap_or_default(); - let (dht, run_worker, run_persistence) = PersistentDht::create(Some(pdht_config)) + PersistentDht::create(Some(pdht_config), Some(token.clone())) .await - .context("error initializing persistent DHT")?; - spawn_with_cancel_token(cancel_rx.clone(), "dht", error_span!("dht"), run_worker); - spawn_with_cancel_token( - cancel_rx.clone(), - "dht_persistence", - error_span!("dht_persistence"), - run_persistence, - ); - dht + .context("error initializing persistent DHT")? }; Some(dht) @@ -455,14 +435,12 @@ impl Session { spawner, output_folder, db: RwLock::new(Default::default()), - cancel_rx, - cancel_tx, + cancellation_token: token, tcp_listen_port, }); if let Some(tcp_listener) = tcp_listener { session.spawn( - "tcp listener", error_span!("tcp_listen", port = tcp_listen_port), session.clone().task_tcp_listener(tcp_listener), ); @@ -471,7 +449,6 @@ impl Session { if let Some(listen_port) = tcp_listen_port { if opts.enable_upnp_port_forwarding { session.spawn( - "upnp_forward", error_span!("upnp_forward", port = listen_port), session.clone().task_upnp_port_forwarder(listen_port), ); @@ -489,11 +466,7 @@ impl Session { })?; } let persistence_task = session.clone().task_persistence(); - session.spawn( - "session persistene", - error_span!("session_persistence"), - persistence_task, - ); + session.spawn(error_span!("session_persistence"), persistence_task); } Ok(session) @@ -645,11 +618,10 @@ impl Session { /// Spawn a task in the context of the session. pub fn spawn( &self, - name: &str, span: tracing::Span, fut: impl std::future::Future> + Send + 'static, ) { - spawn_with_cancel_token(self.cancel_rx.clone(), name, span, fut); + spawn_with_cancel(span, self.cancellation_token.clone(), fut); } /// Stop the session and all managed tasks. @@ -666,7 +638,7 @@ impl Session { debug!("error pausing torrent: {e:#}"); } } - let _ = self.cancel_tx.send(()); + self.cancellation_token.cancel(); // this sucks, but hopefully will be enough tokio::time::sleep(Duration::from_secs(1)).await; } @@ -999,6 +971,7 @@ impl Session { builder .overwrite(opts.overwrite) .spawner(self.spawner) + .cancellation_token(self.cancellation_token.child_token()) .peer_id(self.peer_id); if opts.disable_trackers { diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index a7c0821..e73e91f 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -65,6 +65,7 @@ use itertools::Itertools; use librqbit_core::{ id20::Id20, lengths::{ChunkInfo, Lengths, ValidPieceIndex}, + spawn_utils::spawn_with_cancel, speed_estimator::SpeedEstimator, torrent_metainfo::TorrentMetaV1Info, }; @@ -80,6 +81,7 @@ use tokio::{ }, time::timeout, }; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn}; use url::Url; @@ -90,7 +92,6 @@ use crate::{ PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, }, session::CheckedIncomingConnection, - spawn_utils::spawn, torrent_state::{peer::Peer, utils::atomic_inc}, tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse}, type_aliases::{PeerHandle, BF}, @@ -185,17 +186,16 @@ pub struct TorrentStateLive { finished_notify: Notify, - cancel_tx: tokio::sync::watch::Sender<()>, - cancel_rx: tokio::sync::watch::Receiver<()>, - down_speed_estimator: SpeedEstimator, up_speed_estimator: SpeedEstimator, + cancellation_token: CancellationToken, } impl TorrentStateLive { pub(crate) fn new( paused: TorrentStatePaused, fatal_errors_tx: tokio::sync::oneshot::Sender, + cancellation_token: CancellationToken, ) -> Arc { let (peer_queue_tx, peer_queue_rx) = unbounded_channel(); @@ -206,8 +206,6 @@ impl TorrentStateLive { let needed_bytes = paused.info.lengths.total_length() - have_bytes; let lengths = *paused.chunk_tracker.get_lengths(); - let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(()); - let state = Arc::new(TorrentStateLive { meta: paused.info.clone(), peers: Default::default(), @@ -229,20 +227,17 @@ impl TorrentStateLive { finished_notify: Notify::new(), down_speed_estimator, up_speed_estimator, - cancel_rx, - cancel_tx, + cancellation_token, }); for tracker in state.meta.trackers.iter() { state.spawn( - "tracker_monitor", error_span!(parent: state.meta.span.clone(), "tracker_monitor", url = tracker.to_string()), state.clone().task_single_tracker_monitor(tracker.clone()), ); } state.spawn( - "speed_estimator_updater", error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"), { let state = Arc::downgrade(&state); @@ -273,29 +268,18 @@ impl TorrentStateLive { ); state.spawn( - "peer_adder", error_span!(parent: state.meta.span.clone(), "peer_adder"), state.clone().task_peer_adder(peer_queue_rx), ); state } - fn spawn( + pub(crate) fn spawn( &self, - name: &str, span: tracing::Span, fut: impl std::future::Future> + Send + 'static, ) { - let mut cancel_rx = self.cancel_rx.clone(); - spawn(name, span, async move { - tokio::select! { - r = fut => r, - _ = cancel_rx.changed() => { - debug!("task canceled"); - Ok(()) - } - } - }); + spawn_with_cancel(span, self.cancellation_token.clone(), fut); } pub fn down_speed_estimator(&self) -> &SpeedEstimator { @@ -418,7 +402,6 @@ impl TorrentStateLive { atomic_inc(&counters.incoming_connections); self.spawn( - "incoming peer", error_span!( parent: self.meta.span.clone(), "manage_incoming_peer", @@ -570,7 +553,6 @@ impl TorrentStateLive { let permit = state.peer_semaphore.clone().acquire_owned().await?; state.spawn( - "manage_peer", error_span!(parent: state.meta.span.clone(), "manage_peer", peer = addr.to_string()), state.clone().task_manage_outgoing_peer(addr, permit), ); @@ -682,7 +664,6 @@ impl TorrentStateLive { // We don't want to remember this task as there may be too many. self.spawn( - "transmit_haves", error_span!( parent: self.meta.span.clone(), "transmit_haves", @@ -744,7 +725,7 @@ impl TorrentStateLive { } pub fn pause(&self) -> anyhow::Result { - let _ = self.cancel_tx.send(()); + self.cancellation_token.cancel(); let mut g = self.locked.write(); @@ -971,7 +952,6 @@ impl PeerHandler { if let Some(dur) = backoff { self.state.clone().spawn( - "wait_for_peer", error_span!( parent: self.state.meta.span.clone(), "wait_for_peer", diff --git a/crates/librqbit/src/torrent_state/mod.rs b/crates/librqbit/src/torrent_state/mod.rs index 5639e31..6df68fe 100644 --- a/crates/librqbit/src/torrent_state/mod.rs +++ b/crates/librqbit/src/torrent_state/mod.rs @@ -20,12 +20,14 @@ use librqbit_core::id20::Id20; use librqbit_core::lengths::Lengths; use librqbit_core::peer_id::generate_peer_id; +use librqbit_core::spawn_utils::spawn_with_cancel; use librqbit_core::torrent_metainfo::TorrentMetaV1Info; pub use live::*; use parking_lot::RwLock; use tokio::time::timeout; use tokio_stream::StreamExt; +use tokio_util::sync::CancellationToken; use tracing::debug; use tracing::error_span; use tracing::trace; @@ -33,7 +35,6 @@ use tracing::warn; use url::Url; use crate::chunk_tracker::ChunkTracker; -use crate::spawn_utils::spawn; use crate::spawn_utils::BlockingSpawner; use crate::torrent_state::stats::LiveStats; @@ -91,6 +92,7 @@ pub struct ManagedTorrentInfo { pub struct ManagedTorrent { pub info: Arc, + pub cancellation_token: CancellationToken, pub(crate) only_files: Option>, locked: RwLock, } @@ -179,10 +181,11 @@ impl ManagedTorrent { let spawn_fatal_errors_receiver = |state: &Arc, rx: tokio::sync::oneshot::Receiver| { let span = state.info.span.clone(); + let token = state.cancellation_token.clone(); let state = Arc::downgrade(state); - spawn( - "fatal_errors_receiver", + spawn_with_cancel( error_span!(parent: span, "fatal_errors_receiver"), + token, async move { let e = match rx.await { Ok(e) => e, @@ -191,7 +194,7 @@ impl ManagedTorrent { if let Some(state) = state.upgrade() { state.stop_with_error(e); } else { - warn!("tried to stop the torrent with error, but it's couldn't upgrade the arc"); + warn!("tried to stop the torrent with error, but couldn't upgrade the arc"); } Ok(()) }, @@ -203,40 +206,42 @@ impl ManagedTorrent { initial_peers: Vec, peer_rx: Option, ) { - let span = live.meta().span.clone(); - let live = Arc::downgrade(live); - spawn( - "external_peer_adder", - error_span!(parent: span, "external_peer_adder"), - async move { - { - let live: Arc = - live.upgrade().context("no longer live")?; + live.spawn( + error_span!(parent: live.meta().span.clone(), "external_peer_adder"), + { + let live = live.clone(); + async move { trace!("adding {} initial peers", initial_peers.len()); for peer in initial_peers { live.add_peer_if_not_seen(peer).context("torrent closed")?; } - } - let mut peer_rx = if let Some(peer_rx) = peer_rx { - peer_rx - } else { - return Ok(()); - }; + let live = { + let weak = Arc::downgrade(&live); + drop(live); + weak + }; - loop { - match timeout(Duration::from_secs(5), peer_rx.next()).await { - Ok(Some(peer)) => { - let live = match live.upgrade() { - Some(live) => live, - None => return Ok(()), - }; - live.add_peer_if_not_seen(peer).context("torrent closed")?; + let mut peer_rx = if let Some(peer_rx) = peer_rx { + peer_rx + } else { + return Ok(()); + }; + + loop { + match timeout(Duration::from_secs(5), peer_rx.next()).await { + Ok(Some(peer)) => { + let live = match live.upgrade() { + Some(live) => live, + None => return Ok(()), + }; + live.add_peer_if_not_seen(peer).context("torrent closed")?; + } + Ok(None) => return Ok(()), + // If timeout, check if the torrent is live. + Err(_) if live.strong_count() == 0 => return Ok(()), + Err(_) => continue, } - Ok(None) => return Ok(()), - // If timeout, check if the torrent is live. - Err(_) if live.strong_count() == 0 => return Ok(()), - Err(_) => continue, } } }, @@ -252,9 +257,10 @@ impl ManagedTorrent { drop(g); let t = self.clone(); let span = self.info().span.clone(); - spawn( - "initialize_and_start", + let token = self.cancellation_token.clone(); + spawn_with_cancel( error_span!(parent: span.clone(), "initialize_and_start"), + token.clone(), async move { match init.check().await { Ok(paused) => { @@ -271,7 +277,7 @@ impl ManagedTorrent { } let (tx, rx) = tokio::sync::oneshot::channel(); - let live = TorrentStateLive::new(paused, tx); + let live = TorrentStateLive::new(paused, tx, token.child_token()); g.state = ManagedTorrentState::Live(live.clone()); spawn_fatal_errors_receiver(&t, rx); @@ -292,7 +298,11 @@ impl ManagedTorrent { ManagedTorrentState::Paused(_) => { let paused = g.state.take().assert_paused(); let (tx, rx) = tokio::sync::oneshot::channel(); - let live = TorrentStateLive::new(paused, tx); + let live = TorrentStateLive::new( + paused, + tx, + self.cancellation_token.child_token().clone(), + ); g.state = ManagedTorrentState::Live(live.clone()); spawn_fatal_errors_receiver(self, rx); spawn_peer_adder(&live, initial_peers, peer_rx); @@ -409,6 +419,7 @@ pub struct ManagedTorrentBuilder { peer_id: Option, overwrite: bool, spawner: Option, + cancellation_token: Option, } impl ManagedTorrentBuilder { @@ -429,9 +440,15 @@ impl ManagedTorrentBuilder { trackers: Default::default(), peer_id: None, overwrite: false, + cancellation_token: None, } } + pub fn cancellation_token(&mut self, token: CancellationToken) -> &mut Self { + self.cancellation_token = Some(token); + self + } + pub fn only_files(&mut self, only_files: Vec) -> &mut Self { self.only_files = Some(only_files); self @@ -472,7 +489,7 @@ impl ManagedTorrentBuilder { self } - pub(crate) fn build(self, span: tracing::Span) -> anyhow::Result { + pub(crate) fn build(mut self, span: tracing::Span) -> anyhow::Result { let lengths = Lengths::from_torrent(&self.info)?; let info = Arc::new(ManagedTorrentInfo { span, @@ -499,6 +516,7 @@ impl ManagedTorrentBuilder { locked: RwLock::new(ManagedTorrentLocked { state: ManagedTorrentState::Initializing(initializing), }), + cancellation_token: self.cancellation_token.take().unwrap_or_default(), info, })) } diff --git a/crates/librqbit_core/Cargo.toml b/crates/librqbit_core/Cargo.toml index 0ddd412..ab7db32 100644 --- a/crates/librqbit_core/Cargo.toml +++ b/crates/librqbit_core/Cargo.toml @@ -30,6 +30,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} itertools = "0.12" directories = "5" +tokio-util = "0.7.10" [dev-dependencies] serde_json = "1" diff --git a/crates/librqbit_core/src/spawn_utils.rs b/crates/librqbit_core/src/spawn_utils.rs index ac0dd65..7479ca3 100644 --- a/crates/librqbit_core/src/spawn_utils.rs +++ b/crates/librqbit_core/src/spawn_utils.rs @@ -1,3 +1,5 @@ +use anyhow::bail; +use tokio_util::sync::CancellationToken; use tracing::{error, trace, Instrument}; /// Spawns a future with tracing instrumentation. @@ -32,3 +34,18 @@ pub fn spawn( .instrument(span); tokio::task::spawn(fut) } + +pub fn spawn_with_cancel( + span: tracing::Span, + cancellation_token: CancellationToken, + fut: impl std::future::Future> + Send + 'static, +) -> tokio::task::JoinHandle<()> { + spawn(span, async move { + tokio::select! { + _ = cancellation_token.cancelled() => { + bail!("cancelled"); + }, + r = fut => r + } + }) +} diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 95a7277..45bfe80 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -501,6 +501,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { ) .await .context("error initializing rqbit session")?; + librqbit_spawn( "stats_printer", trace_span!("stats_printer"), diff --git a/desktop/src-tauri/Cargo.lock b/desktop/src-tauri/Cargo.lock index d1be380..e7c8989 100644 --- a/desktop/src-tauri/Cargo.lock +++ b/desktop/src-tauri/Cargo.lock @@ -1900,6 +1900,7 @@ dependencies = [ "size_format", "tokio", "tokio-stream", + "tokio-util", "tower-http", "tracing", "url", @@ -1944,6 +1945,7 @@ dependencies = [ "parking_lot", "serde", "tokio", + "tokio-util", "tracing", "url", "uuid", @@ -1970,6 +1972,7 @@ dependencies = [ "serde_json", "tokio", "tokio-stream", + "tokio-util", "tracing", ] diff --git a/desktop/src-tauri/src/main.rs b/desktop/src-tauri/src/main.rs index 7978cdc..3004fab 100644 --- a/desktop/src-tauri/src/main.rs +++ b/desktop/src-tauri/src/main.rs @@ -101,7 +101,7 @@ async fn api_from_config( librqbit::http_api::HttpApi::new(session.clone(), Some(rust_log_reload_tx.clone())) .make_http_api_and_run(config.http_api.listen_addr, config.http_api.read_only); - session.spawn("http api", error_span!("http_api"), http_api_task); + session.spawn(error_span!("http_api"), http_api_task); } Ok(api) }