From 336bf751e38d535e5b8adbf0163e46093987afbe Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Tue, 28 Nov 2023 10:53:22 +0000 Subject: [PATCH] DHT: better tracking requests/responses --- Cargo.lock | 1 + crates/dht/Cargo.toml | 1 + crates/dht/src/dht.rs | 273 +++++++++++++++++++++++++-------------- crates/rqbit/src/main.rs | 4 +- 4 files changed, 182 insertions(+), 97 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee5227f..8e9c891 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1087,6 +1087,7 @@ name = "librqbit-dht" version = "3.2.0" dependencies = [ "anyhow", + "backoff", "dashmap", "directories", "futures", diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index e6647b5..81decc1 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -27,6 +27,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod anyhow = "1" parking_lot = "0.12" tracing = "0.1" +backoff = "0.4.0" futures = "0.3" rand = "0.8" indexmap = "2" diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 393fc38..f15df9d 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,4 +1,5 @@ use std::{ + f32::consts::E, net::SocketAddr, sync::{ atomic::{AtomicU16, Ordering}, @@ -10,16 +11,17 @@ use std::{ use crate::{ bprotocol::{ - self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message, - MessageKind, Node, PingRequest, Response, + self, CompactNodeInfo, CompactPeerInfo, ErrorDescription, FindNodeRequest, GetPeersRequest, + Message, MessageKind, Node, PingRequest, Response, }, routing_table::{InsertResult, RoutingTable}, RESPONSE_TIMEOUT, }; use anyhow::Context; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use bencode::{ByteBuf, ByteString}; use dashmap::DashMap; -use futures::{stream::FuturesUnordered, Stream, StreamExt}; +use futures::{future::join_all, stream::FuturesUnordered, Stream, StreamExt, TryFutureExt}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; @@ -44,8 +46,7 @@ pub struct DhtStats { } struct OutstandingRequest { - request: Request, - done: tokio::sync::oneshot::Sender<()>, + done: tokio::sync::oneshot::Sender, } pub struct DhtState { @@ -54,7 +55,7 @@ pub struct DhtState { // Created requests: (transaction_id, addr) => Requests. // If we get a response, it gets removed from here. - outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>, + inflight: DashMap<(u16, SocketAddr), OutstandingRequest>, // TODO: clean up old entries made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>, @@ -62,11 +63,7 @@ pub struct DhtState { routing_table: RwLock, listen_addr: SocketAddr, - // This sender sends requests to the worker. - // It is unbounded so that the methods on Dht state don't need to be async. - // If the methods on Dht state were async, we would have a problem, as it's behind - // a lock. - // Alternatively, we can lock only the parts that change, and use that internally inside DhtState... + // Sending requests to the worker. sender: UnboundedSender<(Message, SocketAddr)>, seen_peers: DashMap>, @@ -84,7 +81,7 @@ impl DhtState { Self { id, next_transaction_id: AtomicU16::new(0), - outstanding_requests_by_transaction_id: Default::default(), + inflight: Default::default(), routing_table: RwLock::new(routing_table), sender, listen_addr, @@ -94,41 +91,53 @@ impl DhtState { } } - fn send_request(self: &Arc, request: Request, addr: SocketAddr) -> anyhow::Result<()> { + fn spawn_request(self: &Arc, request: Request, addr: SocketAddr) { + let this = self.clone(); + spawn( + error_span!(parent: None, "dht_request", addr=addr.to_string(), request=format!("{:?}", request)), + async move { this.send_request_and_handle_response(request, addr).await }, + ); + } + + async fn send_request_and_handle_response( + self: &Arc, + request: Request, + addr: SocketAddr, + ) -> anyhow::Result<()> { + let resp = self.request(request, addr).await?; + match resp { + ResponseOrError::Response(r) => self.on_response(addr, request, r), + ResponseOrError::Error(e) => { + anyhow::bail!("received error: {:?}", e); + } + } + } + + async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result { let (tid, msg) = self.create_request(request); + let key = (tid, addr); let (tx, rx) = tokio::sync::oneshot::channel(); - self.outstanding_requests_by_transaction_id - .insert((tid, addr), OutstandingRequest { request, done: tx }); + self.inflight.insert(key, OutstandingRequest { done: tx }); match self.sender.send((msg, addr)) { Ok(_) => {} Err(e) => { - self.outstanding_requests_by_transaction_id - .remove(&(tid, addr)); + self.inflight.remove(&key); return Err(e.into()); } }; - let this = self.clone(); - spawn( - debug_span!("dht_request", tid = tid, addr = addr.to_string()), - async move { - match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await { - Ok(Ok(_)) => {} - Ok(Err(e)) => { - this.outstanding_requests_by_transaction_id - .remove(&(tid, addr)); - warn!("recv error, did not expect this: {:?}", e); - } - Err(e) => { - this.outstanding_requests_by_transaction_id - .remove(&(tid, addr)); - debug!("error: {:?}", e); - } - }; - - Ok(()) - }, - ); - Ok(()) + match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await { + Ok(Ok(r)) => Ok(r), + Ok(Err(e)) => { + self.inflight.remove(&key); + warn!("recv error, did not expect this: {:?}", e); + Err(e.into()) + } + Err(e) => { + self.inflight.remove(&key); + debug!("error: {:?}", e); + anyhow::bail!("timeout") + } + } } fn create_request(&self, request: Request) -> (u16, Message) { @@ -208,6 +217,8 @@ impl DhtState { }; match &msg.kind { + // If it's a response to a request we made, find the request task, notify it with the response, + // and let it handle it. MessageKind::Error(_) | MessageKind::Response(_) => { if msg.transaction_id.len() != 2 { anyhow::bail!( @@ -217,29 +228,32 @@ impl DhtState { ) } let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); - let request = match self - .outstanding_requests_by_transaction_id - .remove(&(tid, addr)) - .map(|(_, v)| v) - { + let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) { Some(req) => req, None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), }; - let request = { - let _ = request.done.send(()); - request.request - }; - let response = match msg.kind { - MessageKind::Error(e) => { - anyhow::bail!("request {:?} received error response {:?}", request, e) + + let response_or_error = match msg.kind { + MessageKind::Error(e) => ResponseOrError::Error(e), + MessageKind::Response(r) => { + self.routing_table.write().mark_response(&r.id); + ResponseOrError::Response(r) } - MessageKind::Response(r) => r, _ => unreachable!(), }; - self.routing_table.write().mark_response(&response.id); - self.on_response(addr, request, response) + match request.done.send(response_or_error) { + Ok(_) => {} + Err(e) => { + warn!( + "recieved response, but the receiver task is closed: {:?}", + e + ); + } + } + Ok(()) } - MessageKind::PingRequest(_) => { + // Otherwise, respond to a query. + MessageKind::PingRequest(req) => { let message = Message { transaction_id: msg.transaction_id, version: None, @@ -249,6 +263,7 @@ impl DhtState { ..Default::default() }), }; + self.routing_table.write().mark_last_query(&req.id); self.sender.send((message, addr))?; Ok(()) } @@ -310,7 +325,7 @@ impl DhtState { pub fn get_stats(&self) -> DhtStats { DhtStats { id: self.id, - outstanding_requests: self.outstanding_requests_by_transaction_id.len(), + outstanding_requests: self.inflight.len(), seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(), made_requests: self.made_requests_by_addr.len(), routing_table_size: self.routing_table.read().len(), @@ -392,7 +407,7 @@ impl DhtState { self.routing_table .write() .mark_outgoing_request(&target_node); - self.send_request(request, addr)?; + self.spawn_request(request, addr); } Ok(()) } @@ -408,7 +423,7 @@ impl DhtState { self.routing_table .write() .mark_outgoing_request(&target_node); - self.send_request(request, addr)?; + self.spawn_request(request, addr); } Ok(()) } @@ -420,7 +435,7 @@ impl DhtState { true }); for addr in questionable_nodes { - let _ = self.send_request(Request::Ping, addr); + self.spawn_request(Request::Ping, addr); } res } @@ -596,6 +611,12 @@ enum Request { Ping, } +#[derive(Debug)] +enum ResponseOrError { + Response(Response), + Error(ErrorDescription), +} + struct DhtWorker { socket: UdpSocket, peer_id: Id20, @@ -607,6 +628,103 @@ impl DhtWorker { self.state.on_incoming_from_remote(msg, addr) } + async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> { + let mut backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_secs(10)) + .with_multiplier(1.5) + .with_max_interval(Duration::from_secs(60)) + .with_max_elapsed_time(Some(Duration::from_secs(86400))) + .build(); + + loop { + let res = self + .state + .send_request_and_handle_response(Request::FindNode(self.peer_id), addr) + .await; + match res { + Ok(r) => return Ok(r), + Err(e) => { + debug!("error: {:?}", e); + if let Some(backoff) = backoff.next_backoff() { + tokio::time::sleep(backoff).await; + continue; + } + anyhow::bail!("given up bootstrapping, timed out") + } + } + } + } + + async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> { + let addrs = tokio::net::lookup_host(hostname) + .await + .with_context(|| format!("error looking up {}", hostname))?; + let mut futs = FuturesUnordered::new(); + for addr in addrs { + futs.push( + self.bootstrap_one_ip_with_backoff(addr) + .instrument(error_span!("addr", addr = addr.to_string())), + ); + } + let requests = futs.len(); + let mut successes = 0; + while let Some(resp) = futs.next().await { + if resp.is_ok() { + successes += 1 + }; + } + if successes == 0 { + anyhow::bail!("none of the {} bootstrap requests succeded", requests); + } + Ok(()) + } + + async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> { + let mut backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_secs(10)) + .with_multiplier(1.5) + .with_max_interval(Duration::from_secs(60)) + .with_max_elapsed_time(Some(Duration::from_secs(86400))) + .build(); + + loop { + let backoff = match self.bootstrap_hostname(addr).await { + Ok(_) => return Ok(()), + Err(e) => { + warn!("error: {}", e); + backoff.next_backoff() + } + }; + if let Some(backoff) = backoff { + tokio::time::sleep(backoff).await; + continue; + } + anyhow::bail!("bootstrap failed") + } + } + + async fn bootstrap(&self, bootstrap_addrs: &[String]) -> anyhow::Result<()> { + let mut futs = FuturesUnordered::new(); + + for addr in bootstrap_addrs.iter() { + let this = &self; + futs.push( + this.bootstrap_hostname_with_backoff(addr) + .instrument(error_span!("bootstrap", hostname = addr)), + ); + } + let mut successes = 0; + while let Some(resp) = futs.next().await { + if resp.is_ok() { + successes += 1 + } + } + if successes == 0 { + anyhow::bail!("bootstrapping failed") + } + Ok(()) + } + async fn start( self, in_rx: UnboundedReceiver<(Message, SocketAddr)>, @@ -615,42 +733,7 @@ impl DhtWorker { let (out_tx, mut out_rx) = channel(1); let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer")); - let bootstrap = async { - let mut futs = FuturesUnordered::new(); - // bootstrap - for addr in bootstrap_addrs.iter() { - let this = &self; - futs.push( - async move { - match tokio::net::lookup_host(addr).await { - Ok(addrs) => { - for addr in addrs { - this.state - .send_request(Request::FindNode(this.peer_id), addr)?; - } - } - Err(e) => { - warn!("error looking up {}: {}", addr, e); - return Err(e.into()); - } - } - Ok::<_, anyhow::Error>(()) - } - .instrument(error_span!("dht_bootstrap", addr = addr)), - ); - } - let mut successes = 0; - while let Some(resp) = futs.next().await { - if resp.is_ok() { - successes += 1 - } - } - if successes == 0 { - anyhow::bail!("bootstrapping did not succeed") - } - Ok(()) - } - .instrument(debug_span!("dht_bootstrapper")); + let bootstrap = self.bootstrap(bootstrap_addrs); let mut bootstrap_done = false; let response_reader = { diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index f9fa5c4..861dd01 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -1,4 +1,4 @@ -use std::{io::BufWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use std::{io::LineWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use anyhow::Context; use clap::{Parser, ValueEnum}; @@ -205,7 +205,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender { if let Some(log_file) = &opts.log_file { let log_file = log_file.clone(); let log_file = move || { - BufWriter::new( + LineWriter::new( std::fs::OpenOptions::new() .create(true) .append(true)