Simplify cancellation as peer_rx doesnt need a token no longer

This commit is contained in:
Igor Katson 2024-02-26 20:45:21 +00:00
parent 39330dc717
commit 18f22cf323
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
5 changed files with 103 additions and 90 deletions

View file

@ -4,13 +4,15 @@ use std::time::Duration;
use anyhow::bail;
use anyhow::Context;
use futures::future::Either;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::Stream;
use librqbit_core::spawn_utils::spawn_with_cancel;
use tokio_util::sync::CancellationToken;
use futures::StreamExt;
use tracing::debug;
use tracing::error_span;
use tracing::info;
use tracing::trace;
use tracing::Instrument;
use url::Url;
use crate::tracker_comms_http;
@ -22,7 +24,6 @@ pub struct TrackerComms {
peer_id: Id20,
stats: Box<dyn TorrentStatsProvider>,
force_tracker_interval: Option<Duration>,
cancellation_token: CancellationToken,
tx: Sender,
tcp_listen_port: Option<u16>,
}
@ -64,69 +65,98 @@ impl TrackerComms {
trackers: Vec<String>,
stats: Box<dyn TorrentStatsProvider>,
force_interval: Option<Duration>,
cancellation_token: CancellationToken,
tcp_listen_port: Option<u16>,
) -> Option<impl Stream<Item = SocketAddr> + Send + Sync + Unpin + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<SocketAddr>(16);
let comms = Arc::new(Self {
info_hash,
peer_id,
stats,
force_tracker_interval: force_interval,
cancellation_token,
tx,
tcp_listen_port,
});
let mut added = false;
for tracker in trackers {
if let Err(e) = comms.clone().add_tracker(&tracker) {
info!(tracker = tracker, "error adding tracker: {:#}", e)
} else {
added = true;
}
}
if !added {
) -> Option<impl Stream<Item = SocketAddr> + Unpin + Send + 'static> {
let trackers = trackers
.into_iter()
.filter_map(|t| match Url::parse(&t) {
Ok(parsed) => Some(parsed),
Err(e) => {
debug!("error parsing tracker URL: {}", e);
None
}
})
.collect::<Vec<_>>();
if trackers.is_empty() {
return None;
}
Some(tokio_stream::wrappers::ReceiverStream::new(rx))
let (tx, mut rx) = tokio::sync::mpsc::channel::<SocketAddr>(16);
let s = async_stream::stream! {
use futures::StreamExt;
let mut rx_done = false;
let comms = Arc::new(Self {
info_hash,
peer_id,
stats,
force_tracker_interval: force_interval,
tx,
tcp_listen_port,
});
let mut futures = FuturesUnordered::new();
for tracker in trackers {
if let Ok(fut) = comms.add_tracker(tracker) {
futures.push(fut);
}
}
if futures.is_empty() {
return;
}
while !(futures.is_empty() && rx_done) {
tokio::select! {
addr = rx.recv(), if !rx_done => {
match addr {
Some(addr) => yield addr,
None => rx_done = true
}
}
e = futures.next(), if !futures.is_empty() => {
if let Some(Err(e)) = e {
debug!("error: {e}");
}
}
}
}
};
Some(s.boxed())
}
fn add_tracker(self: Arc<Self>, tracker: &str) -> anyhow::Result<()> {
if tracker.starts_with("http://") || tracker.starts_with("https://") {
spawn_with_cancel(
error_span!(
parent: None,
"http_tracker",
tracker = tracker,
info_hash = ?self.info_hash
),
self.cancellation_token.clone(),
{
let comms = self;
let url = Url::parse(tracker).context("can't parse URL")?;
async move { comms.task_single_tracker_monitor_http(url).await }
},
);
} else if tracker.starts_with("udp://") {
spawn_with_cancel(
error_span!(parent: None, "udp_tracker", tracker = tracker, info_hash = ?self.info_hash),
self.cancellation_token.clone(),
{
let comms = self;
let url = Url::parse(tracker).context("can't parse URL")?;
async move { comms.task_single_tracker_monitor_udp(url).await }
},
fn add_tracker(
&self,
url: Url,
) -> anyhow::Result<
Either<
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
>,
> {
let info_hash = self.info_hash;
if url.scheme() == "http" || url.scheme() == "https" {
let span = error_span!(
parent: None,
"http_tracker",
tracker = %url,
info_hash = ?info_hash
);
Ok(self
.task_single_tracker_monitor_http(url)
.instrument(span)
.left_future())
} else if url.scheme() == "udp" {
let span =
error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
Ok(self
.task_single_tracker_monitor_udp(url)
.instrument(span)
.right_future())
} else {
bail!("unsupported tracker url {}", tracker)
bail!("unsupported tracker url {}", url)
}
Ok(())
}
async fn task_single_tracker_monitor_http(
self: Arc<Self>,
mut tracker_url: Url,
) -> anyhow::Result<()> {
async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> {
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
loop {
let stats = self.stats.get();