From aa2a41a53cbb9b39ca6b3e4b55fb346499bf7ae7 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Wed, 29 Nov 2023 23:12:20 +0000 Subject: [PATCH] Fixing up bugs, refactored DHT works alright now --- crates/dht/examples/dht.rs | 1 + crates/dht/src/dht.rs | 251 +++++++++++++------------------------ 2 files changed, 90 insertions(+), 162 deletions(-) diff --git a/crates/dht/examples/dht.rs b/crates/dht/examples/dht.rs index dc0cc4f..883ef79 100644 --- a/crates/dht/examples/dht.rs +++ b/crates/dht/examples/dht.rs @@ -36,6 +36,7 @@ async fn main() -> anyhow::Result<()> { let mut f = std::fs::OpenOptions::new() .create(true) .write(true) + .truncate(true) .open(filename) .unwrap(); serde_json::to_writer_pretty(&mut f, r).unwrap(); diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index f4b6695..cb7a756 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -81,15 +81,10 @@ fn make_rate_limiter() -> RateLimiter { } trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static { - fn on_request_start( - &self, - req: &Arc>, - target_node: Id20, - addr: SocketAddr, - ); + fn on_request_start(&self, req: &RecursiveRequest, target_node: Id20, addr: SocketAddr); fn on_request_end( &self, - req: &Arc>, + req: &RecursiveRequest, target_node: Id20, addr: SocketAddr, resp: &anyhow::Result, @@ -98,11 +93,11 @@ trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static { struct RecursiveRequestCallbacksGetPeers {} impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers { - fn on_request_start(&self, _: &Arc>, _: Id20, _: SocketAddr) {} + fn on_request_start(&self, _: &RecursiveRequest, _: Id20, _: SocketAddr) {} fn on_request_end( &self, - _: &Arc>, + _: &RecursiveRequest, _: Id20, _: SocketAddr, _: &anyhow::Result, @@ -112,12 +107,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers { struct RecursiveRequestCallbacksFindNodes {} impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes { - fn on_request_start( - &self, - req: &Arc>, - target_node: Id20, - addr: SocketAddr, - ) { + fn on_request_start(&self, req: &RecursiveRequest, target_node: Id20, addr: SocketAddr) { match req.dht.routing_table_add_node(target_node, addr) { InsertResult::WasExisting | InsertResult::ReplacedBad(_) | InsertResult::Added => { req.dht @@ -131,7 +121,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes { fn on_request_end( &self, - req: &Arc>, + req: &RecursiveRequest, target_node: Id20, _addr: SocketAddr, resp: &anyhow::Result, @@ -150,8 +140,8 @@ struct RecursiveRequest { request: Request, dht: Arc, useful_nodes: RwLock>, - // peer_tx: tokio::sync::mpsc::UnboundedSender, - // node_tx: tokio::sync::mpsc::UnboundedSender<(Option, SocketAddr)>, + peer_tx: tokio::sync::mpsc::UnboundedSender, + node_tx: tokio::sync::mpsc::UnboundedSender<(Option, SocketAddr)>, callbacks: C, } @@ -169,11 +159,11 @@ impl RequestPeersStream { request: Request::GetPeers(info_hash), dht, useful_nodes: RwLock::new(Vec::new()), - // peer_tx, - // node_tx, + peer_tx, + node_tx, callbacks: RecursiveRequestCallbacksGetPeers {}, }); - let join_handle = rp.clone().request_peers_forever(node_rx, node_tx, peer_tx); + let join_handle = rp.request_peers_forever(node_rx); Self { rx: peer_rx, cancel_join_handle: join_handle, @@ -199,77 +189,101 @@ impl Stream for RequestPeersStream { } impl RecursiveRequest { - async fn find_node( - dht: Arc, - target: Id20, - root_addrs: impl Iterator, - ) -> anyhow::Result<()> { - let (peer_tx, peer_rx) = unbounded_channel(); - drop(peer_rx); - + async fn bootstrap(dht: Arc, target: Id20, hostname: &str) -> anyhow::Result<()> { + let addrs = tokio::net::lookup_host(hostname) + .await + .with_context(|| format!("error looking up {}", hostname))?; let (node_tx, mut node_rx) = unbounded_channel(); - let req = Arc::new(RecursiveRequest { + let req = RecursiveRequest { info_hash: target, request: Request::FindNode(target), dht, useful_nodes: RwLock::new(Vec::new()), - // peer_tx: unbounded_channel().0, - // node_tx, + peer_tx: unbounded_channel().0, + node_tx, callbacks: RecursiveRequestCallbacksFindNodes {}, - }); + }; + + let request_one = |id, addr| { + req.request_one(id, addr) + .map_err(|e| { + debug!("error: {e:?}"); + e + }) + .instrument(error_span!( + "find_node", + target = format!("{target:?}"), + addr = addr.to_string() + )) + }; let mut futs = FuturesUnordered::new(); - for addr in root_addrs { - node_tx.send((None, addr)).unwrap(); + let mut initial_addrs = 0; + for addr in addrs { + futs.push(request_one(None, addr)); + initial_addrs += 1; } + let mut successes = 0; + let mut errors = 0; + loop { tokio::select! { + biased; + r = node_rx.recv() => { let (id, addr) = r.unwrap(); - futs.push( - req.request_one(id, addr, node_tx.clone(), peer_tx.clone()) - .instrument( - error_span!("find_node", target=format!("{target:?}"), addr=addr.to_string()) - ) - ) + futs.push(request_one(id, addr)) }, - Some(f) = futs.next(), if !futs.is_empty() => { - if let Err(e) = f { - error!("error: {e:?}"); + f = futs.next() => { + let f = match f { + Some(f) => f, + None => { + // find_node recursion finished. + break; + } + }; + if f.is_ok() { + successes += 1; + } else { + errors += 1; } } } } + if successes == 0 { + bail!("no successful lookups, errors = {errors}"); + } + debug!( + "finished, successes = {successes}, errors = {errors}, initial_addrs = {initial_addrs}" + ); Ok(()) } } impl RecursiveRequest { fn request_peers_forever( - self: Arc, + self: &Arc, mut node_rx: tokio::sync::mpsc::UnboundedReceiver<(Option, SocketAddr)>, - node_tx: tokio::sync::mpsc::UnboundedSender<(Option, SocketAddr)>, - peer_tx: tokio::sync::mpsc::UnboundedSender, ) -> tokio::task::JoinHandle<()> { + let this = self.clone(); spawn( - error_span!("get_peers", info_hash = format!("{:?}", self.info_hash)), + error_span!(parent: None, "get_peers", info_hash = format!("{:?}", self.info_hash)), async move { + let this = &this; // Looper adds root nodes to the queue every 60 seconds. let looper = { - let this = self.clone(); - let node_tx = node_tx.clone(); async move { let mut iteration = 0; loop { debug!("iteration {}", iteration); - let sleep = match this.get_peers_root(&node_tx) { + let sleep = match this.get_peers_root() { Ok(0) => Duration::from_secs(1), Ok(n) if n < 8 => REQUERY_INTERVAL / 2, Ok(_) => REQUERY_INTERVAL, Err(e) => { - error!("error: {e:?}"); + error!("error in get_peers_root(): {e:?}"); return Err::<(), anyhow::Error>(e); } }; @@ -286,7 +300,7 @@ impl RecursiveRequest { addr = node_rx.recv() => { let (id, addr) = addr.unwrap(); futs.push( - self.request_one(id, addr, node_tx.clone(), peer_tx.clone()) + this.request_one(id, addr) .map_err(|e| debug!("error: {e:?}")) .instrument(error_span!("addr", addr=addr.to_string())) ); @@ -299,10 +313,7 @@ impl RecursiveRequest { ) } - fn get_peers_root( - self: &Arc, - node_tx: &UnboundedSender<(Option, SocketAddr)>, - ) -> anyhow::Result { + fn get_peers_root(&self) -> anyhow::Result { let mut count = 0; for (id, addr) in self .dht @@ -314,20 +325,14 @@ impl RecursiveRequest { .take(8) { count += 1; - node_tx.send((Some(id), addr))?; + self.node_tx.send((Some(id), addr))?; } Ok(count) } } impl RecursiveRequest { - async fn request_one<'a>( - self: &'a Arc, - id: Option, - addr: SocketAddr, - node_tx: UnboundedSender<(Option, SocketAddr)>, - peer_tx: UnboundedSender, - ) -> anyhow::Result<()> { + async fn request_one(&self, id: Option, addr: SocketAddr) -> anyhow::Result<()> { if let Some(id) = id { self.callbacks.on_request_start(self, id, addr); } @@ -348,18 +353,26 @@ impl RecursiveRequest { return Err(e); } }; + trace!("received {response:?}"); if let Some(peers) = response.values { for peer in peers { - peer_tx.send(SocketAddr::V4(peer.addr))?; + self.peer_tx.send(SocketAddr::V4(peer.addr))?; } } if let Some(nodes) = response.nodes { for node in nodes.nodes { let addr = SocketAddr::V4(node.addr); - if self.should_request_node(node.id, addr) { - node_tx.send((Some(node.id), addr))?; + let should_request = self.should_request_node(node.id, addr); + trace!( + "should_request={}, id={:?}, addr={}", + should_request, + node.id, + addr + ); + if should_request { + self.node_tx.send((Some(node.id), addr))?; } } } @@ -471,20 +484,6 @@ impl DhtState { } } - async fn send_request_and_handle_response( - self: &Arc, - request: Request, - addr: SocketAddr, - ) -> anyhow::Result<()> { - let resp = self.request(request, addr).await?; - match resp { - ResponseOrError::Response(r) => self.on_response(addr, request, r), - ResponseOrError::Error(e) => { - bail!("received error: {:?}", e); - } - } - } - async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result { self.rate_limiter.acquire_one().await; let (tid, message) = self.create_request(request); @@ -550,24 +549,6 @@ impl DhtState { (transaction_id, message) } - fn on_response( - self: &Arc, - addr: SocketAddr, - request: Request, - response: Response, - ) -> anyhow::Result<()> { - self.routing_table.write().mark_response(&response.id); - match request { - Request::FindNode(id) => { - todo!() - } - Request::Ping => Ok(()), - Request::GetPeers(_info_hash) => { - todo!() - } - } - } - fn on_received_message( self: &Arc, msg: Message, @@ -615,7 +596,7 @@ impl DhtState { match request.done.send(Ok(response_or_error)) { Ok(_) => {} Err(e) => { - warn!( + debug!( "recieved response, but the receiver task is closed: {:?}", e ); @@ -746,68 +727,22 @@ enum ResponseOrError { struct DhtWorker { socket: UdpSocket, - peer_id: Id20, - state: Arc, + dht: Arc, } impl DhtWorker { fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) { if let Some((_, OutstandingRequest { done })) = - self.state.inflight_by_transaction_id.remove(&(tid, addr)) + self.dht.inflight_by_transaction_id.remove(&(tid, addr)) { let _ = done.send(Err(err)).is_err(); }; } - async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> { - let mut backoff = ExponentialBackoffBuilder::new() - .with_initial_interval(Duration::from_secs(10)) - .with_multiplier(1.5) - .with_max_interval(Duration::from_secs(60)) - .with_max_elapsed_time(Some(Duration::from_secs(86400))) - .build(); - - loop { - let res = self - .state - .send_request_and_handle_response(Request::FindNode(self.peer_id), addr) - .await; - match res { - Ok(r) => return Ok(r), - Err(e) => { - debug!("error: {:?}", e); - if let Some(backoff) = backoff.next_backoff() { - tokio::time::sleep(backoff).await; - continue; - } - bail!("given up bootstrapping, timed out") - } - } - } - } - async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> { - let addrs = tokio::net::lookup_host(hostname) + RecursiveRequest::bootstrap(self.dht.clone(), self.dht.id, hostname) + .instrument(error_span!("bootstrap", hostname = hostname)) .await - .with_context(|| format!("error looking up {}", hostname))?; - let mut futs = FuturesUnordered::new(); - for addr in addrs { - futs.push( - self.bootstrap_one_ip_with_backoff(addr) - .instrument(error_span!("addr", addr = addr.to_string())), - ); - } - let requests = futs.len(); - let mut successes = 0; - while let Some(resp) = futs.next().await { - if resp.is_ok() { - successes += 1 - }; - } - if successes == 0 { - bail!("none of the {} bootstrap requests succeded", requests); - } - Ok(()) } async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> { @@ -838,11 +773,7 @@ impl DhtWorker { let mut futs = FuturesUnordered::new(); for addr in bootstrap_addrs.iter() { - let this = &self; - futs.push( - this.bootstrap_hostname_with_backoff(addr) - .instrument(error_span!("bootstrap", hostname = addr)), - ); + futs.push(self.bootstrap_hostname_with_backoff(addr)); } let mut successes = 0; while let Some(resp) = futs.next().await { @@ -937,7 +868,7 @@ impl DhtWorker { let this = &self; async move { while let Some((response, addr)) = out_rx.recv().await { - if let Err(e) = this.state.on_received_message(response, addr) { + if let Err(e) = this.dht.on_received_message(response, addr) { debug!("error in on_response, addr={:?}: {}", addr, e) } } @@ -1011,11 +942,7 @@ impl DhtState { spawn(error_span!("dht"), { let state = state.clone(); async move { - let worker = DhtWorker { - socket, - peer_id, - state, - }; + let worker = DhtWorker { socket, dht: state }; worker.start(in_rx, &bootstrap_addrs).await?; Ok(()) }