diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 60491eb..06bbf3e 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -642,7 +642,7 @@ impl Session { if opts.enable_upnp_port_forwarding { session.spawn( error_span!(parent: session.rs(), "upnp_forward", port = listen_port), - session.clone().task_upnp_port_forwarder(listen_port), + Self::task_upnp_port_forwarder(listen_port), ); } } @@ -691,7 +691,7 @@ impl Session { } async fn check_incoming_connection( - &self, + self: Arc, addr: SocketAddr, mut stream: TcpStream, ) -> anyhow::Result<(Arc, CheckedIncomingConnection)> { @@ -744,6 +744,8 @@ impl Session { async fn task_tcp_listener(self: Arc, l: TcpListener) -> anyhow::Result<()> { let mut futs = FuturesUnordered::new(); + let session = Arc::downgrade(&self); + drop(self); loop { tokio::select! { @@ -751,13 +753,15 @@ impl Session { match r { Ok((stream, addr)) => { trace!("accepted connection from {addr}"); + let session = session.upgrade().context("session is dead")?; + let span = error_span!(parent: session.rs(), "incoming", addr=%addr); futs.push( - self.check_incoming_connection(addr, stream) + session.check_incoming_connection(addr, stream) .map_err(|e| { debug!("error checking incoming connection: {e:#}"); e }) - .instrument(error_span!(parent: self.rs(), "incoming", addr=%addr)) + .instrument(span) ); } Err(e) => { @@ -775,7 +779,7 @@ impl Session { } } - async fn task_upnp_port_forwarder(self: Arc, port: u16) -> anyhow::Result<()> { + async fn task_upnp_port_forwarder(port: u16) -> anyhow::Result<()> { let pf = librqbit_upnp::UpnpPortForwarder::new(vec![port], None)?; pf.run_forever().await } diff --git a/crates/librqbit/src/session_stats/mod.rs b/crates/librqbit/src/session_stats/mod.rs index 8cf5eb0..2939694 100644 --- a/crates/librqbit/src/session_stats/mod.rs +++ b/crates/librqbit/src/session_stats/mod.rs @@ -3,6 +3,7 @@ use std::{ time::{Duration, Instant}, }; +use anyhow::Context; use atomic::AtomicSessionStats; use librqbit_core::speed_estimator::SpeedEstimator; use snapshot::SessionStatsSnapshot; @@ -40,12 +41,13 @@ impl Default for SessionStats { impl Session { pub(crate) fn start_speed_estimator_updater(self: &Arc) { self.spawn(error_span!(parent: self.rs(), "speed_estimator"), { - let s = self.clone(); + let s = Arc::downgrade(self); async move { let mut i = tokio::time::interval(Duration::from_secs(1)); loop { i.tick().await; + let s = s.upgrade().context("session is dead")?; let now = Instant::now(); let fetched = s.stats.atomic.fetched_bytes.load(Ordering::Relaxed); let uploaded = s.stats.atomic.uploaded_bytes.load(Ordering::Relaxed); diff --git a/crates/librqbit/src/tests/e2e.rs b/crates/librqbit/src/tests/e2e.rs index b43b4ac..d762d04 100644 --- a/crates/librqbit/src/tests/e2e.rs +++ b/crates/librqbit/src/tests/e2e.rs @@ -38,7 +38,7 @@ async fn test_e2e_download() { .unwrap(); // Wait to ensure everything is dropped. - tokio::time::sleep(Duration::from_secs(10)).await; + tokio::time::sleep(Duration::from_secs(1)).await; let metrics = tokio::runtime::Handle::current().metrics(); assert_eq!(metrics.num_alive_tasks(), 1); @@ -157,12 +157,13 @@ async fn _test_e2e_download(drop_checks: &DropChecks) { } } info!("torrent is live"); - Ok::<_, anyhow::Error>(SocketAddr::new( + let addr = SocketAddr::new( std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), session .tcp_listen_port() .context("expected session.tcp_listen_port() to be set")?, - )) + ); + Ok::<_, anyhow::Error>((session, addr)) } .instrument(error_span!("server", id = i)), ); @@ -170,12 +171,15 @@ async fn _test_e2e_download(drop_checks: &DropChecks) { } let mut peers = Vec::new(); + + // This is around just not to drop. + let mut _servers = Vec::new(); for (id, peer) in futures::future::join_all(futs) .await .into_iter() .enumerate() { - let peer = peer + let (server, peer) = peer .with_context(|| format!("join error, server={id}")) .unwrap() .with_context(|| format!("timeout, server={id}")) @@ -183,6 +187,7 @@ async fn _test_e2e_download(drop_checks: &DropChecks) { .with_context(|| format!("server couldn't start, server={id}")) .unwrap(); peers.push(peer); + _servers.push(server); } info!("started all servers, starting client");