diff --git a/crates/dht/src/bprotocol.rs b/crates/dht/src/bprotocol.rs index 3ab76ac..5db82fe 100644 --- a/crates/dht/src/bprotocol.rs +++ b/crates/dht/src/bprotocol.rs @@ -292,12 +292,12 @@ pub struct FindNodeRequest { #[derive(Debug, Serialize, Deserialize, Default)] pub struct Response { + #[serde(skip_serializing_if = "Option::is_none")] + pub values: Option>, pub id: Id20, #[serde(skip_serializing_if = "Option::is_none")] pub nodes: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub values: Option>, - #[serde(skip_serializing_if = "Option::is_none")] pub token: Option, } @@ -326,10 +326,10 @@ pub struct GetPeersResponse { #[derive(Debug)] pub struct Message { + pub kind: MessageKind, pub transaction_id: BufT, pub version: Option, pub ip: Option, - pub kind: MessageKind, } impl Message { diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index f6d8143..365352c 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,4 +1,5 @@ use std::{ + cmp::Reverse, net::SocketAddr, sync::{ atomic::{AtomicU16, Ordering}, @@ -19,7 +20,7 @@ use crate::{ use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; -use dashmap::{DashMap, DashSet}; +use dashmap::DashMap; use futures::{stream::FuturesUnordered, Stream, StreamExt}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; @@ -54,6 +55,12 @@ pub struct WorkerSendRequest { addr: SocketAddr, } +struct MaybeUsefulNode { + id: Id20, + addr: SocketAddr, + last_response: Option, +} + pub struct DhtState { id: Id20, next_transaction_id: AtomicU16, @@ -72,6 +79,8 @@ pub struct DhtState { sender: UnboundedSender, seen_peers: DashMap>, + + closest_responding_nodes_for_info_hash: DashMap>, get_peers_subscribers: DashMap>, } @@ -92,6 +101,7 @@ impl DhtState { listen_addr, seen_peers: Default::default(), get_peers_subscribers: Default::default(), + closest_responding_nodes_for_info_hash: Default::default(), recent_requests: Default::default(), } } @@ -127,6 +137,7 @@ impl DhtState { } async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result { + // self.rate_limiter.acquire_one().await; let (tid, message) = self.create_request(request); let key = (tid, addr); let (tx, rx) = tokio::sync::oneshot::channel(); @@ -387,18 +398,34 @@ impl DhtState { let (tx, rx) = tokio::sync::broadcast::channel(100); v.insert(tx); - // We don't need to allocate/collect here, but the borrow checker is not happy otherwise. - let nodes_to_query = self - .routing_table - .read() - .sorted_by_distance_from(info_hash) - .iter() - .map(|n| (n.id(), n.addr())) - .take(8) - .collect::>(); - for (id, addr) in nodes_to_query { - self.send_find_peers_if_not_yet(info_hash, id, addr)?; - } + let this = self.clone(); + spawn( + error_span!("peers_requester", info_hash = format!("{:?}", info_hash)), + async move { + loop { + // We don't need to allocate/collect here, but the borrow checker is not happy otherwise. + let nodes_to_query = this + .routing_table + .read() + .sorted_by_distance_from(info_hash) + .iter() + .map(|n| (n.id(), n.addr())) + .take(8) + .collect::>(); + for (id, addr) in nodes_to_query { + this.send_find_peers_if_not_yet(info_hash, id, addr)?; + } + if let Some(e) = + this.closest_responding_nodes_for_info_hash.get(&info_hash) + { + for MaybeUsefulNode { id, addr, .. } in e.value().iter() { + this.send_find_peers_if_not_yet(info_hash, *id, *addr)?; + } + } + tokio::time::sleep(REQUERY_INTERVAL).await; + } + }, + ); Ok((None, rx)) } @@ -422,18 +449,18 @@ impl DhtState { ) -> anyhow::Result<()> { let key = (request, addr); - use dashmap::mapref::entry::Entry; - match self.recent_requests.entry(key) { - Entry::Occupied(mut o) => { - if o.get().elapsed() < REQUERY_INTERVAL { - return Ok(()); - } - o.insert(Instant::now()); - } - Entry::Vacant(v) => { - v.insert(Instant::now()); - } - } + // use dashmap::mapref::entry::Entry; + // match self.recent_requests.entry(key) { + // Entry::Occupied(mut o) => { + // if o.get().elapsed() < REQUERY_INTERVAL { + // return Ok(()); + // } + // o.insert(Instant::now()); + // } + // Entry::Vacant(v) => { + // v.insert(Instant::now()); + // } + // } let this = self.clone(); @@ -534,6 +561,43 @@ impl DhtState { Ok(()) } + fn am_i_interested_in_node_for_this_info_hash( + &self, + info_hash: Id20, + node_id: Id20, + addr: SocketAddr, + ) -> bool { + use dashmap::mapref::entry::Entry; + let n = MaybeUsefulNode { + id: node_id, + addr, + last_response: None, + }; + match self.closest_responding_nodes_for_info_hash.entry(info_hash) { + Entry::Occupied(mut occ) => { + const LIMIT: usize = 128; + let v = occ.get_mut(); + v.push(n); + v.sort_by_key(|n| { + let responded = Reverse(n.last_response.is_some() as u8); + let distance = n.id.distance(&info_hash); + (responded, distance) + }); + while v.len() > LIMIT { + if v.pop().unwrap().id == node_id { + return false; + } + } + + true + } + Entry::Vacant(v) => { + v.insert(vec![n]); + true + } + } + } + fn on_found_peers_or_nodes( self: &Arc, source: Id20, @@ -555,6 +619,31 @@ impl DhtState { } }; + { + use dashmap::mapref::entry::Entry; + let n = MaybeUsefulNode { + id: source, + addr: source_addr, + last_response: Some(Instant::now()), + }; + match self.closest_responding_nodes_for_info_hash.entry(info_hash) { + Entry::Occupied(mut useful_nodes) => { + if let Some(useful_node) = useful_nodes + .get_mut() + .iter_mut() + .find(|n| n.id == source && n.addr == source_addr) + { + useful_node.last_response = Some(Instant::now()); + } else { + useful_nodes.get_mut().push(n); + } + } + Entry::Vacant(v) => { + v.insert(vec![n]); + } + }; + } + if let Some(peers) = data.values { let mut seen = self.seen_peers.entry(info_hash).or_default(); @@ -573,31 +662,20 @@ impl DhtState { }; if let Some(nodes) = data.nodes { for node in nodes.nodes { - self.routing_table_add_node(node.id, node.addr.into()); - self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?; + if self.am_i_interested_in_node_for_this_info_hash( + info_hash, + node.id, + node.addr.into(), + ) { + self.routing_table_add_node(node.id, node.addr.into()); + self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?; + } } }; Ok(()) } } -fn make_rate_limiter() -> RateLimiter { - // TODO: move to configuration, i'm lazy. - let dht_queries_per_second = std::env::var("DHT_QUERIES_PER_SECOND") - .map(|v| v.parse().expect("couldn't parse DHT_QUERIES_PER_SECOND")) - .unwrap_or(250usize); - - let per_100_ms = dht_queries_per_second / 10; - - RateLimiter::builder() - .initial(per_100_ms) - .max(dht_queries_per_second) - .interval(Duration::from_millis(100)) - .fair(false) - .refill(per_100_ms) - .build() -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum Request { GetPeers(Id20), @@ -731,14 +809,12 @@ impl DhtWorker { ) -> anyhow::Result<()> { let writer = async { let mut buf = Vec::new(); - let rate_limiter = make_rate_limiter(); while let Some(WorkerSendRequest { our_tid, message, addr, }) = input_rx.recv().await { - rate_limiter.acquire_one().await; trace!("{}: sending {:?}", addr, &message); buf.clear(); bprotocol::serialize_message(