diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 7bdfc30..ae1fbbc 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,8 +1,9 @@ use std::{ + any, cmp::Reverse, net::SocketAddr, sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicBool, AtomicU16, Ordering}, Arc, }, task::Poll, @@ -21,7 +22,9 @@ use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; use dashmap::DashMap; -use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt}; +use futures::{ + future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt, TryFutureExt, +}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; @@ -62,6 +65,7 @@ struct MaybeUsefulNode { addr: SocketAddr, last_request: Instant, last_response: Option, + errors_in_a_row: usize, returned_peers: bool, } @@ -86,26 +90,31 @@ struct RequestPeers { info_hash: Id20, dht: Arc, useful_nodes: RwLock>, - tx: tokio::sync::mpsc::UnboundedSender, + peer_tx: tokio::sync::mpsc::UnboundedSender, + node_tx: tokio::sync::mpsc::UnboundedSender, } struct RequestPeersStream { rx: tokio::sync::mpsc::UnboundedReceiver, cancel_join_handle: tokio::task::JoinHandle<()>, + request_peers: Arc, } impl RequestPeersStream { fn new(dht: Arc, info_hash: Id20) -> Self { - let (tx, rx) = unbounded_channel(); + let (peer_tx, peer_rx) = unbounded_channel(); + let (node_tx, node_rx) = unbounded_channel(); let rp = Arc::new(RequestPeers { info_hash, dht, useful_nodes: RwLock::new(Vec::new()), - tx, + peer_tx, + node_tx, }); - let join_handle = rp.request_peers_forever(); + let join_handle = rp.clone().request_peers_forever(node_rx); Self { - rx, + request_peers: rp, + rx: peer_rx, cancel_join_handle: join_handle, } } @@ -128,74 +137,100 @@ impl Stream for RequestPeersStream { } } +// So what do I want to do? +// Every 60 seconds, we add root nodes to the queue. +// We poll the following things: +// 1. The queue. If got item from there, insert into the futures unordered. +// 2. Futures unordered. +// If received, send to the resulting one. +struct Tmp {} + impl RequestPeers { - fn request_peers_forever(self: Arc) -> tokio::task::JoinHandle<()> { + fn request_peers_forever( + self: Arc, + mut node_rx: tokio::sync::mpsc::UnboundedReceiver, + ) -> tokio::task::JoinHandle<()> { spawn( error_span!("request_peers", info_hash = format!("{:?}", self.info_hash)), async move { - let mut iteration = 0; - loop { - debug!("iteration {}", iteration); - let sleep_duration = match self.get_peers_root().await { - Ok(_) => Duration::from_secs(60), - Err(e) => { - debug!("error: {e:?}"); - Duration::from_secs(1) + // Looper adds root nodes to the queue every 60 seconds. + let looper = { + let this = self.clone(); + async move { + let mut iteration = 0; + loop { + debug!("iteration {}", iteration); + let sleep = match this.get_peers_root() { + Ok(0) => Duration::from_secs(1), + Ok(n) if n < 8 => REQUERY_INTERVAL / 2, + Ok(_) => REQUERY_INTERVAL, + Err(e) => { + error!("error: {e:?}"); + return Err::<(), anyhow::Error>(e); + } + }; + tokio::time::sleep(sleep).await; + iteration += 1; } - }; - tokio::time::sleep(sleep_duration).await; - iteration += 1; + } + }; + tokio::pin!(looper); + + let mut futs = FuturesUnordered::new(); + loop { + tokio::select! { + addr = node_rx.recv() => { + let addr = addr.unwrap(); + futs.push( + self.get_peers_one(addr) + .map_err(|e| debug!("error: {e:?}")) + .instrument(error_span!("addr", addr=addr.to_string())) + ); + } + Some(_) = futs.next(), if !futs.is_empty() => {} + _ = &mut looper => {} + } } }, ) } - fn request_peers_one<'a>( - self: &'a Arc, - addr: SocketAddr, - ) -> BoxFuture<'a, anyhow::Result<()>> { - let fut = async move { - let response = self - .dht - .request(Request::GetPeers(self.info_hash), addr) - .await?; - let response = match response { - ResponseOrError::Response(r) => r, - ResponseOrError::Error(e) => { - bail!("error response: {:?}", e) - } - }; - self.mark_node_responded(addr, &response); - - if let Some(peers) = response.values { - for peer in peers { - self.tx.send(SocketAddr::V4(peer.addr))?; - } + async fn get_peers_one<'a>(self: &'a Arc, addr: SocketAddr) -> anyhow::Result<()> { + let response = self + .dht + .request(Request::GetPeers(self.info_hash), addr) + .await + .map_err(|e| { + self.mark_node_error(addr); + e + })?; + self.mark_node_responded(addr, &response); + let response = match response { + ResponseOrError::Response(r) => r, + ResponseOrError::Error(e) => { + bail!("error response: {:?}", e) } - - let mut futs = FuturesUnordered::new(); - if let Some(nodes) = response.nodes { - for node in nodes.nodes { - let addr = SocketAddr::V4(node.addr); - if self.should_request_node(node.id, addr) { - futs.push(self.request_peers_one(addr)); - } - } - } - - while let Some(res) = futs.next().await { - if let Err(e) = res { - debug!("error: {e:?}") - } - } - - Ok(()) }; - fut.boxed() + + if let Some(peers) = response.values { + for peer in peers { + self.peer_tx.send(SocketAddr::V4(peer.addr))?; + } + } + + if let Some(nodes) = response.nodes { + for node in nodes.nodes { + let addr = SocketAddr::V4(node.addr); + if self.should_request_node(node.id, addr) { + self.node_tx.send(addr)?; + } + } + } + Ok(()) } - async fn get_peers_root(self: &Arc) -> anyhow::Result<()> { - let mut futs = FuturesUnordered::new(); + fn get_peers_root(self: &Arc) -> anyhow::Result { + let mut count = 0; for (_, addr) in self .dht .routing_table @@ -205,32 +240,42 @@ impl RequestPeers { .map(|n| (n.id(), n.addr())) .take(8) { - futs.push(self.request_peers_one(addr)) + count += 1; + self.node_tx.send(addr)?; } - if futs.is_empty() { - bail!("no nodes in routing table") - } - while let Some(res) = futs.next().await { - if let Err(e) = res { - debug!("error: {e:?}") - } - } - Ok(()) + Ok(count) } - fn mark_node_responded(&self, addr: SocketAddr, response: &Response) { - let mut closest_nodes = self.useful_nodes.write(); - for node in closest_nodes.iter_mut() { - if node.addr == addr { + fn mark_node_error(&self, addr: SocketAddr) -> bool { + self.useful_nodes + .write() + .iter_mut() + .find(|n| n.addr == addr) + .map(|n| { + n.errors_in_a_row += 1; + }) + .is_some() + } + + fn mark_node_responded(&self, addr: SocketAddr, response: &ResponseOrError) -> bool { + self.useful_nodes + .write() + .iter_mut() + .find(|n| n.addr == addr) + .map(|node| { node.last_response = Some(Instant::now()); - node.returned_peers = response - .values - .as_ref() - .map(|c| !c.is_empty()) - .unwrap_or(false); - break; - } - } + node.errors_in_a_row = 0; + match response { + ResponseOrError::Response(r) => { + node.returned_peers = + r.values.as_ref().map(|c| !c.is_empty()).unwrap_or(false) + } + ResponseOrError::Error(_) => { + node.returned_peers = false; + } + } + }) + .is_some() } fn should_request_node(&self, node_id: Id20, addr: SocketAddr) -> bool { @@ -251,6 +296,7 @@ impl RequestPeers { last_request: Instant::now(), last_response: None, returned_peers: false, + errors_in_a_row: 0, }); const LIMIT: usize = 256; @@ -273,7 +319,6 @@ impl RequestPeers { pub struct DhtState { id: Id20, next_transaction_id: AtomicU16, - bootstrapped: Notify, // Created requests: (transaction_id, addr) => Requests. // If we get a response, it gets removed from here. @@ -297,7 +342,6 @@ impl DhtState { let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id)); Self { id, - bootstrapped: Default::default(), next_transaction_id: AtomicU16::new(0), inflight_by_transaction_id: Default::default(), routing_table: RwLock::new(routing_table),