Use tokio_util::CancellationToken everywhere

This commit is contained in:
Igor Katson 2023-12-07 08:10:17 +00:00
parent 53868ad45e
commit bed7433d8e
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
16 changed files with 176 additions and 178 deletions

3
Cargo.lock generated
View file

@ -1290,6 +1290,7 @@ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-test", "tokio-test",
"tokio-util",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -1336,6 +1337,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"tokio", "tokio",
"tokio-util",
"tracing", "tracing",
"url", "url",
"uuid", "uuid",
@ -1362,6 +1364,7 @@ dependencies = [
"serde_json", "serde_json",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]

View file

@ -32,10 +32,10 @@ futures = "0.3"
rand = "0.8" rand = "0.8"
indexmap = "2" indexmap = "2"
dashmap = {version = "5.5.3", features = ["serde"]} dashmap = {version = "5.5.3", features = ["serde"]}
clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"}
librqbit-core = {path="../librqbit_core", version = "3.3.0"} librqbit-core = {path="../librqbit_core", version = "3.3.0"}
chrono = {version = "0.4.31", features = ["serde"]} chrono = {version = "0.4.31", features = ["serde"]}
tokio-util = "0.7.10"
[dev-dependencies] [dev-dependencies]
tracing-subscriber = "0.3" tracing-subscriber = "0.3"

View file

@ -16,8 +16,7 @@ async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let (dht, worker) = DhtBuilder::new().await.context("error initializing DHT")?; let dht = DhtBuilder::new().await.context("error initializing DHT")?;
tokio::spawn(worker);
let mut stream = dht.get_peers(info_hash, None)?; let mut stream = dht.get_peers(info_hash, None)?;

View file

