Next id -> AtomicU16

This commit is contained in:
Igor Katson 2023-11-28 07:40:27 +00:00
parent 1a6eb05ca1
commit eaf5021908
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
3 changed files with 25 additions and 27 deletions

1
Cargo.lock generated
View file

@ -1087,6 +1087,7 @@ name = "librqbit-dht"
version = "3.2.0" version = "3.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"dashmap",
"directories", "directories",
"futures", "futures",
"hex 0.4.3", "hex 0.4.3",

View file

@ -31,6 +31,7 @@ futures = "0.3"
rand = "0.8" rand = "0.8"
indexmap = "2" indexmap = "2"
directories = "5" directories = "5"
dashmap = "5.5.3"
clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"}
librqbit-core = {path="../librqbit_core", version = "3.1.0"} librqbit-core = {path="../librqbit_core", version = "3.1.0"}

View file

@ -1,7 +1,10 @@
use std::{ use std::{
collections::{hash_map::Entry, HashMap}, collections::{hash_map::Entry, HashMap},
net::SocketAddr, net::SocketAddr,
sync::Arc, sync::{
atomic::{AtomicU16, Ordering},
Arc,
},
task::Poll, task::Poll,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -41,7 +44,7 @@ pub struct DhtStats {
struct DhtState { struct DhtState {
id: Id20, id: Id20,
next_transaction_id: u16, next_transaction_id: AtomicU16,
// Created requests: (transaction_id, addr) => Requests. // Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here. // If we get a response, it gets removed from here.
@ -76,7 +79,7 @@ impl DhtState {
let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id)); let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id));
Self { Self {
id, id,
next_transaction_id: 0, next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(), outstanding_requests_by_transaction_id: Default::default(),
routing_table, routing_table,
sender, sender,
@ -87,15 +90,17 @@ impl DhtState {
} }
} }
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> { fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
let transaction_id = self.next_transaction_id; let (tid, msg) = self.create_request(request, addr);
self.outstanding_requests_by_transaction_id
.insert((tid, addr), request);
Ok(self.sender.send((msg, addr))?)
}
fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message<ByteString>) {
let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed);
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
self.next_transaction_id = if transaction_id == u16::MAX {
0
} else {
transaction_id + 1
};
let message = match request { let message = match request {
Request::GetPeers(info_hash) => Message { Request::GetPeers(info_hash) => Message {
transaction_id: ByteString::from(transaction_id_buf.as_ref()), transaction_id: ByteString::from(transaction_id_buf.as_ref()),
@ -122,9 +127,7 @@ impl DhtState {
kind: MessageKind::PingRequest(PingRequest { id: self.id }), kind: MessageKind::PingRequest(PingRequest { id: self.id }),
}, },
}; };
self.outstanding_requests_by_transaction_id (transaction_id, message)
.insert((transaction_id, addr), request);
message
} }
fn on_incoming_from_remote( fn on_incoming_from_remote(
@ -337,8 +340,7 @@ impl DhtState {
let request = Request::GetPeers(info_hash); let request = Request::GetPeers(info_hash);
if self.should_request(request, addr) { if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node); self.routing_table.mark_outgoing_request(&target_node);
let msg = self.create_request(request, addr); self.send_request(request, addr)?;
self.sender.send((msg, addr))?;
} }
Ok(()) Ok(())
} }
@ -352,8 +354,7 @@ impl DhtState {
let request = Request::FindNode(search_id); let request = Request::FindNode(search_id);
if self.should_request(request, addr) { if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node); self.routing_table.mark_outgoing_request(&target_node);
let msg = self.create_request(request, addr); self.send_request(request, addr)?;
self.sender.send((msg, addr))?;
} }
Ok(()) Ok(())
} }
@ -365,8 +366,7 @@ impl DhtState {
true true
}); });
for addr in questionable_nodes { for addr in questionable_nodes {
let req = self.create_request(Request::Ping, addr); let _ = self.send_request(Request::Ping, addr);
let _ = self.sender.send((req, addr));
} }
res res
} }
@ -560,7 +560,6 @@ impl DhtWorker {
async fn start( async fn start(
self, self,
in_tx: UnboundedSender<(Message<ByteString>, SocketAddr)>,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>, in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
bootstrap_addrs: &[String], bootstrap_addrs: &[String],
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
@ -572,17 +571,14 @@ impl DhtWorker {
// bootstrap // bootstrap
for addr in bootstrap_addrs.iter() { for addr in bootstrap_addrs.iter() {
let this = &self; let this = &self;
let in_tx = &in_tx;
futs.push( futs.push(
async move { async move {
match tokio::net::lookup_host(addr).await { match tokio::net::lookup_host(addr).await {
Ok(addrs) => { Ok(addrs) => {
for addr in addrs { for addr in addrs {
let request = this this.state
.state
.write() .write()
.create_request(Request::FindNode(this.peer_id), addr); .send_request(Request::FindNode(this.peer_id), addr)?;
in_tx.send((request, addr))?;
} }
} }
Err(e) => { Err(e) => {
@ -730,7 +726,7 @@ impl Dht {
let (in_tx, in_rx) = unbounded_channel(); let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(RwLock::new(DhtState::new( let state = Arc::new(RwLock::new(DhtState::new(
peer_id, peer_id,
in_tx.clone(), in_tx,
config.routing_table, config.routing_table,
listen_addr, listen_addr,
))); )));
@ -743,7 +739,7 @@ impl Dht {
peer_id, peer_id,
state, state,
}; };
worker.start(in_tx, in_rx, &bootstrap_addrs).await?; worker.start(in_rx, &bootstrap_addrs).await?;
Ok(()) Ok(())
} }
}); });