DHT: better tracking requests/responses

This commit is contained in:
Igor Katson 2023-11-28 10:53:22 +00:00
parent 0478577a72
commit 336bf751e3
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
4 changed files with 182 additions and 97 deletions

1
Cargo.lock generated
View file

@ -1087,6 +1087,7 @@ name = "librqbit-dht"
version = "3.2.0" version = "3.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"backoff",
"dashmap", "dashmap",
"directories", "directories",
"futures", "futures",

View file

@ -27,6 +27,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod
anyhow = "1" anyhow = "1"
parking_lot = "0.12" parking_lot = "0.12"
tracing = "0.1" tracing = "0.1"
backoff = "0.4.0"
futures = "0.3" futures = "0.3"
rand = "0.8" rand = "0.8"
indexmap = "2" indexmap = "2"

View file

@ -1,4 +1,5 @@
use std::{ use std::{
f32::consts::E,
net::SocketAddr, net::SocketAddr,
sync::{ sync::{
atomic::{AtomicU16, Ordering}, atomic::{AtomicU16, Ordering},
@ -10,16 +11,17 @@ use std::{
use crate::{ use crate::{
bprotocol::{ bprotocol::{
self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message, self, CompactNodeInfo, CompactPeerInfo, ErrorDescription, FindNodeRequest, GetPeersRequest,
MessageKind, Node, PingRequest, Response, Message, MessageKind, Node, PingRequest, Response,
}, },
routing_table::{InsertResult, RoutingTable}, routing_table::{InsertResult, RoutingTable},
RESPONSE_TIMEOUT, RESPONSE_TIMEOUT,
}; };
use anyhow::Context; use anyhow::Context;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use bencode::{ByteBuf, ByteString}; use bencode::{ByteBuf, ByteString};
use dashmap::DashMap; use dashmap::DashMap;
use futures::{stream::FuturesUnordered, Stream, StreamExt}; use futures::{future::join_all, stream::FuturesUnordered, Stream, StreamExt, TryFutureExt};
use indexmap::IndexSet; use indexmap::IndexSet;
use leaky_bucket::RateLimiter; use leaky_bucket::RateLimiter;
use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn}; use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn};
@ -44,8 +46,7 @@ pub struct DhtStats {
} }
struct OutstandingRequest { struct OutstandingRequest {
request: Request, done: tokio::sync::oneshot::Sender<ResponseOrError>,
done: tokio::sync::oneshot::Sender<()>,
} }
pub struct DhtState { pub struct DhtState {
@ -54,7 +55,7 @@ pub struct DhtState {
// Created requests: (transaction_id, addr) => Requests. // Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here. // If we get a response, it gets removed from here.
outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>, inflight: DashMap<(u16, SocketAddr), OutstandingRequest>,
// TODO: clean up old entries // TODO: clean up old entries
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>, made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
@ -62,11 +63,7 @@ pub struct DhtState {
routing_table: RwLock<RoutingTable>, routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
// This sender sends requests to the worker. // Sending requests to the worker.
// It is unbounded so that the methods on Dht state don't need to be async.
// If the methods on Dht state were async, we would have a problem, as it's behind
// a lock.
// Alternatively, we can lock only the parts that change, and use that internally inside DhtState...
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>, seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
@ -84,7 +81,7 @@ impl DhtState {
Self { Self {
id, id,
next_transaction_id: AtomicU16::new(0), next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(), inflight: Default::default(),
routing_table: RwLock::new(routing_table), routing_table: RwLock::new(routing_table),
sender, sender,
listen_addr, listen_addr,
@ -94,41 +91,53 @@ impl DhtState {
} }
} }
fn send_request(self: &Arc<Self>, request: Request, addr: SocketAddr) -> anyhow::Result<()> { fn spawn_request(self: &Arc<Self>, request: Request, addr: SocketAddr) {
let this = self.clone();
spawn(
error_span!(parent: None, "dht_request", addr=addr.to_string(), request=format!("{:?}", request)),
async move { this.send_request_and_handle_response(request, addr).await },
);
}
async fn send_request_and_handle_response(
self: &Arc<Self>,
request: Request,
addr: SocketAddr,
) -> anyhow::Result<()> {
let resp = self.request(request, addr).await?;
match resp {
ResponseOrError::Response(r) => self.on_response(addr, request, r),
ResponseOrError::Error(e) => {
anyhow::bail!("received error: {:?}", e);
}
}
}
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
let (tid, msg) = self.create_request(request); let (tid, msg) = self.create_request(request);
let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
self.outstanding_requests_by_transaction_id self.inflight.insert(key, OutstandingRequest { done: tx });
.insert((tid, addr), OutstandingRequest { request, done: tx });
match self.sender.send((msg, addr)) { match self.sender.send((msg, addr)) {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
self.outstanding_requests_by_transaction_id self.inflight.remove(&key);
.remove(&(tid, addr));
return Err(e.into()); return Err(e.into());
} }
}; };
let this = self.clone(); match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
spawn( Ok(Ok(r)) => Ok(r),
debug_span!("dht_request", tid = tid, addr = addr.to_string()), Ok(Err(e)) => {
async move { self.inflight.remove(&key);
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await { warn!("recv error, did not expect this: {:?}", e);
Ok(Ok(_)) => {} Err(e.into())
Ok(Err(e)) => { }
this.outstanding_requests_by_transaction_id Err(e) => {
.remove(&(tid, addr)); self.inflight.remove(&key);
warn!("recv error, did not expect this: {:?}", e); debug!("error: {:?}", e);
} anyhow::bail!("timeout")
Err(e) => { }
this.outstanding_requests_by_transaction_id }
.remove(&(tid, addr));
debug!("error: {:?}", e);
}
};
Ok(())
},
);
Ok(())
} }
fn create_request(&self, request: Request) -> (u16, Message<ByteString>) { fn create_request(&self, request: Request) -> (u16, Message<ByteString>) {
@ -208,6 +217,8 @@ impl DhtState {
}; };
match &msg.kind { match &msg.kind {
// If it's a response to a request we made, find the request task, notify it with the response,
// and let it handle it.
MessageKind::Error(_) | MessageKind::Response(_) => { MessageKind::Error(_) | MessageKind::Response(_) => {
if msg.transaction_id.len() != 2 { if msg.transaction_id.len() != 2 {
anyhow::bail!( anyhow::bail!(
@ -217,29 +228,32 @@ impl DhtState {
) )
} }
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
let request = match self let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
.outstanding_requests_by_transaction_id
.remove(&(tid, addr))
.map(|(_, v)| v)
{
Some(req) => req, Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
}; };
let request = {
let _ = request.done.send(()); let response_or_error = match msg.kind {
request.request MessageKind::Error(e) => ResponseOrError::Error(e),
}; MessageKind::Response(r) => {
let response = match msg.kind { self.routing_table.write().mark_response(&r.id);
MessageKind::Error(e) => { ResponseOrError::Response(r)
anyhow::bail!("request {:?} received error response {:?}", request, e)
} }
MessageKind::Response(r) => r,
_ => unreachable!(), _ => unreachable!(),
}; };
self.routing_table.write().mark_response(&response.id); match request.done.send(response_or_error) {
self.on_response(addr, request, response) Ok(_) => {}
Err(e) => {
warn!(
"recieved response, but the receiver task is closed: {:?}",
e
);
}
}
Ok(())
} }
MessageKind::PingRequest(_) => { // Otherwise, respond to a query.
MessageKind::PingRequest(req) => {
let message = Message { let message = Message {
transaction_id: msg.transaction_id, transaction_id: msg.transaction_id,
version: None, version: None,
@ -249,6 +263,7 @@ impl DhtState {
..Default::default() ..Default::default()
}), }),
}; };
self.routing_table.write().mark_last_query(&req.id);
self.sender.send((message, addr))?; self.sender.send((message, addr))?;
Ok(()) Ok(())
} }
@ -310,7 +325,7 @@ impl DhtState {
pub fn get_stats(&self) -> DhtStats { pub fn get_stats(&self) -> DhtStats {
DhtStats { DhtStats {
id: self.id, id: self.id,
outstanding_requests: self.outstanding_requests_by_transaction_id.len(), outstanding_requests: self.inflight.len(),
seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(), seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(), made_requests: self.made_requests_by_addr.len(),
routing_table_size: self.routing_table.read().len(), routing_table_size: self.routing_table.read().len(),
@ -392,7 +407,7 @@ impl DhtState {
self.routing_table self.routing_table
.write() .write()
.mark_outgoing_request(&target_node); .mark_outgoing_request(&target_node);
self.send_request(request, addr)?; self.spawn_request(request, addr);
} }
Ok(()) Ok(())
} }
@ -408,7 +423,7 @@ impl DhtState {
self.routing_table self.routing_table
.write() .write()
.mark_outgoing_request(&target_node); .mark_outgoing_request(&target_node);
self.send_request(request, addr)?; self.spawn_request(request, addr);
} }
Ok(()) Ok(())
} }
@ -420,7 +435,7 @@ impl DhtState {
true true
}); });
for addr in questionable_nodes { for addr in questionable_nodes {
let _ = self.send_request(Request::Ping, addr); self.spawn_request(Request::Ping, addr);
} }
res res
} }
@ -596,6 +611,12 @@ enum Request {
Ping, Ping,
} }
#[derive(Debug)]
enum ResponseOrError {
Response(Response<ByteString>),
Error(ErrorDescription<ByteString>),
}
struct DhtWorker { struct DhtWorker {
socket: UdpSocket, socket: UdpSocket,
peer_id: Id20, peer_id: Id20,
@ -607,6 +628,103 @@ impl DhtWorker {
self.state.on_incoming_from_remote(msg, addr) self.state.on_incoming_from_remote(msg, addr)
} }
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(60))
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
.build();
loop {
let res = self
.state
.send_request_and_handle_response(Request::FindNode(self.peer_id), addr)
.await;
match res {
Ok(r) => return Ok(r),
Err(e) => {
debug!("error: {:?}", e);
if let Some(backoff) = backoff.next_backoff() {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("given up bootstrapping, timed out")
}
}
}
}
async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> {
let addrs = tokio::net::lookup_host(hostname)
.await
.with_context(|| format!("error looking up {}", hostname))?;
let mut futs = FuturesUnordered::new();
for addr in addrs {
futs.push(
self.bootstrap_one_ip_with_backoff(addr)
.instrument(error_span!("addr", addr = addr.to_string())),
);
}
let requests = futs.len();
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
};
}
if successes == 0 {
anyhow::bail!("none of the {} bootstrap requests succeded", requests);
}
Ok(())
}
async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(60))
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
.build();
loop {
let backoff = match self.bootstrap_hostname(addr).await {
Ok(_) => return Ok(()),
Err(e) => {
warn!("error: {}", e);
backoff.next_backoff()
}
};
if let Some(backoff) = backoff {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("bootstrap failed")
}
}
async fn bootstrap(&self, bootstrap_addrs: &[String]) -> anyhow::Result<()> {
let mut futs = FuturesUnordered::new();
for addr in bootstrap_addrs.iter() {
let this = &self;
futs.push(
this.bootstrap_hostname_with_backoff(addr)
.instrument(error_span!("bootstrap", hostname = addr)),
);
}
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
}
}
if successes == 0 {
anyhow::bail!("bootstrapping failed")
}
Ok(())
}
async fn start( async fn start(
self, self,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>, in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
@ -615,42 +733,7 @@ impl DhtWorker {
let (out_tx, mut out_rx) = channel(1); let (out_tx, mut out_rx) = channel(1);
let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer")); let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer"));
let bootstrap = async { let bootstrap = self.bootstrap(bootstrap_addrs);
let mut futs = FuturesUnordered::new();
// bootstrap
for addr in bootstrap_addrs.iter() {
let this = &self;
futs.push(
async move {
match tokio::net::lookup_host(addr).await {
Ok(addrs) => {
for addr in addrs {
this.state
.send_request(Request::FindNode(this.peer_id), addr)?;
}
}
Err(e) => {
warn!("error looking up {}: {}", addr, e);
return Err(e.into());
}
}
Ok::<_, anyhow::Error>(())
}
.instrument(error_span!("dht_bootstrap", addr = addr)),
);
}
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
}
}
if successes == 0 {
anyhow::bail!("bootstrapping did not succeed")
}
Ok(())
}
.instrument(debug_span!("dht_bootstrapper"));
let mut bootstrap_done = false; let mut bootstrap_done = false;
let response_reader = { let response_reader = {

View file

@ -1,4 +1,4 @@
use std::{io::BufWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use std::{io::LineWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use anyhow::Context; use anyhow::Context;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
@ -205,7 +205,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender<String> {
if let Some(log_file) = &opts.log_file { if let Some(log_file) = &opts.log_file {
let log_file = log_file.clone(); let log_file = log_file.clone();
let log_file = move || { let log_file = move || {
BufWriter::new( LineWriter::new(
std::fs::OpenOptions::new() std::fs::OpenOptions::new()
.create(true) .create(true)
.append(true) .append(true)