diff --git a/crates/dht/src/bprotocol.rs b/crates/dht/src/bprotocol.rs index 562a0ba..3ab76ac 100644 --- a/crates/dht/src/bprotocol.rs +++ b/crates/dht/src/bprotocol.rs @@ -333,7 +333,8 @@ pub struct Message { } impl Message { - pub fn get_transaction_id(&self) -> Option { + // This implies that the transaction id was generated by us. + pub fn get_our_transaction_id(&self) -> Option { if self.transaction_id.len() != 2 { return None; } diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 3c42f1c..a5d2ec8 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -5,7 +5,7 @@ use std::{ Arc, }, task::Poll, - time::{Duration, Instant}, + time::Duration, }; use crate::{ @@ -14,12 +14,12 @@ use crate::{ Message, MessageKind, Node, PingRequest, Response, }, routing_table::{InsertResult, RoutingTable}, - RESPONSE_TIMEOUT, + REQUERY_INTERVAL, RESPONSE_TIMEOUT, }; -use anyhow::Context; +use anyhow::{bail, Context}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::ByteString; -use dashmap::DashMap; +use dashmap::{DashMap, DashSet}; use futures::{stream::FuturesUnordered, Stream, StreamExt}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; @@ -40,7 +40,7 @@ pub struct DhtStats { pub id: Id20, pub outstanding_requests: usize, pub seen_peers: usize, - pub made_requests: usize, + pub outstanding_backoff_tasks: usize, pub routing_table_size: usize, } @@ -54,10 +54,10 @@ pub struct DhtState { // Created requests: (transaction_id, addr) => Requests. // If we get a response, it gets removed from here. - inflight: DashMap<(u16, SocketAddr), OutstandingRequest>, + inflight_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>, - // TODO: clean up old entries - made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>, + // Current requests to addr being re-sent with backoff. + inflight_by_request: DashSet<(Request, SocketAddr)>, routing_table: RwLock, listen_addr: SocketAddr, @@ -80,13 +80,13 @@ impl DhtState { Self { id, next_transaction_id: AtomicU16::new(0), - inflight: Default::default(), + inflight_by_transaction_id: Default::default(), routing_table: RwLock::new(routing_table), sender, listen_addr, seen_peers: Default::default(), get_peers_subscribers: Default::default(), - made_requests_by_addr: Default::default(), + inflight_by_request: Default::default(), } } @@ -115,7 +115,7 @@ impl DhtState { match resp { ResponseOrError::Response(r) => self.on_response(addr, request, r), ResponseOrError::Error(e) => { - anyhow::bail!("received error: {:?}", e); + bail!("received error: {:?}", e); } } } @@ -124,24 +124,25 @@ impl DhtState { let (tid, msg) = self.create_request(request); let key = (tid, addr); let (tx, rx) = tokio::sync::oneshot::channel(); - self.inflight.insert(key, OutstandingRequest { done: tx }); + self.inflight_by_transaction_id + .insert(key, OutstandingRequest { done: tx }); match self.sender.send((msg, addr)) { Ok(_) => {} Err(e) => { - self.inflight.remove(&key); + self.inflight_by_transaction_id.remove(&key); return Err(e.into()); } }; match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await { Ok(Ok(r)) => r, Ok(Err(e)) => { - self.inflight.remove(&key); + self.inflight_by_transaction_id.remove(&key); warn!("recv error, did not expect this: {:?}", e); Err(e.into()) } Err(_) => { - self.inflight.remove(&key); - anyhow::bail!("timeout") + self.inflight_by_transaction_id.remove(&key); + bail!("timeout") } } } @@ -192,12 +193,14 @@ impl DhtState { .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; self.on_found_nodes(response.id, addr, id, nodes) } - Request::GetPeers(id) => self.on_found_peers_or_nodes(response.id, addr, id, response), Request::Ping => Ok(()), + Request::GetPeers(info_hash) => { + self.on_found_peers_or_nodes(response.id, addr, info_hash, response) + } } } - fn on_incoming_from_remote( + fn on_received_message( self: &Arc, msg: Message, addr: SocketAddr, @@ -226,10 +229,14 @@ impl DhtState { // If it's a response to a request we made, find the request task, notify it with the response, // and let it handle it. MessageKind::Error(_) | MessageKind::Response(_) => { - let tid = msg.get_transaction_id().context("bad transaction id")?; - let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) { + let tid = msg.get_our_transaction_id().context("bad transaction id")?; + let request = match self + .inflight_by_transaction_id + .remove(&(tid, addr)) + .map(|(_, v)| v) + { Some(req) => req, - None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), + None => bail!("outstanding request not found. Message: {:?}", msg), }; let response_or_error = match msg.kind { @@ -324,9 +331,9 @@ impl DhtState { pub fn get_stats(&self) -> DhtStats { DhtStats { id: self.id, - outstanding_requests: self.inflight.len(), + outstanding_requests: self.inflight_by_transaction_id.len(), seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(), - made_requests: self.made_requests_by_addr.len(), + outstanding_backoff_tasks: self.inflight_by_request.len(), routing_table_size: self.routing_table.read().len(), } } @@ -376,38 +383,86 @@ impl DhtState { } } - fn should_request(&self, request: Request, addr: SocketAddr) -> bool { - const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60); - use dashmap::mapref::entry::Entry; - match self.made_requests_by_addr.entry((request, addr)) { - Entry::Occupied(mut o) => { - if o.get().elapsed() > RE_REQUEST_TIME { - o.insert(Instant::now()); - true - } else { - false - } - } - Entry::Vacant(v) => { - v.insert(Instant::now()); - true - } - } - } - fn send_find_peers_if_not_yet( self: &Arc, info_hash: Id20, target_node: Id20, addr: SocketAddr, ) -> anyhow::Result<()> { - let request = Request::GetPeers(info_hash); - if self.should_request(request, addr) { - self.routing_table - .write() - .mark_outgoing_request(&target_node); - self.spawn_request(request, addr); + self.send_request_if_not_yet(target_node, Request::GetPeers(info_hash), addr) + } + + fn send_request_if_not_yet( + self: &Arc, + target_node: Id20, + request: Request, + addr: SocketAddr, + ) -> anyhow::Result<()> { + let key = (request, addr); + if !self.inflight_by_request.insert(key) { + return Ok(()); } + + let this = self.clone(); + + let fut = async move { + let mut backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_secs(60)) + .with_multiplier(1.5) + .with_max_interval(Duration::from_secs(10 * 60)) + .with_max_elapsed_time(Some(Duration::from_secs(15 * 60))) + .build(); + + loop { + this.routing_table + .write() + .mark_outgoing_request(&target_node); + + let resp = this.request(request, addr).await; + let sleep = match resp { + Ok(ResponseOrError::Response(response)) => { + match this.on_response(addr, request, response) { + Ok(()) => { + backoff.reset(); + Some(REQUERY_INTERVAL) + } + Err(e) => { + warn!("error in on_response: {:?}", e); + backoff.next_backoff() + } + } + } + Ok(ResponseOrError::Error(e)) => { + debug!("error response: {:?}", e); + backoff.next_backoff() + } + Err(e) => { + debug!("error: {:?}", e); + backoff.next_backoff() + } + }; + if let Some(sleep) = sleep { + tokio::time::sleep(sleep).await; + continue; + } + + tokio::task::spawn(async move { + this.inflight_by_request.remove(&key); + }); + + return Ok(()); + } + }; + + spawn( + error_span!( + parent: None, + "dht_request", + addr = addr.to_string(), + request = format!("{:?}", request), + ), + fut, + ); Ok(()) } @@ -417,14 +472,7 @@ impl DhtState { target_node: Id20, addr: SocketAddr, ) -> anyhow::Result<()> { - let request = Request::FindNode(search_id); - if self.should_request(request, addr) { - self.routing_table - .write() - .mark_outgoing_request(&target_node); - self.spawn_request(request, addr); - } - Ok(()) + self.send_request_if_not_yet(target_node, Request::FindNode(search_id), addr) } fn routing_table_add_node(self: &Arc, id: Id20, addr: SocketAddr) -> InsertResult { @@ -482,25 +530,25 @@ impl DhtState { self: &Arc, source: Id20, source_addr: SocketAddr, - target: Id20, + info_hash: Id20, data: bprotocol::Response, ) -> anyhow::Result<()> { self.routing_table_add_node(source, source_addr); self.routing_table.write().mark_response(&source); - let bsender = match self.get_peers_subscribers.get(&target) { + let bsender = match self.get_peers_subscribers.get(&info_hash) { Some(s) => s, None => { warn!( "ignoring get_peers response, no subscribers for {:?}", - target + info_hash ); return Ok(()); } }; if let Some(peers) = data.values { - let mut seen = self.seen_peers.entry(target).or_default(); + let mut seen = self.seen_peers.entry(info_hash).or_default(); for peer in peers.iter() { if peer.addr.port() < 1024 { @@ -518,7 +566,7 @@ impl DhtState { if let Some(nodes) = data.nodes { for node in nodes.nodes { self.routing_table_add_node(node.id, node.addr.into()); - self.send_find_peers_if_not_yet(target, node.id, node.addr.into())?; + self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?; } }; Ok(()) @@ -562,12 +610,10 @@ struct DhtWorker { } impl DhtWorker { - fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { - self.state.on_incoming_from_remote(msg, addr) - } - fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) { - if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) { + if let Some((_, OutstandingRequest { done })) = + self.state.inflight_by_transaction_id.remove(&(tid, addr)) + { let _ = done.send(Err(err)).is_err(); }; } @@ -593,7 +639,7 @@ impl DhtWorker { tokio::time::sleep(backoff).await; continue; } - anyhow::bail!("given up bootstrapping, timed out") + bail!("given up bootstrapping, timed out") } } } @@ -618,7 +664,7 @@ impl DhtWorker { }; } if successes == 0 { - anyhow::bail!("none of the {} bootstrap requests succeded", requests); + bail!("none of the {} bootstrap requests succeded", requests); } Ok(()) } @@ -643,7 +689,7 @@ impl DhtWorker { tokio::time::sleep(backoff).await; continue; } - anyhow::bail!("bootstrap failed") + bail!("bootstrap failed") } } @@ -664,7 +710,7 @@ impl DhtWorker { } } if successes == 0 { - anyhow::bail!("bootstrapping failed") + bail!("bootstrapping failed") } Ok(()) } @@ -682,7 +728,7 @@ impl DhtWorker { rate_limiter.acquire_one().await; trace!("{}: sending {:?}", addr, &msg); buf.clear(); - let tid = msg.get_transaction_id().unwrap(); + let tid = msg.get_our_transaction_id(); bprotocol::serialize_message( &mut buf, msg.transaction_id, @@ -692,7 +738,10 @@ impl DhtWorker { ) .unwrap(); if let Err(e) = socket.send_to(&buf, addr).await { - self.on_send_error(tid, addr, e.into()); + debug!("error sending to {addr}: {e:?}"); + if let Some(tid) = tid { + self.on_send_error(tid, addr, e.into()); + } } } Err::<(), _>(anyhow::anyhow!( @@ -745,7 +794,7 @@ impl DhtWorker { let this = &self; async move { while let Some((response, addr)) = out_rx.recv().await { - if let Err(e) = this.on_response(response, addr) { + if let Err(e) = this.state.on_received_message(response, addr) { debug!("error in on_response, addr={:?}: {}", addr, e) } } diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index 9e5bfe4..5a28d07 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -16,6 +16,8 @@ pub type Dht = Arc; // How long do we wait for a response from a DHT node. pub(crate) const RESPONSE_TIMEOUT: Duration = Duration::from_secs(60); +// TODO: Not sure if we should re-query tbh. +pub(crate) const REQUERY_INTERVAL: Duration = Duration::from_secs(60); // After how long should we ping the node again. pub(crate) const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(15 * 60);