From c7cf5eedefed3de1a139274d5305b75d0ea9923a Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Tue, 28 Nov 2023 08:03:12 +0000 Subject: [PATCH] Remove the giant lock from dht --- crates/dht/examples/dht.rs | 4 +- crates/dht/src/dht.rs | 114 ++++++++++++++++--------------- crates/dht/src/lib.rs | 19 +++++- crates/dht/src/persistence.rs | 4 +- crates/librqbit/src/dht_utils.rs | 4 +- crates/librqbit/src/session.rs | 4 +- 6 files changed, 84 insertions(+), 65 deletions(-) diff --git a/crates/dht/examples/dht.rs b/crates/dht/examples/dht.rs index 8862cdc..38c5342 100644 --- a/crates/dht/examples/dht.rs +++ b/crates/dht/examples/dht.rs @@ -2,7 +2,7 @@ use std::time::Duration; use anyhow::Context; use librqbit_core::magnet::Magnet; -use librqbit_dht::Dht; +use librqbit_dht::{Dht, DhtBuilder}; use tokio_stream::StreamExt; use tracing::info; @@ -16,7 +16,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); - let dht = Dht::new().await.context("error initializing DHT")?; + let dht = DhtBuilder::new().await.context("error initializing DHT")?; let mut stream = dht.get_peers(info_hash)?; let stats_printer = async { diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index ae040aa..5769893 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,5 +1,4 @@ use std::{ - collections::{hash_map::Entry, HashMap}, net::SocketAddr, sync::{ atomic::{AtomicU16, Ordering}, @@ -18,6 +17,7 @@ use crate::{ }; use anyhow::Context; use bencode::ByteString; +use dashmap::DashMap; use futures::{stream::FuturesUnordered, Stream, StreamExt}; use indexmap::IndexSet; use leaky_bucket::RateLimiter; @@ -42,7 +42,7 @@ pub struct DhtStats { pub routing_table_size: usize, } -struct DhtState { +pub struct DhtState { id: Id20, next_transaction_id: AtomicU16, @@ -50,12 +50,12 @@ struct DhtState { // If we get a response, it gets removed from here. // // TODO: clean up old entries - outstanding_requests_by_transaction_id: HashMap<(u16, SocketAddr), Request>, + outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), Request>, // TODO: clean up old entries - made_requests_by_addr: HashMap<(Request, SocketAddr), Instant>, + made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>, - routing_table: RoutingTable, + routing_table: RwLock, listen_addr: SocketAddr, // This sender sends requests to the worker. @@ -65,12 +65,12 @@ struct DhtState { // Alternatively, we can lock only the parts that change, and use that internally inside DhtState... sender: UnboundedSender<(Message, SocketAddr)>, - seen_peers: HashMap>, - get_peers_subscribers: HashMap>, + seen_peers: DashMap>, + get_peers_subscribers: DashMap>, } impl DhtState { - fn new( + fn new_internal( id: Id20, sender: UnboundedSender<(Message, SocketAddr)>, routing_table: Option, @@ -81,7 +81,7 @@ impl DhtState { id, next_transaction_id: AtomicU16::new(0), outstanding_requests_by_transaction_id: Default::default(), - routing_table, + routing_table: RwLock::new(routing_table), sender, listen_addr, seen_peers: Default::default(), @@ -90,14 +90,14 @@ impl DhtState { } } - fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> { - let (tid, msg) = self.create_request(request, addr); + fn send_request(self: &Arc, request: Request, addr: SocketAddr) -> anyhow::Result<()> { + let (tid, msg) = self.create_request(request); self.outstanding_requests_by_transaction_id .insert((tid, addr), request); Ok(self.sender.send((msg, addr))?) } - fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message) { + fn create_request(&self, request: Request) -> (u16, Message) { let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed); let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; @@ -131,13 +131,14 @@ impl DhtState { } fn on_incoming_from_remote( - &mut self, + self: &Arc, msg: Message, addr: SocketAddr, ) -> anyhow::Result<()> { let generate_compact_nodes = |target| { let nodes = self .routing_table + .read() .sorted_by_distance_from(target) .into_iter() .filter_map(|r| { @@ -167,6 +168,7 @@ impl DhtState { let request = match self .outstanding_requests_by_transaction_id .remove(&(tid, addr)) + .map(|(_, v)| v) { Some(req) => req, None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), @@ -178,7 +180,7 @@ impl DhtState { MessageKind::Response(r) => r, _ => unreachable!(), }; - self.routing_table.mark_response(&response.id); + self.routing_table.write().mark_response(&response.id); match request { Request::FindNode(id) => { let nodes = response.nodes.ok_or_else(|| { @@ -226,7 +228,7 @@ impl DhtState { None }; let compact_node_info = generate_compact_nodes(req.info_hash); - self.routing_table.mark_last_query(&req.id); + self.routing_table.write().mark_last_query(&req.id); let message = Message { transaction_id: msg.transaction_id, version: None, @@ -243,7 +245,7 @@ impl DhtState { } MessageKind::FindNodeRequest(req) => { let compact_node_info = generate_compact_nodes(req.target); - self.routing_table.mark_last_query(&req.id); + self.routing_table.write().mark_last_query(&req.id); let message = Message { transaction_id: msg.transaction_id, version: None, @@ -264,20 +266,21 @@ impl DhtState { DhtStats { id: self.id, outstanding_requests: self.outstanding_requests_by_transaction_id.len(), - seen_peers: self.seen_peers.values().map(|v| v.len()).sum(), + seen_peers: self.seen_peers.iter().map(|(e)| e.value().len()).sum(), made_requests: self.made_requests_by_addr.len(), - routing_table_size: self.routing_table.len(), + routing_table_size: self.routing_table.read().len(), } } #[allow(clippy::type_complexity)] - fn get_peers( - &mut self, + fn get_peers_internal( + self: &Arc, info_hash: Id20, ) -> anyhow::Result<( Option<(usize, usize)>, tokio::sync::broadcast::Receiver, )> { + use dashmap::mapref::entry::Entry; match self.get_peers_subscribers.entry(info_hash) { Entry::Occupied(o) => { let pos = self.seen_peers.get(&info_hash).and_then(|p| { @@ -299,6 +302,7 @@ impl DhtState { // We don't need to allocate/collect here, but the borrow checker is not happy otherwise. let nodes_to_query = self .routing_table + .read() .sorted_by_distance_from(info_hash) .iter() .map(|n| (n.id(), n.addr())) @@ -313,8 +317,9 @@ impl DhtState { } } - fn should_request(&mut self, request: Request, addr: SocketAddr) -> bool { + fn should_request(&self, request: Request, addr: SocketAddr) -> bool { const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60); + use dashmap::mapref::entry::Entry; match self.made_requests_by_addr.entry((request, addr)) { Entry::Occupied(mut o) => { if o.get().elapsed() > RE_REQUEST_TIME { @@ -332,36 +337,40 @@ impl DhtState { } fn send_find_peers_if_not_yet( - &mut self, + self: &Arc, info_hash: Id20, target_node: Id20, addr: SocketAddr, ) -> anyhow::Result<()> { let request = Request::GetPeers(info_hash); if self.should_request(request, addr) { - self.routing_table.mark_outgoing_request(&target_node); + self.routing_table + .write() + .mark_outgoing_request(&target_node); self.send_request(request, addr)?; } Ok(()) } fn send_find_node_if_not_yet( - &mut self, + self: &Arc, search_id: Id20, target_node: Id20, addr: SocketAddr, ) -> anyhow::Result<()> { let request = Request::FindNode(search_id); if self.should_request(request, addr) { - self.routing_table.mark_outgoing_request(&target_node); + self.routing_table + .write() + .mark_outgoing_request(&target_node); self.send_request(request, addr)?; } Ok(()) } - fn routing_table_add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult { + fn routing_table_add_node(self: &Arc, id: Id20, addr: SocketAddr) -> InsertResult { let mut questionable_nodes = Vec::new(); - let res = self.routing_table.add_node(id, addr, |addr| { + let res = self.routing_table.write().add_node(id, addr, |addr| { questionable_nodes.push(addr); true }); @@ -372,7 +381,7 @@ impl DhtState { } fn on_found_nodes( - &mut self, + self: &Arc, source: Id20, source_addr: SocketAddr, target: Id20, @@ -382,8 +391,8 @@ impl DhtState { // otherwise when we iterate self.searching_for_peers and mutating self in the loop. let searching_for_peers = self .get_peers_subscribers - .keys() - .copied() + .iter() + .map(|e| *e.key()) .collect::>(); // On newly discovered nodes, ask them for peers that we are interested in. @@ -411,14 +420,14 @@ impl DhtState { } fn on_found_peers_or_nodes( - &mut self, + self: &Arc, source: Id20, source_addr: SocketAddr, target: Id20, data: bprotocol::Response, ) -> anyhow::Result<()> { self.routing_table_add_node(source, source_addr); - self.routing_table.mark_response(&source); + self.routing_table.write().mark_response(&source); let bsender = match self.get_peers_subscribers.get(&target) { Some(s) => s, @@ -432,7 +441,7 @@ impl DhtState { }; if let Some(peers) = data.values { - let seen = self.seen_peers.entry(target).or_default(); + let mut seen = self.seen_peers.entry(target).or_default(); for peer in peers.iter() { if peer.addr.port() < 1024 { @@ -542,20 +551,15 @@ enum Request { Ping, } -#[derive(Clone)] -pub struct Dht { - state: Arc>, -} - struct DhtWorker { socket: UdpSocket, peer_id: Id20, - state: Arc>, + state: Arc, } impl DhtWorker { fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { - self.state.write().on_incoming_from_remote(msg, addr) + self.state.on_incoming_from_remote(msg, addr) } async fn start( @@ -577,7 +581,6 @@ impl DhtWorker { Ok(addrs) => { for addr in addrs { this.state - .write() .send_request(Request::FindNode(this.peer_id), addr)?; } } @@ -641,7 +644,7 @@ impl DhtWorker { struct PeerStream { info_hash: Id20, - state: Arc>, + state: Arc, absolute_stream_pos: usize, initial_peers_pos: Option<(usize, usize)>, broadcast_rx: BroadcastStream, @@ -658,7 +661,6 @@ impl Stream for PeerStream { if let Some((pos, end)) = self.initial_peers_pos.take() { let addr = *self .state - .read() .seen_peers .get(&self.info_hash) .unwrap() @@ -698,11 +700,11 @@ pub struct DhtConfig { pub listen_addr: Option, } -impl Dht { - pub async fn new() -> anyhow::Result { +impl DhtState { + pub async fn new() -> anyhow::Result> { Self::with_config(DhtConfig::default()).await } - pub async fn with_config(config: DhtConfig) -> anyhow::Result { + pub async fn with_config(config: DhtConfig) -> anyhow::Result> { let socket = match config.listen_addr { Some(addr) => UdpSocket::bind(addr) .await @@ -724,12 +726,12 @@ impl Dht { .unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect()); let (in_tx, in_rx) = unbounded_channel(); - let state = Arc::new(RwLock::new(DhtState::new( + let state = Arc::new(Self::new_internal( peer_id, in_tx, config.routing_table, listen_addr, - ))); + )); spawn(error_span!("dht"), { let state = state.clone(); @@ -743,17 +745,17 @@ impl Dht { Ok(()) } }); - Ok(Dht { state }) + Ok(state) } pub fn get_peers( - &self, + self: &Arc, info_hash: Id20, ) -> anyhow::Result + Unpin> { - let (pos, rx) = self.state.write().get_peers(info_hash)?; + let (pos, rx) = self.get_peers_internal(info_hash)?; Ok(PeerStream { info_hash, - state: self.state.clone(), + state: self.clone(), absolute_stream_pos: 0, initial_peers_pos: pos, broadcast_rx: BroadcastStream::new(rx), @@ -761,18 +763,18 @@ impl Dht { } pub fn listen_addr(&self) -> SocketAddr { - self.state.read().listen_addr + self.listen_addr } pub fn stats(&self) -> DhtStats { - self.state.read().get_stats() + self.get_stats() } pub fn with_routing_table R>(&self, f: F) -> R { - f(&self.state.read().routing_table) + f(&self.routing_table.read()) } pub fn clone_routing_table(&self) -> RoutingTable { - self.state.read().routing_table.clone() + self.routing_table.read().clone() } } diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index 9000fcc..81713d3 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -4,9 +4,26 @@ mod persistence; mod routing_table; mod utils; +use std::sync::Arc; + pub use crate::dht::DhtStats; -pub use crate::dht::{Dht, DhtConfig}; +pub use crate::dht::{DhtConfig, DhtState}; pub use librqbit_core::id20::Id20; pub use persistence::{PersistentDht, PersistentDhtConfig}; +pub type Dht = Arc; + +pub struct DhtBuilder {} + +impl DhtBuilder { + #[allow(clippy::new_ret_no_self)] + pub async fn new() -> anyhow::Result { + DhtState::new().await + } + + pub async fn with_config(config: DhtConfig) -> anyhow::Result { + DhtState::with_config(config).await + } +} + pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]; diff --git a/crates/dht/src/persistence.rs b/crates/dht/src/persistence.rs index a4f091e..bf91903 100644 --- a/crates/dht/src/persistence.rs +++ b/crates/dht/src/persistence.rs @@ -11,8 +11,8 @@ use std::time::Duration; use anyhow::Context; use tracing::{debug, error, error_span, info, trace, warn}; -use crate::dht::{Dht, DhtConfig}; use crate::routing_table::RoutingTable; +use crate::{Dht, DhtConfig, DhtState}; #[derive(Default, Clone)] pub struct PersistentDhtConfig { @@ -108,7 +108,7 @@ impl PersistentDht { listen_addr, ..Default::default() }; - let dht = Dht::with_config(dht_config).await?; + let dht = DhtState::with_config(dht_config).await?; spawn(error_span!("dht_persistence"), { let dht = dht.clone(); diff --git a/crates/librqbit/src/dht_utils.rs b/crates/librqbit/src/dht_utils.rs index 9e7d60f..fb5b339 100644 --- a/crates/librqbit/src/dht_utils.rs +++ b/crates/librqbit/src/dht_utils.rs @@ -86,7 +86,7 @@ pub async fn read_metainfo_from_peer_receiver + Unp #[cfg(test)] mod tests { - use dht::{Dht, Id20}; + use dht::{Dht, DhtBuilder, Id20}; use librqbit_core::peer_id::generate_peer_id; use super::*; @@ -106,7 +106,7 @@ mod tests { init_logging(); let info_hash = Id20::from_str("cf3ea75e2ebbd30e0da6e6e215e2226bf35f2e33").unwrap(); - let dht = Dht::new().await.unwrap(); + let dht = DhtBuilder::new().await.unwrap(); let peer_rx = dht.get_peers(info_hash).unwrap(); let peer_id = generate_peer_id(); match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await { diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index d4a56bf..47bc9cc 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{bail, Context}; use buffers::ByteString; -use dht::{Dht, Id20, PersistentDht, PersistentDhtConfig}; +use dht::{Dht, DhtBuilder, Id20, PersistentDht, PersistentDhtConfig}; use librqbit_core::{ magnet::Magnet, peer_id::generate_peer_id, @@ -234,7 +234,7 @@ impl Session { None } else { let dht = if opts.disable_dht_persistence { - Dht::new().await + DhtBuilder::new().await } else { PersistentDht::create(opts.dht_config).await }