@ -23,10 +23,14 @@ use anyhow::{bail, Context};
use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use bencode::ByteString; use bencode::ByteString;
use dashmap::DashMap; use dashmap::DashMap;
use futures::{stream::FuturesUnordered, Future, Stream, StreamExt, TryFutureExt}; use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt};
use leaky_bucket::RateLimiter; use leaky_bucket::RateLimiter;
use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; use librqbit_core::{
id20::Id20,
peer_id::generate_peer_id,
spawn_utils::{spawn, spawn_with_cancel},
};
use parking_lot::RwLock; use parking_lot::RwLock;
use serde::Serialize; use serde::Serialize;
@ -35,6 +39,7 @@ use tokio::{
sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender},
}; };
use tokio_util::sync::CancellationToken;
use tracing::{debug, debug_span, error, error_span, info, trace, warn, Instrument}; use tracing::{debug, debug_span, error, error_span, info, trace, warn, Instrument};
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -535,6 +540,8 @@ pub struct DhtState {
// This is to send raw messages // This is to send raw messages
worker_sender: UnboundedSender<WorkerSendRequest>, worker_sender: UnboundedSender<WorkerSendRequest>,
cancellation_token: CancellationToken,
pub(crate) peer_store: PeerStore, pub(crate) peer_store: PeerStore,
} }
@ -545,6 +552,7 @@ impl DhtState {
routing_table: Option<RoutingTable>, routing_table: Option<RoutingTable>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
peer_store: PeerStore, peer_store: PeerStore,
cancellation_token: CancellationToken,
) -> Self { ) -> Self {
let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id, None)); let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id, None));
Self { Self {
@ -556,6 +564,7 @@ impl DhtState {
listen_addr, listen_addr,
rate_limiter: make_rate_limiter(), rate_limiter: make_rate_limiter(),
peer_store, peer_store,
cancellation_token,
} }
} }
@ -1124,21 +1133,18 @@ pub struct DhtConfig {
pub routing_table: Option<RoutingTable>, pub routing_table: Option<RoutingTable>,
pub listen_addr: Option<SocketAddr>, pub listen_addr: Option<SocketAddr>,
pub peer_store: Option<PeerStore>, pub peer_store: Option<PeerStore>,
pub cancellation_token: Option<CancellationToken>,
} }
impl DhtState { impl DhtState {
pub async fn new() -> anyhow::Result<( pub async fn new() -> anyhow::Result<Arc<Self>> {
Arc<Self>,
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static,
)> {
Self::with_config(DhtConfig::default()).await Self::with_config(DhtConfig::default()).await
} }
pub async fn with_config( pub fn cancellation_token(&self) -> &CancellationToken {
config: DhtConfig, &self.cancellation_token
) -> anyhow::Result<( }
Arc<Self>,
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static, pub async fn with_config(mut config: DhtConfig) -> anyhow::Result<Arc<Self>> {
)> {
let socket = match config.listen_addr { let socket = match config.listen_addr {
Some(addr) => UdpSocket::bind(addr) Some(addr) => UdpSocket::bind(addr)
.await .await
@ -1159,6 +1165,8 @@ impl DhtState {
.bootstrap_addrs .bootstrap_addrs
.unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect()); .unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect());
let token = config.cancellation_token.take().unwrap_or_default();
let (in_tx, in_rx) = unbounded_channel(); let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(Self::new_internal( let state = Arc::new(Self::new_internal(
peer_id, peer_id,
@ -1166,17 +1174,17 @@ impl DhtState {
config.routing_table, config.routing_table,
listen_addr, listen_addr,
config.peer_store.unwrap_or_else(|| PeerStore::new(peer_id)), config.peer_store.unwrap_or_else(|| PeerStore::new(peer_id)),
token,
)); ));
let run_worker = { spawn_with_cancel(error_span!("dht"), state.cancellation_token.clone(), {
let state = state.clone(); let state = state.clone();
async move { async move {
let worker = DhtWorker { socket, dht: state }; let worker = DhtWorker { socket, dht: state };
worker.start(in_rx, &bootstrap_addrs).await?; worker.start(in_rx, &bootstrap_addrs).await
Ok(())
} }
}; });
Ok((state, run_worker)) Ok(state)
} }
pub fn get_peers( pub fn get_peers(

View file

@ -10,7 +10,6 @@ use std::time::Duration;
pub use crate::dht::DhtStats; pub use crate::dht::DhtStats;
pub use crate::dht::{DhtConfig, DhtState, RequestPeersStream}; pub use crate::dht::{DhtConfig, DhtState, RequestPeersStream};
use futures::Future;
pub use librqbit_core::id20::Id20; pub use librqbit_core::id20::Id20;
pub use persistence::{PersistentDht, PersistentDhtConfig}; pub use persistence::{PersistentDht, PersistentDhtConfig};
@ -27,19 +26,11 @@ pub struct DhtBuilder {}
impl DhtBuilder { impl DhtBuilder {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub async fn new() -> anyhow::Result<( pub async fn new() -> anyhow::Result<Dht> {
Dht,
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static,
)> {
DhtState::new().await DhtState::new().await
} }
pub async fn with_config( pub async fn with_config(config: DhtConfig) -> anyhow::Result<Dht> {
config: DhtConfig,
) -> anyhow::Result<(
Dht,
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static,
)> {
DhtState::with_config(config).await DhtState::with_config(config).await
} }
} }

View file

@ -1,16 +1,17 @@
// TODO: this now stores only the routing table, but we also need AT LEAST the same socket address... // TODO: this now stores only the routing table, but we also need AT LEAST the same socket address...
use futures::Future;
use librqbit_core::directories::get_configuration_directory; use librqbit_core::directories::get_configuration_directory;
use librqbit_core::spawn_utils::spawn_with_cancel;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::time::Duration; use std::time::Duration;
use tokio_util::sync::CancellationToken;
use anyhow::Context; use anyhow::Context;
use tracing::{debug, error, info, trace, warn}; use tracing::{debug, error, error_span, info, trace, warn};
use crate::peer_store::PeerStore; use crate::peer_store::PeerStore;
use crate::routing_table::RoutingTable; use crate::routing_table::RoutingTable;
@ -76,11 +77,8 @@ impl PersistentDht {
pub async fn create( pub async fn create(
config: Option<PersistentDhtConfig>, config: Option<PersistentDhtConfig>,
) -> anyhow::Result<( cancellation_token: Option<CancellationToken>,
Dht, ) -> anyhow::Result<Dht> {
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static,
impl Future<Output = anyhow::Result<()>> + Send + Sync + 'static,
)> {
let mut config = config.unwrap_or_default(); let mut config = config.unwrap_or_default();
let config_filename = match config.config_filename.take() { let config_filename = match config.config_filename.take() {
Some(config_filename) => config_filename, Some(config_filename) => config_filename,
@ -129,35 +127,41 @@ impl PersistentDht {
routing_table, routing_table,
listen_addr, listen_addr,
peer_store, peer_store,
cancellation_token,
..Default::default() ..Default::default()
}; };
let (dht, run_worker) = DhtState::with_config(dht_config).await?; let dht = DhtState::with_config(dht_config).await?;
spawn_with_cancel(
error_span!("dht_persistence"),
dht.cancellation_token().clone(),
{
let dht = dht.clone();
let dump_interval = config
.dump_interval
.unwrap_or_else(|| Duration::from_secs(3));
async move {
let tempfile_name = {
let file_name = format!("dht.json.tmp.{}", std::process::id());
let mut tmp = config_filename.clone();
tmp.set_file_name(file_name);
tmp
};
let run_persistence = { loop {
let dht = dht.clone(); trace!("sleeping for {:?}", &dump_interval);
let dump_interval = config tokio::time::sleep(dump_interval).await;
.dump_interval
.unwrap_or_else(|| Duration::from_secs(3));
async move {
let tempfile_name = {
let file_name = format!("dht.json.tmp.{}", std::process::id());
let mut tmp = config_filename.clone();
tmp.set_file_name(file_name);
tmp
};
loop { match dump_dht(&dht, &config_filename, &tempfile_name) {
trace!("sleeping for {:?}", &dump_interval); Ok(_) => debug!("dumped DHT to {:?}", &config_filename),
tokio::time::sleep(dump_interval).await; Err(e) => {
error!("error dumping DHT to {:?}: {:#}", &config_filename, e)
match dump_dht(&dht, &config_filename, &tempfile_name) { }
Ok(_) => debug!("dumped DHT to {:?}", &config_filename), }
Err(e) => error!("error dumping DHT to {:?}: {:#}", &config_filename, e),
} }
} }
} },
}; );
Ok((dht, run_worker, run_persistence)) Ok(dht)
} }
} }

View file

@ -64,6 +64,7 @@ backoff = "0.4.0"
dashmap = "5.5.3" dashmap = "5.5.3"
base64 = "0.21.5" base64 = "0.21.5"
serde_with = "3.4.0" serde_with = "3.4.0"
tokio-util = "0.7.10"
[dev-dependencies] [dev-dependencies]
futures = {version = "0.3"} futures = {version = "0.3"}

View file

@ -107,8 +107,7 @@ mod tests {
init_logging(); init_logging();
let info_hash = Id20::from_str("cab507494d02ebb1178b38f2e9d7be299c86b862").unwrap(); let info_hash = Id20::from_str("cab507494d02ebb1178b38f2e9d7be299c86b862").unwrap();
let (dht, run_dht) = DhtBuilder::new().await.unwrap(); let dht = DhtBuilder::new().await.unwrap();
tokio::spawn(run_dht);
let peer_rx = dht.get_peers(info_hash, None).unwrap(); let peer_rx = dht.get_peers(info_hash, None).unwrap();
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();

View file

@ -21,6 +21,7 @@ use librqbit_core::{
directories::get_configuration_directory, directories::get_configuration_directory,
magnet::Magnet, magnet::Magnet,
peer_id::generate_peer_id, peer_id::generate_peer_id,
spawn_utils::spawn_with_cancel,
torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned},
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
@ -32,12 +33,13 @@ use tokio::{
io::AsyncReadExt, io::AsyncReadExt,
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
}; };
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use tracing::{debug, error, error_span, info, trace, warn, Instrument};
use crate::{ use crate::{
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
peer_connection::{with_timeout, PeerConnectionOptions}, peer_connection::{with_timeout, PeerConnectionOptions},
spawn_utils::{spawn, BlockingSpawner}, spawn_utils::BlockingSpawner,
torrent_state::{ torrent_state::{
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
}, },
@ -150,23 +152,6 @@ struct SerializedSessionDatabase {
torrents: HashMap<usize, SerializedTorrent>, torrents: HashMap<usize, SerializedTorrent>,
} }
fn spawn_with_cancel_token(
mut cancel_rx: tokio::sync::watch::Receiver<()>,
name: &str,
span: tracing::Span,
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
) {
spawn(name, span, async move {
tokio::select! {
r = fut => r,
_ = cancel_rx.changed() => {
debug!("task canceled");
Ok(())
}
}
});
}
pub struct Session { pub struct Session {
peer_id: Id20, peer_id: Id20,
dht: Option<Dht>, dht: Option<Dht>,
@ -178,8 +163,7 @@ pub struct Session {
tcp_listen_port: Option<u16>, tcp_listen_port: Option<u16>,
cancel_tx: tokio::sync::watch::Sender<()>, cancellation_token: CancellationToken,
cancel_rx: tokio::sync::watch::Receiver<()>,
} }
async fn torrent_from_url(url: &str) -> anyhow::Result<TorrentMetaV1Owned> { async fn torrent_from_url(url: &str) -> anyhow::Result<TorrentMetaV1Owned> {
@ -395,14 +379,17 @@ impl Session {
Ok(dir.data_dir().join("session.json")) Ok(dir.data_dir().join("session.json"))
} }
pub fn cancellation_token(&self) -> &CancellationToken {
&self.cancellation_token
}
/// Create a new session with options. /// Create a new session with options.
pub async fn new_with_opts( pub async fn new_with_opts(
output_folder: PathBuf, output_folder: PathBuf,
mut opts: SessionOptions, mut opts: SessionOptions,
) -> anyhow::Result<Arc<Self>> { ) -> anyhow::Result<Arc<Self>> {
let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id);
let token = CancellationToken::new();
let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(());
let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range { let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range {
let (l, p) = create_tcp_listener(port_range) let (l, p) = create_tcp_listener(port_range)
@ -418,24 +405,17 @@ impl Session {
None None
} else { } else {
let dht = if opts.disable_dht_persistence { let dht = if opts.disable_dht_persistence {
let (dht, run_worker) = DhtBuilder::with_config(DhtConfig::default()) DhtBuilder::with_config(DhtConfig {
.await cancellation_token: Some(token.child_token()),
.context("error initializing DHT")?; ..Default::default()
spawn_with_cancel_token(cancel_rx.clone(), "dht", error_span!("dht"), run_worker); })
dht .await
.context("error initializing DHT")?
} else { } else {
let pdht_config = opts.dht_config.take().unwrap_or_default(); let pdht_config = opts.dht_config.take().unwrap_or_default();
let (dht, run_worker, run_persistence) = PersistentDht::create(Some(pdht_config)) PersistentDht::create(Some(pdht_config), Some(token.clone()))
.await .await
.context("error initializing persistent DHT")?; .context("error initializing persistent DHT")?
spawn_with_cancel_token(cancel_rx.clone(), "dht", error_span!("dht"), run_worker);
spawn_with_cancel_token(
cancel_rx.clone(),
"dht_persistence",
error_span!("dht_persistence"),
run_persistence,
);
dht
}; };
Some(dht) Some(dht)
@ -455,14 +435,12 @@ impl Session {
spawner, spawner,
output_folder, output_folder,
db: RwLock::new(Default::default()), db: RwLock::new(Default::default()),
cancel_rx, cancellation_token: token,
cancel_tx,
tcp_listen_port, tcp_listen_port,
}); });
if let Some(tcp_listener) = tcp_listener { if let Some(tcp_listener) = tcp_listener {
session.spawn( session.spawn(
"tcp listener",
error_span!("tcp_listen", port = tcp_listen_port), error_span!("tcp_listen", port = tcp_listen_port),
session.clone().task_tcp_listener(tcp_listener), session.clone().task_tcp_listener(tcp_listener),
); );
@ -471,7 +449,6 @@ impl Session {
if let Some(listen_port) = tcp_listen_port { if let Some(listen_port) = tcp_listen_port {
if opts.enable_upnp_port_forwarding { if opts.enable_upnp_port_forwarding {
session.spawn( session.spawn(
"upnp_forward",
error_span!("upnp_forward", port = listen_port), error_span!("upnp_forward", port = listen_port),
session.clone().task_upnp_port_forwarder(listen_port), session.clone().task_upnp_port_forwarder(listen_port),
); );
@ -489,11 +466,7 @@ impl Session {
})?; })?;
} }
let persistence_task = session.clone().task_persistence(); let persistence_task = session.clone().task_persistence();
session.spawn( session.spawn(error_span!("session_persistence"), persistence_task);
"session persistene",
error_span!("session_persistence"),
persistence_task,
);
} }
Ok(session) Ok(session)
@ -645,11 +618,10 @@ impl Session {
/// Spawn a task in the context of the session. /// Spawn a task in the context of the session.
pub fn spawn( pub fn spawn(
&self, &self,
name: &str,
span: tracing::Span, span: tracing::Span,
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static, fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
) { ) {
spawn_with_cancel_token(self.cancel_rx.clone(), name, span, fut); spawn_with_cancel(span, self.cancellation_token.clone(), fut);
} }
/// Stop the session and all managed tasks. /// Stop the session and all managed tasks.
@ -666,7 +638,7 @@ impl Session {
debug!("error pausing torrent: {e:#}"); debug!("error pausing torrent: {e:#}");
} }
} }
let _ = self.cancel_tx.send(()); self.cancellation_token.cancel();
// this sucks, but hopefully will be enough // this sucks, but hopefully will be enough
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
} }
@ -999,6 +971,7 @@ impl Session {
builder builder
.overwrite(opts.overwrite) .overwrite(opts.overwrite)
.spawner(self.spawner) .spawner(self.spawner)
.cancellation_token(self.cancellation_token.child_token())
.peer_id(self.peer_id); .peer_id(self.peer_id);
if opts.disable_trackers { if opts.disable_trackers {

View file

@ -65,6 +65,7 @@ use itertools::Itertools;
use librqbit_core::{ use librqbit_core::{
id20::Id20, id20::Id20,
lengths::{ChunkInfo, Lengths, ValidPieceIndex}, lengths::{ChunkInfo, Lengths, ValidPieceIndex},
spawn_utils::spawn_with_cancel,
speed_estimator::SpeedEstimator, speed_estimator::SpeedEstimator,
torrent_metainfo::TorrentMetaV1Info, torrent_metainfo::TorrentMetaV1Info,
}; };
@ -80,6 +81,7 @@ use tokio::{
}, },
time::timeout, time::timeout,
}; };
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace, warn}; use tracing::{debug, error, error_span, info, trace, warn};
use url::Url; use url::Url;
@ -90,7 +92,6 @@ use crate::{
PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest,
}, },
session::CheckedIncomingConnection, session::CheckedIncomingConnection,
spawn_utils::spawn,
torrent_state::{peer::Peer, utils::atomic_inc}, torrent_state::{peer::Peer, utils::atomic_inc},
tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse}, tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse},
type_aliases::{PeerHandle, BF}, type_aliases::{PeerHandle, BF},
@ -185,17 +186,16 @@ pub struct TorrentStateLive {
finished_notify: Notify, finished_notify: Notify,
cancel_tx: tokio::sync::watch::Sender<()>,
cancel_rx: tokio::sync::watch::Receiver<()>,
down_speed_estimator: SpeedEstimator, down_speed_estimator: SpeedEstimator,
up_speed_estimator: SpeedEstimator, up_speed_estimator: SpeedEstimator,
cancellation_token: CancellationToken,
} }
impl TorrentStateLive { impl TorrentStateLive {
pub(crate) fn new( pub(crate) fn new(
paused: TorrentStatePaused, paused: TorrentStatePaused,
fatal_errors_tx: tokio::sync::oneshot::Sender<anyhow::Error>, fatal_errors_tx: tokio::sync::oneshot::Sender<anyhow::Error>,
cancellation_token: CancellationToken,
) -> Arc<Self> { ) -> Arc<Self> {
let (peer_queue_tx, peer_queue_rx) = unbounded_channel(); let (peer_queue_tx, peer_queue_rx) = unbounded_channel();
@ -206,8 +206,6 @@ impl TorrentStateLive {
let needed_bytes = paused.info.lengths.total_length() - have_bytes; let needed_bytes = paused.info.lengths.total_length() - have_bytes;
let lengths = *paused.chunk_tracker.get_lengths(); let lengths = *paused.chunk_tracker.get_lengths();
let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(());
let state = Arc::new(TorrentStateLive { let state = Arc::new(TorrentStateLive {
meta: paused.info.clone(), meta: paused.info.clone(),
peers: Default::default(), peers: Default::default(),
@ -229,20 +227,17 @@ impl TorrentStateLive {
finished_notify: Notify::new(), finished_notify: Notify::new(),
down_speed_estimator, down_speed_estimator,
up_speed_estimator, up_speed_estimator,
cancel_rx, cancellation_token,
cancel_tx,
}); });
for tracker in state.meta.trackers.iter() { for tracker in state.meta.trackers.iter() {
state.spawn( state.spawn(
"tracker_monitor",
error_span!(parent: state.meta.span.clone(), "tracker_monitor", url = tracker.to_string()), error_span!(parent: state.meta.span.clone(), "tracker_monitor", url = tracker.to_string()),
state.clone().task_single_tracker_monitor(tracker.clone()), state.clone().task_single_tracker_monitor(tracker.clone()),
); );
} }
state.spawn( state.spawn(
"speed_estimator_updater",
error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"), error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"),
{ {
let state = Arc::downgrade(&state); let state = Arc::downgrade(&state);
@ -273,29 +268,18 @@ impl TorrentStateLive {
); );
state.spawn( state.spawn(
"peer_adder",
error_span!(parent: state.meta.span.clone(), "peer_adder"), error_span!(parent: state.meta.span.clone(), "peer_adder"),
state.clone().task_peer_adder(peer_queue_rx), state.clone().task_peer_adder(peer_queue_rx),
); );
state state
} }
fn spawn( pub(crate) fn spawn(
&self, &self,
name: &str,
span: tracing::Span, span: tracing::Span,
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static, fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
) { ) {
let mut cancel_rx = self.cancel_rx.clone(); spawn_with_cancel(span, self.cancellation_token.clone(), fut);
spawn(name, span, async move {
tokio::select! {
r = fut => r,
_ = cancel_rx.changed() => {
debug!("task canceled");
Ok(())
}
}
});
} }
pub fn down_speed_estimator(&self) -> &SpeedEstimator { pub fn down_speed_estimator(&self) -> &SpeedEstimator {
@ -418,7 +402,6 @@ impl TorrentStateLive {
atomic_inc(&counters.incoming_connections); atomic_inc(&counters.incoming_connections);
self.spawn( self.spawn(
"incoming peer",
error_span!( error_span!(
parent: self.meta.span.clone(), parent: self.meta.span.clone(),
"manage_incoming_peer", "manage_incoming_peer",
@ -570,7 +553,6 @@ impl TorrentStateLive {
let permit = state.peer_semaphore.clone().acquire_owned().await?; let permit = state.peer_semaphore.clone().acquire_owned().await?;
state.spawn( state.spawn(
"manage_peer",
error_span!(parent: state.meta.span.clone(), "manage_peer", peer = addr.to_string()), error_span!(parent: state.meta.span.clone(), "manage_peer", peer = addr.to_string()),
state.clone().task_manage_outgoing_peer(addr, permit), state.clone().task_manage_outgoing_peer(addr, permit),
); );
@ -682,7 +664,6 @@ impl TorrentStateLive {
// We don't want to remember this task as there may be too many. // We don't want to remember this task as there may be too many.
self.spawn( self.spawn(
"transmit_haves",
error_span!( error_span!(
parent: self.meta.span.clone(), parent: self.meta.span.clone(),
"transmit_haves", "transmit_haves",
@ -744,7 +725,7 @@ impl TorrentStateLive {
} }
pub fn pause(&self) -> anyhow::Result<TorrentStatePaused> { pub fn pause(&self) -> anyhow::Result<TorrentStatePaused> {
let _ = self.cancel_tx.send(()); self.cancellation_token.cancel();
let mut g = self.locked.write(); let mut g = self.locked.write();
@ -971,7 +952,6 @@ impl PeerHandler {
if let Some(dur) = backoff { if let Some(dur) = backoff {
self.state.clone().spawn( self.state.clone().spawn(
"wait_for_peer",
error_span!( error_span!(
parent: self.state.meta.span.clone(), parent: self.state.meta.span.clone(),
"wait_for_peer", "wait_for_peer",

View file

@ -20,12 +20,14 @@ use librqbit_core::id20::Id20;
use librqbit_core::lengths::Lengths; use librqbit_core::lengths::Lengths;
use librqbit_core::peer_id::generate_peer_id; use librqbit_core::peer_id::generate_peer_id;
use librqbit_core::spawn_utils::spawn_with_cancel;
use librqbit_core::torrent_metainfo::TorrentMetaV1Info; use librqbit_core::torrent_metainfo::TorrentMetaV1Info;
pub use live::*; pub use live::*;
use parking_lot::RwLock; use parking_lot::RwLock;
use tokio::time::timeout; use tokio::time::timeout;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::debug; use tracing::debug;
use tracing::error_span; use tracing::error_span;
use tracing::trace; use tracing::trace;
@ -33,7 +35,6 @@ use tracing::warn;
use url::Url; use url::Url;
use crate::chunk_tracker::ChunkTracker; use crate::chunk_tracker::ChunkTracker;
use crate::spawn_utils::spawn;
use crate::spawn_utils::BlockingSpawner; use crate::spawn_utils::BlockingSpawner;
use crate::torrent_state::stats::LiveStats; use crate::torrent_state::stats::LiveStats;
@ -91,6 +92,7 @@ pub struct ManagedTorrentInfo {
pub struct ManagedTorrent { pub struct ManagedTorrent {
pub info: Arc<ManagedTorrentInfo>, pub info: Arc<ManagedTorrentInfo>,
pub cancellation_token: CancellationToken,
pub(crate) only_files: Option<Vec<usize>>, pub(crate) only_files: Option<Vec<usize>>,
locked: RwLock<ManagedTorrentLocked>, locked: RwLock<ManagedTorrentLocked>,
} }
@ -179,10 +181,11 @@ impl ManagedTorrent {
let spawn_fatal_errors_receiver = let spawn_fatal_errors_receiver =
|state: &Arc<Self>, rx: tokio::sync::oneshot::Receiver<anyhow::Error>| { |state: &Arc<Self>, rx: tokio::sync::oneshot::Receiver<anyhow::Error>| {
let span = state.info.span.clone(); let span = state.info.span.clone();
let token = state.cancellation_token.clone();
let state = Arc::downgrade(state); let state = Arc::downgrade(state);
spawn( spawn_with_cancel(
"fatal_errors_receiver",
error_span!(parent: span, "fatal_errors_receiver"), error_span!(parent: span, "fatal_errors_receiver"),
token,
async move { async move {
let e = match rx.await { let e = match rx.await {
Ok(e) => e, Ok(e) => e,
@ -191,7 +194,7 @@ impl ManagedTorrent {
if let Some(state) = state.upgrade() { if let Some(state) = state.upgrade() {
state.stop_with_error(e); state.stop_with_error(e);
} else { } else {
warn!("tried to stop the torrent with error, but it's couldn't upgrade the arc"); warn!("tried to stop the torrent with error, but couldn't upgrade the arc");
} }
Ok(()) Ok(())
}, },
@ -203,40 +206,42 @@ impl ManagedTorrent {
initial_peers: Vec<SocketAddr>, initial_peers: Vec<SocketAddr>,
peer_rx: Option<RequestPeersStream>, peer_rx: Option<RequestPeersStream>,
) { ) {
let span = live.meta().span.clone(); live.spawn(
let live = Arc::downgrade(live); error_span!(parent: live.meta().span.clone(), "external_peer_adder"),
spawn( {
"external_peer_adder", let live = live.clone();
error_span!(parent: span, "external_peer_adder"), async move {
async move {
{
let live: Arc<TorrentStateLive> =
live.upgrade().context("no longer live")?;
trace!("adding {} initial peers", initial_peers.len()); trace!("adding {} initial peers", initial_peers.len());
for peer in initial_peers { for peer in initial_peers {
live.add_peer_if_not_seen(peer).context("torrent closed")?; live.add_peer_if_not_seen(peer).context("torrent closed")?;
} }
}
let mut peer_rx = if let Some(peer_rx) = peer_rx { let live = {
peer_rx let weak = Arc::downgrade(&live);
} else { drop(live);
return Ok(()); weak
}; };
loop { let mut peer_rx = if let Some(peer_rx) = peer_rx {
match timeout(Duration::from_secs(5), peer_rx.next()).await { peer_rx
Ok(Some(peer)) => { } else {
let live = match live.upgrade() { return Ok(());
Some(live) => live, };
None => return Ok(()),
}; loop {
live.add_peer_if_not_seen(peer).context("torrent closed")?; match timeout(Duration::from_secs(5), peer_rx.next()).await {
Ok(Some(peer)) => {
let live = match live.upgrade() {
Some(live) => live,
None => return Ok(()),
};
live.add_peer_if_not_seen(peer).context("torrent closed")?;
}
Ok(None) => return Ok(()),
// If timeout, check if the torrent is live.
Err(_) if live.strong_count() == 0 => return Ok(()),
Err(_) => continue,
} }
Ok(None) => return Ok(()),
// If timeout, check if the torrent is live.
Err(_) if live.strong_count() == 0 => return Ok(()),
Err(_) => continue,
} }
} }
}, },
@ -252,9 +257,10 @@ impl ManagedTorrent {
drop(g); drop(g);
let t = self.clone(); let t = self.clone();
let span = self.info().span.clone(); let span = self.info().span.clone();
spawn( let token = self.cancellation_token.clone();
"initialize_and_start", spawn_with_cancel(
error_span!(parent: span.clone(), "initialize_and_start"), error_span!(parent: span.clone(), "initialize_and_start"),
token.clone(),
async move { async move {
match init.check().await { match init.check().await {
Ok(paused) => { Ok(paused) => {
@ -271,7 +277,7 @@ impl ManagedTorrent {
} }
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
let live = TorrentStateLive::new(paused, tx); let live = TorrentStateLive::new(paused, tx, token.child_token());
g.state = ManagedTorrentState::Live(live.clone()); g.state = ManagedTorrentState::Live(live.clone());
spawn_fatal_errors_receiver(&t, rx); spawn_fatal_errors_receiver(&t, rx);
@ -292,7 +298,11 @@ impl ManagedTorrent {
ManagedTorrentState::Paused(_) => { ManagedTorrentState::Paused(_) => {
let paused = g.state.take().assert_paused(); let paused = g.state.take().assert_paused();
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
let live = TorrentStateLive::new(paused, tx); let live = TorrentStateLive::new(
paused,
tx,
self.cancellation_token.child_token().clone(),
);
g.state = ManagedTorrentState::Live(live.clone()); g.state = ManagedTorrentState::Live(live.clone());
spawn_fatal_errors_receiver(self, rx); spawn_fatal_errors_receiver(self, rx);
spawn_peer_adder(&live, initial_peers, peer_rx); spawn_peer_adder(&live, initial_peers, peer_rx);
@ -409,6 +419,7 @@ pub struct ManagedTorrentBuilder {
peer_id: Option<Id20>, peer_id: Option<Id20>,
overwrite: bool, overwrite: bool,
spawner: Option<BlockingSpawner>, spawner: Option<BlockingSpawner>,
cancellation_token: Option<CancellationToken>,
} }
impl ManagedTorrentBuilder { impl ManagedTorrentBuilder {
@ -429,9 +440,15 @@ impl ManagedTorrentBuilder {
trackers: Default::default(), trackers: Default::default(),
peer_id: None, peer_id: None,
overwrite: false, overwrite: false,
cancellation_token: None,
} }
} }
pub fn cancellation_token(&mut self, token: CancellationToken) -> &mut Self {
self.cancellation_token = Some(token);
self
}
pub fn only_files(&mut self, only_files: Vec<usize>) -> &mut Self { pub fn only_files(&mut self, only_files: Vec<usize>) -> &mut Self {
self.only_files = Some(only_files); self.only_files = Some(only_files);
self self
@ -472,7 +489,7 @@ impl ManagedTorrentBuilder {
self self
} }
pub(crate) fn build(self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> { pub(crate) fn build(mut self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> {
let lengths = Lengths::from_torrent(&self.info)?; let lengths = Lengths::from_torrent(&self.info)?;
let info = Arc::new(ManagedTorrentInfo { let info = Arc::new(ManagedTorrentInfo {
span, span,
@ -499,6 +516,7 @@ impl ManagedTorrentBuilder {
locked: RwLock::new(ManagedTorrentLocked { locked: RwLock::new(ManagedTorrentLocked {
state: ManagedTorrentState::Initializing(initializing), state: ManagedTorrentState::Initializing(initializing),
}), }),
cancellation_token: self.cancellation_token.take().unwrap_or_default(),
info, info,
})) }))
} }

View file

@ -30,6 +30,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod
clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"}
itertools = "0.12" itertools = "0.12"
directories = "5" directories = "5"
tokio-util = "0.7.10"
[dev-dependencies] [dev-dependencies]
serde_json = "1" serde_json = "1"

View file

@ -1,3 +1,5 @@
use anyhow::bail;
use tokio_util::sync::CancellationToken;
use tracing::{error, trace, Instrument}; use tracing::{error, trace, Instrument};
/// Spawns a future with tracing instrumentation. /// Spawns a future with tracing instrumentation.
@ -32,3 +34,18 @@ pub fn spawn(
.instrument(span); .instrument(span);
tokio::task::spawn(fut) tokio::task::spawn(fut)
} }
pub fn spawn_with_cancel(
span: tracing::Span,
cancellation_token: CancellationToken,
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
) -> tokio::task::JoinHandle<()> {
spawn(span, async move {
tokio::select! {
_ = cancellation_token.cancelled() => {
bail!("cancelled");
},
r = fut => r
}
})
}

View file

@ -501,6 +501,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
) )
.await .await
.context("error initializing rqbit session")?; .context("error initializing rqbit session")?;
librqbit_spawn( librqbit_spawn(
"stats_printer", "stats_printer",
trace_span!("stats_printer"), trace_span!("stats_printer"),

View file

@ -1900,6 +1900,7 @@ dependencies = [
"size_format", "size_format",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util",
"tower-http", "tower-http",
"tracing", "tracing",
"url", "url",
@ -1944,6 +1945,7 @@ dependencies = [
"parking_lot", "parking_lot",
"serde", "serde",
"tokio", "tokio",
"tokio-util",
"tracing", "tracing",
"url", "url",
"uuid", "uuid",
@ -1970,6 +1972,7 @@ dependencies = [
"serde_json", "serde_json",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util",
"tracing", "tracing",
] ]

View file

@ -101,7 +101,7 @@ async fn api_from_config(
librqbit::http_api::HttpApi::new(session.clone(), Some(rust_log_reload_tx.clone())) librqbit::http_api::HttpApi::new(session.clone(), Some(rust_log_reload_tx.clone()))
.make_http_api_and_run(config.http_api.listen_addr, config.http_api.read_only); .make_http_api_and_run(config.http_api.listen_addr, config.http_api.read_only);
session.spawn("http api", error_span!("http_api"), http_api_task); session.spawn(error_span!("http_api"), http_api_task);
} }
Ok(api) Ok(api)
} }