Simplify cancellation as peer_rx doesnt need a token no longer

This commit is contained in:
Igor Katson 2024-02-26 20:45:21 +00:00
parent 39330dc717
commit 18f22cf323
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
5 changed files with 103 additions and 90 deletions

1
Cargo.lock generated
View file

@ -1257,6 +1257,7 @@ name = "librqbit"
version = "5.5.0" version = "5.5.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream",
"axum 0.7.4", "axum 0.7.4",
"backoff", "backoff",
"base64", "base64",

View file

@ -68,6 +68,7 @@ serde_with = "3.4.0"
tokio-util = "0.7.10" tokio-util = "0.7.10"
bytes = "1.5.0" bytes = "1.5.0"
rlimit = "0.10.1" rlimit = "0.10.1"
async-stream = "0.3.5"
[dev-dependencies] [dev-dependencies]
futures = {version = "0.3"} futures = {version = "0.3"}

View file

@ -756,15 +756,13 @@ impl Session {
self.tcp_listen_port self.tcp_listen_port
}; };
let cancellation_token = self.cancellation_token.child_token();
let cancellation_token_drop_guard = cancellation_token.clone().drop_guard();
let paused = opts.list_only || opts.paused; let paused = opts.list_only || opts.paused;
// The main difference between magnet link and torrent file, is that we need to resolve the magnet link // The main difference between magnet link and torrent file, is that we need to resolve the magnet link
// into a torrent file by connecting to peers that support extended handshakes. // into a torrent file by connecting to peers that support extended handshakes.
// So we must discover at least one peer and connect to it to be able to proceed further. // So we must discover at least one peer and connect to it to be able to proceed further.
let (info_hash, info, trackers, peer_rx, initial_peers, cancellation_token) = match add { let (info_hash, info, trackers, peer_rx, initial_peers) = match add {
AddTorrent::Url(magnet) if magnet.starts_with("magnet:") => { AddTorrent::Url(magnet) if magnet.starts_with("magnet:") => {
let magnet = let magnet =
Magnet::parse(&magnet).context("provided path is not a valid magnet URL")?; Magnet::parse(&magnet).context("provided path is not a valid magnet URL")?;
@ -772,11 +770,9 @@ impl Session {
.as_id20() .as_id20()
.context("magnet link didn't contain a BTv1 infohash")?; .context("magnet link didn't contain a BTv1 infohash")?;
let peer_token = cancellation_token.child_token();
let peer_rx = self.make_peer_rx( let peer_rx = self.make_peer_rx(
info_hash, info_hash,
magnet.trackers.clone(), magnet.trackers.clone(),
peer_token.clone(),
announce_port, announce_port,
opts.force_tracker_interval, opts.force_tracker_interval,
)?; )?;
@ -800,9 +796,6 @@ impl Session {
anyhow::bail!("DHT died, no way to discover torrent metainfo") anyhow::bail!("DHT died, no way to discover torrent metainfo")
} }
}; };
if paused {
peer_token.cancel();
}
debug!(?info, "received result from DHT"); debug!(?info, "received result from DHT");
( (
info_hash, info_hash,
@ -810,7 +803,6 @@ impl Session {
magnet.trackers, magnet.trackers,
Some(peer_rx), Some(peer_rx),
initial_peers, initial_peers,
cancellation_token,
) )
} }
other => { other => {
@ -849,7 +841,6 @@ impl Session {
self.make_peer_rx( self.make_peer_rx(
torrent.info_hash, torrent.info_hash,
trackers.clone(), trackers.clone(),
cancellation_token.clone(),
announce_port, announce_port,
opts.force_tracker_interval, opts.force_tracker_interval,
)? )?
@ -865,13 +856,10 @@ impl Session {
.unwrap_or_default() .unwrap_or_default()
.into_iter() .into_iter()
.collect(), .collect(),
cancellation_token,
) )
} }
}; };
cancellation_token_drop_guard.disarm();
self.main_torrent_info( self.main_torrent_info(
info_hash, info_hash,
info, info,
@ -879,7 +867,6 @@ impl Session {
peer_rx, peer_rx,
initial_peers.into_iter().collect(), initial_peers.into_iter().collect(),
opts, opts,
cancellation_token,
) )
.await .await
} }
@ -893,12 +880,9 @@ impl Session {
peer_rx: Option<PeerStream>, peer_rx: Option<PeerStream>,
initial_peers: Vec<SocketAddr>, initial_peers: Vec<SocketAddr>,
opts: AddTorrentOptions, opts: AddTorrentOptions,
cancellation_token: CancellationToken,
) -> anyhow::Result<AddTorrentResponse> { ) -> anyhow::Result<AddTorrentResponse> {
debug!("Torrent info: {:#?}", &info); debug!("Torrent info: {:#?}", &info);
let drop_guard = cancellation_token.clone().drop_guard();
let get_only_files = let get_only_files =
|only_files: Option<Vec<usize>>, only_files_regex: Option<String>, list_only: bool| { |only_files: Option<Vec<usize>>, only_files_regex: Option<String>, list_only: bool| {
match (only_files, only_files_regex) { match (only_files, only_files_regex) {
@ -1016,20 +1000,16 @@ impl Session {
let span = managed_torrent.info.span.clone(); let span = managed_torrent.info.span.clone();
let _ = span.enter(); let _ = span.enter();
// Just in case, cancel all tasks started for this torrent so far.
// This is defensive, and not proven necessary.
let token = if opts.paused {
cancellation_token.cancel();
self.cancellation_token.child_token()
} else {
cancellation_token
};
managed_torrent managed_torrent
.start(initial_peers, peer_rx, opts.paused, token) .start(
initial_peers,
peer_rx,
opts.paused,
self.cancellation_token.child_token(),
)
.context("error starting torrent")?; .context("error starting torrent")?;
} }
drop_guard.disarm();
Ok(AddTorrentResponse::Added(id, managed_torrent)) Ok(AddTorrentResponse::Added(id, managed_torrent))
} }
@ -1080,7 +1060,6 @@ impl Session {
&self, &self,
info_hash: Id20, info_hash: Id20,
trackers: Vec<String>, trackers: Vec<String>,
cancel: CancellationToken,
announce_port: Option<u16>, announce_port: Option<u16>,
force_tracker_interval: Option<Duration>, force_tracker_interval: Option<Duration>,
) -> anyhow::Result<Option<PeerStream>> { ) -> anyhow::Result<Option<PeerStream>> {
@ -1097,7 +1076,6 @@ impl Session {
// TODO: report actual bytes, not zeroes. // TODO: report actual bytes, not zeroes.
Box::new(()), Box::new(()),
force_tracker_interval, force_tracker_interval,
cancel,
announce_port, announce_port,
); );
@ -1111,15 +1089,18 @@ impl Session {
} }
pub fn unpause(&self, handle: &ManagedTorrentHandle) -> anyhow::Result<()> { pub fn unpause(&self, handle: &ManagedTorrentHandle) -> anyhow::Result<()> {
let token = self.cancellation_token.child_token();
let peer_rx = self.make_peer_rx( let peer_rx = self.make_peer_rx(
handle.info_hash(), handle.info_hash(),
handle.info().trackers.clone().into_iter().collect(), handle.info().trackers.clone().into_iter().collect(),
token.clone(),
self.tcp_listen_port, self.tcp_listen_port,
handle.info().options.force_tracker_interval, handle.info().options.force_tracker_interval,
)?; )?;
handle.start(Default::default(), peer_rx, false, token)?; handle.start(
Default::default(),
peer_rx,
false,
self.cancellation_token.child_token(),
)?;
Ok(()) Ok(())
} }
} }

