From 69b9918e4fefce44d4b55a5b1badc5bcfac264fa Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Wed, 29 Nov 2023 19:34:29 +0000 Subject: [PATCH] Going so far again... --- TODO.md | 3 +- crates/dht/src/dht.rs | 389 +++++++++++++++++--------------- crates/dht/src/routing_table.rs | 2 +- 3 files changed, 204 insertions(+), 190 deletions(-) diff --git a/TODO.md b/TODO.md index b21b187..7214b9f 100644 --- a/TODO.md +++ b/TODO.md @@ -14,9 +14,10 @@ - [x] pause/unpause - [x] remove including from disk - [ ] DHT + - [ ] bootstrapping is lame - [x] many nodes in "Unknown" status, do smth about it - [x] for torrents with a few seeds might be cool to re-query DHT once in a while. - - [ ] don't leak memory when deleting torrents (i.e. remove torrent information (seen peers etc) once the torrent is deleted) + - [x] don't leak memory when deleting torrents (i.e. remove torrent information (seen peers etc) once the torrent is deleted) - [ ] Buckets that have not been changed in 15 minutes should be "refreshed." (per RFC) - [x] it's sending many requests now way too fast, locks up Mac OS UI annoyingly - [ ] After the search is exhausted, the client then inserts the peer contact information for itself onto the responding nodes with IDs closest to the infohash of the torrent. diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index ae1fbbc..f4b6695 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,9 +1,8 @@ use std::{ - any, cmp::Reverse, net::SocketAddr, sync::{ - atomic::{AtomicBool, AtomicU16, Ordering}, + atomic::{AtomicU16, Ordering}, Arc, }, task::Poll, @@ -12,8 +11,8 @@ use std::{ use crate::{ bprotocol::{ - self, CompactNodeInfo, CompactPeerInfo, ErrorDescription, FindNodeRequest, GetPeersRequest, - Message, MessageKind, Node, PingRequest, Response, + self, CompactNodeInfo, ErrorDescription, FindNodeRequest, GetPeersRequest, Message, + MessageKind, Node, PingRequest, Response, }, routing_table::{InsertResult, RoutingTable}, REQUERY_INTERVAL, RESPONSE_TIMEOUT, @@ -22,23 +21,18 @@ use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; use dashmap::DashMap; -use futures::{ - future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt, TryFutureExt, -}; -use indexmap::IndexSet; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; + use leaky_bucket::RateLimiter; use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; use parking_lot::RwLock; -use rand::Rng; + use serde::Serialize; use tokio::{ net::UdpSocket, - sync::{ - mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, - Notify, - }, + sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, }; -use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; + use tracing::{debug, debug_span, error, error_span, info, trace, warn, Instrument}; #[derive(Debug, Serialize)] @@ -86,34 +80,101 @@ fn make_rate_limiter() -> RateLimiter { .build() } -struct RequestPeers { +trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static { + fn on_request_start( + &self, + req: &Arc>, + target_node: Id20, + addr: SocketAddr, + ); + fn on_request_end( + &self, + req: &Arc>, + target_node: Id20, + addr: SocketAddr, + resp: &anyhow::Result, + ); +} + +struct RecursiveRequestCallbacksGetPeers {} +impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers { + fn on_request_start(&self, _: &Arc>, _: Id20, _: SocketAddr) {} + + fn on_request_end( + &self, + _: &Arc>, + _: Id20, + _: SocketAddr, + _: &anyhow::Result, + ) { + } +} + +struct RecursiveRequestCallbacksFindNodes {} +impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes { + fn on_request_start( + &self, + req: &Arc>, + target_node: Id20, + addr: SocketAddr, + ) { + match req.dht.routing_table_add_node(target_node, addr) { + InsertResult::WasExisting | InsertResult::ReplacedBad(_) | InsertResult::Added => { + req.dht + .routing_table + .write() + .mark_outgoing_request(&target_node); + } + InsertResult::Ignored => {} + } + } + + fn on_request_end( + &self, + req: &Arc>, + target_node: Id20, + _addr: SocketAddr, + resp: &anyhow::Result, + ) { + let mut table = req.dht.routing_table.write(); + if resp.is_ok() { + table.mark_response(&target_node); + } else { + table.mark_error(&target_node); + } + } +} + +struct RecursiveRequest { info_hash: Id20, + request: Request, dht: Arc, useful_nodes: RwLock>, - peer_tx: tokio::sync::mpsc::UnboundedSender, - node_tx: tokio::sync::mpsc::UnboundedSender, + // peer_tx: tokio::sync::mpsc::UnboundedSender, + // node_tx: tokio::sync::mpsc::UnboundedSender<(Option, SocketAddr)>, + callbacks: C, } struct RequestPeersStream { rx: tokio::sync::mpsc::UnboundedReceiver, cancel_join_handle: tokio::task::JoinHandle<()>, - request_peers: Arc, } impl RequestPeersStream { fn new(dht: Arc, info_hash: Id20) -> Self { let (peer_tx, peer_rx) = unbounded_channel(); let (node_tx, node_rx) = unbounded_channel(); - let rp = Arc::new(RequestPeers { + let rp = Arc::new(RecursiveRequest { info_hash, + request: Request::GetPeers(info_hash), dht, useful_nodes: RwLock::new(Vec::new()), - peer_tx, - node_tx, + // peer_tx, + // node_tx, + callbacks: RecursiveRequestCallbacksGetPeers {}, }); - let join_handle = rp.clone().request_peers_forever(node_rx); + let join_handle = rp.clone().request_peers_forever(node_rx, node_tx, peer_tx); Self { - request_peers: rp, rx: peer_rx, cancel_join_handle: join_handle, } @@ -137,30 +198,73 @@ impl Stream for RequestPeersStream { } } -// So what do I want to do? -// Every 60 seconds, we add root nodes to the queue. -// We poll the following things: -// 1. The queue. If got item from there, insert into the futures unordered. -// 2. Futures unordered. -// If received, send to the resulting one. -struct Tmp {} +impl RecursiveRequest { + async fn find_node( + dht: Arc, + target: Id20, + root_addrs: impl Iterator, + ) -> anyhow::Result<()> { + let (peer_tx, peer_rx) = unbounded_channel(); + drop(peer_rx); -impl RequestPeers { + let (node_tx, mut node_rx) = unbounded_channel(); + let req = Arc::new(RecursiveRequest { + info_hash: target, + request: Request::FindNode(target), + dht, + useful_nodes: RwLock::new(Vec::new()), + // peer_tx: unbounded_channel().0, + // node_tx, + callbacks: RecursiveRequestCallbacksFindNodes {}, + }); + + let mut futs = FuturesUnordered::new(); + + for addr in root_addrs { + node_tx.send((None, addr)).unwrap(); + } + + loop { + tokio::select! { + r = node_rx.recv() => { + let (id, addr) = r.unwrap(); + futs.push( + req.request_one(id, addr, node_tx.clone(), peer_tx.clone()) + .instrument( + error_span!("find_node", target=format!("{target:?}"), addr=addr.to_string()) + ) + ) + }, + Some(f) = futs.next(), if !futs.is_empty() => { + if let Err(e) = f { + error!("error: {e:?}"); + } + } + } + } + Ok(()) + } +} + +impl RecursiveRequest { fn request_peers_forever( self: Arc, - mut node_rx: tokio::sync::mpsc::UnboundedReceiver, + mut node_rx: tokio::sync::mpsc::UnboundedReceiver<(Option, SocketAddr)>, + node_tx: tokio::sync::mpsc::UnboundedSender<(Option, SocketAddr)>, + peer_tx: tokio::sync::mpsc::UnboundedSender, ) -> tokio::task::JoinHandle<()> { spawn( - error_span!("request_peers", info_hash = format!("{:?}", self.info_hash)), + error_span!("get_peers", info_hash = format!("{:?}", self.info_hash)), async move { // Looper adds root nodes to the queue every 60 seconds. let looper = { let this = self.clone(); + let node_tx = node_tx.clone(); async move { let mut iteration = 0; loop { debug!("iteration {}", iteration); - let sleep = match this.get_peers_root() { + let sleep = match this.get_peers_root(&node_tx) { Ok(0) => Duration::from_secs(1), Ok(n) if n < 8 => REQUERY_INTERVAL / 2, Ok(_) => REQUERY_INTERVAL, @@ -180,9 +284,9 @@ impl RequestPeers { loop { tokio::select! { addr = node_rx.recv() => { - let addr = addr.unwrap(); + let (id, addr) = addr.unwrap(); futs.push( - self.get_peers_one(addr) + self.request_one(id, addr, node_tx.clone(), peer_tx.clone()) .map_err(|e| debug!("error: {e:?}")) .instrument(error_span!("addr", addr=addr.to_string())) ); @@ -195,43 +299,12 @@ impl RequestPeers { ) } - async fn get_peers_one<'a>(self: &'a Arc, addr: SocketAddr) -> anyhow::Result<()> { - let response = self - .dht - .request(Request::GetPeers(self.info_hash), addr) - .await - .map_err(|e| { - self.mark_node_error(addr); - e - })?; - self.mark_node_responded(addr, &response); - let response = match response { - ResponseOrError::Response(r) => r, - ResponseOrError::Error(e) => { - bail!("error response: {:?}", e) - } - }; - - if let Some(peers) = response.values { - for peer in peers { - self.peer_tx.send(SocketAddr::V4(peer.addr))?; - } - } - - if let Some(nodes) = response.nodes { - for node in nodes.nodes { - let addr = SocketAddr::V4(node.addr); - if self.should_request_node(node.id, addr) { - self.node_tx.send(addr)?; - } - } - } - Ok(()) - } - - fn get_peers_root(self: &Arc) -> anyhow::Result { + fn get_peers_root( + self: &Arc, + node_tx: &UnboundedSender<(Option, SocketAddr)>, + ) -> anyhow::Result { let mut count = 0; - for (_, addr) in self + for (id, addr) in self .dht .routing_table .read() @@ -241,10 +314,57 @@ impl RequestPeers { .take(8) { count += 1; - self.node_tx.send(addr)?; + node_tx.send((Some(id), addr))?; } Ok(count) } +} + +impl RecursiveRequest { + async fn request_one<'a>( + self: &'a Arc, + id: Option, + addr: SocketAddr, + node_tx: UnboundedSender<(Option, SocketAddr)>, + peer_tx: UnboundedSender, + ) -> anyhow::Result<()> { + if let Some(id) = id { + self.callbacks.on_request_start(self, id, addr); + } + + let response = self.dht.request(self.request, addr).await.map(|r| { + self.mark_node_responded(addr, &r); + r + }); + if let Some(id) = id { + self.callbacks.on_request_end(self, id, addr, &response); + } + + let response = match self.dht.request(self.request, addr).await { + Ok(ResponseOrError::Response(r)) => r, + Ok(ResponseOrError::Error(e)) => bail!("error response: {:?}", e), + Err(e) => { + self.mark_node_error(addr); + return Err(e); + } + }; + + if let Some(peers) = response.values { + for peer in peers { + peer_tx.send(SocketAddr::V4(peer.addr))?; + } + } + + if let Some(nodes) = response.nodes { + for node in nodes.nodes { + let addr = SocketAddr::V4(node.addr); + if self.should_request_node(node.id, addr) { + node_tx.send((Some(node.id), addr))?; + } + } + } + Ok(()) + } fn mark_node_error(&self, addr: SocketAddr) -> bool { self.useful_nodes @@ -351,22 +471,6 @@ impl DhtState { } } - fn spawn_request(self: &Arc, request: Request, addr: SocketAddr) { - let this = self.clone(); - spawn( - error_span!(parent: None, "dht_spawn_request", addr=addr.to_string(), request=format!("{:?}", request)), - async move { - match this.send_request_and_handle_response(request, addr).await { - Ok(_) => {} - Err(e) => { - debug!("error: {:?}", e); - } - }; - Ok(()) - }, - ); - } - async fn send_request_and_handle_response( self: &Arc, request: Request, @@ -455,15 +559,11 @@ impl DhtState { self.routing_table.write().mark_response(&response.id); match request { Request::FindNode(id) => { - let nodes = response - .nodes - .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; - self.on_found_nodes(response.id, addr, id, nodes) + todo!() } Request::Ping => Ok(()), - Request::GetPeers(info_hash) => { + Request::GetPeers(_info_hash) => { todo!() - // self.on_found_peers_or_nodes(response.id, addr, info_hash, response) } } } @@ -613,62 +713,6 @@ impl DhtState { } } - fn send_request_if_not_yet( - self: &Arc, - target_node: Id20, - request: Request, - addr: SocketAddr, - ) -> anyhow::Result<()> { - let this = self.clone(); - let fut = async move { - this.routing_table - .write() - .mark_outgoing_request(&target_node); - - let resp = this.request(request, addr).await; - match resp { - Ok(ResponseOrError::Response(response)) => { - this.routing_table.write().mark_response(&target_node); - match this.on_response(addr, request, response) { - Ok(()) => {} - Err(e) => { - warn!("error in on_response: {:?}", e); - } - } - } - Ok(ResponseOrError::Error(e)) => { - this.routing_table.write().mark_response(&target_node); - debug!("error response: {:?}", e); - } - Err(e) => { - this.routing_table.write().mark_error(&target_node); - debug!("error: {:?}", e); - } - }; - Ok(()) - }; - - spawn( - error_span!( - parent: None, - "dht_request", - addr = addr.to_string(), - request = format!("{:?}", request), - ), - fut, - ); - Ok(()) - } - - fn send_find_node_if_not_yet( - self: &Arc, - search_id: Id20, - target_node: Id20, - addr: SocketAddr, - ) -> anyhow::Result<()> { - self.send_request_if_not_yet(target_node, Request::FindNode(search_id), addr) - } - fn routing_table_add_node(self: &Arc, id: Id20, addr: SocketAddr) -> InsertResult { let mut questionable_nodes = Vec::new(); let res = self.routing_table.write().add_node(id, addr, |addr| { @@ -676,30 +720,15 @@ impl DhtState { true }); for addr in questionable_nodes { - self.spawn_request(Request::Ping, addr); + let (_, req) = self.create_request(Request::Ping); + let _ = self.sender.send(WorkerSendRequest { + our_tid: None, + message: req, + addr, + }); } res } - - fn on_found_nodes( - self: &Arc, - source: Id20, - source_addr: SocketAddr, - target: Id20, - nodes: CompactNodeInfo, - ) -> anyhow::Result<()> { - self.routing_table_add_node(source, source_addr); - for node in nodes.nodes { - match self.routing_table_add_node(node.id, node.addr.into()) { - InsertResult::ReplacedBad(_) | InsertResult::Added => { - // recursively find nodes closest to us until we can't find more. - self.send_find_node_if_not_yet(target, source, source_addr)?; - } - _ => {} - }; - } - Ok(()) - } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -938,22 +967,6 @@ impl DhtWorker { } } -struct PeerStream { - info_hash: Id20, - state: Arc, -} - -impl Stream for PeerStream { - type Item = SocketAddr; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - todo!() - } -} - #[derive(Default)] pub struct DhtConfig { pub peer_id: Option, diff --git a/crates/dht/src/routing_table.rs b/crates/dht/src/routing_table.rs index 6cfb6da..6da6ce7 100644 --- a/crates/dht/src/routing_table.rs +++ b/crates/dht/src/routing_table.rs @@ -7,7 +7,7 @@ use serde::{ }; use tracing::debug; -use crate::{INACTIVITY_TIMEOUT, RESPONSE_TIMEOUT}; +use crate::{INACTIVITY_TIMEOUT}; #[derive(Debug, Clone, Serialize, Deserialize)] enum BucketTreeNodeData {