diff --git a/Cargo.lock b/Cargo.lock index 89191e2..e83606a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1614,6 +1614,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 0ad6cfb..7dbdaac 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [dependencies] tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]} -tokio-stream = "0.1" +tokio-stream = {version = "0.1", features = ["sync"]} serde = {version = "1", features = ["derive"]} hex = "0.4" bencode = {path = "../bencode"} diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 65d188f..65b6d76 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,6 +1,7 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{hash_map::Entry, HashMap, HashSet}, net::SocketAddr, + sync::Arc, }; use crate::{ @@ -12,7 +13,7 @@ use crate::{ }; use anyhow::Context; use bencode::ByteString; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, Stream, StreamExt, TryStreamExt}; use librqbit_core::{id20::Id20, peer_id::generate_peer_id}; use log::{debug, info, trace, warn}; use parking_lot::Mutex; @@ -22,7 +23,7 @@ use tokio::{ channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, }, }; -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream}; struct OutstandingRequest { transaction_id: u16, @@ -30,35 +31,42 @@ struct OutstandingRequest { request: Request, } +// TODO: +// - searching for peers - make it a set +// - peers - convert to broadcast +// - return a DHT handle. +// - flatten abstractions +// - framer is fine (I guess) +// - DhtHandle - straight out do things + struct DhtState { id: Id20, next_transaction_id: u16, outstanding_requests: Vec, - searching_for_peers: Vec, routing_table: RoutingTable, sender: UnboundedSender<(Message, SocketAddr)>, - // TODO: convert to broadcast - subscribers: HashMap>>, + seen_peers: HashMap>, + get_peers_subscribers: HashMap>, made_requests: HashSet<(Request, SocketAddr)>, } impl DhtState { - pub fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { + fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { Self { id, next_transaction_id: 0, outstanding_requests: Vec::new(), - searching_for_peers: Vec::new(), routing_table: RoutingTable::new(id), sender, - subscribers: Default::default(), + seen_peers: Default::default(), + get_peers_subscribers: Default::default(), made_requests: Default::default(), } } - pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { + fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { let transaction_id = self.next_transaction_id; let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; self.next_transaction_id += 1; @@ -191,18 +199,27 @@ impl DhtState { } } - pub fn on_request( + pub fn get_peers( &mut self, - request: Request, - sender: UnboundedSender, - ) -> anyhow::Result<()> { - match request { - Request::GetPeers(info_hash) => { - let subs = self.subscribers.entry(info_hash).or_default(); - subs.push(sender); - self.searching_for_peers.push(info_hash); + info_hash: Id20, + ) -> anyhow::Result<( + Vec, + tokio::sync::broadcast::Receiver, + )> { + match self.get_peers_subscribers.entry(info_hash) { + Entry::Occupied(o) => { + let existing_peers = self + .seen_peers + .get(&info_hash) + .map(|c| c.iter().copied().collect()) + .unwrap_or_default(); + let rx = o.get().subscribe(); + return Ok((existing_peers, rx)); + } + Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::broadcast::channel(100); + v.insert(tx); - // workaround borrow checker. let mut addrs = Vec::new(); for node in self .routing_table @@ -219,10 +236,10 @@ impl DhtState { .send((request, addr)) .context("DhtState: error sending to self.sender")?; } + + return Ok((Vec::new(), rx)); } - Request::FindNode(_) => todo!(), - }; - Ok(()) + } } fn on_found_nodes( @@ -232,11 +249,18 @@ impl DhtState { _target: Id20, nodes: CompactNodeInfo, ) -> anyhow::Result<()> { + // We don't need to allocate/collect here, but the borrow checker is not happy + // otherwise when we iterate self.searching_for_peers and mutating self in the loop. + let searching_for_peers = self + .get_peers_subscribers + .keys() + .copied() + .collect::>(); + match self.routing_table.add_node(source, source_addr) { InsertResult::ReplacedBad(_) | InsertResult::Added => { - for idx in 0..self.searching_for_peers.len() { - let info_hash = self.searching_for_peers[idx]; - let request = Request::GetPeers(info_hash); + for info_hash in &searching_for_peers { + let request = Request::GetPeers(*info_hash); if self.made_requests.insert((request, source_addr)) { self.routing_table.mark_outgoing_request(&source); let msg = self.create_request(request, source_addr); @@ -249,12 +273,10 @@ impl DhtState { for node in nodes.nodes { match self.routing_table.add_node(node.id, node.addr.into()) { InsertResult::ReplacedBad(_) | InsertResult::Added => { - for idx in 0..self.searching_for_peers.len() { - let info_hash = self.searching_for_peers[idx]; - let request = Request::GetPeers(info_hash); + for info_hash in &searching_for_peers { + let request = Request::GetPeers(*info_hash); if self.made_requests.insert((request, node.addr.into())) { - let msg = - self.create_request(Request::GetPeers(info_hash), node.addr.into()); + let msg = self.create_request(request, node.addr.into()); self.routing_table.mark_outgoing_request(&node.id); self.sender.send((msg, node.addr.into()))? } @@ -277,8 +299,8 @@ impl DhtState { self.routing_table.mark_response(&source); if let Some(peers) = data.values { - let subscribers = match self.subscribers.get(&target) { - Some(subscribers) => subscribers, + let bsender = match self.get_peers_subscribers.get(&target) { + Some(s) => s, None => { warn!( "ignoring peers for {:?}: no subscribers left. Peers: {:?}", @@ -287,10 +309,10 @@ impl DhtState { return Ok(()); } }; - for subscriber in subscribers { - for peer in peers.iter() { - subscriber.send(Response::Peer(peer.addr.into()))? - } + for peer in peers.iter() { + bsender + .send(peer.addr.into()) + .context("error sending peers to subscribers")?; } }; if let Some(nodes) = data.nodes { @@ -378,24 +400,18 @@ enum Response { Peer(SocketAddr), } +#[derive(Clone)] pub struct Dht { - request_tx: Sender<(Request, UnboundedSender)>, + state: Arc>, } struct DhtWorker { socket: UdpSocket, peer_id: Id20, - state: Mutex, + state: Arc>, } impl DhtWorker { - fn on_request( - &self, - request: Request, - sender: UnboundedSender, - ) -> anyhow::Result<()> { - self.state.lock().on_request(request, sender) - } fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { self.state.lock().on_incoming_from_remote(msg, addr) } @@ -447,19 +463,6 @@ impl DhtWorker { }; let mut bootstrap_done = false; - let request_reader = { - let this = &self; - async move { - while let Some((request, sender)) = request_rx.recv().await { - this.on_request(request, sender) - .context("error processing request")?; - } - Err::<(), _>(anyhow::anyhow!( - "closed request reader, no more subscribers" - )) - } - }; - let response_reader = { let this = &self; async move { @@ -476,7 +479,6 @@ impl DhtWorker { tokio::pin!(framer); tokio::pin!(bootstrap); - tokio::pin!(request_reader); tokio::pin!(response_reader); loop { @@ -488,7 +490,6 @@ impl DhtWorker { bootstrap_done = true; result?; }, - err = &mut request_reader => {anyhow::bail!("request reader quit: {:?}", err)} err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)} } } @@ -511,35 +512,32 @@ impl Dht { .map(|s| s.to_string()) .collect::>(); - tokio::spawn(async move { - let (in_tx, in_rx) = unbounded_channel(); - let worker = DhtWorker { - socket, - peer_id, - state: Mutex::new(DhtState::new(peer_id, in_tx.clone())), - }; - let result = worker - .start(in_tx, in_rx, request_rx, &bootstrap_addrs) - .await; - warn!("DHT worker finished with {:?}", result); + let (in_tx, in_rx) = unbounded_channel(); + let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone()))); + + tokio::spawn({ + let state = state.clone(); + async move { + let worker = DhtWorker { + socket, + peer_id, + state, + }; + let result = worker + .start(in_tx, in_rx, request_rx, &bootstrap_addrs) + .await; + warn!("DHT worker finished with {:?}", result); + } }); - Ok(Dht { request_tx }) + Ok(Dht { state }) } - pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt { - let (tx, rx) = unbounded_channel::(); - - // This is a hack to test localhost speeds, uncomment to test that quickly. - // - // tx.send(Response::Peer("127.0.0.1:27311".parse().unwrap())) - // .unwrap(); - // std::mem::forget(tx); - - self.request_tx - .send((Request::GetPeers(info_hash), tx)) - .await - .unwrap(); - UnboundedReceiverStream::new(rx).map(|r| match r { - Response::Peer(addr) => addr, - }) + pub async fn get_peers( + &self, + info_hash: Id20, + ) -> anyhow::Result> + Unpin> { + let (initial_peers, rx) = self.state.lock().get_peers(info_hash)?; + let rx = BroadcastStream::new(rx).map_err(|e| e.into()); + let rx = futures::stream::iter(initial_peers).map(Ok).chain(rx); + Ok(rx) } } diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs index a4abbc9..7240a35 100644 --- a/crates/dht/src/main.rs +++ b/crates/dht/src/main.rs @@ -10,9 +10,10 @@ async fn main() -> anyhow::Result<()> { let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap(); let dht = Dht::new().await.context("error initializing DHT")?; - let mut stream = dht.get_peers(info_hash).await; + let mut stream = dht.get_peers(info_hash).await?; let mut seen = HashSet::new(); while let Some(peer) = stream.next().await { + let peer = peer.context("error reading peer stream")?; if seen.insert(peer) { log::info!("peer found: {}", peer) } diff --git a/crates/librqbit/src/dht_utils.rs b/crates/librqbit/src/dht_utils.rs index 056edf1..3d98a1c 100644 --- a/crates/librqbit/src/dht_utils.rs +++ b/crates/librqbit/src/dht_utils.rs @@ -2,7 +2,7 @@ use std::{collections::HashSet, net::SocketAddr}; use anyhow::Context; use buffers::ByteString; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, Stream, StreamExt}; use librqbit_core::torrent_metainfo::TorrentMetaV1Info; use log::debug; @@ -21,7 +21,7 @@ pub enum ReadMetainfoResult { }, } -pub async fn read_metainfo_from_peer_receiver + Unpin>( +pub async fn read_metainfo_from_peer_receiver + Unpin>( peer_id: Id20, info_hash: Id20, mut addrs: A, @@ -101,8 +101,10 @@ mod tests { let info_hash = Id20::from_str("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7").unwrap(); let dht = Dht::new().await.unwrap(); - let peer_rx = dht.get_peers(info_hash).await; + let peer_rx = dht.get_peers(info_hash).await.unwrap(); let peer_id = generate_peer_id(); + let peer_rx = peer_rx.filter_map(|r| async move { r.ok() }); + tokio::pin!(peer_rx); match read_metainfo_from_peer_receiver(peer_id, info_hash, peer_rx, None).await { ReadMetainfoResult::Found { info, .. } => dbg!(info), ReadMetainfoResult::ChannelClosed { .. } => todo!("should not have happened"), diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index e25b4fa..aa85358 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -208,7 +208,16 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> let dht_rx = dht .ok_or_else(|| anyhow::anyhow!("magnet links without DHT are not supported"))? .get_peers(info_hash) - .await; + .await?; + let dht_rx = Box::pin(dht_rx.filter_map(|addr| async move { + match addr { + Ok(addr) => Some(addr), + Err(e) => { + warn!("DHT peer receiver got an error: {:#}", e); + None + } + } + })); let trackers = trackers .into_iter() @@ -250,7 +259,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> torrent_from_file(&opts.torrent_path)? }; let dht_rx = match dht { - Some(dht) => Some(dht.get_peers(torrent.info_hash).await), + Some(dht) => Some(Box::pin( + dht.get_peers(torrent.info_hash) + .await? + .filter_map(|r| async move { + match r { + Ok(addr) => Some(addr), + Err(e) => { + warn!("DHT peer receiver got an error: {:#}", e); + None + } + } + }), + )), None => None, }; let trackers = torrent