diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index e9382a7..935097c 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -129,7 +129,48 @@ impl DhtState { }; match &msg.kind { - MessageKind::Error(_) | MessageKind::Response(_) => {} + MessageKind::Error(_) | MessageKind::Response(_) => { + if msg.transaction_id.len() != 2 { + anyhow::bail!( + "{}: transaction id unrecognized, expected its length == 2. Message: {:?}", + addr, + msg + ) + } + let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); + // O(n) but whatever + let outstanding_id = self + .outstanding_requests + .iter() + .position(|req| req.transaction_id == tid && req.addr == addr) + .ok_or_else(|| { + anyhow::anyhow!("outstanding request not found. Message: {:?}", msg) + })?; + let outstanding = self.outstanding_requests.remove(outstanding_id); + let response = match msg.kind { + MessageKind::Error(e) => { + anyhow::bail!( + "request {:?} received error response {:?}", + outstanding.request, + e + ) + } + MessageKind::Response(r) => r, + _ => unreachable!(), + }; + self.routing_table.mark_response(&response.id); + match outstanding.request { + Request::FindNode(id) => { + let nodes = response.nodes.ok_or_else(|| { + anyhow::anyhow!("expected nodes for find_node requests") + })?; + self.on_found_nodes(response.id, addr, id, nodes) + } + Request::GetPeers(id) => { + self.on_found_peers_or_nodes(response.id, addr, id, response) + } + } + } MessageKind::PingRequest(_) => { let response = bprotocol::Response { id: self.id, @@ -144,7 +185,7 @@ impl DhtState { kind: MessageKind::Response(response), }; self.sender.send((message, addr))?; - return Ok(()); + Ok(()) } MessageKind::GetPeersRequest(req) => { let peers = self.seen_peers.get(&req.info_hash).map(|peers| { @@ -180,7 +221,7 @@ impl DhtState { kind: MessageKind::Response(response), }; self.sender.send((message, addr))?; - return Ok(()); + Ok(()) } MessageKind::FindNodeRequest(req) => { let compact_node_info = generate_compact_nodes(req.target); @@ -197,44 +238,8 @@ impl DhtState { kind: MessageKind::Response(response), }; self.sender.send((message, addr))?; - return Ok(()); + Ok(()) } - }; - if msg.transaction_id.len() != 2 { - anyhow::bail!( - "{}: transaction id unrecognized, we didn't ask for it. Message: {:?}", - addr, - msg - ) - } - let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); - // O(n) but whatever - let outstanding_id = self - .outstanding_requests - .iter() - .position(|req| req.transaction_id == tid && req.addr == addr) - .ok_or_else(|| anyhow::anyhow!("outstanding request not found. Message: {:?}", msg))?; - let outstanding = self.outstanding_requests.remove(outstanding_id); - let response = match msg.kind { - MessageKind::Error(e) => { - anyhow::bail!( - "request {:?} received error response {:?}", - outstanding.request, - e - ) - } - MessageKind::Response(r) => r, - _ => unreachable!(), - }; - self.routing_table.mark_response(&response.id); - match outstanding.request { - Request::FindNode(id) => { - let nodes = response - .nodes - .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; - self.on_found_nodes(response.id, addr, id, nodes) - } - Request::GetPeers(id) => self.on_found_peers_or_nodes(response.id, addr, id, response), } }