Remove the giant lock from dht

This commit is contained in:
Igor Katson 2023-11-28 08:03:12 +00:00
parent eaf5021908
commit c7cf5eedef
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
6 changed files with 84 additions and 65 deletions

View file

@ -2,7 +2,7 @@ use std::time::Duration;
use anyhow::Context;
use librqbit_core::magnet::Magnet;
use librqbit_dht::Dht;
use librqbit_dht::{Dht, DhtBuilder};
use tokio_stream::StreamExt;
use tracing::info;
@ -16,7 +16,7 @@ async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
let dht = Dht::new().await.context("error initializing DHT")?;
let dht = DhtBuilder::new().await.context("error initializing DHT")?;
let mut stream = dht.get_peers(info_hash)?;
let stats_printer = async {

View file

@ -1,5 +1,4 @@
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::{
atomic::{AtomicU16, Ordering},
@ -18,6 +17,7 @@ use crate::{
};
use anyhow::Context;
use bencode::ByteString;
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, Stream, StreamExt};
use indexmap::IndexSet;
use leaky_bucket::RateLimiter;
@ -42,7 +42,7 @@ pub struct DhtStats {
pub routing_table_size: usize,
}
struct DhtState {
pub struct DhtState {
id: Id20,
next_transaction_id: AtomicU16,
@ -50,12 +50,12 @@ struct DhtState {
// If we get a response, it gets removed from here.
//
// TODO: clean up old entries
outstanding_requests_by_transaction_id: HashMap<(u16, SocketAddr), Request>,
outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), Request>,
// TODO: clean up old entries
made_requests_by_addr: HashMap<(Request, SocketAddr), Instant>,
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
routing_table: RoutingTable,
routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr,
// This sender sends requests to the worker.
@ -65,12 +65,12 @@ struct DhtState {
// Alternatively, we can lock only the parts that change, and use that internally inside DhtState...
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
seen_peers: HashMap<Id20, IndexSet<SocketAddr>>,
get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
}
impl DhtState {
fn new(
fn new_internal(
id: Id20,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
routing_table: Option<RoutingTable>,
@ -81,7 +81,7 @@ impl DhtState {
id,
next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(),
routing_table,
routing_table: RwLock::new(routing_table),
sender,
listen_addr,
seen_peers: Default::default(),
@ -90,14 +90,14 @@ impl DhtState {
}
}
fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
let (tid, msg) = self.create_request(request, addr);
fn send_request(self: &Arc<Self>, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
let (tid, msg) = self.create_request(request);
self.outstanding_requests_by_transaction_id
.insert((tid, addr), request);
Ok(self.sender.send((msg, addr))?)
}
fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message<ByteString>) {
fn create_request(&self, request: Request) -> (u16, Message<ByteString>) {
let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed);
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
@ -131,13 +131,14 @@ impl DhtState {
}
fn on_incoming_from_remote(
&mut self,
self: &Arc<Self>,
msg: Message<ByteString>,
addr: SocketAddr,
) -> anyhow::Result<()> {
let generate_compact_nodes = |target| {
let nodes = self
.routing_table
.read()
.sorted_by_distance_from(target)
.into_iter()
.filter_map(|r| {
@ -167,6 +168,7 @@ impl DhtState {
let request = match self
.outstanding_requests_by_transaction_id
.remove(&(tid, addr))
.map(|(_, v)| v)
{
Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
@ -178,7 +180,7 @@ impl DhtState {
MessageKind::Response(r) => r,
_ => unreachable!(),
};
self.routing_table.mark_response(&response.id);
self.routing_table.write().mark_response(&response.id);
match request {
Request::FindNode(id) => {
let nodes = response.nodes.ok_or_else(|| {
@ -226,7 +228,7 @@ impl DhtState {
None
};
let compact_node_info = generate_compact_nodes(req.info_hash);
self.routing_table.mark_last_query(&req.id);
self.routing_table.write().mark_last_query(&req.id);
let message = Message {
transaction_id: msg.transaction_id,
version: None,
@ -243,7 +245,7 @@ impl DhtState {
}
MessageKind::FindNodeRequest(req) => {
let compact_node_info = generate_compact_nodes(req.target);
self.routing_table.mark_last_query(&req.id);
self.routing_table.write().mark_last_query(&req.id);
let message = Message {
transaction_id: msg.transaction_id,
version: None,
@ -264,20 +266,21 @@ impl DhtState {
DhtStats {
id: self.id,
outstanding_requests: self.outstanding_requests_by_transaction_id.len(),
seen_peers: self.seen_peers.values().map(|v| v.len()).sum(),
seen_peers: self.seen_peers.iter().map(|(e)| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(),
routing_table_size: self.routing_table.len(),
routing_table_size: self.routing_table.read().len(),
}
}
#[allow(clippy::type_complexity)]
fn get_peers(
&mut self,
fn get_peers_internal(
self: &Arc<Self>,
info_hash: Id20,
) -> anyhow::Result<(
Option<(usize, usize)>,
tokio::sync::broadcast::Receiver<SocketAddr>,
)> {
use dashmap::mapref::entry::Entry;
match self.get_peers_subscribers.entry(info_hash) {
Entry::Occupied(o) => {
let pos = self.seen_peers.get(&info_hash).and_then(|p| {
@ -299,6 +302,7 @@ impl DhtState {
// We don't need to allocate/collect here, but the borrow checker is not happy otherwise.
let nodes_to_query = self
.routing_table
.read()
.sorted_by_distance_from(info_hash)
.iter()
.map(|n| (n.id(), n.addr()))
@ -313,8 +317,9 @@ impl DhtState {
}
}
fn should_request(&mut self, request: Request, addr: SocketAddr) -> bool {
fn should_request(&self, request: Request, addr: SocketAddr) -> bool {
const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60);
use dashmap::mapref::entry::Entry;
match self.made_requests_by_addr.entry((request, addr)) {
Entry::Occupied(mut o) => {
if o.get().elapsed() > RE_REQUEST_TIME {
@ -332,36 +337,40 @@ impl DhtState {
}
fn send_find_peers_if_not_yet(
&mut self,
self: &Arc<Self>,
info_hash: Id20,
target_node: Id20,
addr: SocketAddr,
) -> anyhow::Result<()> {
let request = Request::GetPeers(info_hash);
if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node);
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?;
}
Ok(())
}
fn send_find_node_if_not_yet(
&mut self,
self: &Arc<Self>,
search_id: Id20,
target_node: Id20,
addr: SocketAddr,
) -> anyhow::Result<()> {
let request = Request::FindNode(search_id);
if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node);
self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?;
}
Ok(())
}
fn routing_table_add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult {
fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult {
let mut questionable_nodes = Vec::new();
let res = self.routing_table.add_node(id, addr, |addr| {
let res = self.routing_table.write().add_node(id, addr, |addr| {
questionable_nodes.push(addr);
true
});
@ -372,7 +381,7 @@ impl DhtState {
}
fn on_found_nodes(
&mut self,
self: &Arc<Self>,
source: Id20,
source_addr: SocketAddr,
target: Id20,
@ -382,8 +391,8 @@ impl DhtState {
// otherwise when we iterate self.searching_for_peers and mutating self in the loop.
let searching_for_peers = self
.get_peers_subscribers
.keys()
.copied()
.iter()
.map(|e| *e.key())
.collect::<Vec<_>>();
// On newly discovered nodes, ask them for peers that we are interested in.
@ -411,14 +420,14 @@ impl DhtState {
}
fn on_found_peers_or_nodes(
&mut self,
self: &Arc<Self>,
source: Id20,
source_addr: SocketAddr,
target: Id20,
data: bprotocol::Response<ByteString>,
) -> anyhow::Result<()> {
self.routing_table_add_node(source, source_addr);
self.routing_table.mark_response(&source);
self.routing_table.write().mark_response(&source);
let bsender = match self.get_peers_subscribers.get(&target) {
Some(s) => s,
@ -432,7 +441,7 @@ impl DhtState {
};
if let Some(peers) = data.values {
let seen = self.seen_peers.entry(target).or_default();
let mut seen = self.seen_peers.entry(target).or_default();
for peer in peers.iter() {
if peer.addr.port() < 1024 {
@ -542,20 +551,15 @@ enum Request {
Ping,
}
#[derive(Clone)]
pub struct Dht {
state: Arc<RwLock<DhtState>>,
}
struct DhtWorker {
socket: UdpSocket,
peer_id: Id20,
state: Arc<RwLock<DhtState>>,
state: Arc<DhtState>,
}
impl DhtWorker {
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.write().on_incoming_from_remote(msg, addr)
self.state.on_incoming_from_remote(msg, addr)
}
async fn start(
@ -577,7 +581,6 @@ impl DhtWorker {
Ok(addrs) => {
for addr in addrs {
this.state
.write()
.send_request(Request::FindNode(this.peer_id), addr)?;
}
}
@ -641,7 +644,7 @@ impl DhtWorker {
struct PeerStream {
info_hash: Id20,
state: Arc<RwLock<DhtState>>,
state: Arc<DhtState>,
absolute_stream_pos: usize,
initial_peers_pos: Option<(usize, usize)>,
broadcast_rx: BroadcastStream<SocketAddr>,
@ -658,7 +661,6 @@ impl Stream for PeerStream {
if let Some((pos, end)) = self.initial_peers_pos.take() {
let addr = *self
.state
.read()
.seen_peers
.get(&self.info_hash)
.unwrap()
@ -698,11 +700,11 @@ pub struct DhtConfig {
pub listen_addr: Option<SocketAddr>,
}
impl Dht {
pub async fn new() -> anyhow::Result<Self> {
impl DhtState {
pub async fn new() -> anyhow::Result<Arc<Self>> {
Self::with_config(DhtConfig::default()).await
}
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Self> {
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Arc<Self>> {
let socket = match config.listen_addr {
Some(addr) => UdpSocket::bind(addr)
.await
@ -724,12 +726,12 @@ impl Dht {
.unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect());
let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(RwLock::new(DhtState::new(
let state = Arc::new(Self::new_internal(
peer_id,
in_tx,
config.routing_table,
listen_addr,
)));
));
spawn(error_span!("dht"), {
let state = state.clone();
@ -743,17 +745,17 @@ impl Dht {
Ok(())
}
});
Ok(Dht { state })
Ok(state)
}
pub fn get_peers(
&self,
self: &Arc<Self>,
info_hash: Id20,
) -> anyhow::Result<impl Stream<Item = SocketAddr> + Unpin> {
let (pos, rx) = self.state.write().get_peers(info_hash)?;
let (pos, rx) = self.get_peers_internal(info_hash)?;
Ok(PeerStream {
info_hash,
state: self.state.clone(),
state: self.clone(),
absolute_stream_pos: 0,
initial_peers_pos: pos,
broadcast_rx: BroadcastStream::new(rx),
@ -761,18 +763,18 @@ impl Dht {
}
pub fn listen_addr(&self) -> SocketAddr {
self.state.read().listen_addr
self.listen_addr
}
pub fn stats(&self) -> DhtStats {
self.state.read().get_stats()
self.get_stats()
}
pub fn with_routing_table<R, F: FnOnce(&RoutingTable) -> R>(&self, f: F) -> R {
f(&self.state.read().routing_table)
f(&self.routing_table.read())
}
pub fn clone_routing_table(&self) -> RoutingTable {
self.state.read().routing_table.clone()
self.routing_table.read().clone()
}
}

View file

@ -4,9 +4,26 @@ mod persistence;
mod routing_table;
mod utils;
use std::sync::Arc;
pub use crate::dht::DhtStats;
pub use crate::dht::{Dht, DhtConfig};
pub use crate::dht::{DhtConfig, DhtState};
pub use librqbit_core::id20::Id20;
pub use persistence::{PersistentDht, PersistentDhtConfig};
pub type Dht = Arc<DhtState>;
pub struct DhtBuilder {}
impl DhtBuilder {
#[allow(clippy::new_ret_no_self)]
pub async fn new() -> anyhow::Result<Dht> {
DhtState::new().await
}
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Dht> {
DhtState::with_config(config).await
}
}
pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"];

View file

@ -11,8 +11,8 @@ use std::time::Duration;
use anyhow::Context;
use tracing::{debug, error, error_span, info, trace, warn};
use crate::dht::{Dht, DhtConfig};
use crate::routing_table::RoutingTable;
use crate::{Dht, DhtConfig, DhtState};
#[derive(Default, Clone)]
pub struct PersistentDhtConfig {
@ -108,7 +108,7 @@ impl PersistentDht {
listen_addr,
..Default::default()
};
let dht = Dht::with_config(dht_config).await?;
let dht = DhtState::with_config(dht_config).await?;
spawn(error_span!("dht_persistence"), {
let dht = dht.clone();

View file

@ -86,7 +86,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
#[cfg(test)]
mod tests {
use dht::{Dht, Id20};
use dht::{Dht, DhtBuilder, Id20};
use librqbit_core::peer_id::generate_peer_id;
use super::*;
@ -106,7 +106,7 @@ mod tests {
init_logging();
let info_hash = Id20::from_str("cf3ea75e2ebbd30e0da6e6e215e2226bf35f2e33").unwrap();
let dht = Dht::new().await.unwrap();
let dht = DhtBuilder::new().await.unwrap();
let peer_rx = dht.get_peers(info_hash).unwrap();
let peer_id = generate_peer_id();
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {

View file

@ -11,7 +11,7 @@ use std::{
use anyhow::{bail, Context};
use buffers::ByteString;
use dht::{Dht, Id20, PersistentDht, PersistentDhtConfig};
use dht::{Dht, DhtBuilder, Id20, PersistentDht, PersistentDhtConfig};
use librqbit_core::{
magnet::Magnet,
peer_id::generate_peer_id,
@ -234,7 +234,7 @@ impl Session {
None
} else {
let dht = if opts.disable_dht_persistence {
Dht::new().await
DhtBuilder::new().await
} else {
PersistentDht::create(opts.dht_config).await
}