diff --git a/Cargo.lock b/Cargo.lock index 55e824b..45b64ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -235,6 +235,7 @@ dependencies = [ "clone_to_owned", "futures", "hex 0.4.3", + "indexmap", "librqbit_core", "log", "parking_lot", diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 2894ca2..7ea32e3 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -18,6 +18,7 @@ log = "0.4" pretty_env_logger = "0.4" futures = "0.3" rand = "0.8" +indexmap = "1.7" clone_to_owned = {path="../clone_to_owned"} librqbit_core = {path="../librqbit_core"} diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 843889b..2680236 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -2,6 +2,7 @@ use std::{ collections::{hash_map::Entry, HashMap, HashSet}, net::SocketAddr, sync::Arc, + task::Poll, }; use crate::{ @@ -14,7 +15,8 @@ use crate::{ }; use anyhow::Context; use bencode::ByteString; -use futures::{stream::FuturesUnordered, Stream, StreamExt, TryStreamExt}; +use futures::{stream::FuturesUnordered, Stream, StreamExt}; +use indexmap::IndexSet; use librqbit_core::{id20::Id20, peer_id::generate_peer_id}; use log::{debug, info, trace, warn}; use parking_lot::Mutex; @@ -49,7 +51,7 @@ struct DhtState { // Alternatively, we can lock only the parts that change, and use that internally inside DhtState... sender: UnboundedSender<(Message, SocketAddr)>, - seen_peers: HashMap>, + seen_peers: HashMap>, get_peers_subscribers: HashMap>, made_requests: HashSet<(Request, SocketAddr)>, @@ -231,29 +233,30 @@ impl DhtState { } } - pub fn get_peers( + #[allow(clippy::type_complexity)] + fn get_peers( &mut self, info_hash: Id20, ) -> anyhow::Result<( - Vec, + Option<(usize, usize)>, tokio::sync::broadcast::Receiver, )> { match self.get_peers_subscribers.entry(info_hash) { Entry::Occupied(o) => { - let existing_peers = self - .seen_peers - .get(&info_hash) - .map(|c| c.iter().copied().collect()) - .unwrap_or_default(); + let pos = self.seen_peers.get(&info_hash).and_then(|p| { + if p.is_empty() { + None + } else { + Some((0, p.len())) + } + }); let rx = o.get().subscribe(); - Ok((existing_peers, rx)) + Ok((pos, rx)) } Entry::Vacant(v) => { - // DHT sends peers REALLY fast, so the consumer of this broadcast should not lag behind. - // That's why capacity is so high. - // - // What could be done is we could re-send all known peers once someone lags. Maybe do that... - let (tx, rx) = tokio::sync::broadcast::channel(20000); + // DHT sends peers REALLY fast, so ideally the consumer of this broadcast should not lag behind. + // In case it does though we have PeerStream to replay. + let (tx, rx) = tokio::sync::broadcast::channel(100); v.insert(tx); // We don't need to allocate/collect here, but the borrow checker is not happy otherwise. @@ -268,7 +271,7 @@ impl DhtState { self.send_find_peers_if_not_yet(info_hash, id, addr)?; } - Ok((Vec::new(), rx)) + Ok((None, rx)) } } } @@ -352,25 +355,27 @@ impl DhtState { self.routing_table.add_node(source, source_addr); self.routing_table.mark_response(&source); + let bsender = match self.get_peers_subscribers.get(&target) { + Some(s) => s, + None => { + warn!( + "ignoring get_peers response, no subscribers for {:?}", + target + ); + return Ok(()); + } + }; + if let Some(peers) = data.values { let seen = self.seen_peers.entry(target).or_default(); + for peer in peers.iter() { - seen.insert(SocketAddr::V4(peer.addr)); - } - let bsender = match self.get_peers_subscribers.get(&target) { - Some(s) => s, - None => { - warn!( - "ignoring peers for {:?}: no subscribers left. Peers: {:?}", - target, peers - ); - return Ok(()); + let addr = SocketAddr::V4(peer.addr); + if seen.insert(addr) { + bsender + .send(addr) + .context("error sending peers to subscribers")?; } - }; - for peer in peers.iter() { - bsender - .send(peer.addr.into()) - .context("error sending peers to subscribers")?; } }; if let Some(nodes) = data.nodes { @@ -543,6 +548,61 @@ impl DhtWorker { } } +struct PeerStream { + info_hash: Id20, + state: Arc>, + absolute_stream_pos: usize, + initial_peers_pos: Option<(usize, usize)>, + broadcast_rx: BroadcastStream, +} + +impl Stream for PeerStream { + type Item = SocketAddr; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + loop { + if let Some((pos, end)) = self.initial_peers_pos.take() { + let g = self.state.lock(); + let seen = g.seen_peers.get(&self.info_hash).unwrap(); + let addr = *seen.get_index(pos).unwrap(); + drop(g); + if pos < end { + self.initial_peers_pos = Some((pos + 1, end)); + } + self.absolute_stream_pos += 1; + return Poll::Ready(Some(addr)); + } + + let r = match self.broadcast_rx.poll_next_unpin(cx) { + Poll::Ready(r) => match r { + Some(r) => r, + None => return Poll::Ready(None), + }, + Poll::Pending => return Poll::Pending, + }; + + match r { + Ok(v) => { + self.absolute_stream_pos += 1; + return Poll::Ready(Some(v)); + } + Err(e) => match e { + tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(lagged_by) => { + warn!("peer stream is lagged by {}", lagged_by); + let s = self.absolute_stream_pos; + let e = s + lagged_by as usize; + self.initial_peers_pos = Some((s, e)); + continue; + } + }, + } + } + } +} + impl Dht { pub async fn new() -> anyhow::Result { Self::with_bootstrap_addrs(DHT_BOOTSTRAP).await @@ -578,11 +638,17 @@ impl Dht { pub async fn get_peers( &self, info_hash: Id20, - ) -> anyhow::Result> + Unpin> { - let (initial_peers, rx) = self.state.lock().get_peers(info_hash)?; - let rx = BroadcastStream::new(rx).map_err(|e| e.into()); - let rx = futures::stream::iter(initial_peers).map(Ok).chain(rx); - Ok(rx) + ) -> anyhow::Result + Unpin> { + // TODO: we don't need the vec here. + let (pos, rx) = self.state.lock().get_peers(info_hash)?; + let stream = PeerStream { + info_hash, + state: self.state.clone(), + absolute_stream_pos: 0, + initial_peers_pos: pos, + broadcast_rx: BroadcastStream::new(rx), + }; + Ok(stream) } pub fn stats(&self) -> DhtStats { diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs index bc69280..762829a 100644 --- a/crates/dht/src/main.rs +++ b/crates/dht/src/main.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, str::FromStr, time::Duration}; +use std::{str::FromStr, time::Duration}; use anyhow::Context; use dht::{Dht, Id20}; @@ -12,13 +12,13 @@ async fn main() -> anyhow::Result<()> { let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap(); let dht = Dht::new().await.context("error initializing DHT")?; let mut stream = dht.get_peers(info_hash).await?; - let mut seen = HashSet::new(); let stats_printer = async { loop { tokio::time::sleep(Duration::from_secs(5)).await; info!("DHT stats: {:?}", dht.stats()); } + #[allow(unreachable_code)] Ok::<_, anyhow::Error>(()) }; @@ -36,15 +36,13 @@ async fn main() -> anyhow::Result<()> { info!("Dumped DHT routing table to {}", filename); }); } + #[allow(unreachable_code)] Ok::<_, anyhow::Error>(()) }; let peer_printer = async { while let Some(peer) = stream.next().await { - let peer = peer.context("error reading peer stream")?; - if seen.insert(peer) { - log::info!("peer found: {}", peer) - } + log::info!("peer found: {}", peer) } Ok(()) }; diff --git a/crates/librqbit/src/dht_utils.rs b/crates/librqbit/src/dht_utils.rs index 3d98a1c..b965946 100644 --- a/crates/librqbit/src/dht_utils.rs +++ b/crates/librqbit/src/dht_utils.rs @@ -103,8 +103,6 @@ mod tests { let dht = Dht::new().await.unwrap(); let peer_rx = dht.get_peers(info_hash).await.unwrap(); let peer_id = generate_peer_id(); - let peer_rx = peer_rx.filter_map(|r| async move { r.ok() }); - tokio::pin!(peer_rx); match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await { ReadMetainfoResult::Found { info, .. } => dbg!(info), ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"), diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 2f8e40e..107a2a0 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -211,7 +211,6 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> .ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))? .get_peers(info_hash) .await?; - let dht_rx = flatten_dht_peers_stream(dht_rx); let trackers = trackers .into_iter() @@ -254,9 +253,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> torrent_from_file(&opts.torrent_path)? }; let dht_rx = match dht.as_ref() { - Some(dht) => Some(flatten_dht_peers_stream( - dht.get_peers(torrent.info_hash).await?, - )), + Some(dht) => Some(dht.get_peers(torrent.info_hash).await?), None => None, }; let trackers = torrent @@ -293,21 +290,6 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> } } -fn flatten_dht_peers_stream( - rx: impl Stream> + Unpin, -) -> impl Stream + Unpin { - let rx = rx.filter_map(|addr| async move { - match addr { - Ok(addr) => Some(addr), - Err(e) => { - warn!("DHT peer receiver got an error: {:#}", e); - None - } - } - }); - Box::pin(rx) -} - #[allow(clippy::too_many_arguments)] async fn main_torrent_info( opts: Opts,