Refactor DHT a bit

This commit is contained in:
Igor Katson 2021-07-13 16:10:36 +01:00
parent 48f4c0a8b7
commit ace11186ef
6 changed files with 119 additions and 96 deletions

1
Cargo.lock generated
View file

@ -1614,6 +1614,7 @@ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]

View file

@ -7,7 +7,7 @@ edition = "2018"
[dependencies] [dependencies]
tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]} tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]}
tokio-stream = "0.1" tokio-stream = {version = "0.1", features = ["sync"]}
serde = {version = "1", features = ["derive"]} serde = {version = "1", features = ["derive"]}
hex = "0.4" hex = "0.4"
bencode = {path = "../bencode"} bencode = {path = "../bencode"}

View file

@ -1,6 +1,7 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{hash_map::Entry, HashMap, HashSet},
net::SocketAddr, net::SocketAddr,
sync::Arc,
}; };
use crate::{ use crate::{
@ -12,7 +13,7 @@ use crate::{
}; };
use anyhow::Context; use anyhow::Context;
use bencode::ByteString; use bencode::ByteString;
use futures::{stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, Stream, StreamExt, TryStreamExt};
use librqbit_core::{id20::Id20, peer_id::generate_peer_id}; use librqbit_core::{id20::Id20, peer_id::generate_peer_id};
use log::{debug, info, trace, warn}; use log::{debug, info, trace, warn};
use parking_lot::Mutex; use parking_lot::Mutex;
@ -22,7 +23,7 @@ use tokio::{
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
}, },
}; };
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream};
struct OutstandingRequest { struct OutstandingRequest {
transaction_id: u16, transaction_id: u16,
@ -30,35 +31,42 @@ struct OutstandingRequest {
request: Request, request: Request,
} }
// TODO:
// - searching for peers - make it a set
// - peers - convert to broadcast
// - return a DHT handle.
// - flatten abstractions
// - framer is fine (I guess)
// - DhtHandle - straight out do things
struct DhtState { struct DhtState {
id: Id20, id: Id20,
next_transaction_id: u16, next_transaction_id: u16,
outstanding_requests: Vec<OutstandingRequest>, outstanding_requests: Vec<OutstandingRequest>,
searching_for_peers: Vec<Id20>,
routing_table: RoutingTable, routing_table: RoutingTable,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
// TODO: convert to broadcast seen_peers: HashMap<Id20, HashSet<SocketAddr>>,
subscribers: HashMap<Id20, Vec<UnboundedSender<Response>>>, get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
made_requests: HashSet<(Request, SocketAddr)>, made_requests: HashSet<(Request, SocketAddr)>,
} }
impl DhtState { impl DhtState {
pub fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self { fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
Self { Self {
id, id,
next_transaction_id: 0, next_transaction_id: 0,
outstanding_requests: Vec::new(), outstanding_requests: Vec::new(),
searching_for_peers: Vec::new(),
routing_table: RoutingTable::new(id), routing_table: RoutingTable::new(id),
sender, sender,
subscribers: Default::default(), seen_peers: Default::default(),
get_peers_subscribers: Default::default(),
made_requests: Default::default(), made_requests: Default::default(),
} }
} }
pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> { fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
let transaction_id = self.next_transaction_id; let transaction_id = self.next_transaction_id;
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
self.next_transaction_id += 1; self.next_transaction_id += 1;
@ -191,18 +199,27 @@ impl DhtState {
} }
} }
pub fn on_request( pub fn get_peers(
&mut self, &mut self,
request: Request, info_hash: Id20,
sender: UnboundedSender<Response>, ) -> anyhow::Result<(
) -> anyhow::Result<()> { Vec<SocketAddr>,
match request { tokio::sync::broadcast::Receiver<SocketAddr>,
Request::GetPeers(info_hash) => { )> {
let subs = self.subscribers.entry(info_hash).or_default(); match self.get_peers_subscribers.entry(info_hash) {
subs.push(sender); Entry::Occupied(o) => {
self.searching_for_peers.push(info_hash); let existing_peers = self
.seen_peers
.get(&info_hash)
.map(|c| c.iter().copied().collect())
.unwrap_or_default();
let rx = o.get().subscribe();
return Ok((existing_peers, rx));
}
Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::broadcast::channel(100);
v.insert(tx);
// workaround borrow checker.
let mut addrs = Vec::new(); let mut addrs = Vec::new();
for node in self for node in self
.routing_table .routing_table
@ -219,10 +236,10 @@ impl DhtState {
.send((request, addr)) .send((request, addr))
.context("DhtState: error sending to self.sender")?; .context("DhtState: error sending to self.sender")?;
} }
return Ok((Vec::new(), rx));
} }
Request::FindNode(_) => todo!(), }
};
Ok(())
} }
fn on_found_nodes( fn on_found_nodes(
@ -232,11 +249,18 @@ impl DhtState {
_target: Id20, _target: Id20,
nodes: CompactNodeInfo, nodes: CompactNodeInfo,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// We don't need to allocate/collect here, but the borrow checker is not happy
// otherwise when we iterate self.searching_for_peers and mutating self in the loop.
let searching_for_peers = self
.get_peers_subscribers
.keys()
.copied()
.collect::<Vec<_>>();
match self.routing_table.add_node(source, source_addr) { match self.routing_table.add_node(source, source_addr) {
InsertResult::ReplacedBad(_) | InsertResult::Added => { InsertResult::ReplacedBad(_) | InsertResult::Added => {
for idx in 0..self.searching_for_peers.len() { for info_hash in &searching_for_peers {
let info_hash = self.searching_for_peers[idx]; let request = Request::GetPeers(*info_hash);
let request = Request::GetPeers(info_hash);
if self.made_requests.insert((request, source_addr)) { if self.made_requests.insert((request, source_addr)) {
self.routing_table.mark_outgoing_request(&source); self.routing_table.mark_outgoing_request(&source);
let msg = self.create_request(request, source_addr); let msg = self.create_request(request, source_addr);
@ -249,12 +273,10 @@ impl DhtState {
for node in nodes.nodes { for node in nodes.nodes {
match self.routing_table.add_node(node.id, node.addr.into()) { match self.routing_table.add_node(node.id, node.addr.into()) {
InsertResult::ReplacedBad(_) | InsertResult::Added => { InsertResult::ReplacedBad(_) | InsertResult::Added => {
for idx in 0..self.searching_for_peers.len() { for info_hash in &searching_for_peers {
let info_hash = self.searching_for_peers[idx]; let request = Request::GetPeers(*info_hash);
let request = Request::GetPeers(info_hash);
if self.made_requests.insert((request, node.addr.into())) { if self.made_requests.insert((request, node.addr.into())) {
let msg = let msg = self.create_request(request, node.addr.into());
self.create_request(Request::GetPeers(info_hash), node.addr.into());
self.routing_table.mark_outgoing_request(&node.id); self.routing_table.mark_outgoing_request(&node.id);
self.sender.send((msg, node.addr.into()))? self.sender.send((msg, node.addr.into()))?
} }
@ -277,8 +299,8 @@ impl DhtState {
self.routing_table.mark_response(&source); self.routing_table.mark_response(&source);
if let Some(peers) = data.values { if let Some(peers) = data.values {
let subscribers = match self.subscribers.get(&target) { let bsender = match self.get_peers_subscribers.get(&target) {
Some(subscribers) => subscribers, Some(s) => s,
None => { None => {
warn!( warn!(
"ignoring peers for {:?}: no subscribers left. Peers: {:?}", "ignoring peers for {:?}: no subscribers left. Peers: {:?}",
@ -287,10 +309,10 @@ impl DhtState {
return Ok(()); return Ok(());
} }
}; };
for subscriber in subscribers { for peer in peers.iter() {
for peer in peers.iter() { bsender
subscriber.send(Response::Peer(peer.addr.into()))? .send(peer.addr.into())
} .context("error sending peers to subscribers")?;
} }
}; };
if let Some(nodes) = data.nodes { if let Some(nodes) = data.nodes {
@ -378,24 +400,18 @@ enum Response {
Peer(SocketAddr), Peer(SocketAddr),
} }
#[derive(Clone)]
pub struct Dht { pub struct Dht {
request_tx: Sender<(Request, UnboundedSender<Response>)>, state: Arc<Mutex<DhtState>>,
} }
struct DhtWorker { struct DhtWorker {
socket: UdpSocket, socket: UdpSocket,
peer_id: Id20, peer_id: Id20,
state: Mutex<DhtState>, state: Arc<Mutex<DhtState>>,
} }
impl DhtWorker { impl DhtWorker {
fn on_request(
&self,
request: Request,
sender: UnboundedSender<Response>,
) -> anyhow::Result<()> {
self.state.lock().on_request(request, sender)
}
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> { fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.lock().on_incoming_from_remote(msg, addr) self.state.lock().on_incoming_from_remote(msg, addr)
} }
@ -447,19 +463,6 @@ impl DhtWorker {
}; };
let mut bootstrap_done = false; let mut bootstrap_done = false;
let request_reader = {
let this = &self;
async move {
while let Some((request, sender)) = request_rx.recv().await {
this.on_request(request, sender)
.context("error processing request")?;
}
Err::<(), _>(anyhow::anyhow!(
"closed request reader, no more subscribers"
))
}
};
let response_reader = { let response_reader = {
let this = &self; let this = &self;
async move { async move {
@ -476,7 +479,6 @@ impl DhtWorker {
tokio::pin!(framer); tokio::pin!(framer);
tokio::pin!(bootstrap); tokio::pin!(bootstrap);
tokio::pin!(request_reader);
tokio::pin!(response_reader); tokio::pin!(response_reader);
loop { loop {
@ -488,7 +490,6 @@ impl DhtWorker {
bootstrap_done = true; bootstrap_done = true;
result?; result?;
}, },
err = &mut request_reader => {anyhow::bail!("request reader quit: {:?}", err)}
err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)} err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)}
} }
} }
@ -511,35 +512,32 @@ impl Dht {
.map(|s| s.to_string()) .map(|s| s.to_string())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
tokio::spawn(async move { let (in_tx, in_rx) = unbounded_channel();
let (in_tx, in_rx) = unbounded_channel(); let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone())));
let worker = DhtWorker {
socket, tokio::spawn({
peer_id, let state = state.clone();
state: Mutex::new(DhtState::new(peer_id, in_tx.clone())), async move {
}; let worker = DhtWorker {
let result = worker socket,
.start(in_tx, in_rx, request_rx, &bootstrap_addrs) peer_id,
.await; state,
warn!("DHT worker finished with {:?}", result); };
let result = worker
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
.await;
warn!("DHT worker finished with {:?}", result);
}
}); });
Ok(Dht { request_tx }) Ok(Dht { state })
} }
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> { pub async fn get_peers(
let (tx, rx) = unbounded_channel::<Response>(); &self,
info_hash: Id20,
// This is a hack to test localhost speeds, uncomment to test that quickly. ) -> anyhow::Result<impl Stream<Item = anyhow::Result<SocketAddr>> + Unpin> {
// let (initial_peers, rx) = self.state.lock().get_peers(info_hash)?;
// tx.send(Response::Peer("127.0.0.1:27311".parse().unwrap())) let rx = BroadcastStream::new(rx).map_err(|e| e.into());
// .unwrap(); let rx = futures::stream::iter(initial_peers).map(Ok).chain(rx);
// std::mem::forget(tx); Ok(rx)
self.request_tx
.send((Request::GetPeers(info_hash), tx))
.await
.unwrap();
UnboundedReceiverStream::new(rx).map(|r| match r {
Response::Peer(addr) => addr,
})
} }
} }

View file

@ -10,9 +10,10 @@ async fn main() -> anyhow::Result<()> {
let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap(); let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap();
let dht = Dht::new().await.context("error initializing DHT")?; let dht = Dht::new().await.context("error initializing DHT")?;
let mut stream = dht.get_peers(info_hash).await; let mut stream = dht.get_peers(info_hash).await?;
let mut seen = HashSet::new(); let mut seen = HashSet::new();
while let Some(peer) = stream.next().await { while let Some(peer) = stream.next().await {
let peer = peer.context("error reading peer stream")?;
if seen.insert(peer) { if seen.insert(peer) {
log::info!("peer found: {}", peer) log::info!("peer found: {}", peer)
} }

View file

@ -2,7 +2,7 @@ use std::{collections::HashSet, net::SocketAddr};
use anyhow::Context; use anyhow::Context;
use buffers::ByteString; use buffers::ByteString;
use futures::{stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, Stream, StreamExt};
use librqbit_core::torrent_metainfo::TorrentMetaV1Info; use librqbit_core::torrent_metainfo::TorrentMetaV1Info;
use log::debug; use log::debug;
@ -21,7 +21,7 @@ pub enum ReadMetainfoResult<Rx> {
}, },
} }
pub async fn read_metainfo_from_peer_receiver<A: StreamExt<Item = SocketAddr> + Unpin>( pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unpin>(
peer_id: Id20, peer_id: Id20,
info_hash: Id20, info_hash: Id20,
mut addrs: A, mut addrs: A,
@ -101,8 +101,10 @@ mod tests {
let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap(); let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap();
let dht = Dht::new().await.unwrap(); let dht = Dht::new().await.unwrap();
let peer_rx = dht.get_peers(info_hash).await; let peer_rx = dht.get_peers(info_hash).await.unwrap();
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();
let peer_rx = peer_rx.filter_map(|r| async move { r.ok() });
tokio::pin!(peer_rx);
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await { match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {
ReadMetainfoResult::Found { info, .. } => dbg!(info), ReadMetainfoResult::Found { info, .. } => dbg!(info),
ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"), ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"),

View file

@ -208,7 +208,16 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
let dht_rx = dht let dht_rx = dht
.ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))? .ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))?
.get_peers(info_hash) .get_peers(info_hash)
.await; .await?;
let dht_rx = Box::pin(dht_rx.filter_map(|addr| async move {
match addr {
Ok(addr) => Some(addr),
Err(e) => {
warn!("DHT peer receiver got an error: {:#}", e);
None
}
}
}));
let trackers = trackers let trackers = trackers
.into_iter() .into_iter()
@ -250,7 +259,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
torrent_from_file(&opts.torrent_path)? torrent_from_file(&opts.torrent_path)?
}; };
let dht_rx = match dht { let dht_rx = match dht {
Some(dht) => Some(dht.get_peers(torrent.info_hash).await), Some(dht) => Some(Box::pin(
dht.get_peers(torrent.info_hash)
.await?
.filter_map(|r| async move {
match r {
Ok(addr) => Some(addr),
Err(e) => {
warn!("DHT peer receiver got an error: {:#}", e);
None
}
}
}),
)),
None => None, None => None,
}; };
let trackers = torrent let trackers = torrent