From eaf5021908a3c2c1b8c5d9df15cbf40c05d76d7f Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Tue, 28 Nov 2023 07:40:27 +0000 Subject: [PATCH] Next id -> AtomicU16 --- Cargo.lock | 1 + crates/dht/Cargo.toml | 1 + crates/dht/src/dht.rs | 50 ++++++++++++++++++++----------------------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7dec58..ee5227f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1087,6 +1087,7 @@ name = "librqbit-dht" version = "3.2.0" dependencies = [ "anyhow", + "dashmap", "directories", "futures", "hex 0.4.3", diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index ec5c9d5..e6647b5 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -31,6 +31,7 @@ futures = "0.3" rand = "0.8" indexmap = "2" directories = "5" +dashmap = "5.5.3" clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} librqbit-core = {path="../librqbit_core", version = "3.1.0"} diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 533c221..ae040aa 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,7 +1,10 @@ use std::{ collections::{hash_map::Entry, HashMap}, net::SocketAddr, - sync::Arc, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, task::Poll, time::{Duration, Instant}, }; @@ -41,7 +44,7 @@ pub struct DhtStats { struct DhtState { id: Id20, - next_transaction_id: u16, + next_transaction_id: AtomicU16, // Created requests: (transaction_id, addr) => Requests. // If we get a response, it gets removed from here. @@ -76,7 +79,7 @@ impl DhtState { let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id)); Self { id, - next_transaction_id: 0, + next_transaction_id: AtomicU16::new(0), outstanding_requests_by_transaction_id: Default::default(), routing_table, sender, @@ -87,15 +90,17 @@ impl DhtState { } } - fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { - let transaction_id = self.next_transaction_id; + fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> { + let (tid, msg) = self.create_request(request, addr); + self.outstanding_requests_by_transaction_id + .insert((tid, addr), request); + Ok(self.sender.send((msg, addr))?) + } + + fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message) { + let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed); let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; - self.next_transaction_id = if transaction_id == u16::MAX { - 0 - } else { - transaction_id + 1 - }; let message = match request { Request::GetPeers(info_hash) => Message { transaction_id: ByteString::from(transaction_id_buf.as_ref()), @@ -122,9 +127,7 @@ impl DhtState { kind: MessageKind::PingRequest(PingRequest { id: self.id }), }, }; - self.outstanding_requests_by_transaction_id - .insert((transaction_id, addr), request); - message + (transaction_id, message) } fn on_incoming_from_remote( @@ -337,8 +340,7 @@ impl DhtState { let request = Request::GetPeers(info_hash); if self.should_request(request, addr) { self.routing_table.mark_outgoing_request(&target_node); - let msg = self.create_request(request, addr); - self.sender.send((msg, addr))?; + self.send_request(request, addr)?; } Ok(()) } @@ -352,8 +354,7 @@ impl DhtState { let request = Request::FindNode(search_id); if self.should_request(request, addr) { self.routing_table.mark_outgoing_request(&target_node); - let msg = self.create_request(request, addr); - self.sender.send((msg, addr))?; + self.send_request(request, addr)?; } Ok(()) } @@ -365,8 +366,7 @@ impl DhtState { true }); for addr in questionable_nodes { - let req = self.create_request(Request::Ping, addr); - let _ = self.sender.send((req, addr)); + let _ = self.send_request(Request::Ping, addr); } res } @@ -560,7 +560,6 @@ impl DhtWorker { async fn start( self, - in_tx: UnboundedSender<(Message, SocketAddr)>, in_rx: UnboundedReceiver<(Message, SocketAddr)>, bootstrap_addrs: &[String], ) -> anyhow::Result<()> { @@ -572,17 +571,14 @@ impl DhtWorker { // bootstrap for addr in bootstrap_addrs.iter() { let this = &self; - let in_tx = &in_tx; futs.push( async move { match tokio::net::lookup_host(addr).await { Ok(addrs) => { for addr in addrs { - let request = this - .state + this.state .write() - .create_request(Request::FindNode(this.peer_id), addr); - in_tx.send((request, addr))?; + .send_request(Request::FindNode(this.peer_id), addr)?; } } Err(e) => { @@ -730,7 +726,7 @@ impl Dht { let (in_tx, in_rx) = unbounded_channel(); let state = Arc::new(RwLock::new(DhtState::new( peer_id, - in_tx.clone(), + in_tx, config.routing_table, listen_addr, ))); @@ -743,7 +739,7 @@ impl Dht { peer_id, state, }; - worker.start(in_tx, in_rx, &bootstrap_addrs).await?; + worker.start(in_rx, &bootstrap_addrs).await?; Ok(()) } });