Refactor DHT a bit

This commit is contained in:
Igor Katson 2021-07-13 16:10:36 +01:00
parent 48f4c0a8b7
commit ace11186ef
6 changed files with 119 additions and 96 deletions

1
Cargo.lock generated
View file

@ -1614,6 +1614,7 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]

View file

@ -7,7 +7,7 @@ edition = "2018"
[dependencies]
tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]}
tokio-stream = "0.1"
tokio-stream = {version = "0.1", features = ["sync"]}
serde = {version = "1", features = ["derive"]}
hex = "0.4"
bencode = {path = "../bencode"}

View file

@ -1,6 +1,7 @@
use std::{
collections::{HashMap, HashSet},
collections::{hash_map::Entry, HashMap, HashSet},
net::SocketAddr,
sync::Arc,
};
use crate::{
@ -12,7 +13,7 @@ use crate::{
};
use anyhow::Context;
use bencode::ByteString;
use futures::{stream::FuturesUnordered, StreamExt};
use futures::{stream::FuturesUnordered, Stream, StreamExt, TryStreamExt};
use librqbit_core::{id20::Id20, peer_id::generate_peer_id};
use log::{debug, info, trace, warn};
use parking_lot::Mutex;
@ -22,7 +23,7 @@ use tokio::{
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream};
struct OutstandingRequest {
transaction_id: u16,
@ -30,35 +31,42 @@ struct OutstandingRequest {
request: Request,
}
// TODO:
// - searching for peers - make it a set
// - peers - convert to broadcast
// - return a DHT handle.
// - flatten abstractions
// - framer is fine (I guess)
// - DhtHandle - straight out do things
struct DhtState {
id: Id20,
next_transaction_id: u16,
outstanding_requests: Vec<OutstandingRequest>,
searching_for_peers: Vec<Id20>,
routing_table: RoutingTable,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
// TODO: convert to broadcast
subscribers: HashMap<Id20, Vec<UnboundedSender<Response>>>,
seen_peers: HashMap<Id20, HashSet<SocketAddr>>,
get_peers_subscribers: HashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
made_requests: HashSet<(Request, SocketAddr)>,
}
impl DhtState {
pub fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
Self {
id,
next_transaction_id: 0,
outstanding_requests: Vec::new(),
searching_for_peers: Vec::new(),
routing_table: RoutingTable::new(id),
sender,
subscribers: Default::default(),
seen_peers: Default::default(),
get_peers_subscribers: Default::default(),
made_requests: Default::default(),
}
}
pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
let transaction_id = self.next_transaction_id;
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
self.next_transaction_id += 1;
@ -191,18 +199,27 @@ impl DhtState {
}
}
pub fn on_request(
pub fn get_peers(
&mut self,
request: Request,
sender: UnboundedSender<Response>,
) -> anyhow::Result<()> {
match request {
Request::GetPeers(info_hash) => {
let subs = self.subscribers.entry(info_hash).or_default();
subs.push(sender);
self.searching_for_peers.push(info_hash);
info_hash: Id20,
) -> anyhow::Result<(
Vec<SocketAddr>,
tokio::sync::broadcast::Receiver<SocketAddr>,
)> {
match self.get_peers_subscribers.entry(info_hash) {
Entry::Occupied(o) => {
let existing_peers = self
.seen_peers
.get(&info_hash)
.map(|c| c.iter().copied().collect())
.unwrap_or_default();
let rx = o.get().subscribe();
return Ok((existing_peers, rx));
}
Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::broadcast::channel(100);
v.insert(tx);
// workaround borrow checker.
let mut addrs = Vec::new();
for node in self
.routing_table
@ -219,10 +236,10 @@ impl DhtState {
.send((request, addr))
.context("DhtState: error sending to self.sender")?;
}
return Ok((Vec::new(), rx));
}
Request::FindNode(_) => todo!(),
};
Ok(())
}
}
fn on_found_nodes(
@ -232,11 +249,18 @@ impl DhtState {
_target: Id20,
nodes: CompactNodeInfo,
) -> anyhow::Result<()> {
// We don't need to allocate/collect here, but the borrow checker is not happy
// otherwise when we iterate self.searching_for_peers and mutating self in the loop.
let searching_for_peers = self
.get_peers_subscribers
.keys()
.copied()
.collect::<Vec<_>>();
match self.routing_table.add_node(source, source_addr) {
InsertResult::ReplacedBad(_) | InsertResult::Added => {
for idx in 0..self.searching_for_peers.len() {
let info_hash = self.searching_for_peers[idx];
let request = Request::GetPeers(info_hash);
for info_hash in &searching_for_peers {
let request = Request::GetPeers(*info_hash);
if self.made_requests.insert((request, source_addr)) {
self.routing_table.mark_outgoing_request(&source);
let msg = self.create_request(request, source_addr);
@ -249,12 +273,10 @@ impl DhtState {
for node in nodes.nodes {
match self.routing_table.add_node(node.id, node.addr.into()) {
InsertResult::ReplacedBad(_) | InsertResult::Added => {
for idx in 0..self.searching_for_peers.len() {
let info_hash = self.searching_for_peers[idx];
let request = Request::GetPeers(info_hash);
for info_hash in &searching_for_peers {
let request = Request::GetPeers(*info_hash);
if self.made_requests.insert((request, node.addr.into())) {
let msg =
self.create_request(Request::GetPeers(info_hash), node.addr.into());
let msg = self.create_request(request, node.addr.into());
self.routing_table.mark_outgoing_request(&node.id);
self.sender.send((msg, node.addr.into()))?
}
@ -277,8 +299,8 @@ impl DhtState {
self.routing_table.mark_response(&source);
if let Some(peers) = data.values {
let subscribers = match self.subscribers.get(&target) {
Some(subscribers) => subscribers,
let bsender = match self.get_peers_subscribers.get(&target) {
Some(s) => s,
None => {
warn!(
"ignoring peers for {:?}: no subscribers left. Peers: {:?}",
@ -287,10 +309,10 @@ impl DhtState {
return Ok(());
}
};
for subscriber in subscribers {
for peer in peers.iter() {
subscriber.send(Response::Peer(peer.addr.into()))?
}
for peer in peers.iter() {
bsender
.send(peer.addr.into())
.context("error sending peers to subscribers")?;
}
};
if let Some(nodes) = data.nodes {
@ -378,24 +400,18 @@ enum Response {
Peer(SocketAddr),
}
#[derive(Clone)]
pub struct Dht {
request_tx: Sender<(Request, UnboundedSender<Response>)>,
state: Arc<Mutex<DhtState>>,
}
struct DhtWorker {
socket: UdpSocket,
peer_id: Id20,
state: Mutex<DhtState>,
state: Arc<Mutex<DhtState>>,
}
impl DhtWorker {
fn on_request(
&self,
request: Request,
sender: UnboundedSender<Response>,
) -> anyhow::Result<()> {
self.state.lock().on_request(request, sender)
}
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.lock().on_incoming_from_remote(msg, addr)
}
@ -447,19 +463,6 @@ impl DhtWorker {
};
let mut bootstrap_done = false;
let request_reader = {
let this = &self;
async move {
while let Some((request, sender)) = request_rx.recv().await {
this.on_request(request, sender)
.context("error processing request")?;
}
Err::<(), _>(anyhow::anyhow!(
"closed request reader, no more subscribers"
))
}
};
let response_reader = {
let this = &self;
async move {
@ -476,7 +479,6 @@ impl DhtWorker {
tokio::pin!(framer);
tokio::pin!(bootstrap);
tokio::pin!(request_reader);
tokio::pin!(response_reader);
loop {
@ -488,7 +490,6 @@ impl DhtWorker {
bootstrap_done = true;
result?;
},
err = &mut request_reader => {anyhow::bail!("request reader quit: {:?}", err)}
err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)}
}
}
@ -511,35 +512,32 @@ impl Dht {
.map(|s| s.to_string())
.collect::<Vec<_>>();
tokio::spawn(async move {
let (in_tx, in_rx) = unbounded_channel();
let worker = DhtWorker {
socket,
peer_id,
state: Mutex::new(DhtState::new(peer_id, in_tx.clone())),
};
let result = worker
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
.await;
warn!("DHT worker finished with {:?}", result);
let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone())));
tokio::spawn({
let state = state.clone();
async move {
let worker = DhtWorker {
socket,
peer_id,
state,
};
let result = worker
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
.await;
warn!("DHT worker finished with {:?}", result);
}
});
Ok(Dht { request_tx })
Ok(Dht { state })
}
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> {
let (tx, rx) = unbounded_channel::<Response>();
// This is a hack to test localhost speeds, uncomment to test that quickly.
//
// tx.send(Response::Peer("127.0.0.1:27311".parse().unwrap()))
// .unwrap();
// std::mem::forget(tx);
self.request_tx
.send((Request::GetPeers(info_hash), tx))
.await
.unwrap();
UnboundedReceiverStream::new(rx).map(|r| match r {
Response::Peer(addr) => addr,
})
pub async fn get_peers(
&self,
info_hash: Id20,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<SocketAddr>> + Unpin> {
let (initial_peers, rx) = self.state.lock().get_peers(info_hash)?;
let rx = BroadcastStream::new(rx).map_err(|e| e.into());
let rx = futures::stream::iter(initial_peers).map(Ok).chain(rx);
Ok(rx)
}
}

View file

@ -10,9 +10,10 @@ async fn main() -> anyhow::Result<()> {
let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap();
let dht = Dht::new().await.context("error initializing DHT")?;
let mut stream = dht.get_peers(info_hash).await;
let mut stream = dht.get_peers(info_hash).await?;
let mut seen = HashSet::new();
while let Some(peer) = stream.next().await {
let peer = peer.context("error reading peer stream")?;
if seen.insert(peer) {
log::info!("peer found: {}", peer)
}

View file

@ -2,7 +2,7 @@ use std::{collections::HashSet, net::SocketAddr};
use anyhow::Context;
use buffers::ByteString;
use futures::{stream::FuturesUnordered, StreamExt};
use futures::{stream::FuturesUnordered, Stream, StreamExt};
use librqbit_core::torrent_metainfo::TorrentMetaV1Info;
use log::debug;
@ -21,7 +21,7 @@ pub enum ReadMetainfoResult<Rx> {
},
}
pub async fn read_metainfo_from_peer_receiver<A: StreamExt<Item = SocketAddr> + Unpin>(
pub async fn read_metainfo_from_peer_receiver<A: Stream<Item = SocketAddr> + Unpin>(
peer_id: Id20,
info_hash: Id20,
mut addrs: A,
@ -101,8 +101,10 @@ mod tests {
let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap();
let dht = Dht::new().await.unwrap();
let peer_rx = dht.get_peers(info_hash).await;
let peer_rx = dht.get_peers(info_hash).await.unwrap();
let peer_id = generate_peer_id();
let peer_rx = peer_rx.filter_map(|r| async move { r.ok() });
tokio::pin!(peer_rx);
match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await {
ReadMetainfoResult::Found { info, .. } => dbg!(info),
ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"),

View file

@ -208,7 +208,16 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
let dht_rx = dht
.ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))?
.get_peers(info_hash)
.await;
.await?;
let dht_rx = Box::pin(dht_rx.filter_map(|addr| async move {
match addr {
Ok(addr) => Some(addr),
Err(e) => {
warn!("DHT peer receiver got an error: {:#}", e);
None
}
}
}));
let trackers = trackers
.into_iter()
@ -250,7 +259,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
torrent_from_file(&opts.torrent_path)?
};
let dht_rx = match dht {
Some(dht) => Some(dht.get_peers(torrent.info_hash).await),
Some(dht) => Some(Box::pin(
dht.get_peers(torrent.info_hash)
.await?
.filter_map(|r| async move {
match r {
Ok(addr) => Some(addr),
Err(e) => {
warn!("DHT peer receiver got an error: {:#}", e);
None
}
}
}),
)),
None => None,
};
let trackers = torrent