Created more tasks (one long-lived backoff/re-query task per request and address), but it impacts perf and memory badly

This commit is contained in:
Igor Katson 2023-11-28 15:35:27 +00:00
parent 91c99a272f
commit 93740ec84b
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
3 changed files with 126 additions and 74 deletions

View file

@ -333,7 +333,8 @@ pub struct Message<BufT> {
} }
impl Message<ByteString> { impl Message<ByteString> {
pub fn get_transaction_id(&self) -> Option<u16> { // This implies that the transaction id was generated by us.
pub fn get_our_transaction_id(&self) -> Option<u16> {
if self.transaction_id.len() != 2 { if self.transaction_id.len() != 2 {
return None; return None;
} }

View file

@ -5,7 +5,7 @@ use std::{
Arc, Arc,
}, },
task::Poll, task::Poll,
time::{Duration, Instant}, time::Duration,
}; };
use crate::{ use crate::{
@ -14,12 +14,12 @@ use crate::{
Message, MessageKind, Node, PingRequest, Response, Message, MessageKind, Node, PingRequest, Response,
}, },
routing_table::{InsertResult, RoutingTable}, routing_table::{InsertResult, RoutingTable},
RESPONSE_TIMEOUT, REQUERY_INTERVAL, RESPONSE_TIMEOUT,
}; };
use anyhow::Context; use anyhow::{bail, Context};
use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use bencode::ByteString; use bencode::ByteString;
use dashmap::DashMap; use dashmap::{DashMap, DashSet};
use futures::{stream::FuturesUnordered, Stream, StreamExt}; use futures::{stream::FuturesUnordered, Stream, StreamExt};
use indexmap::IndexSet; use indexmap::IndexSet;
use leaky_bucket::RateLimiter; use leaky_bucket::RateLimiter;
@ -40,7 +40,7 @@ pub struct DhtStats {
pub id: Id20, pub id: Id20,
pub outstanding_requests: usize, pub outstanding_requests: usize,
pub seen_peers: usize, pub seen_peers: usize,
pub made_requests: usize, pub outstanding_backoff_tasks: usize,
pub routing_table_size: usize, pub routing_table_size: usize,
} }
@ -54,10 +54,10 @@ pub struct DhtState {
// Created requests: (transaction_id, addr) => Requests. // Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here. // If we get a response, it gets removed from here.
inflight: DashMap<(u16, SocketAddr), OutstandingRequest>, inflight_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>,
// TODO: clean up old entries // Current requests to addr being re-sent with backoff.
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>, inflight_by_request: DashSet<(Request, SocketAddr)>,
routing_table: RwLock<RoutingTable>, routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
@ -80,13 +80,13 @@ impl DhtState {
Self { Self {
id, id,
next_transaction_id: AtomicU16::new(0), next_transaction_id: AtomicU16::new(0),
inflight: Default::default(), inflight_by_transaction_id: Default::default(),
routing_table: RwLock::new(routing_table), routing_table: RwLock::new(routing_table),
sender, sender,
listen_addr, listen_addr,
seen_peers: Default::default(), seen_peers: Default::default(),
get_peers_subscribers: Default::default(), get_peers_subscribers: Default::default(),
made_requests_by_addr: Default::default(), inflight_by_request: Default::default(),
} }
} }
@ -115,7 +115,7 @@ impl DhtState {
match resp { match resp {
ResponseOrError::Response(r) => self.on_response(addr, request, r), ResponseOrError::Response(r) => self.on_response(addr, request, r),
ResponseOrError::Error(e) => { ResponseOrError::Error(e) => {
anyhow::bail!("received error: {:?}", e); bail!("received error: {:?}", e);
} }
} }
} }
@ -124,24 +124,25 @@ impl DhtState {
let (tid, msg) = self.create_request(request); let (tid, msg) = self.create_request(request);
let key = (tid, addr); let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
self.inflight.insert(key, OutstandingRequest { done: tx }); self.inflight_by_transaction_id
.insert(key, OutstandingRequest { done: tx });
match self.sender.send((msg, addr)) { match self.sender.send((msg, addr)) {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
self.inflight.remove(&key); self.inflight_by_transaction_id.remove(&key);
return Err(e.into()); return Err(e.into());
} }
}; };
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await { match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
Ok(Ok(r)) => r, Ok(Ok(r)) => r,
Ok(Err(e)) => { Ok(Err(e)) => {
self.inflight.remove(&key); self.inflight_by_transaction_id.remove(&key);
warn!("recv error, did not expect this: {:?}", e); warn!("recv error, did not expect this: {:?}", e);
Err(e.into()) Err(e.into())
} }
Err(_) => { Err(_) => {
self.inflight.remove(&key); self.inflight_by_transaction_id.remove(&key);
anyhow::bail!("timeout") bail!("timeout")
} }
} }
} }
@ -192,12 +193,14 @@ impl DhtState {
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
self.on_found_nodes(response.id, addr, id, nodes) self.on_found_nodes(response.id, addr, id, nodes)
} }
Request::GetPeers(id) => self.on_found_peers_or_nodes(response.id, addr, id, response),
Request::Ping => Ok(()), Request::Ping => Ok(()),
Request::GetPeers(info_hash) => {
self.on_found_peers_or_nodes(response.id, addr, info_hash, response)
}
} }
} }
fn on_incoming_from_remote( fn on_received_message(
self: &Arc<Self>, self: &Arc<Self>,
msg: Message<ByteString>, msg: Message<ByteString>,
addr: SocketAddr, addr: SocketAddr,
@ -226,10 +229,14 @@ impl DhtState {
// If it's a response to a request we made, find the request task, notify it with the response, // If it's a response to a request we made, find the request task, notify it with the response,
// and let it handle it. // and let it handle it.
MessageKind::Error(_) | MessageKind::Response(_) => { MessageKind::Error(_) | MessageKind::Response(_) => {
let tid = msg.get_transaction_id().context("bad transaction id")?; let tid = msg.get_our_transaction_id().context("bad transaction id")?;
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) { let request = match self
.inflight_by_transaction_id
.remove(&(tid, addr))
.map(|(_, v)| v)
{
Some(req) => req, Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), None => bail!("outstanding request not found. Message: {:?}", msg),
}; };
let response_or_error = match msg.kind { let response_or_error = match msg.kind {
@ -324,9 +331,9 @@ impl DhtState {
pub fn get_stats(&self) -> DhtStats { pub fn get_stats(&self) -> DhtStats {
DhtStats { DhtStats {
id: self.id, id: self.id,
outstanding_requests: self.inflight.len(), outstanding_requests: self.inflight_by_transaction_id.len(),
seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(), seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(), outstanding_backoff_tasks: self.inflight_by_request.len(),
routing_table_size: self.routing_table.read().len(), routing_table_size: self.routing_table.read().len(),
} }
} }
@ -376,38 +383,86 @@ impl DhtState {
} }
} }
fn should_request(&self, request: Request, addr: SocketAddr) -> bool {
const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60);
use dashmap::mapref::entry::Entry;
match self.made_requests_by_addr.entry((request, addr)) {
Entry::Occupied(mut o) => {
if o.get().elapsed() > RE_REQUEST_TIME {
o.insert(Instant::now());
true
} else {
false
}
}
Entry::Vacant(v) => {
v.insert(Instant::now());
true
}
}
}
fn send_find_peers_if_not_yet( fn send_find_peers_if_not_yet(
self: &Arc<Self>, self: &Arc<Self>,
info_hash: Id20, info_hash: Id20,
target_node: Id20, target_node: Id20,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let request = Request::GetPeers(info_hash); self.send_request_if_not_yet(target_node, Request::GetPeers(info_hash), addr)
if self.should_request(request, addr) { }
self.routing_table
.write() fn send_request_if_not_yet(
.mark_outgoing_request(&target_node); self: &Arc<Self>,
self.spawn_request(request, addr); target_node: Id20,
request: Request,
addr: SocketAddr,
) -> anyhow::Result<()> {
let key = (request, addr);
if !self.inflight_by_request.insert(key) {
return Ok(());
} }
let this = self.clone();
let fut = async move {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(60))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(10 * 60))
.with_max_elapsed_time(Some(Duration::from_secs(15 * 60)))
.build();
loop {
this.routing_table
.write()
.mark_outgoing_request(&target_node);
let resp = this.request(request, addr).await;
let sleep = match resp {
Ok(ResponseOrError::Response(response)) => {
match this.on_response(addr, request, response) {
Ok(()) => {
backoff.reset();
Some(REQUERY_INTERVAL)
}
Err(e) => {
warn!("error in on_response: {:?}", e);
backoff.next_backoff()
}
}
}
Ok(ResponseOrError::Error(e)) => {
debug!("error response: {:?}", e);
backoff.next_backoff()
}
Err(e) => {
debug!("error: {:?}", e);
backoff.next_backoff()
}
};
if let Some(sleep) = sleep {
tokio::time::sleep(sleep).await;
continue;
}
tokio::task::spawn(async move {
this.inflight_by_request.remove(&key);
});
return Ok(());
}
};
spawn(
error_span!(
parent: None,
"dht_request",
addr = addr.to_string(),
request = format!("{:?}", request),
),
fut,
);
Ok(()) Ok(())
} }
@ -417,14 +472,7 @@ impl DhtState {
target_node: Id20, target_node: Id20,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let request = Request::FindNode(search_id); self.send_request_if_not_yet(target_node, Request::FindNode(search_id), addr)
if self.should_request(request, addr) {
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.spawn_request(request, addr);
}
Ok(())
} }
fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult { fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult {
@ -482,25 +530,25 @@ impl DhtState {
self: &Arc<Self>, self: &Arc<Self>,
source: Id20, source: Id20,
source_addr: SocketAddr, source_addr: SocketAddr,
target: Id20, info_hash: Id20,
data: bprotocol::Response<ByteString>, data: bprotocol::Response<ByteString>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.routing_table_add_node(source, source_addr); self.routing_table_add_node(source, source_addr);
self.routing_table.write().mark_response(&source); self.routing_table.write().mark_response(&source);
let bsender = match self.get_peers_subscribers.get(&target) { let bsender = match self.get_peers_subscribers.get(&info_hash) {
Some(s) => s, Some(s) => s,
None => { None => {
warn!( warn!(
"ignoring get_peers response, no subscribers for {:?}", "ignoring get_peers response, no subscribers for {:?}",
target info_hash
); );
return Ok(()); return Ok(());
} }
}; };
if let Some(peers) = data.values { if let Some(peers) = data.values {
let mut seen = self.seen_peers.entry(target).or_default(); let mut seen = self.seen_peers.entry(info_hash).or_default();
for peer in peers.iter() { for peer in peers.iter() {
if peer.addr.port() < 1024 { if peer.addr.port() < 1024 {
@ -518,7 +566,7 @@ impl DhtState {
if let Some(nodes) = data.nodes { if let Some(nodes) = data.nodes {
for node in nodes.nodes { for node in nodes.nodes {
self.routing_table_add_node(node.id, node.addr.into()); self.routing_table_add_node(node.id, node.addr.into());
self.send_find_peers_if_not_yet(target, node.id, node.addr.into())?; self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?;
} }
}; };
Ok(()) Ok(())
@ -562,12 +610,10 @@ struct DhtWorker {
} }
impl DhtWorker { impl DhtWorker {
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.on_incoming_from_remote(msg, addr)
}
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) { fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) { if let Some((_, OutstandingRequest { done })) =
self.state.inflight_by_transaction_id.remove(&(tid, addr))
{
let _ = done.send(Err(err)).is_err(); let _ = done.send(Err(err)).is_err();
}; };
} }
@ -593,7 +639,7 @@ impl DhtWorker {
tokio::time::sleep(backoff).await; tokio::time::sleep(backoff).await;
continue; continue;
} }
anyhow::bail!("given up bootstrapping, timed out") bail!("given up bootstrapping, timed out")
} }
} }
} }
@ -618,7 +664,7 @@ impl DhtWorker {
}; };
} }
if successes == 0 { if successes == 0 {
anyhow::bail!("none of the {} bootstrap requests succeded", requests); bail!("none of the {} bootstrap requests succeded", requests);
} }
Ok(()) Ok(())
} }
@ -643,7 +689,7 @@ impl DhtWorker {
tokio::time::sleep(backoff).await; tokio::time::sleep(backoff).await;
continue; continue;
} }
anyhow::bail!("bootstrap failed") bail!("bootstrap failed")
} }
} }
@ -664,7 +710,7 @@ impl DhtWorker {
} }
} }
if successes == 0 { if successes == 0 {
anyhow::bail!("bootstrapping failed") bail!("bootstrapping failed")
} }
Ok(()) Ok(())
} }
@ -682,7 +728,7 @@ impl DhtWorker {
rate_limiter.acquire_one().await; rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg); trace!("{}: sending {:?}", addr, &msg);
buf.clear(); buf.clear();
let tid = msg.get_transaction_id().unwrap(); let tid = msg.get_our_transaction_id();
bprotocol::serialize_message( bprotocol::serialize_message(
&mut buf, &mut buf,
msg.transaction_id, msg.transaction_id,
@ -692,7 +738,10 @@ impl DhtWorker {
) )
.unwrap(); .unwrap();
if let Err(e) = socket.send_to(&buf, addr).await { if let Err(e) = socket.send_to(&buf, addr).await {
self.on_send_error(tid, addr, e.into()); debug!("error sending to {addr}: {e:?}");
if let Some(tid) = tid {
self.on_send_error(tid, addr, e.into());
}
} }
} }
Err::<(), _>(anyhow::anyhow!( Err::<(), _>(anyhow::anyhow!(
@ -745,7 +794,7 @@ impl DhtWorker {
let this = &self; let this = &self;
async move { async move {
while let Some((response, addr)) = out_rx.recv().await { while let Some((response, addr)) = out_rx.recv().await {
if let Err(e) = this.on_response(response, addr) { if let Err(e) = this.state.on_received_message(response, addr) {
debug!("error in on_response, addr={:?}: {}", addr, e) debug!("error in on_response, addr={:?}: {}", addr, e)
} }
} }

View file

@ -16,6 +16,8 @@ pub type Dht = Arc<DhtState>;
// How long do we wait for a response from a DHT node. // How long do we wait for a response from a DHT node.
pub(crate) const RESPONSE_TIMEOUT: Duration = Duration::from_secs(60); pub(crate) const RESPONSE_TIMEOUT: Duration = Duration::from_secs(60);
// How long to wait before re-sending a request to a node that answered it successfully. TODO: Not sure if we should re-query at all.
pub(crate) const REQUERY_INTERVAL: Duration = Duration::from_secs(60);
// After how long should we ping the node again. // After how long should we ping the node again.
pub(crate) const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(15 * 60); pub(crate) const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(15 * 60);