DHT: better tracking requests/responses

This commit is contained in:
Igor Katson 2023-11-28 10:53:22 +00:00
parent 0478577a72
commit 336bf751e3
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
4 changed files with 182 additions and 97 deletions

1
Cargo.lock generated
View file

@ -1087,6 +1087,7 @@ name = "librqbit-dht"
version = "3.2.0"
dependencies = [
"anyhow",
"backoff",
"dashmap",
"directories",
"futures",

View file

@ -27,6 +27,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod
anyhow = "1"
parking_lot = "0.12"
tracing = "0.1"
backoff = "0.4.0"
futures = "0.3"
rand = "0.8"
indexmap = "2"

View file

@ -1,4 +1,5 @@
use std::{
f32::consts::E,
net::SocketAddr,
sync::{
atomic::{AtomicU16, Ordering},
@ -10,16 +11,17 @@ use std::{
use crate::{
bprotocol::{
self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message,
MessageKind, Node, PingRequest, Response,
self, CompactNodeInfo, CompactPeerInfo, ErrorDescription, FindNodeRequest, GetPeersRequest,
Message, MessageKind, Node, PingRequest, Response,
},
routing_table::{InsertResult, RoutingTable},
RESPONSE_TIMEOUT,
};
use anyhow::Context;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use bencode::{ByteBuf, ByteString};
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, Stream, StreamExt};
use futures::{future::join_all, stream::FuturesUnordered, Stream, StreamExt, TryFutureExt};
use indexmap::IndexSet;
use leaky_bucket::RateLimiter;
use librqbit_core::{id20::Id20, peer_id::generate_peer_id, spawn_utils::spawn};
@ -44,8 +46,7 @@ pub struct DhtStats {
}
struct OutstandingRequest {
request: Request,
done: tokio::sync::oneshot::Sender<()>,
done: tokio::sync::oneshot::Sender<ResponseOrError>,
}
pub struct DhtState {
@ -54,7 +55,7 @@ pub struct DhtState {
// Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here.
outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>,
inflight: DashMap<(u16, SocketAddr), OutstandingRequest>,
// TODO: clean up old entries
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
@ -62,11 +63,7 @@ pub struct DhtState {
routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr,
// This sender sends requests to the worker.
// It is unbounded so that the methods on Dht state don't need to be async.
// If the methods on Dht state were async, we would have a problem, as it's behind
// a lock.
// Alternatively, we can lock only the parts that change, and use that internally inside DhtState...
// Sending requests to the worker.
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
@ -84,7 +81,7 @@ impl DhtState {
Self {
id,
next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(),
inflight: Default::default(),
routing_table: RwLock::new(routing_table),
sender,
listen_addr,
@ -94,41 +91,53 @@ impl DhtState {
}
}
fn send_request(self: &Arc<Self>, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
fn spawn_request(self: &Arc<Self>, request: Request, addr: SocketAddr) {
let this = self.clone();
spawn(
error_span!(parent: None, "dht_request", addr=addr.to_string(), request=format!("{:?}", request)),
async move { this.send_request_and_handle_response(request, addr).await },
);
}
async fn send_request_and_handle_response(
self: &Arc<Self>,
request: Request,
addr: SocketAddr,
) -> anyhow::Result<()> {
let resp = self.request(request, addr).await?;
match resp {
ResponseOrError::Response(r) => self.on_response(addr, request, r),
ResponseOrError::Error(e) => {
anyhow::bail!("received error: {:?}", e);
}
}
}
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
let (tid, msg) = self.create_request(request);
let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel();
self.outstanding_requests_by_transaction_id
.insert((tid, addr), OutstandingRequest { request, done: tx });
self.inflight.insert(key, OutstandingRequest { done: tx });
match self.sender.send((msg, addr)) {
Ok(_) => {}
Err(e) => {
self.outstanding_requests_by_transaction_id
.remove(&(tid, addr));
self.inflight.remove(&key);
return Err(e.into());
}
};
let this = self.clone();
spawn(
debug_span!("dht_request", tid = tid, addr = addr.to_string()),
async move {
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
this.outstanding_requests_by_transaction_id
.remove(&(tid, addr));
warn!("recv error, did not expect this: {:?}", e);
}
Err(e) => {
this.outstanding_requests_by_transaction_id
.remove(&(tid, addr));
debug!("error: {:?}", e);
}
};
Ok(())
},
);
Ok(())
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
Ok(Ok(r)) => Ok(r),
Ok(Err(e)) => {
self.inflight.remove(&key);
warn!("recv error, did not expect this: {:?}", e);
Err(e.into())
}
Err(e) => {
self.inflight.remove(&key);
debug!("error: {:?}", e);
anyhow::bail!("timeout")
}
}
}
fn create_request(&self, request: Request) -> (u16, Message<ByteString>) {
@ -208,6 +217,8 @@ impl DhtState {
};
match &msg.kind {
// If it's a response to a request we made, find the request task, notify it with the response,
// and let it handle it.
MessageKind::Error(_) | MessageKind::Response(_) => {
if msg.transaction_id.len() != 2 {
anyhow::bail!(
@ -217,29 +228,32 @@ impl DhtState {
)
}
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
let request = match self
.outstanding_requests_by_transaction_id
.remove(&(tid, addr))
.map(|(_, v)| v)
{
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
};
let request = {
let _ = request.done.send(());
request.request
};
let response = match msg.kind {
MessageKind::Error(e) => {
anyhow::bail!("request {:?} received error response {:?}", request, e)
let response_or_error = match msg.kind {
MessageKind::Error(e) => ResponseOrError::Error(e),
MessageKind::Response(r) => {
self.routing_table.write().mark_response(&r.id);
ResponseOrError::Response(r)
}
MessageKind::Response(r) => r,
_ => unreachable!(),
};
self.routing_table.write().mark_response(&response.id);
self.on_response(addr, request, response)
match request.done.send(response_or_error) {
Ok(_) => {}
Err(e) => {
warn!(
"recieved response, but the receiver task is closed: {:?}",
e
);
}
}
Ok(())
}
MessageKind::PingRequest(_) => {
// Otherwise, respond to a query.
MessageKind::PingRequest(req) => {
let message = Message {
transaction_id: msg.transaction_id,
version: None,
@ -249,6 +263,7 @@ impl DhtState {
..Default::default()
}),
};
self.routing_table.write().mark_last_query(&req.id);
self.sender.send((message, addr))?;
Ok(())
}
@ -310,7 +325,7 @@ impl DhtState {
pub fn get_stats(&self) -> DhtStats {
DhtStats {
id: self.id,
outstanding_requests: self.outstanding_requests_by_transaction_id.len(),
outstanding_requests: self.inflight.len(),
seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(),
routing_table_size: self.routing_table.read().len(),
@ -392,7 +407,7 @@ impl DhtState {
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?;
self.spawn_request(request, addr);
}
Ok(())
}
@ -408,7 +423,7 @@ impl DhtState {
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?;
self.spawn_request(request, addr);
}
Ok(())
}
@ -420,7 +435,7 @@ impl DhtState {
true
});
for addr in questionable_nodes {
let _ = self.send_request(Request::Ping, addr);
self.spawn_request(Request::Ping, addr);
}
res
}
@ -596,6 +611,12 @@ enum Request {
Ping,
}
#[derive(Debug)]
enum ResponseOrError {
Response(Response<ByteString>),
Error(ErrorDescription<ByteString>),
}
struct DhtWorker {
socket: UdpSocket,
peer_id: Id20,
@ -607,6 +628,103 @@ impl DhtWorker {
self.state.on_incoming_from_remote(msg, addr)
}
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(60))
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
.build();
loop {
let res = self
.state
.send_request_and_handle_response(Request::FindNode(self.peer_id), addr)
.await;
match res {
Ok(r) => return Ok(r),
Err(e) => {
debug!("error: {:?}", e);
if let Some(backoff) = backoff.next_backoff() {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("given up bootstrapping, timed out")
}
}
}
}
async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> {
let addrs = tokio::net::lookup_host(hostname)
.await
.with_context(|| format!("error looking up {}", hostname))?;
let mut futs = FuturesUnordered::new();
for addr in addrs {
futs.push(
self.bootstrap_one_ip_with_backoff(addr)
.instrument(error_span!("addr", addr = addr.to_string())),
);
}
let requests = futs.len();
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
};
}
if successes == 0 {
anyhow::bail!("none of the {} bootstrap requests succeded", requests);
}
Ok(())
}
async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(60))
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
.build();
loop {
let backoff = match self.bootstrap_hostname(addr).await {
Ok(_) => return Ok(()),
Err(e) => {
warn!("error: {}", e);
backoff.next_backoff()
}
};
if let Some(backoff) = backoff {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("bootstrap failed")
}
}
async fn bootstrap(&self, bootstrap_addrs: &[String]) -> anyhow::Result<()> {
let mut futs = FuturesUnordered::new();
for addr in bootstrap_addrs.iter() {
let this = &self;
futs.push(
this.bootstrap_hostname_with_backoff(addr)
.instrument(error_span!("bootstrap", hostname = addr)),
);
}
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
}
}
if successes == 0 {
anyhow::bail!("bootstrapping failed")
}
Ok(())
}
async fn start(
self,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
@ -615,42 +733,7 @@ impl DhtWorker {
let (out_tx, mut out_rx) = channel(1);
let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer"));
let bootstrap = async {
let mut futs = FuturesUnordered::new();
// bootstrap
for addr in bootstrap_addrs.iter() {
let this = &self;
futs.push(
async move {
match tokio::net::lookup_host(addr).await {
Ok(addrs) => {
for addr in addrs {
this.state
.send_request(Request::FindNode(this.peer_id), addr)?;
}
}
Err(e) => {
warn!("error looking up {}: {}", addr, e);
return Err(e.into());
}
}
Ok::<_, anyhow::Error>(())
}
.instrument(error_span!("dht_bootstrap", addr = addr)),
);
}
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
}
}
if successes == 0 {
anyhow::bail!("bootstrapping did not succeed")
}
Ok(())
}
.instrument(debug_span!("dht_bootstrapper"));
let bootstrap = self.bootstrap(bootstrap_addrs);
let mut bootstrap_done = false;
let response_reader = {

View file

@ -1,4 +1,4 @@
use std::{io::BufWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use std::{io::LineWriter, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use anyhow::Context;
use clap::{Parser, ValueEnum};
@ -205,7 +205,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender<String> {
if let Some(log_file) = &opts.log_file {
let log_file = log_file.clone();
let log_file = move || {
BufWriter::new(
LineWriter::new(
std::fs::OpenOptions::new()
.create(true)
.append(true)