diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 65b6d76..4492b19 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -44,6 +44,12 @@ struct DhtState { next_transaction_id: u16, outstanding_requests: Vec, routing_table: RoutingTable, + + // This sender sends requests to the worker. + // It is unbounded so that the methods on Dht state don't need to be async. + // If the methods on Dht state were async, we would have a problem, as it's behind + // a lock. + // Alternatively, we can lock only the parts that change, and use that internally inside DhtState... sender: UnboundedSender<(Message, SocketAddr)>, seen_peers: HashMap>, @@ -420,7 +426,6 @@ impl DhtWorker { self, in_tx: UnboundedSender<(Message, SocketAddr)>, in_rx: UnboundedReceiver<(Message, SocketAddr)>, - mut request_rx: Receiver<(Request, UnboundedSender)>, bootstrap_addrs: &[String], ) -> anyhow::Result<()> { let (out_tx, mut out_rx) = channel(1); @@ -501,7 +506,6 @@ impl Dht { Self::with_bootstrap_addrs(DHT_BOOTSTRAP).await } pub async fn with_bootstrap_addrs(bootstrap_addrs: &[&str]) -> anyhow::Result { - let (request_tx, request_rx) = channel(1); let socket = UdpSocket::bind("0.0.0.0:0") .await .context("error binding socket")?; @@ -523,9 +527,7 @@ impl Dht { peer_id, state, }; - let result = worker - .start(in_tx, in_rx, request_rx, &bootstrap_addrs) - .await; + let result = worker.start(in_tx, in_rx, &bootstrap_addrs).await; warn!("DHT worker finished with {:?}", result); } }); diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index aa85358..5fc710f 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -1,9 +1,9 @@ -use std::{fs::File, io::Read, net::SocketAddr, str::FromStr, time::Duration}; +use std::{fs::File, io::Read, net::SocketAddr, pin::Pin, str::FromStr, time::Duration}; use anyhow::Context; use clap::Clap; use dht::{Dht, Id20}; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use librqbit::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, generate_peer_id, @@ -209,15 +209,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> .ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))? .get_peers(info_hash) .await?; - let dht_rx = Box::pin(dht_rx.filter_map(|addr| async move { - match addr { - Ok(addr) => Some(addr), - Err(e) => { - warn!("DHT peer receiver got an error: {:#}", e); - None - } - } - })); + let dht_rx = flatten_dht_peers_stream(dht_rx); let trackers = trackers .into_iter() @@ -259,18 +251,8 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> torrent_from_file(&opts.torrent_path)? }; let dht_rx = match dht { - Some(dht) => Some(Box::pin( - dht.get_peers(torrent.info_hash) - .await? - .filter_map(|r| async move { - match r { - Ok(addr) => Some(addr), - Err(e) => { - warn!("DHT peer receiver got an error: {:#}", e); - None - } - } - }), + Some(dht) => Some(flatten_dht_peers_stream( + dht.get_peers(torrent.info_hash).await?, )), None => None, }; @@ -307,6 +289,21 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> } } +fn flatten_dht_peers_stream( + rx: impl Stream> + Unpin, +) -> impl Stream + Unpin { + let rx = rx.filter_map(|addr| async move { + match addr { + Ok(addr) => Some(addr), + Err(e) => { + warn!("DHT peer receiver got an error: {:#}", e); + None + } + } + }); + Box::pin(rx) +} + #[allow(clippy::too_many_arguments)] async fn main_torrent_info( opts: Opts,