Remove the giant lock from dht

This commit is contained in:
Igor Katson 2023-11-28 08:03:12 +00:00
parent eaf5021908
commit c7cf5eedef
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
6 changed files with 84 additions and 65 deletions

View file

@ -2,7 +2,7 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use librqbit_core::magnet::Magnet; use librqbit_core::magnet::Magnet;
use librqbit_dht::Dht; use librqbit_dht::{Dht, DhtBuilder};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::info; use tracing::info;
@ -16,7 +16,7 @@ async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let dht = Dht::new().await.context("error initializing DHT")?; let dht = DhtBuilder::new().await.context("error initializing DHT")?;
let mut stream = dht.get_peers(info_hash)?; let mut stream = dht.get_peers(info_hash)?;
let stats_printer = async { let stats_printer = async {

View file

@ -1,5 +1,4 @@
use std::{ use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr, net::SocketAddr,
sync::{ sync::{
atomic::{AtomicU16, Ordering}, atomic::{AtomicU16, Ordering},
@ -18,6 +17,7 @@ use crate::{
}; };
use anyhow::Context; use anyhow::Context;
use bencode::ByteString; use bencode::ByteString;
use dashmap::DashMap;
use futures::{stream::FuturesUnordered, Stream, StreamExt}; use futures::{stream::FuturesUnordered, Stream, StreamExt};
use indexmap::IndexSet; use indexmap::IndexSet;
use leaky_bucket::RateLimiter; use leaky_bucket::RateLimiter;
@ -42,7 +42,7 @@ pub struct DhtStats {
pub routing_table_size: usize, pub routing_table_size: usize,
} }
struct DhtState { pub struct DhtState {
id: Id20, id: Id20,
next_transaction_id: AtomicU16, next_transaction_id: AtomicU16,
@ -50,12 +50,12 @@ struct DhtState {
// If we get a response, it gets removed from here. // If we get a response, it gets removed from here.
// //
// TODO: clean up old entries // TODO: clean up old entries
outstanding_requests_by_transaction_id: HashMap<(u16, SocketAddr), Request>, outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), Request>,
// TODO: clean up old entries // TODO: clean up old entries
made_requests_by_addr: HashMap<(Request, SocketAddr), Instant>, made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
routing_table: RoutingTable, routing_table: RwLock<RoutingTable>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
// This sender sends requests to the worker. // This sender sends requests to the worker.
@ -65,12 +65,12 @@ struct DhtState {
// Alternatively, we can lock only the parts that change, and use that internally inside DhtState... // Alternatively, we can lock only the parts that change, and use that internally inside DhtState...
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
seen_peers: HashMap<Id20, IndexSet<SocketAddr>>, seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>, get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
} }
impl DhtState { impl DhtState {
fn new( fn new_internal(
id: Id20, id: Id20,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
routing_table: Option<RoutingTable>, routing_table: Option<RoutingTable>,
@ -81,7 +81,7 @@ impl DhtState {
id, id,
next_transaction_id: AtomicU16::new(0), next_transaction_id: AtomicU16::new(0),
outstanding_requests_by_transaction_id: Default::default(), outstanding_requests_by_transaction_id: Default::default(),
routing_table, routing_table: RwLock::new(routing_table),
sender, sender,
listen_addr, listen_addr,
seen_peers: Default::default(), seen_peers: Default::default(),
@ -90,14 +90,14 @@ impl DhtState {
} }
} }
fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> { fn send_request(self: &Arc<Self>, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
let (tid, msg) = self.create_request(request, addr); let (tid, msg) = self.create_request(request);
self.outstanding_requests_by_transaction_id self.outstanding_requests_by_transaction_id
.insert((tid, addr), request); .insert((tid, addr), request);
Ok(self.sender.send((msg, addr))?) Ok(self.sender.send((msg, addr))?)
} }
fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message<ByteString>) { fn create_request(&self, request: Request) -> (u16, Message<ByteString>) {
let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed); let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed);
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
@ -131,13 +131,14 @@ impl DhtState {
} }
fn on_incoming_from_remote( fn on_incoming_from_remote(
&mut self, self: &Arc<Self>,
msg: Message<ByteString>, msg: Message<ByteString>,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let generate_compact_nodes = |target| { let generate_compact_nodes = |target| {
let nodes = self let nodes = self
.routing_table .routing_table
.read()
.sorted_by_distance_from(target) .sorted_by_distance_from(target)
.into_iter() .into_iter()
.filter_map(|r| { .filter_map(|r| {
@ -167,6 +168,7 @@ impl DhtState {
let request = match self let request = match self
.outstanding_requests_by_transaction_id .outstanding_requests_by_transaction_id
.remove(&(tid, addr)) .remove(&(tid, addr))
.map(|(_, v)| v)
{ {
Some(req) => req, Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg), None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
@ -178,7 +180,7 @@ impl DhtState {
MessageKind::Response(r) => r, MessageKind::Response(r) => r,
_ => unreachable!(), _ => unreachable!(),
}; };
self.routing_table.mark_response(&response.id); self.routing_table.write().mark_response(&response.id);
match request { match request {
Request::FindNode(id) => { Request::FindNode(id) => {
let nodes = response.nodes.ok_or_else(|| { let nodes = response.nodes.ok_or_else(|| {
@ -226,7 +228,7 @@ impl DhtState {
None None
}; };
let compact_node_info = generate_compact_nodes(req.info_hash); let compact_node_info = generate_compact_nodes(req.info_hash);
self.routing_table.mark_last_query(&req.id); self.routing_table.write().mark_last_query(&req.id);
let message = Message { let message = Message {
transaction_id: msg.transaction_id, transaction_id: msg.transaction_id,
version: None, version: None,
@ -243,7 +245,7 @@ impl DhtState {
} }
MessageKind::FindNodeRequest(req) => { MessageKind::FindNodeRequest(req) => {
let compact_node_info = generate_compact_nodes(req.target); let compact_node_info = generate_compact_nodes(req.target);
self.routing_table.mark_last_query(&req.id); self.routing_table.write().mark_last_query(&req.id);
let message = Message { let message = Message {
transaction_id: msg.transaction_id, transaction_id: msg.transaction_id,
version: None, version: None,
@ -264,20 +266,21 @@ impl DhtState {
DhtStats { DhtStats {
id: self.id, id: self.id,
outstanding_requests: self.outstanding_requests_by_transaction_id.len(), outstanding_requests: self.outstanding_requests_by_transaction_id.len(),
seen_peers: self.seen_peers.values().map(|v| v.len()).sum(), seen_peers: self.seen_peers.iter().map(|(e)| e.value().len()).sum(),
made_requests: self.made_requests_by_addr.len(), made_requests: self.made_requests_by_addr.len(),
routing_table_size: self.routing_table.len(), routing_table_size: self.routing_table.read().len(),
} }
} }
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn get_peers( fn get_peers_internal(
&mut self, self: &Arc<Self>,
info_hash: Id20, info_hash: Id20,
) -> anyhow::Result<( ) -> anyhow::Result<(
Option<(usize, usize)>, Option<(usize, usize)>,
tokio::sync::broadcast::Receiver<SocketAddr>, tokio::sync::broadcast::Receiver<SocketAddr>,
)> { )> {
use dashmap::mapref::entry::Entry;
match self.get_peers_subscribers.entry(info_hash) { match self.get_peers_subscribers.entry(info_hash) {
Entry::Occupied(o) => { Entry::Occupied(o) => {
let pos = self.seen_peers.get(&info_hash).and_then(|p| { let pos = self.seen_peers.get(&info_hash).and_then(|p| {
@ -299,6 +302,7 @@ impl DhtState {
// We don't need to allocate/collect here, but the borrow checker is not happy otherwise. // We don't need to allocate/collect here, but the borrow checker is not happy otherwise.
let nodes_to_query = self let nodes_to_query = self
.routing_table .routing_table
.read()
.sorted_by_distance_from(info_hash) .sorted_by_distance_from(info_hash)
.iter() .iter()
.map(|n| (n.id(), n.addr())) .map(|n| (n.id(), n.addr()))
@ -313,8 +317,9 @@ impl DhtState {
} }
} }
fn should_request(&mut self, request: Request, addr: SocketAddr) -> bool { fn should_request(&self, request: Request, addr: SocketAddr) -> bool {
const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60); const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60);
use dashmap::mapref::entry::Entry;
match self.made_requests_by_addr.entry((request, addr)) { match self.made_requests_by_addr.entry((request, addr)) {
Entry::Occupied(mut o) => { Entry::Occupied(mut o) => {
if o.get().elapsed() > RE_REQUEST_TIME { if o.get().elapsed() > RE_REQUEST_TIME {
@ -332,36 +337,40 @@ impl DhtState {
} }
fn send_find_peers_if_not_yet( fn send_find_peers_if_not_yet(
&mut self, self: &Arc<Self>,
info_hash: Id20, info_hash: Id20,
target_node: Id20, target_node: Id20,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let request = Request::GetPeers(info_hash); let request = Request::GetPeers(info_hash);
if self.should_request(request, addr) { if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node); self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?; self.send_request(request, addr)?;
} }
Ok(()) Ok(())
} }
fn send_find_node_if_not_yet( fn send_find_node_if_not_yet(
&mut self, self: &Arc<Self>,
search_id: Id20, search_id: Id20,
target_node: Id20, target_node: Id20,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let request = Request::FindNode(search_id); let request = Request::FindNode(search_id);
if self.should_request(request, addr) { if self.should_request(request, addr) {
self.routing_table.mark_outgoing_request(&target_node); self.routing_table
.write()
.mark_outgoing_request(&target_node);
self.send_request(request, addr)?; self.send_request(request, addr)?;
} }
Ok(()) Ok(())
} }
fn routing_table_add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult { fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult {
let mut questionable_nodes = Vec::new(); let mut questionable_nodes = Vec::new();
let res = self.routing_table.add_node(id, addr, |addr| { let res = self.routing_table.write().add_node(id, addr, |addr| {
questionable_nodes.push(addr); questionable_nodes.push(addr);
true true
}); });
@ -372,7 +381,7 @@ impl DhtState {
} }
fn on_found_nodes( fn on_found_nodes(
&mut self, self: &Arc<Self>,
source: Id20, source: Id20,
source_addr: SocketAddr, source_addr: SocketAddr,
target: Id20, target: Id20,
@ -382,8 +391,8 @@ impl DhtState {
// otherwise when we iterate self.searching_for_peers and mutating self in the loop. // otherwise when we iterate self.searching_for_peers and mutating self in the loop.
let searching_for_peers = self let searching_for_peers = self
.get_peers_subscribers .get_peers_subscribers
.keys() .iter()
.copied() .map(|e| *e.key())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// On newly discovered nodes, ask them for peers that we are interested in. // On newly discovered nodes, ask them for peers that we are interested in.
@ -411,14 +420,14 @@ impl DhtState {
} }
fn on_found_peers_or_nodes( fn on_found_peers_or_nodes(
&mut self, self: &Arc<Self>,
source: Id20, source: Id20,
source_addr: SocketAddr, source_addr: SocketAddr,
target: Id20, target: Id20,
data: bprotocol::Response<ByteString>, data: bprotocol::Response<ByteString>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.routing_table_add_node(source, source_addr); self.routing_table_add_node(source, source_addr);
self.routing_table.mark_response(&source); self.routing_table.write().mark_response(&source);
let bsender = match self.get_peers_subscribers.get(&target) { let bsender = match self.get_peers_subscribers.get(&target) {
Some(s) => s, Some(s) => s,
@ -432,7 +441,7 @@ impl DhtState {
}; };
if let Some(peers) = data.values { if let Some(peers) = data.values {
let seen = self.seen_peers.entry(target).or_default(); let mut seen = self.seen_peers.entry(target).or_default();
for peer in peers.iter() { for peer in peers.iter() {
if peer.addr.port() < 1024 { if peer.addr.port() < 1024 {
@ -542,20 +551,15 @@ enum Request {
Ping, Ping,
} }
#[derive(Clone)]
pub struct Dht {
state: Arc<RwLock<DhtState>>,
}
struct DhtWorker { struct DhtWorker {
socket: UdpSocket, socket: UdpSocket,
peer_id: Id20, peer_id: Id20,
state: Arc<RwLock<DhtState>>, state: Arc<DhtState>,
} }
impl DhtWorker { impl DhtWorker {
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> { fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.write().on_incoming_from_remote(msg, addr) self.state.on_incoming_from_remote(msg, addr)
} }
async fn start( async fn start(
@ -577,7 +581,6 @@ impl DhtWorker {
Ok(addrs) => { Ok(addrs) => {
for addr in addrs { for addr in addrs {
this.state this.state
.write()
.send_request(Request::FindNode(this.peer_id), addr)?; .send_request(Request::FindNode(this.peer_id), addr)?;
} }
} }
@ -641,7 +644,7 @@ impl DhtWorker {
struct PeerStream { struct PeerStream {
info_hash: Id20, info_hash: Id20,
state: Arc<RwLock<DhtState>>, state: Arc<DhtState>,
absolute_stream_pos: usize, absolute_stream_pos: usize,
initial_peers_pos: Option<(usize, usize)>, initial_peers_pos: Option<(usize, usize)>,
broadcast_rx: BroadcastStream<SocketAddr>, broadcast_rx: BroadcastStream<SocketAddr>,
@ -658,7 +661,6 @@ impl Stream for PeerStream {
if let Some((pos, end)) = self.initial_peers_pos.take() { if let Some((pos, end)) = self.initial_peers_pos.take() {
let addr = *self let addr = *self
.state .state
.read()
.seen_peers .seen_peers
.get(&self.info_hash) .get(&self.info_hash)
.unwrap() .unwrap()
@ -698,11 +700,11 @@ pub struct DhtConfig {
pub listen_addr: Option<SocketAddr>, pub listen_addr: Option<SocketAddr>,
} }
impl Dht { impl DhtState {
pub async fn new() -> anyhow::Result<Self> { pub async fn new() -> anyhow::Result<Arc<Self>> {
Self::with_config(DhtConfig::default()).await Self::with_config(DhtConfig::default()).await
} }
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Self> { pub async fn with_config(config: DhtConfig) -> anyhow::Result<Arc<Self>> {
let socket = match config.listen_addr { let socket = match config.listen_addr {
Some(addr) => UdpSocket::bind(addr) Some(addr) => UdpSocket::bind(addr)
.await .await
@ -724,12 +726,12 @@ impl Dht {
.unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect()); .unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect());
let (in_tx, in_rx) = unbounded_channel(); let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(RwLock::new(DhtState::new( let state = Arc::new(Self::new_internal(
peer_id, peer_id,
in_tx, in_tx,
config.routing_table, config.routing_table,
listen_addr, listen_addr,
))); ));
spawn(error_span!("dht"), { spawn(error_span!("dht"), {
let state = state.clone(); let state = state.clone();
@ -743,17 +745,17 @@ impl Dht {
Ok(()) Ok(())
} }
}); });
Ok(Dht { state }) Ok(state)
} }
pub fn get_peers( pub fn get_peers(
&self, self: &Arc<Self>,
info_hash: Id20, info_hash: Id20,
) -> anyhow::Result<impl Stream<Item = SocketAddr> + Unpin> { ) -> anyhow::Result<impl Stream<Item = SocketAddr> + Unpin> {
let (pos, rx) = self.state.write().get_peers(info_hash)?; let (pos, rx) = self.get_peers_internal(info_hash)?;
Ok(PeerStream { Ok(PeerStream {
info_hash, info_hash,
state: self.state.clone(), state: self.clone(),
absolute_stream_pos: 0, absolute_stream_pos: 0,
initial_peers_pos: pos, initial_peers_pos: pos,
broadcast_rx: BroadcastStream::new(rx), broadcast_rx: BroadcastStream::new(rx),
@ -761,18 +763,18 @@ impl Dht {
} }
pub fn listen_addr(&self) -> SocketAddr { pub fn listen_addr(&self) -> SocketAddr {
self.state.read().listen_addr self.listen_addr
} }
pub fn stats(&self) -> DhtStats { pub fn stats(&self) -> DhtStats {
self.state.read().get_stats() self.get_stats()
} }
pub fn with_routing_table<R, F: FnOnce(&RoutingTable) -> R>(&self, f: F) -> R { pub fn with_routing_table<R, F: FnOnce(&RoutingTable) -> R>(&self, f: F) -> R {
f(&self.state.read().routing_table) f(&self.routing_table.read())
} }
pub fn clone_routing_table(&self) -> RoutingTable { pub fn clone_routing_table(&self) -> RoutingTable {
self.state.read().routing_table.clone() self.routing_table.read().clone()
} }
} }

View file

@ -4,9 +4,26 @@ mod persistence;
mod routing_table; mod routing_table;
mod utils; mod utils;
use std::sync::Arc;
pub use crate::dht::DhtStats; pub use crate::dht::DhtStats;
pub use crate::dht::{Dht, DhtConfig}; pub use crate::dht::{DhtConfig, DhtState};
pub use librqbit_core::id20::Id20; pub use librqbit_core::id20::Id20;
pub use persistence::{PersistentDht, PersistentDhtConfig}; pub use persistence::{PersistentDht, PersistentDhtConfig};
pub type Dht = Arc<DhtState>;
pub struct DhtBuilder {}
impl DhtBuilder {
#[allow(clippy::new_ret_no_self)]
pub async fn new() -> anyhow::Result<Dht> {
DhtState::new().await
}
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Dht> {
DhtState::with_config(config).await
}
}
pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]; pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"];

View file

@ -11,8 +11,8 @@ use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use tracing::{debug, error, error_span, info, trace, warn}; use tracing::{debug, error, error_span, info, trace, warn};
use crate::dht::{Dht, DhtConfig};
use crate::routing_table::RoutingTable; use crate::routing_table::RoutingTable;
use crate::{Dht, DhtConfig, DhtState};
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub struct PersistentDhtConfig { pub struct PersistentDhtConfig {
@ -108,7 +108,7 @@ impl PersistentDht {
listen_addr, listen_addr,
..Default::default() ..Default::default()
}; };
let dht = Dht::with_config(dht_config).await?; let dht = DhtState::with_config(dht_config).await?;
spawn(error_span!("dht_persistence"), { spawn(error_span!("dht_persistence"), {
let dht = dht.clone(); let dht = dht.clone();

View file

@ -86,7 +86,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use dht::{Dht, Id20}; use dht::{Dht, DhtBuilder, Id20};
use librqbit_core::peer_id::generate_peer_id; use librqbit_core::peer_id::generate_peer_id;
use super::*; use super::*;
@ -106,7 +106,7 @@ mod tests {
init_logging(); init_logging();
let info_hash = Id20::from_str("cf3ea75e2ebbd30e0da6e6e215e2226bf35f2e33").unwrap(); let info_hash = Id20::from_str("cf3ea75e2ebbd30e0da6e6e215e2226bf35f2e33").unwrap();
let dht = Dht::new().await.unwrap(); let dht = DhtBuilder::new().await.unwrap();
let peer_rx = dht.get_peers(info_hash).unwrap(); let peer_rx = dht.get_peers(info_hash).unwrap();
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await { match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {

View file

@ -11,7 +11,7 @@ use std::{
use anyhow::{bail, Context}; use anyhow::{bail, Context};
use buffers::ByteString; use buffers::ByteString;
use dht::{Dht, Id20, PersistentDht, PersistentDhtConfig}; use dht::{Dht, DhtBuilder, Id20, PersistentDht, PersistentDhtConfig};
use librqbit_core::{ use librqbit_core::{
magnet::Magnet, magnet::Magnet,
peer_id::generate_peer_id, peer_id::generate_peer_id,
@ -234,7 +234,7 @@ impl Session {
None None
} else { } else {
let dht = if opts.disable_dht_persistence { let dht = if opts.disable_dht_persistence {
Dht::new().await DhtBuilder::new().await
} else { } else {
PersistentDht::create(opts.dht_config).await PersistentDht::create(opts.dht_config).await
} }