View file

@ -4,13 +4,15 @@ use std::time::Duration;
use anyhow::bail; use anyhow::bail;
use anyhow::Context; use anyhow::Context;
use futures::future::Either;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::Stream; use futures::Stream;
use librqbit_core::spawn_utils::spawn_with_cancel; use futures::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::debug; use tracing::debug;
use tracing::error_span; use tracing::error_span;
use tracing::info;
use tracing::trace; use tracing::trace;
use tracing::Instrument;
use url::Url; use url::Url;
use crate::tracker_comms_http; use crate::tracker_comms_http;
@ -22,7 +24,6 @@ pub struct TrackerComms {
peer_id: Id20, peer_id: Id20,
stats: Box<dyn TorrentStatsProvider>, stats: Box<dyn TorrentStatsProvider>,
force_tracker_interval: Option<Duration>, force_tracker_interval: Option<Duration>,
cancellation_token: CancellationToken,
tx: Sender, tx: Sender,
tcp_listen_port: Option<u16>, tcp_listen_port: Option<u16>,
} }
@ -64,69 +65,98 @@ impl TrackerComms {
trackers: Vec<String>, trackers: Vec<String>,
stats: Box<dyn TorrentStatsProvider>, stats: Box<dyn TorrentStatsProvider>,
force_interval: Option<Duration>, force_interval: Option<Duration>,
cancellation_token: CancellationToken,
tcp_listen_port: Option<u16>, tcp_listen_port: Option<u16>,
) -> Option<impl Stream<Item = SocketAddr> + Send + Sync + Unpin + 'static> { ) -> Option<impl Stream<Item = SocketAddr> + Unpin + Send + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<SocketAddr>(16); let trackers = trackers
let comms = Arc::new(Self { .into_iter()
info_hash, .filter_map(|t| match Url::parse(&t) {
peer_id, Ok(parsed) => Some(parsed),
stats, Err(e) => {
force_tracker_interval: force_interval, debug!("error parsing tracker URL: {}", e);
cancellation_token, None
tx, }
tcp_listen_port, })
}); .collect::<Vec<_>>();
let mut added = false; if trackers.is_empty() {
for tracker in trackers {
if let Err(e) = comms.clone().add_tracker(&tracker) {
info!(tracker = tracker, "error adding tracker: {:#}", e)
} else {
added = true;
}
}
if !added {
return None; return None;
} }
Some(tokio_stream::wrappers::ReceiverStream::new(rx))
let (tx, mut rx) = tokio::sync::mpsc::channel::<SocketAddr>(16);
let s = async_stream::stream! {
use futures::StreamExt;
let mut rx_done = false;
let comms = Arc::new(Self {
info_hash,
peer_id,
stats,
force_tracker_interval: force_interval,
tx,
tcp_listen_port,
});
let mut futures = FuturesUnordered::new();
for tracker in trackers {
if let Ok(fut) = comms.add_tracker(tracker) {
futures.push(fut);
}
}
if futures.is_empty() {
return;
}
while !(futures.is_empty() && rx_done) {
tokio::select! {
addr = rx.recv(), if !rx_done => {
match addr {
Some(addr) => yield addr,
None => rx_done = true
}
}
e = futures.next(), if !futures.is_empty() => {
if let Some(Err(e)) = e {
debug!("error: {e}");
}
}
}
}
};
Some(s.boxed())
} }
fn add_tracker(self: Arc<Self>, tracker: &str) -> anyhow::Result<()> { fn add_tracker(
if tracker.starts_with("http://") || tracker.starts_with("https://") { &self,
spawn_with_cancel( url: Url,
error_span!( ) -> anyhow::Result<
parent: None, Either<
"http_tracker", impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
tracker = tracker, impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
info_hash = ?self.info_hash >,
), > {
self.cancellation_token.clone(), let info_hash = self.info_hash;
{ if url.scheme() == "http" || url.scheme() == "https" {
let comms = self; let span = error_span!(
let url = Url::parse(tracker).context("can't parse URL")?; parent: None,
async move { comms.task_single_tracker_monitor_http(url).await } "http_tracker",
}, tracker = %url,
); info_hash = ?info_hash
} else if tracker.starts_with("udp://") {
spawn_with_cancel(
error_span!(parent: None, "udp_tracker", tracker = tracker, info_hash = ?self.info_hash),
self.cancellation_token.clone(),
{
let comms = self;
let url = Url::parse(tracker).context("can't parse URL")?;
async move { comms.task_single_tracker_monitor_udp(url).await }
},
); );
Ok(self
.task_single_tracker_monitor_http(url)
.instrument(span)
.left_future())
} else if url.scheme() == "udp" {
let span =
error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
Ok(self
.task_single_tracker_monitor_udp(url)
.instrument(span)
.right_future())
} else { } else {
bail!("unsupported tracker url {}", tracker) bail!("unsupported tracker url {}", url)
} }
Ok(())
} }
async fn task_single_tracker_monitor_http( async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> {
self: Arc<Self>,
mut tracker_url: Url,
) -> anyhow::Result<()> {
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started); let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
loop { loop {
let stats = self.stats.get(); let stats = self.stats.get();

View file

@ -5,4 +5,4 @@ use futures::Stream;
pub type BF = bitvec::vec::BitVec<u8, bitvec::order::Msb0>; pub type BF = bitvec::vec::BitVec<u8, bitvec::order::Msb0>;
pub type PeerHandle = SocketAddr; pub type PeerHandle = SocketAddr;
pub type PeerStream = Box<dyn Stream<Item = SocketAddr> + Unpin + Send + Sync + 'static>; pub type PeerStream = Box<dyn Stream<Item = SocketAddr> + Unpin + Send + 'static>;