Next id -> AtomicU16

This commit is contained in:
Igor Katson 2023-11-28 07:40:27 +00:00
parent 1a6eb05ca1
commit eaf5021908
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
3 changed files with 25 additions and 27 deletions

1
Cargo.lock generated
View file

@ -1087,6 +1087,7 @@ name = "librqbit-dht"
version = "3.2.0"
dependencies = [
"anyhow",
"dashmap",
"directories",
"futures",
"hex 0.4.3",

View file

@ -31,6 +31,7 @@ futures = "0.3"
rand = "0.8"
indexmap = "2"
directories = "5"
dashmap = "5.5.3"
clone_to_owned = {path="../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"}
librqbit-core = {path="../librqbit_core", version = "3.1.0"}

View file

@ -1,7 +1,10 @@
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::Arc,
sync::{
atomic::{AtomicU16, Ordering},
Arc,
},
task::Poll,
time::{Duration, Instant},
};
@ -41,7 +44,7 @@ pub struct DhtStats {
struct DhtState {
id: Id20,
next_transaction_id: u16,
next_transaction_id: AtomicU16,
// Created requests: (transaction_id, addr) => Requests.
// If we get a response, it gets removed from here.
@ -76,7 +79,7 @@ impl DhtState {
let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id));
Self {
id,
next_transaction_id: 0,
next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(),
routing_table,
sender,
@ -87,15 +90,17 @@ impl DhtState {
}
}
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
let transaction_id = self.next_transaction_id;
fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
let (tid, msg) = self.create_request(request, addr);
self.outstanding_requests_by_transaction_id
.insert((tid, addr), request);
Ok(self.sender.send((msg, addr))?)
}
fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message<ByteString>) {
let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed);
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
self.next_transaction_id = if transaction_id == u16::MAX {
0
} else {
transaction_id + 1
};
let message = match request {
Request::GetPeers(info_hash) => Message {
transaction_id: ByteString::from(transaction_id_buf.as_ref()),
@ -122,9 +127,7 @@ impl DhtState {
kind: MessageKind::PingRequest(PingRequest { id: self.id }),
},
};
self.outstanding_requests_by_transaction_id
.insert((transaction_id, addr), request);
message
(transaction_id, message)
}
fn on_incoming_from_remote(
@ -337,8 +340,7 @@ impl DhtState {
let request = Request::GetPeers(info_hash);
if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node);
let msg = self.create_request(request, addr);
self.sender.send((msg, addr))?;
self.send_request(request, addr)?;
}
Ok(())
}
@ -352,8 +354,7 @@ impl DhtState {
let request = Request::FindNode(search_id);
if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node);
let msg = self.create_request(request, addr);
self.sender.send((msg, addr))?;
self.send_request(request, addr)?;
}
Ok(())
}
@ -365,8 +366,7 @@ impl DhtState {
true
});
for addr in questionable_nodes {
let req = self.create_request(Request::Ping, addr);
let _ = self.sender.send((req, addr));
let _ = self.send_request(Request::Ping, addr);
}
res
}
@ -560,7 +560,6 @@ impl DhtWorker {
async fn start(
self,
in_tx: UnboundedSender<(Message<ByteString>, SocketAddr)>,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
bootstrap_addrs: &[String],
) -> anyhow::Result<()> {
@ -572,17 +571,14 @@ impl DhtWorker {
// bootstrap
for addr in bootstrap_addrs.iter() {
let this = &self;
let in_tx = &in_tx;
futs.push(
async move {
match tokio::net::lookup_host(addr).await {
Ok(addrs) => {
for addr in addrs {
let request = this
.state
this.state
.write()
.create_request(Request::FindNode(this.peer_id), addr);
in_tx.send((request, addr))?;
.send_request(Request::FindNode(this.peer_id), addr)?;
}
}
Err(e) => {
@ -730,7 +726,7 @@ impl Dht {
let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(RwLock::new(DhtState::new(
peer_id,
in_tx.clone(),
in_tx,
config.routing_table,
listen_addr,
)));
@ -743,7 +739,7 @@ impl Dht {
peer_id,
state,
};
worker.start(in_tx, in_rx, &bootstrap_addrs).await?;
worker.start(in_rx, &bootstrap_addrs).await?;
Ok(())
}
});