diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index d458650..7bdfc30 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -21,7 +21,7 @@ use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; use dashmap::DashMap; -use futures::{stream::FuturesUnordered, Stream, StreamExt}; +use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; @@ -30,18 +30,19 @@ use rand::Rng; use serde::Serialize; use tokio::{ net::UdpSocket, - sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, + sync::{ + mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, + Notify, + }, }; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use tracing::{debug, debug_span, error_span, info, trace, warn, Instrument}; +use tracing::{debug, debug_span, error, error_span, info, trace, warn, Instrument}; #[derive(Debug, Serialize)] pub struct DhtStats { #[serde(serialize_with = "crate::utils::serialize_id20")] pub id: Id20, pub outstanding_requests: usize, - pub seen_peers: usize, - pub recent_requests: usize, pub routing_table_size: usize, } @@ -59,6 +60,7 @@ pub struct WorkerSendRequest { struct MaybeUsefulNode { id: Id20, addr: SocketAddr, + last_request: Instant, last_response: Option, returned_peers: bool, } @@ -80,33 +82,209 @@ fn make_rate_limiter() -> RateLimiter { .build() } -struct InfoHashMeta { - seen_peers: IndexSet, - subscriber: tokio::sync::broadcast::Sender, - closest_responding_nodes: Vec, - join_handle: tokio::task::JoinHandle<()>, +struct RequestPeers { + info_hash: Id20, + dht: Arc, + useful_nodes: RwLock>, + tx: tokio::sync::mpsc::UnboundedSender, +} + +struct RequestPeersStream { + rx: tokio::sync::mpsc::UnboundedReceiver, + cancel_join_handle: tokio::task::JoinHandle<()>, +} + +impl RequestPeersStream { + fn new(dht: Arc, info_hash: Id20) -> Self { + let (tx, rx) = unbounded_channel(); + let rp = Arc::new(RequestPeers { + info_hash, + dht, + useful_nodes: RwLock::new(Vec::new()), + tx, + }); + let join_handle = rp.request_peers_forever(); + Self { + rx, + cancel_join_handle: join_handle, + } + } +} + +impl Drop for RequestPeersStream { + fn drop(&mut self) { + self.cancel_join_handle.abort(); + } +} + +impl Stream for RequestPeersStream { + type Item = SocketAddr; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.rx.poll_recv(cx) + } +} + +impl RequestPeers { + fn request_peers_forever(self: Arc) -> tokio::task::JoinHandle<()> { + spawn( + error_span!("request_peers", info_hash = format!("{:?}", self.info_hash)), + async move { + let mut iteration = 0; + loop { + debug!("iteration {}", iteration); + let sleep_duration = match self.get_peers_root().await { + Ok(_) => Duration::from_secs(60), + Err(e) => { + debug!("error: {e:?}"); + Duration::from_secs(1) + } + }; + tokio::time::sleep(sleep_duration).await; + iteration += 1; + } + }, + ) + } + fn request_peers_one<'a>( + self: &'a Arc, + addr: SocketAddr, + ) -> BoxFuture<'a, anyhow::Result<()>> { + let fut = async move { + let response = self + .dht + .request(Request::GetPeers(self.info_hash), addr) + .await?; + let response = match response { + ResponseOrError::Response(r) => r, + ResponseOrError::Error(e) => { + bail!("error response: {:?}", e) + } + }; + + self.mark_node_responded(addr, &response); + + if let Some(peers) = response.values { + for peer in peers { + self.tx.send(SocketAddr::V4(peer.addr))?; + } + } + + let mut futs = FuturesUnordered::new(); + if let Some(nodes) = response.nodes { + for node in nodes.nodes { + let addr = SocketAddr::V4(node.addr); + if self.should_request_node(node.id, addr) { + futs.push(self.request_peers_one(addr)); + } + } + } + + while let Some(res) = futs.next().await { + if let Err(e) = res { + debug!("error: {e:?}") + } + } + + Ok(()) + }; + fut.boxed() + } + + async fn get_peers_root(self: &Arc) -> anyhow::Result<()> { + let mut futs = FuturesUnordered::new(); + for (_, addr) in self + .dht + .routing_table + .read() + .sorted_by_distance_from(self.info_hash) + .iter() + .map(|n| (n.id(), n.addr())) + .take(8) + { + futs.push(self.request_peers_one(addr)) + } + if futs.is_empty() { + bail!("no nodes in routing table") + } + while let Some(res) = futs.next().await { + if let Err(e) = res { + debug!("error: {e:?}") + } + } + Ok(()) + } + + fn mark_node_responded(&self, addr: SocketAddr, response: &Response) { + let mut closest_nodes = self.useful_nodes.write(); + for node in closest_nodes.iter_mut() { + if node.addr == addr { + node.last_response = Some(Instant::now()); + node.returned_peers = response + .values + .as_ref() + .map(|c| !c.is_empty()) + .unwrap_or(false); + break; + } + } + } + + fn should_request_node(&self, node_id: Id20, addr: SocketAddr) -> bool { + let mut closest_nodes = self.useful_nodes.write(); + + // If recently requested, ignore + if let Some(existing) = closest_nodes.iter_mut().find(|n| n.id == node_id) { + if existing.last_request.elapsed() > Duration::from_secs(60) { + existing.last_request = Instant::now(); + return true; + } + return false; + } + + closest_nodes.push(MaybeUsefulNode { + id: node_id, + addr, + last_request: Instant::now(), + last_response: None, + returned_peers: false, + }); + + const LIMIT: usize = 256; + closest_nodes.sort_by_key(|n| { + let has_returned_peers_desc = Reverse(n.returned_peers); + let has_responded_desc = Reverse(n.last_response.is_some() as u8); + let distance = n.id.distance(&self.info_hash); + (has_returned_peers_desc, has_responded_desc, distance) + }); + if closest_nodes.len() > LIMIT { + let popped = closest_nodes.pop().unwrap(); + if popped.id == node_id { + return false; + } + } + true + } } pub struct DhtState { id: Id20, next_transaction_id: AtomicU16, + bootstrapped: Notify, // Created requests: (transaction_id, addr) => Requests. // If we get a response, it gets removed from here. inflight_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>, - // Current requests to addr being re-sent with backoff. - recent_requests: DashMap<(Request, SocketAddr), Instant>, - routing_table: RwLock, listen_addr: SocketAddr, // Sending requests to the worker. rate_limiter: RateLimiter, sender: UnboundedSender, - - // Per-torrent stats. - info_hash_meta: DashMap, } impl DhtState { @@ -119,14 +297,13 @@ impl DhtState { let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id)); Self { id, + bootstrapped: Default::default(), next_transaction_id: AtomicU16::new(0), inflight_by_transaction_id: Default::default(), routing_table: RwLock::new(routing_table), sender, listen_addr, rate_limiter: make_rate_limiter(), - info_hash_meta: Default::default(), - recent_requests: Default::default(), } } @@ -241,7 +418,8 @@ impl DhtState { } Request::Ping => Ok(()), Request::GetPeers(info_hash) => { - self.on_found_peers_or_nodes(response.id, addr, info_hash, response) + todo!() + // self.on_found_peers_or_nodes(response.id, addr, info_hash, response) } } } @@ -321,26 +499,26 @@ impl DhtState { Ok(()) } MessageKind::GetPeersRequest(req) => { - let peers = self.info_hash_meta.get(&req.info_hash).map(|meta| { - meta.seen_peers - .iter() - .copied() - .filter_map(|a| match a { - SocketAddr::V4(v4) => Some(CompactPeerInfo { addr: v4 }), - // this should never happen in practice - SocketAddr::V6(_) => None, - }) - .take(50) - .collect::>() - }); - let token = if peers.is_some() { - let mut token = [0u8; 20]; - rand::thread_rng().fill(&mut token); - Some(ByteString::from(token.as_ref())) - } else { - None - }; - let compact_node_info = generate_compact_nodes(req.info_hash); + // let peers = self.info_hash_meta.get(&req.info_hash).map(|meta| { + // meta.seen_peers + // .iter() + // .copied() + // .filter_map(|a| match a { + // SocketAddr::V4(v4) => Some(CompactPeerInfo { addr: v4 }), + // // this should never happen in practice + // SocketAddr::V6(_) => None, + // }) + // .take(50) + // .collect::>() + // }); + // let token = if peers.is_some() { + // let mut token = [0u8; 20]; + // rand::thread_rng().fill(&mut token); + // Some(ByteString::from(token.as_ref())) + // } else { + // None + // }; + // let compact_node_info = generate_compact_nodes(req.info_hash); self.routing_table.write().mark_last_query(&req.id); let message = Message { transaction_id: msg.transaction_id, @@ -348,9 +526,9 @@ impl DhtState { ip: None, kind: MessageKind::Response(bprotocol::Response { id: self.id, - nodes: Some(compact_node_info), - values: peers, - token, + nodes: None, + values: None, + token: None, }), }; self.sender.send(WorkerSendRequest { @@ -387,123 +565,17 @@ impl DhtState { DhtStats { id: self.id, outstanding_requests: self.inflight_by_transaction_id.len(), - seen_peers: self - .info_hash_meta - .iter() - .map(|e| e.value().seen_peers.len()) - .sum(), - recent_requests: self.recent_requests.len(), routing_table_size: self.routing_table.read().len(), } } - #[allow(clippy::type_complexity)] - fn get_peers_internal( - self: &Arc, - info_hash: Id20, - ) -> anyhow::Result<( - Option<(usize, usize)>, - tokio::sync::broadcast::Receiver, - )> { - use dashmap::mapref::entry::Entry; - match self.info_hash_meta.entry(info_hash) { - Entry::Occupied(o) => { - let seen_peers = &o.get().seen_peers; - let pos = if seen_peers.is_empty() { - None - } else { - Some((0, seen_peers.len())) - }; - let rx = o.get().subscriber.subscribe(); - Ok((pos, rx)) - } - Entry::Vacant(v) => { - // DHT sends peers REALLY fast, so ideally the consumer of this broadcast should not lag behind. - // In case it does though we have PeerStream to replay. - - let this = self.clone(); - let join_handle = spawn( - error_span!("peers_requester", info_hash = format!("{:?}", info_hash)), - async move { - let mut iteration = 0usize; - loop { - let meta = match this.info_hash_meta.get(&info_hash) { - Some(meta) => meta, - None => { - debug!("no more subscribers, closing peers_requester"); - return Ok(()); - } - }; - trace!("iteration {iteration}"); - let nodes_to_query = this - .routing_table - .read() - .sorted_by_distance_from(info_hash) - .iter() - .map(|n| (n.id(), n.addr())) - .take(8) - .collect::>(); - for (id, addr) in nodes_to_query { - this.send_find_peers_if_not_yet(info_hash, id, addr)?; - } - for MaybeUsefulNode { id, addr, .. } in - meta.closest_responding_nodes.iter() - { - this.send_find_peers_if_not_yet(info_hash, *id, *addr)?; - } - drop(meta); - tokio::time::sleep(REQUERY_INTERVAL).await; - iteration += 1; - } - }, - ); - - let (tx, rx) = tokio::sync::broadcast::channel(100); - v.insert(InfoHashMeta { - seen_peers: Default::default(), - subscriber: tx, - closest_responding_nodes: Default::default(), - join_handle, - }); - - Ok((None, rx)) - } - } - } - - fn send_find_peers_if_not_yet( - self: &Arc, - info_hash: Id20, - target_node: Id20, - addr: SocketAddr, - ) -> anyhow::Result<()> { - self.send_request_if_not_yet(target_node, Request::GetPeers(info_hash), addr) - } - fn send_request_if_not_yet( self: &Arc, target_node: Id20, request: Request, addr: SocketAddr, ) -> anyhow::Result<()> { - let key = (request, addr); - - use dashmap::mapref::entry::Entry; - match self.recent_requests.entry(key) { - Entry::Occupied(mut o) => { - // minus to account for randomness - if o.get().elapsed() < REQUERY_INTERVAL - Duration::from_secs(1) { - return Ok(()); - } - o.insert(Instant::now()); - } - Entry::Vacant(v) => { - v.insert(Instant::now()); - } - } - let this = self.clone(); - let fut = async move { this.routing_table .write() @@ -572,27 +644,10 @@ impl DhtState { target: Id20, nodes: CompactNodeInfo, ) -> anyhow::Result<()> { - let searching_for_peers = self - .info_hash_meta - .iter() - .map(|e| *e.key()) - .collect::>(); - - // On newly discovered nodes, ask them for peers that we are interested in. - match self.routing_table_add_node(source, source_addr) { - InsertResult::ReplacedBad(_) | InsertResult::Added => { - for info_hash in &searching_for_peers { - self.send_find_peers_if_not_yet(*info_hash, source, source_addr)?; - } - } - _ => {} - }; + self.routing_table_add_node(source, source_addr); for node in nodes.nodes { match self.routing_table_add_node(node.id, node.addr.into()) { InsertResult::ReplacedBad(_) | InsertResult::Added => { - for info_hash in &searching_for_peers { - self.send_find_peers_if_not_yet(*info_hash, node.id, node.addr.into())?; - } // recursively find nodes closest to us until we can't find more. self.send_find_node_if_not_yet(target, source, source_addr)?; } @@ -601,116 +656,6 @@ impl DhtState { } Ok(()) } - - fn am_i_interested_in_node_for_this_info_hash( - &self, - info_hash: Id20, - node_id: Id20, - addr: SocketAddr, - closest_nodes: &mut Vec, - ) -> bool { - closest_nodes.push(MaybeUsefulNode { - id: node_id, - addr, - last_response: None, - returned_peers: false, - }); - - const LIMIT: usize = 256; - closest_nodes.sort_by_key(|n| { - let has_returned_peers_desc = Reverse(n.returned_peers); - let has_responded_desc = Reverse(n.last_response.is_some() as u8); - let distance = n.id.distance(&info_hash); - (has_returned_peers_desc, has_responded_desc, distance) - }); - if closest_nodes.len() > LIMIT { - let popped = closest_nodes.pop().unwrap(); - if popped.id == node_id { - return false; - } - } - true - } - - fn on_found_peers_or_nodes( - self: &Arc, - source: Id20, - source_addr: SocketAddr, - info_hash: Id20, - data: bprotocol::Response, - ) -> anyhow::Result<()> { - self.routing_table_add_node(source, source_addr); - - use dashmap::mapref::entry::Entry; - let mut meta = match self.info_hash_meta.entry(info_hash) { - Entry::Occupied(o) => o, - Entry::Vacant(_) => { - warn!( - "ignoring found_peers response, no subscribers for {:?}", - info_hash - ); - return Ok(()); - } - }; - - let meta_mut = meta.get_mut(); - - { - let now = Some(Instant::now()); - let returned_peers = data.values.as_ref().map(|p| !p.is_empty()).unwrap_or(false); - - if let Some(existing_useful_node) = meta_mut - .closest_responding_nodes - .iter_mut() - .find(|n| n.id == source && n.addr == source_addr) - { - existing_useful_node.last_response = now; - existing_useful_node.returned_peers |= returned_peers; - } else { - meta_mut.closest_responding_nodes.push(MaybeUsefulNode { - id: source, - addr: source_addr, - last_response: now, - returned_peers, - }); - } - } - - if let Some(peers) = data.values { - for peer in peers.iter() { - if peer.addr.port() < 1024 { - debug!("bad peer port, ignoring: {}", peer.addr); - continue; - } - let addr = SocketAddr::V4(peer.addr); - if meta_mut.seen_peers.insert(addr) { - match meta_mut.subscriber.send(addr) { - Ok(_) => {} - Err(_) => { - debug!("no more subscribers for {:?}, cleaning up", info_hash); - meta_mut.join_handle.abort(); - meta.remove(); - return Ok(()); - } - } - } - } - }; - if let Some(nodes) = data.nodes { - for node in nodes.nodes { - if self.am_i_interested_in_node_for_this_info_hash( - info_hash, - node.id, - node.addr.into(), - &mut meta_mut.closest_responding_nodes, - ) { - self.routing_table_add_node(node.id, node.addr.into()); - self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?; - } - } - }; - Ok(()) - } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -952,9 +897,6 @@ impl DhtWorker { struct PeerStream { info_hash: Id20, state: Arc, - absolute_stream_pos: usize, - initial_peers_pos: Option<(usize, usize)>, - broadcast_rx: BroadcastStream, } impl Stream for PeerStream { @@ -964,40 +906,7 @@ impl Stream for PeerStream { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - loop { - if let Some((pos, end)) = self.initial_peers_pos.take() { - let addr = match self - .state - .info_hash_meta - .get(&self.info_hash) - .and_then(|meta| meta.seen_peers.get_index(pos).copied()) - { - Some(addr) => addr, - None => return Poll::Ready(None), - }; - if pos + 1 < end { - self.initial_peers_pos = Some((pos + 1, end)); - } - self.absolute_stream_pos += 1; - return Poll::Ready(Some(addr)); - } - - match self.broadcast_rx.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(v))) => { - self.absolute_stream_pos += 1; - return Poll::Ready(Some(v)); - } - Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(lagged_by)))) => { - debug!("peer stream is lagged by {}", lagged_by); - let s = self.absolute_stream_pos; - let e = s + lagged_by as usize; - self.initial_peers_pos = Some((s, e)); - continue; - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - }; - } + todo!() } } @@ -1061,14 +970,7 @@ impl DhtState { self: &Arc, info_hash: Id20, ) -> anyhow::Result + Unpin> { - let (pos, rx) = self.get_peers_internal(info_hash)?; - Ok(PeerStream { - info_hash, - state: self.clone(), - absolute_stream_pos: 0, - initial_peers_pos: pos, - broadcast_rx: BroadcastStream::new(rx), - }) + Ok(RequestPeersStream::new(self.clone(), info_hash)) } pub fn listen_addr(&self) -> SocketAddr {