Created more tasks, but this badly impacts performance and memory usage

This commit is contained in:
Igor Katson 2023-11-28 15:35:27 +00:00
parent 91c99a272f
commit 93740ec84b
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
3 changed files with 126 additions and 74 deletions

View file

@ -333,7 +333,8 @@ pub struct Message<BufT> {
}
impl Message<ByteString> {
pub fn get_transaction_id(&self) -> Option<u16> {
// This implies that the transaction id was generated by us.
pub fn get_our_transaction_id(&self) -> Option<u16> {
if self.transaction_id.len() != 2 {
return None;
}

View file

@ -5,7 +5,7 @@ use std::{
Arc,
},
task::Poll,
time::{Duration, Instant},
time::Duration,
};
use crate::{
@ -14,12 +14,12 @@ use crate::{
Message, MessageKind, Node, PingRequest, Response,
},
routing_table::{InsertResult, RoutingTable},
RESPONSE_TIMEOUT,
REQUERY_INTERVAL, RESPONSE_TIMEOUT,
};
use anyhow::Context;
use anyhow::{bail, Context};
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use bencode::ByteString;
use dashmap::DashMap;
use dashmap::{DashMap, DashSet};
use futures::{stream::FuturesUnordered, Stream, StreamExt};
use indexmap::IndexSet;
use leaky_bucket::RateLimiter;
@ -40,7 +40,7 @@ pub struct DhtStats {
pub id: Id20,
pub outstanding_requests: usize,
pub seen_peers: usize,
pub made_requests: usize,
pub outstanding_backoff_tasks: usize,
pub routing_table_size: usize,
}
@ -54,10 +54,10 @@ pub struct DhtState {
// Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here.
inflight: DashMap<(u16, SocketAddr), OutstandingRequest>,
inflight_by_transaction_id: DashMap<(u16, SocketAddr), OutstandingRequest>,
// TODO: clean up old entries
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
// Current requests to addr being re-sent with backoff.
inflight_by_request: DashSet<(Request, SocketAddr)>,
routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr,
@ -80,13 +80,13 @@ impl DhtState {
Self {
id,
next_transaction_id: AtomicU16::new(0),
inflight: Default::default(),
inflight_by_transaction_id: Default::default(),
routing_table: RwLock::new(routing_table),
sender,
listen_addr,
seen_peers: Default::default(),
get_peers_subscribers: Default::default(),
made_requests_by_addr: Default::default(),
inflight_by_request: Default::default(),
}
}
@ -115,7 +115,7 @@ impl DhtState {
match resp {
ResponseOrError::Response(r) => self.on_response(addr, request, r),
ResponseOrError::Error(e) => {
anyhow::bail!("received error: {:?}", e);
bail!("received error: {:?}", e);
}
}
}
@ -124,24 +124,25 @@ impl DhtState {
let (tid, msg) = self.create_request(request);
let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel();
self.inflight.insert(key, OutstandingRequest { done: tx });
self.inflight_by_transaction_id
.insert(key, OutstandingRequest { done: tx });
match self.sender.send((msg, addr)) {
Ok(_) => {}
Err(e) => {
self.inflight.remove(&key);
self.inflight_by_transaction_id.remove(&key);
return Err(e.into());
}
};
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
Ok(Ok(r)) => r,
Ok(Err(e)) => {
self.inflight.remove(&key);
self.inflight_by_transaction_id.remove(&key);
warn!("recv error, did not expect this: {:?}", e);
Err(e.into())
}
Err(_) => {
self.inflight.remove(&key);
anyhow::bail!("timeout")
self.inflight_by_transaction_id.remove(&key);
bail!("timeout")
}
}
}
@ -192,12 +193,14 @@ impl DhtState {
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
self.on_found_nodes(response.id, addr, id, nodes)
}
Request::GetPeers(id) => self.on_found_peers_or_nodes(response.id, addr, id, response),
Request::Ping => Ok(()),
Request::GetPeers(info_hash) => {
self.on_found_peers_or_nodes(response.id, addr, info_hash, response)
}
}
}
fn on_incoming_from_remote(
fn on_received_message(
self: &Arc<Self>,
msg: Message<ByteString>,
addr: SocketAddr,
@ -226,10 +229,14 @@ impl DhtState {
// If it's a response to a request we made, find the request task, notify it with the response,
// and let it handle it.
MessageKind::Error(_) | MessageKind::Response(_) => {
let tid = msg.get_transaction_id().context("bad transaction id")?;
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
let tid = msg.get_our_transaction_id().context("bad transaction id")?;
let request = match self
.inflight_by_transaction_id
.remove(&(tid, addr))
.map(|(_, v)| v)
{
Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
None => bail!("outstanding request not found. Message: {:?}", msg),
};
let response_or_error = match msg.kind {
@ -324,9 +331,9 @@ impl DhtState {
pub fn get_stats(&self) -> DhtStats {
DhtStats {
id: self.id,
outstanding_requests: self.inflight.len(),
outstanding_requests: self.inflight_by_transaction_id.len(),
seen_peers: self.seen_peers.iter().map(|e| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(),
outstanding_backoff_tasks: self.inflight_by_request.len(),
routing_table_size: self.routing_table.read().len(),
}
}
@ -376,38 +383,86 @@ impl DhtState {
}
}
/// Decides whether `request` should be sent to `addr` right now.
///
/// Returns `true` and stores the current time for the `(request, addr)` pair when no
/// timestamp is stored for it yet, or when the stored timestamp is older than
/// `RE_REQUEST_TIME`. Returns `false` otherwise, leaving the stored timestamp untouched.
fn should_request(&self, request: Request, addr: SocketAddr) -> bool {
    use dashmap::mapref::entry::Entry;

    // Minimum age of the previous identical request before it may be repeated.
    const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60);

    let mut previous = match self.made_requests_by_addr.entry((request, addr)) {
        // Never asked this address for this request: record it and go ahead.
        Entry::Vacant(slot) => {
            slot.insert(Instant::now());
            return true;
        }
        Entry::Occupied(slot) => slot,
    };

    let is_stale = previous.get().elapsed() > RE_REQUEST_TIME;
    if is_stale {
        // Refresh the timestamp only when we are actually going to re-request.
        previous.insert(Instant::now());
    }
    is_stale
}
fn send_find_peers_if_not_yet(
self: &Arc<Self>,
info_hash: Id20,
target_node: Id20,
addr: SocketAddr,
) -> anyhow::Result<()> {
let request = Request::GetPeers(info_hash);
if self.should_request(request, addr) {
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.spawn_request(request, addr);
self.send_request_if_not_yet(target_node, Request::GetPeers(info_hash), addr)
}
fn send_request_if_not_yet(
self: &Arc<Self>,
target_node: Id20,
request: Request,
addr: SocketAddr,
) -> anyhow::Result<()> {
let key = (request, addr);
if !self.inflight_by_request.insert(key) {
return Ok(());
}
let this = self.clone();
let fut = async move {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(60))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(10 * 60))
.with_max_elapsed_time(Some(Duration::from_secs(15 * 60)))
.build();
loop {
this.routing_table
.write()
.mark_outgoing_request(&target_node);
let resp = this.request(request, addr).await;
let sleep = match resp {
Ok(ResponseOrError::Response(response)) => {
match this.on_response(addr, request, response) {
Ok(()) => {
backoff.reset();
Some(REQUERY_INTERVAL)
}
Err(e) => {
warn!("error in on_response: {:?}", e);
backoff.next_backoff()
}
}
}
Ok(ResponseOrError::Error(e)) => {
debug!("error response: {:?}", e);
backoff.next_backoff()
}
Err(e) => {
debug!("error: {:?}", e);
backoff.next_backoff()
}
};
if let Some(sleep) = sleep {
tokio::time::sleep(sleep).await;
continue;
}
tokio::task::spawn(async move {
this.inflight_by_request.remove(&key);
});
return Ok(());
}
};
spawn(
error_span!(
parent: None,
"dht_request",
addr = addr.to_string(),
request = format!("{:?}", request),
),
fut,
);
Ok(())
}
@ -417,14 +472,7 @@ impl DhtState {
target_node: Id20,
addr: SocketAddr,
) -> anyhow::Result<()> {
let request = Request::FindNode(search_id);
if self.should_request(request, addr) {
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.spawn_request(request, addr);
}
Ok(())
self.send_request_if_not_yet(target_node, Request::FindNode(search_id), addr)
}
fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult {
@ -482,25 +530,25 @@ impl DhtState {
self: &Arc<Self>,
source: Id20,
source_addr: SocketAddr,
target: Id20,
info_hash: Id20,
data: bprotocol::Response<ByteString>,
) -> anyhow::Result<()> {
self.routing_table_add_node(source, source_addr);
self.routing_table.write().mark_response(&source);
let bsender = match self.get_peers_subscribers.get(&target) {
let bsender = match self.get_peers_subscribers.get(&info_hash) {
Some(s) => s,
None => {
warn!(
"ignoring get_peers response, no subscribers for {:?}",
target
info_hash
);
return Ok(());
}
};
if let Some(peers) = data.values {
let mut seen = self.seen_peers.entry(target).or_default();
let mut seen = self.seen_peers.entry(info_hash).or_default();
for peer in peers.iter() {
if peer.addr.port() < 1024 {
@ -518,7 +566,7 @@ impl DhtState {
if let Some(nodes) = data.nodes {
for node in nodes.nodes {
self.routing_table_add_node(node.id, node.addr.into());
self.send_find_peers_if_not_yet(target, node.id, node.addr.into())?;
self.send_find_peers_if_not_yet(info_hash, node.id, node.addr.into())?;
}
};
Ok(())
@ -562,12 +610,10 @@ struct DhtWorker {
}
impl DhtWorker {
/// Forwards a message received from `addr` to the shared state's
/// `on_incoming_from_remote` handler and returns its result unchanged.
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
    let state = &self.state;
    state.on_incoming_from_remote(msg, addr)
}
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) {
if let Some((_, OutstandingRequest { done })) =
self.state.inflight_by_transaction_id.remove(&(tid, addr))
{
let _ = done.send(Err(err)).is_err();
};
}
@ -593,7 +639,7 @@ impl DhtWorker {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("given up bootstrapping, timed out")
bail!("given up bootstrapping, timed out")
}
}
}
@ -618,7 +664,7 @@ impl DhtWorker {
};
}
if successes == 0 {
anyhow::bail!("none of the {} bootstrap requests succeded", requests);
bail!("none of the {} bootstrap requests succeded", requests);
}
Ok(())
}
@ -643,7 +689,7 @@ impl DhtWorker {
tokio::time::sleep(backoff).await;
continue;
}
anyhow::bail!("bootstrap failed")
bail!("bootstrap failed")
}
}
@ -664,7 +710,7 @@ impl DhtWorker {
}
}
if successes == 0 {
anyhow::bail!("bootstrapping failed")
bail!("bootstrapping failed")
}
Ok(())
}
@ -682,7 +728,7 @@ impl DhtWorker {
rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg);
buf.clear();
let tid = msg.get_transaction_id().unwrap();
let tid = msg.get_our_transaction_id();
bprotocol::serialize_message(
&mut buf,
msg.transaction_id,
@ -692,7 +738,10 @@ impl DhtWorker {
)
.unwrap();
if let Err(e) = socket.send_to(&buf, addr).await {
self.on_send_error(tid, addr, e.into());
debug!("error sending to {addr}: {e:?}");
if let Some(tid) = tid {
self.on_send_error(tid, addr, e.into());
}
}
}
Err::<(), _>(anyhow::anyhow!(
@ -745,7 +794,7 @@ impl DhtWorker {
let this = &self;
async move {
while let Some((response, addr)) = out_rx.recv().await {
if let Err(e) = this.on_response(response, addr) {
if let Err(e) = this.state.on_received_message(response, addr) {
debug!("error in on_response, addr={:?}: {}", addr, e)
}
}

View file

@ -16,6 +16,8 @@ pub type Dht = Arc<DhtState>;
// How long do we wait for a response from a DHT node.
pub(crate) const RESPONSE_TIMEOUT: Duration = Duration::from_secs(60);
// TODO: Unclear whether we should re-query at all after a successful response.
pub(crate) const REQUERY_INTERVAL: Duration = Duration::from_secs(60);
// After how long should we ping the node again.
pub(crate) const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(15 * 60);