Remove the giant lock from dht
This commit is contained in:
parent
eaf5021908
commit
c7cf5eedef
6 changed files with 84 additions and 65 deletions
|
|
@ -2,7 +2,7 @@ use std::time::Duration;
|
|||
|
||||
use anyhow::Context;
|
||||
use librqbit_core::magnet::Magnet;
|
||||
use librqbit_dht::Dht;
|
||||
use librqbit_dht::{Dht, DhtBuilder};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::info;
|
||||
|
||||
|
|
@ -16,7 +16,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let dht = Dht::new().await.context("error initializing DHT")?;
|
||||
let dht = DhtBuilder::new().await.context("error initializing DHT")?;
|
||||
let mut stream = dht.get_peers(info_hash)?;
|
||||
|
||||
let stats_printer = async {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicU16, Ordering},
|
||||
|
|
@ -18,6 +17,7 @@ use crate::{
|
|||
};
|
||||
use anyhow::Context;
|
||||
use bencode::ByteString;
|
||||
use dashmap::DashMap;
|
||||
use futures::{stream::FuturesUnordered, Stream, StreamExt};
|
||||
use indexmap::IndexSet;
|
||||
use leaky_bucket::RateLimiter;
|
||||
|
|
@ -42,7 +42,7 @@ pub struct DhtStats {
|
|||
pub routing_table_size: usize,
|
||||
}
|
||||
|
||||
struct DhtState {
|
||||
pub struct DhtState {
|
||||
id: Id20,
|
||||
next_transaction_id: AtomicU16,
|
||||
|
||||
|
|
@ -50,12 +50,12 @@ struct DhtState {
|
|||
// If we get a response, it gets removed from here.
|
||||
//
|
||||
// TODO: clean up old entries
|
||||
outstanding_requests_by_transaction_id: HashMap<(u16, SocketAddr), Request>,
|
||||
outstanding_requests_by_transaction_id: DashMap<(u16, SocketAddr), Request>,
|
||||
|
||||
// TODO: clean up old entries
|
||||
made_requests_by_addr: HashMap<(Request, SocketAddr), Instant>,
|
||||
made_requests_by_addr: DashMap<(Request, SocketAddr), Instant>,
|
||||
|
||||
routing_table: RoutingTable,
|
||||
routing_table: RwLock<RoutingTable>,
|
||||
listen_addr: SocketAddr,
|
||||
|
||||
// This sender sends requests to the worker.
|
||||
|
|
@ -65,12 +65,12 @@ struct DhtState {
|
|||
// Alternatively, we can lock only the parts that change, and use that internally inside DhtState...
|
||||
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
|
||||
|
||||
seen_peers: HashMap<Id20, IndexSet<SocketAddr>>,
|
||||
get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
|
||||
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
|
||||
get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
|
||||
}
|
||||
|
||||
impl DhtState {
|
||||
fn new(
|
||||
fn new_internal(
|
||||
id: Id20,
|
||||
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
|
||||
routing_table: Option<RoutingTable>,
|
||||
|
|
@ -81,7 +81,7 @@ impl DhtState {
|
|||
id,
|
||||
next_transaction_id: AtomicU16::new(0),
|
||||
outstanding_requests_by_transaction_id: Default::default(),
|
||||
routing_table,
|
||||
routing_table: RwLock::new(routing_table),
|
||||
sender,
|
||||
listen_addr,
|
||||
seen_peers: Default::default(),
|
||||
|
|
@ -90,14 +90,14 @@ impl DhtState {
|
|||
}
|
||||
}
|
||||
|
||||
fn send_request(&mut self, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
let (tid, msg) = self.create_request(request, addr);
|
||||
fn send_request(self: &Arc<Self>, request: Request, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
let (tid, msg) = self.create_request(request);
|
||||
self.outstanding_requests_by_transaction_id
|
||||
.insert((tid, addr), request);
|
||||
Ok(self.sender.send((msg, addr))?)
|
||||
}
|
||||
|
||||
fn create_request(&mut self, request: Request, addr: SocketAddr) -> (u16, Message<ByteString>) {
|
||||
fn create_request(&self, request: Request) -> (u16, Message<ByteString>) {
|
||||
let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::Relaxed);
|
||||
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
|
||||
|
||||
|
|
@ -131,13 +131,14 @@ impl DhtState {
|
|||
}
|
||||
|
||||
fn on_incoming_from_remote(
|
||||
&mut self,
|
||||
self: &Arc<Self>,
|
||||
msg: Message<ByteString>,
|
||||
addr: SocketAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let generate_compact_nodes = |target| {
|
||||
let nodes = self
|
||||
.routing_table
|
||||
.read()
|
||||
.sorted_by_distance_from(target)
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
|
|
@ -167,6 +168,7 @@ impl DhtState {
|
|||
let request = match self
|
||||
.outstanding_requests_by_transaction_id
|
||||
.remove(&(tid, addr))
|
||||
.map(|(_, v)| v)
|
||||
{
|
||||
Some(req) => req,
|
||||
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
|
||||
|
|
@ -178,7 +180,7 @@ impl DhtState {
|
|||
MessageKind::Response(r) => r,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.routing_table.mark_response(&response.id);
|
||||
self.routing_table.write().mark_response(&response.id);
|
||||
match request {
|
||||
Request::FindNode(id) => {
|
||||
let nodes = response.nodes.ok_or_else(|| {
|
||||
|
|
@ -226,7 +228,7 @@ impl DhtState {
|
|||
None
|
||||
};
|
||||
let compact_node_info = generate_compact_nodes(req.info_hash);
|
||||
self.routing_table.mark_last_query(&req.id);
|
||||
self.routing_table.write().mark_last_query(&req.id);
|
||||
let message = Message {
|
||||
transaction_id: msg.transaction_id,
|
||||
version: None,
|
||||
|
|
@ -243,7 +245,7 @@ impl DhtState {
|
|||
}
|
||||
MessageKind::FindNodeRequest(req) => {
|
||||
let compact_node_info = generate_compact_nodes(req.target);
|
||||
self.routing_table.mark_last_query(&req.id);
|
||||
self.routing_table.write().mark_last_query(&req.id);
|
||||
let message = Message {
|
||||
transaction_id: msg.transaction_id,
|
||||
version: None,
|
||||
|
|
@ -264,20 +266,21 @@ impl DhtState {
|
|||
DhtStats {
|
||||
id: self.id,
|
||||
outstanding_requests: self.outstanding_requests_by_transaction_id.len(),
|
||||
seen_peers: self.seen_peers.values().map(|v| v.len()).sum(),
|
||||
seen_peers: self.seen_peers.iter().map(|(e)| e.value().len()).sum(),
|
||||
made_requests: self.made_requests_by_addr.len(),
|
||||
routing_table_size: self.routing_table.len(),
|
||||
routing_table_size: self.routing_table.read().len(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn get_peers(
|
||||
&mut self,
|
||||
fn get_peers_internal(
|
||||
self: &Arc<Self>,
|
||||
info_hash: Id20,
|
||||
) -> anyhow::Result<(
|
||||
Option<(usize, usize)>,
|
||||
tokio::sync::broadcast::Receiver<SocketAddr>,
|
||||
)> {
|
||||
use dashmap::mapref::entry::Entry;
|
||||
match self.get_peers_subscribers.entry(info_hash) {
|
||||
Entry::Occupied(o) => {
|
||||
let pos = self.seen_peers.get(&info_hash).and_then(|p| {
|
||||
|
|
@ -299,6 +302,7 @@ impl DhtState {
|
|||
// We don't need to allocate/collect here, but the borrow checker is not happy otherwise.
|
||||
let nodes_to_query = self
|
||||
.routing_table
|
||||
.read()
|
||||
.sorted_by_distance_from(info_hash)
|
||||
.iter()
|
||||
.map(|n| (n.id(), n.addr()))
|
||||
|
|
@ -313,8 +317,9 @@ impl DhtState {
|
|||
}
|
||||
}
|
||||
|
||||
fn should_request(&mut self, request: Request, addr: SocketAddr) -> bool {
|
||||
fn should_request(&self, request: Request, addr: SocketAddr) -> bool {
|
||||
const RE_REQUEST_TIME: Duration = Duration::from_secs(10 * 60);
|
||||
use dashmap::mapref::entry::Entry;
|
||||
match self.made_requests_by_addr.entry((request, addr)) {
|
||||
Entry::Occupied(mut o) => {
|
||||
if o.get().elapsed() > RE_REQUEST_TIME {
|
||||
|
|
@ -332,36 +337,40 @@ impl DhtState {
|
|||
}
|
||||
|
||||
fn send_find_peers_if_not_yet(
|
||||
&mut self,
|
||||
self: &Arc<Self>,
|
||||
info_hash: Id20,
|
||||
target_node: Id20,
|
||||
addr: SocketAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let request = Request::GetPeers(info_hash);
|
||||
if self.should_request(request, addr) {
|
||||
self.routing_table.mark_outgoing_request(&target_node);
|
||||
self.routing_table
|
||||
.write()
|
||||
.mark_outgoing_request(&target_node);
|
||||
self.send_request(request, addr)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_find_node_if_not_yet(
|
||||
&mut self,
|
||||
self: &Arc<Self>,
|
||||
search_id: Id20,
|
||||
target_node: Id20,
|
||||
addr: SocketAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let request = Request::FindNode(search_id);
|
||||
if self.should_request(request, addr) {
|
||||
self.routing_table.mark_outgoing_request(&target_node);
|
||||
self.routing_table
|
||||
.write()
|
||||
.mark_outgoing_request(&target_node);
|
||||
self.send_request(request, addr)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn routing_table_add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult {
|
||||
fn routing_table_add_node(self: &Arc<Self>, id: Id20, addr: SocketAddr) -> InsertResult {
|
||||
let mut questionable_nodes = Vec::new();
|
||||
let res = self.routing_table.add_node(id, addr, |addr| {
|
||||
let res = self.routing_table.write().add_node(id, addr, |addr| {
|
||||
questionable_nodes.push(addr);
|
||||
true
|
||||
});
|
||||
|
|
@ -372,7 +381,7 @@ impl DhtState {
|
|||
}
|
||||
|
||||
fn on_found_nodes(
|
||||
&mut self,
|
||||
self: &Arc<Self>,
|
||||
source: Id20,
|
||||
source_addr: SocketAddr,
|
||||
target: Id20,
|
||||
|
|
@ -382,8 +391,8 @@ impl DhtState {
|
|||
// otherwise when we iterate self.searching_for_peers and mutating self in the loop.
|
||||
let searching_for_peers = self
|
||||
.get_peers_subscribers
|
||||
.keys()
|
||||
.copied()
|
||||
.iter()
|
||||
.map(|e| *e.key())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// On newly discovered nodes, ask them for peers that we are interested in.
|
||||
|
|
@ -411,14 +420,14 @@ impl DhtState {
|
|||
}
|
||||
|
||||
fn on_found_peers_or_nodes(
|
||||
&mut self,
|
||||
self: &Arc<Self>,
|
||||
source: Id20,
|
||||
source_addr: SocketAddr,
|
||||
target: Id20,
|
||||
data: bprotocol::Response<ByteString>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.routing_table_add_node(source, source_addr);
|
||||
self.routing_table.mark_response(&source);
|
||||
self.routing_table.write().mark_response(&source);
|
||||
|
||||
let bsender = match self.get_peers_subscribers.get(&target) {
|
||||
Some(s) => s,
|
||||
|
|
@ -432,7 +441,7 @@ impl DhtState {
|
|||
};
|
||||
|
||||
if let Some(peers) = data.values {
|
||||
let seen = self.seen_peers.entry(target).or_default();
|
||||
let mut seen = self.seen_peers.entry(target).or_default();
|
||||
|
||||
for peer in peers.iter() {
|
||||
if peer.addr.port() < 1024 {
|
||||
|
|
@ -542,20 +551,15 @@ enum Request {
|
|||
Ping,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Dht {
|
||||
state: Arc<RwLock<DhtState>>,
|
||||
}
|
||||
|
||||
struct DhtWorker {
|
||||
socket: UdpSocket,
|
||||
peer_id: Id20,
|
||||
state: Arc<RwLock<DhtState>>,
|
||||
state: Arc<DhtState>,
|
||||
}
|
||||
|
||||
impl DhtWorker {
|
||||
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
self.state.write().on_incoming_from_remote(msg, addr)
|
||||
self.state.on_incoming_from_remote(msg, addr)
|
||||
}
|
||||
|
||||
async fn start(
|
||||
|
|
@ -577,7 +581,6 @@ impl DhtWorker {
|
|||
Ok(addrs) => {
|
||||
for addr in addrs {
|
||||
this.state
|
||||
.write()
|
||||
.send_request(Request::FindNode(this.peer_id), addr)?;
|
||||
}
|
||||
}
|
||||
|
|
@ -641,7 +644,7 @@ impl DhtWorker {
|
|||
|
||||
struct PeerStream {
|
||||
info_hash: Id20,
|
||||
state: Arc<RwLock<DhtState>>,
|
||||
state: Arc<DhtState>,
|
||||
absolute_stream_pos: usize,
|
||||
initial_peers_pos: Option<(usize, usize)>,
|
||||
broadcast_rx: BroadcastStream<SocketAddr>,
|
||||
|
|
@ -658,7 +661,6 @@ impl Stream for PeerStream {
|
|||
if let Some((pos, end)) = self.initial_peers_pos.take() {
|
||||
let addr = *self
|
||||
.state
|
||||
.read()
|
||||
.seen_peers
|
||||
.get(&self.info_hash)
|
||||
.unwrap()
|
||||
|
|
@ -698,11 +700,11 @@ pub struct DhtConfig {
|
|||
pub listen_addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl Dht {
|
||||
pub async fn new() -> anyhow::Result<Self> {
|
||||
impl DhtState {
|
||||
pub async fn new() -> anyhow::Result<Arc<Self>> {
|
||||
Self::with_config(DhtConfig::default()).await
|
||||
}
|
||||
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Self> {
|
||||
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Arc<Self>> {
|
||||
let socket = match config.listen_addr {
|
||||
Some(addr) => UdpSocket::bind(addr)
|
||||
.await
|
||||
|
|
@ -724,12 +726,12 @@ impl Dht {
|
|||
.unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect());
|
||||
|
||||
let (in_tx, in_rx) = unbounded_channel();
|
||||
let state = Arc::new(RwLock::new(DhtState::new(
|
||||
let state = Arc::new(Self::new_internal(
|
||||
peer_id,
|
||||
in_tx,
|
||||
config.routing_table,
|
||||
listen_addr,
|
||||
)));
|
||||
));
|
||||
|
||||
spawn(error_span!("dht"), {
|
||||
let state = state.clone();
|
||||
|
|
@ -743,17 +745,17 @@ impl Dht {
|
|||
Ok(())
|
||||
}
|
||||
});
|
||||
Ok(Dht { state })
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
pub fn get_peers(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
info_hash: Id20,
|
||||
) -> anyhow::Result<impl Stream<Item = SocketAddr> + Unpin> {
|
||||
let (pos, rx) = self.state.write().get_peers(info_hash)?;
|
||||
let (pos, rx) = self.get_peers_internal(info_hash)?;
|
||||
Ok(PeerStream {
|
||||
info_hash,
|
||||
state: self.state.clone(),
|
||||
state: self.clone(),
|
||||
absolute_stream_pos: 0,
|
||||
initial_peers_pos: pos,
|
||||
broadcast_rx: BroadcastStream::new(rx),
|
||||
|
|
@ -761,18 +763,18 @@ impl Dht {
|
|||
}
|
||||
|
||||
pub fn listen_addr(&self) -> SocketAddr {
|
||||
self.state.read().listen_addr
|
||||
self.listen_addr
|
||||
}
|
||||
|
||||
pub fn stats(&self) -> DhtStats {
|
||||
self.state.read().get_stats()
|
||||
self.get_stats()
|
||||
}
|
||||
|
||||
pub fn with_routing_table<R, F: FnOnce(&RoutingTable) -> R>(&self, f: F) -> R {
|
||||
f(&self.state.read().routing_table)
|
||||
f(&self.routing_table.read())
|
||||
}
|
||||
|
||||
pub fn clone_routing_table(&self) -> RoutingTable {
|
||||
self.state.read().routing_table.clone()
|
||||
self.routing_table.read().clone()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,26 @@ mod persistence;
|
|||
mod routing_table;
|
||||
mod utils;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::dht::DhtStats;
|
||||
pub use crate::dht::{Dht, DhtConfig};
|
||||
pub use crate::dht::{DhtConfig, DhtState};
|
||||
pub use librqbit_core::id20::Id20;
|
||||
pub use persistence::{PersistentDht, PersistentDhtConfig};
|
||||
|
||||
pub type Dht = Arc<DhtState>;
|
||||
|
||||
pub struct DhtBuilder {}
|
||||
|
||||
impl DhtBuilder {
|
||||
#[allow(clippy::new_ret_no_self)]
|
||||
pub async fn new() -> anyhow::Result<Dht> {
|
||||
DhtState::new().await
|
||||
}
|
||||
|
||||
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Dht> {
|
||||
DhtState::with_config(config).await
|
||||
}
|
||||
}
|
||||
|
||||
pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"];
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ use std::time::Duration;
|
|||
use anyhow::Context;
|
||||
use tracing::{debug, error, error_span, info, trace, warn};
|
||||
|
||||
use crate::dht::{Dht, DhtConfig};
|
||||
use crate::routing_table::RoutingTable;
|
||||
use crate::{Dht, DhtConfig, DhtState};
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct PersistentDhtConfig {
|
||||
|
|
@ -108,7 +108,7 @@ impl PersistentDht {
|
|||
listen_addr,
|
||||
..Default::default()
|
||||
};
|
||||
let dht = Dht::with_config(dht_config).await?;
|
||||
let dht = DhtState::with_config(dht_config).await?;
|
||||
|
||||
spawn(error_span!("dht_persistence"), {
|
||||
let dht = dht.clone();
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unp
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use dht::{Dht, Id20};
|
||||
use dht::{Dht, DhtBuilder, Id20};
|
||||
use librqbit_core::peer_id::generate_peer_id;
|
||||
|
||||
use super::*;
|
||||
|
|
@ -106,7 +106,7 @@ mod tests {
|
|||
init_logging();
|
||||
|
||||
let info_hash = Id20::from_str("cf3ea75e2ebbd30e0da6e6e215e2226bf35f2e33").unwrap();
|
||||
let dht = Dht::new().await.unwrap();
|
||||
let dht = DhtBuilder::new().await.unwrap();
|
||||
let peer_rx = dht.get_peers(info_hash).unwrap();
|
||||
let peer_id = generate_peer_id();
|
||||
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use std::{
|
|||
|
||||
use anyhow::{bail, Context};
|
||||
use buffers::ByteString;
|
||||
use dht::{Dht, Id20, PersistentDht, PersistentDhtConfig};
|
||||
use dht::{Dht, DhtBuilder, Id20, PersistentDht, PersistentDhtConfig};
|
||||
use librqbit_core::{
|
||||
magnet::Magnet,
|
||||
peer_id::generate_peer_id,
|
||||
|
|
@ -234,7 +234,7 @@ impl Session {
|
|||
None
|
||||
} else {
|
||||
let dht = if opts.disable_dht_persistence {
|
||||
Dht::new().await
|
||||
DhtBuilder::new().await
|
||||
} else {
|
||||
PersistentDht::create(opts.dht_config).await
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue