Refactor DHT a bit
This commit is contained in:
parent
48f4c0a8b7
commit
ace11186ef
6 changed files with 119 additions and 96 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -1614,6 +1614,7 @@ dependencies = [
|
|||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ edition = "2018"
|
|||
|
||||
[dependencies]
|
||||
tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]}
|
||||
tokio-stream = "0.1"
|
||||
tokio-stream = {version = "0.1", features = ["sync"]}
|
||||
serde = {version = "1", features = ["derive"]}
|
||||
hex = "0.4"
|
||||
bencode = {path = "../bencode"}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::{hash_map::Entry, HashMap, HashSet},
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -12,7 +13,7 @@ use crate::{
|
|||
};
|
||||
use anyhow::Context;
|
||||
use bencode::ByteString;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use futures::{stream::FuturesUnordered, Stream, StreamExt, TryStreamExt};
|
||||
use librqbit_core::{id20::Id20, peer_id::generate_peer_id};
|
||||
use log::{debug, info, trace, warn};
|
||||
use parking_lot::Mutex;
|
||||
|
|
@ -22,7 +23,7 @@ use tokio::{
|
|||
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
|
||||
},
|
||||
};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream};
|
||||
|
||||
struct OutstandingRequest {
|
||||
transaction_id: u16,
|
||||
|
|
@ -30,35 +31,42 @@ struct OutstandingRequest {
|
|||
request: Request,
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// - searching for peers - make it a set
|
||||
// - peers - convert to broadcast
|
||||
// - return a DHT handle.
|
||||
// - flatten abstractions
|
||||
// - framer is fine (I guess)
|
||||
// - DhtHandle - straight out do things
|
||||
|
||||
struct DhtState {
|
||||
id: Id20,
|
||||
next_transaction_id: u16,
|
||||
outstanding_requests: Vec<OutstandingRequest>,
|
||||
searching_for_peers: Vec<Id20>,
|
||||
routing_table: RoutingTable,
|
||||
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
|
||||
|
||||
// TODO: convert to broadcast
|
||||
subscribers: HashMap<Id20, Vec<UnboundedSender<Response>>>,
|
||||
seen_peers: HashMap<Id20, HashSet<SocketAddr>>,
|
||||
get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
|
||||
|
||||
made_requests: HashSet<(Request, SocketAddr)>,
|
||||
}
|
||||
|
||||
impl DhtState {
|
||||
pub fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
|
||||
fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
next_transaction_id: 0,
|
||||
outstanding_requests: Vec::new(),
|
||||
searching_for_peers: Vec::new(),
|
||||
routing_table: RoutingTable::new(id),
|
||||
sender,
|
||||
subscribers: Default::default(),
|
||||
seen_peers: Default::default(),
|
||||
get_peers_subscribers: Default::default(),
|
||||
made_requests: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
|
||||
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
|
||||
let transaction_id = self.next_transaction_id;
|
||||
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
|
||||
self.next_transaction_id += 1;
|
||||
|
|
@ -191,18 +199,27 @@ impl DhtState {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn on_request(
|
||||
pub fn get_peers(
|
||||
&mut self,
|
||||
request: Request,
|
||||
sender: UnboundedSender<Response>,
|
||||
) -> anyhow::Result<()> {
|
||||
match request {
|
||||
Request::GetPeers(info_hash) => {
|
||||
let subs = self.subscribers.entry(info_hash).or_default();
|
||||
subs.push(sender);
|
||||
self.searching_for_peers.push(info_hash);
|
||||
info_hash: Id20,
|
||||
) -> anyhow::Result<(
|
||||
Vec<SocketAddr>,
|
||||
tokio::sync::broadcast::Receiver<SocketAddr>,
|
||||
)> {
|
||||
match self.get_peers_subscribers.entry(info_hash) {
|
||||
Entry::Occupied(o) => {
|
||||
let existing_peers = self
|
||||
.seen_peers
|
||||
.get(&info_hash)
|
||||
.map(|c| c.iter().copied().collect())
|
||||
.unwrap_or_default();
|
||||
let rx = o.get().subscribe();
|
||||
return Ok((existing_peers, rx));
|
||||
}
|
||||
Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::broadcast::channel(100);
|
||||
v.insert(tx);
|
||||
|
||||
// workaround borrow checker.
|
||||
let mut addrs = Vec::new();
|
||||
for node in self
|
||||
.routing_table
|
||||
|
|
@ -219,10 +236,10 @@ impl DhtState {
|
|||
.send((request, addr))
|
||||
.context("DhtState: error sending to self.sender")?;
|
||||
}
|
||||
|
||||
return Ok((Vec::new(), rx));
|
||||
}
|
||||
Request::FindNode(_) => todo!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn on_found_nodes(
|
||||
|
|
@ -232,11 +249,18 @@ impl DhtState {
|
|||
_target: Id20,
|
||||
nodes: CompactNodeInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
// We don't need to allocate/collect here, but the borrow checker is not happy
|
||||
// otherwise when we iterate self.searching_for_peers and mutating self in the loop.
|
||||
let searching_for_peers = self
|
||||
.get_peers_subscribers
|
||||
.keys()
|
||||
.copied()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match self.routing_table.add_node(source, source_addr) {
|
||||
InsertResult::ReplacedBad(_) | InsertResult::Added => {
|
||||
for idx in 0..self.searching_for_peers.len() {
|
||||
let info_hash = self.searching_for_peers[idx];
|
||||
let request = Request::GetPeers(info_hash);
|
||||
for info_hash in &searching_for_peers {
|
||||
let request = Request::GetPeers(*info_hash);
|
||||
if self.made_requests.insert((request, source_addr)) {
|
||||
self.routing_table.mark_outgoing_request(&source);
|
||||
let msg = self.create_request(request, source_addr);
|
||||
|
|
@ -249,12 +273,10 @@ impl DhtState {
|
|||
for node in nodes.nodes {
|
||||
match self.routing_table.add_node(node.id, node.addr.into()) {
|
||||
InsertResult::ReplacedBad(_) | InsertResult::Added => {
|
||||
for idx in 0..self.searching_for_peers.len() {
|
||||
let info_hash = self.searching_for_peers[idx];
|
||||
let request = Request::GetPeers(info_hash);
|
||||
for info_hash in &searching_for_peers {
|
||||
let request = Request::GetPeers(*info_hash);
|
||||
if self.made_requests.insert((request, node.addr.into())) {
|
||||
let msg =
|
||||
self.create_request(Request::GetPeers(info_hash), node.addr.into());
|
||||
let msg = self.create_request(request, node.addr.into());
|
||||
self.routing_table.mark_outgoing_request(&node.id);
|
||||
self.sender.send((msg, node.addr.into()))?
|
||||
}
|
||||
|
|
@ -277,8 +299,8 @@ impl DhtState {
|
|||
self.routing_table.mark_response(&source);
|
||||
|
||||
if let Some(peers) = data.values {
|
||||
let subscribers = match self.subscribers.get(&target) {
|
||||
Some(subscribers) => subscribers,
|
||||
let bsender = match self.get_peers_subscribers.get(&target) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
warn!(
|
||||
"ignoring peers for {:?}: no subscribers left. Peers: {:?}",
|
||||
|
|
@ -287,10 +309,10 @@ impl DhtState {
|
|||
return Ok(());
|
||||
}
|
||||
};
|
||||
for subscriber in subscribers {
|
||||
for peer in peers.iter() {
|
||||
subscriber.send(Response::Peer(peer.addr.into()))?
|
||||
}
|
||||
for peer in peers.iter() {
|
||||
bsender
|
||||
.send(peer.addr.into())
|
||||
.context("error sending peers to subscribers")?;
|
||||
}
|
||||
};
|
||||
if let Some(nodes) = data.nodes {
|
||||
|
|
@ -378,24 +400,18 @@ enum Response {
|
|||
Peer(SocketAddr),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Dht {
|
||||
request_tx: Sender<(Request, UnboundedSender<Response>)>,
|
||||
state: Arc<Mutex<DhtState>>,
|
||||
}
|
||||
|
||||
struct DhtWorker {
|
||||
socket: UdpSocket,
|
||||
peer_id: Id20,
|
||||
state: Mutex<DhtState>,
|
||||
state: Arc<Mutex<DhtState>>,
|
||||
}
|
||||
|
||||
impl DhtWorker {
|
||||
fn on_request(
|
||||
&self,
|
||||
request: Request,
|
||||
sender: UnboundedSender<Response>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.state.lock().on_request(request, sender)
|
||||
}
|
||||
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
self.state.lock().on_incoming_from_remote(msg, addr)
|
||||
}
|
||||
|
|
@ -447,19 +463,6 @@ impl DhtWorker {
|
|||
};
|
||||
let mut bootstrap_done = false;
|
||||
|
||||
let request_reader = {
|
||||
let this = &self;
|
||||
async move {
|
||||
while let Some((request, sender)) = request_rx.recv().await {
|
||||
this.on_request(request, sender)
|
||||
.context("error processing request")?;
|
||||
}
|
||||
Err::<(), _>(anyhow::anyhow!(
|
||||
"closed request reader, no more subscribers"
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let response_reader = {
|
||||
let this = &self;
|
||||
async move {
|
||||
|
|
@ -476,7 +479,6 @@ impl DhtWorker {
|
|||
|
||||
tokio::pin!(framer);
|
||||
tokio::pin!(bootstrap);
|
||||
tokio::pin!(request_reader);
|
||||
tokio::pin!(response_reader);
|
||||
|
||||
loop {
|
||||
|
|
@ -488,7 +490,6 @@ impl DhtWorker {
|
|||
bootstrap_done = true;
|
||||
result?;
|
||||
},
|
||||
err = &mut request_reader => {anyhow::bail!("request reader quit: {:?}", err)}
|
||||
err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)}
|
||||
}
|
||||
}
|
||||
|
|
@ -511,35 +512,32 @@ impl Dht {
|
|||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (in_tx, in_rx) = unbounded_channel();
|
||||
let worker = DhtWorker {
|
||||
socket,
|
||||
peer_id,
|
||||
state: Mutex::new(DhtState::new(peer_id, in_tx.clone())),
|
||||
};
|
||||
let result = worker
|
||||
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
|
||||
.await;
|
||||
warn!("DHT worker finished with {:?}", result);
|
||||
let (in_tx, in_rx) = unbounded_channel();
|
||||
let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone())));
|
||||
|
||||
tokio::spawn({
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let worker = DhtWorker {
|
||||
socket,
|
||||
peer_id,
|
||||
state,
|
||||
};
|
||||
let result = worker
|
||||
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
|
||||
.await;
|
||||
warn!("DHT worker finished with {:?}", result);
|
||||
}
|
||||
});
|
||||
Ok(Dht { request_tx })
|
||||
Ok(Dht { state })
|
||||
}
|
||||
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> {
|
||||
let (tx, rx) = unbounded_channel::<Response>();
|
||||
|
||||
// This is a hack to test localhost speeds, uncomment to test that quickly.
|
||||
//
|
||||
// tx.send(Response::Peer("127.0.0.1:27311".parse().unwrap()))
|
||||
// .unwrap();
|
||||
// std::mem::forget(tx);
|
||||
|
||||
self.request_tx
|
||||
.send((Request::GetPeers(info_hash), tx))
|
||||
.await
|
||||
.unwrap();
|
||||
UnboundedReceiverStream::new(rx).map(|r| match r {
|
||||
Response::Peer(addr) => addr,
|
||||
})
|
||||
pub async fn get_peers(
|
||||
&self,
|
||||
info_hash: Id20,
|
||||
) -> anyhow::Result<impl Stream<Item = anyhow::Result<SocketAddr>> + Unpin> {
|
||||
let (initial_peers, rx) = self.state.lock().get_peers(info_hash)?;
|
||||
let rx = BroadcastStream::new(rx).map_err(|e| e.into());
|
||||
let rx = futures::stream::iter(initial_peers).map(Ok).chain(rx);
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ async fn main() -> anyhow::Result<()> {
|
|||
|
||||
let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap();
|
||||
let dht = Dht::new().await.context("error initializing DHT")?;
|
||||
let mut stream = dht.get_peers(info_hash).await;
|
||||
let mut stream = dht.get_peers(info_hash).await?;
|
||||
let mut seen = HashSet::new();
|
||||
while let Some(peer) = stream.next().await {
|
||||
let peer = peer.context("error reading peer stream")?;
|
||||
if seen.insert(peer) {
|
||||
log::info!("peer found: {}", peer)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::{collections::HashSet, net::SocketAddr};
|
|||
|
||||
use anyhow::Context;
|
||||
use buffers::ByteString;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use futures::{stream::FuturesUnordered, Stream, StreamExt};
|
||||
use librqbit_core::torrent_metainfo::TorrentMetaV1Info;
|
||||
use log::debug;
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ pub enum ReadMetainfoResult<Rx> {
|
|||
},
|
||||
}
|
||||
|
||||
pub async fn read_metainfo_from_peer_receiver<A: StreamExt<Item = SocketAddr> + Unpin>(
|
||||
pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unpin>(
|
||||
peer_id: Id20,
|
||||
info_hash: Id20,
|
||||
mut addrs: A,
|
||||
|
|
@ -101,8 +101,10 @@ mod tests {
|
|||
|
||||
let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap();
|
||||
let dht = Dht::new().await.unwrap();
|
||||
let peer_rx = dht.get_peers(info_hash).await;
|
||||
let peer_rx = dht.get_peers(info_hash).await.unwrap();
|
||||
let peer_id = generate_peer_id();
|
||||
let peer_rx = peer_rx.filter_map(|r| async move { r.ok() });
|
||||
tokio::pin!(peer_rx);
|
||||
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {
|
||||
ReadMetainfoResult::Found { info, .. } => dbg!(info),
|
||||
ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"),
|
||||
|
|
|
|||
|
|
@ -208,7 +208,16 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
let dht_rx = dht
|
||||
.ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))?
|
||||
.get_peers(info_hash)
|
||||
.await;
|
||||
.await?;
|
||||
let dht_rx = Box::pin(dht_rx.filter_map(|addr| async move {
|
||||
match addr {
|
||||
Ok(addr) => Some(addr),
|
||||
Err(e) => {
|
||||
warn!("DHT peer receiver got an error: {:#}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
let trackers = trackers
|
||||
.into_iter()
|
||||
|
|
@ -250,7 +259,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
torrent_from_file(&opts.torrent_path)?
|
||||
};
|
||||
let dht_rx = match dht {
|
||||
Some(dht) => Some(dht.get_peers(torrent.info_hash).await),
|
||||
Some(dht) => Some(Box::pin(
|
||||
dht.get_peers(torrent.info_hash)
|
||||
.await?
|
||||
.filter_map(|r| async move {
|
||||
match r {
|
||||
Ok(addr) => Some(addr),
|
||||
Err(e) => {
|
||||
warn!("DHT peer receiver got an error: {:#}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}),
|
||||
)),
|
||||
None => None,
|
||||
};
|
||||
let trackers = torrent
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue