From 5c417964852d909a13bf4aec69ddae18129a1de7 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Mon, 12 Jul 2021 19:42:48 +0100 Subject: [PATCH] Saving --- Cargo.lock | 1 + crates/dht/Cargo.toml | 1 + crates/dht/src/bprotocol.rs | 62 +++- crates/dht/src/dht.rs | 535 ++++++++++++++++++++++++++++ crates/dht/src/id20.rs | 15 +- crates/dht/src/lib.rs | 1 + crates/dht/src/main.rs | 398 +-------------------- crates/dht/src/routing_table.rs | 16 +- crates/librqbit/src/dht/inforead.rs | 8 +- 9 files changed, 633 insertions(+), 404 deletions(-) create mode 100644 crates/dht/src/dht.rs diff --git a/Cargo.lock b/Cargo.lock index 4fbb82d..9acc6be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,6 +276,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bencode", + "clone_to_owned", "futures 0.3.15", "hex 0.4.3", "kad", diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 7b12095..5e2aa63 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -19,6 +19,7 @@ pretty_env_logger = "0.4" futures = "0.3" rand = "0.8" +clone_to_owned = {path="../clone_to_owned"} librqbit_core = {path="../librqbit_core"} [dev-dependencies] diff --git a/crates/dht/src/bprotocol.rs b/crates/dht/src/bprotocol.rs index 57aa219..f2b31a3 100644 --- a/crates/dht/src/bprotocol.rs +++ b/crates/dht/src/bprotocol.rs @@ -5,6 +5,7 @@ use std::{ }; use bencode::ByteBuf; +use clone_to_owned::CloneToOwned; use serde::{ de::{IgnoredAny, Unexpected}, Deserialize, Deserializer, Serialize, @@ -67,6 +68,20 @@ pub struct ErrorDescription { pub description: BufT, } +impl CloneToOwned for ErrorDescription +where + BufT: CloneToOwned, +{ + type Target = ErrorDescription<::Target>; + + fn clone_to_owned(&self) -> Self::Target { + ErrorDescription { + code: self.code, + description: self.description.clone_to_owned(), + } + } +} + impl Serialize for ErrorDescription where BufT: Serialize, @@ -293,6 +308,11 @@ pub struct GetPeersRequest { pub info_hash: Id20, } +#[derive(Debug, Serialize, Deserialize)] +pub struct PingRequest { + id: Id20, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(bound(serialize = "BufT: AsRef<[u8]> + Serialize"))] #[serde(bound(deserialize = "BufT: From<&'de [u8]> + Deserialize<'de>"))] @@ -319,6 +339,7 @@ pub enum MessageKind { GetPeersRequest(GetPeersRequest), FindNodeRequest(FindNodeRequest), Response(Response), + PingRequest(PingRequest), } pub fn serialize_message<'a, W: Write, BufT: Serialize + From<&'a [u8]>>( @@ -382,6 +403,19 @@ pub fn serialize_message<'a, W: Write, BufT: Serialize + From<&'a [u8]>>( }; Ok(bencode::bencode_serialize_to_writer(msg, writer)?) } + MessageKind::PingRequest(ping) => { + let msg: RawMessage = RawMessage { + message_type: MessageType::Request, + transaction_id, + error: None, + response: None, + method_name: Some(BufT::from(b"ping")), + arguments: Some(ping), + ip, + version, + }; + Ok(bencode::bencode_serialize_to_writer(msg, writer)?) + } } } @@ -391,7 +425,7 @@ where { let de: RawMessage = bencode::from_bytes(buf)?; match de.message_type { - MessageType::Request => match (de.arguments, de.method_name, de.response, de.error) { + MessageType::Request => match (&de.arguments, &de.method_name, &de.response, &de.error) { (Some(_), Some(method_name), None, None) => match method_name.as_ref() { b"find_node" => { let de: RawMessage = bencode::from_bytes(buf)?; @@ -411,14 +445,24 @@ where kind: MessageKind::GetPeersRequest(de.arguments.unwrap()), }) } + b"ping" => { + let de: RawMessage = bencode::from_bytes(buf)?; + Ok(Message { + transaction_id: de.transaction_id, + version: de.version, + ip: de.ip.map(|c| c.addr), + kind: MessageKind::PingRequest(de.arguments.unwrap()), + }) + } other => anyhow::bail!("unsupported method {:?}", ByteBuf(other)), }, _ => anyhow::bail!( - "cannot deserialize message as request, expected exactly \"a\" and \"q\" to be set" + "cannot deserialize message as request, expected exactly \"a\" and \"q\" to be set. Message: {:?}", de ), }, - MessageType::Response => match (de.arguments, de.method_name, de.response, de.error) { - (None, None, Some(_), None) => { + MessageType::Response => match (&de.arguments, &de.method_name, &de.response, &de.error) { + // some peers are sending method name against the protocol, so ignore it. + (None, _, Some(_), None) => { let de: RawMessage> = bencode::from_bytes(buf)?; Ok(Message { transaction_id: de.transaction_id, @@ -428,11 +472,12 @@ where }) } _ => anyhow::bail!( - "cannot deserialize message as response, expected exactly \"r\" to be set" + "cannot deserialize message as response, expected exactly \"r\" to be set. Message: {:?}", de ), }, - MessageType::Error => match (de.arguments, de.method_name, de.response, de.error) { - (None, None, None, Some(e)) => { + MessageType::Error => match (&de.arguments, &de.method_name, &de.response, &de.error) { + // some peers are sending method name against the protocol, so ignore it. + (None, _, None, Some(_)) => { let de: RawMessage> = bencode::from_bytes(buf)?; Ok(Message { transaction_id: de.transaction_id, @@ -442,7 +487,7 @@ where }) } _ => anyhow::bail!( - "cannot deserialize message as response, expected exactly \"r\" to be set" + "cannot deserialize message as error, expected exactly \"e\" to be set. Message: {:?}", de ), }, } @@ -454,7 +499,6 @@ mod tests { use crate::bprotocol; use bencode::ByteBuf; - use librqbit_core::peer_id::generate_peer_id; // Dumped with wireshark. const FIND_NODE_REQUEST: &[u8] = b"64313a6164323a696432303abd7b477cfbcd10f30b705da20201e7101d8df155363a74617267657432303abd7b477cfbcd10f30b705da20201e7101d8df15565313a71393a66696e645f6e6f6465313a74323a0005313a79313a7165"; diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs new file mode 100644 index 0000000..6a6fbf2 --- /dev/null +++ b/crates/dht/src/dht.rs @@ -0,0 +1,535 @@ +use std::{ + collections::{HashMap, HashSet}, + net::SocketAddr, +}; + +use crate::{ + bprotocol::{ + self, CompactNodeInfo, FindNodeRequest, GetPeersRequest, Message, MessageKind, Node, + }, + id20::Id20, + routing_table::{InsertResult, RoutingTable}, +}; +use anyhow::Context; +use bencode::ByteString; +use futures::{stream::FuturesUnordered, StreamExt}; +use librqbit_core::peer_id::generate_peer_id; +use log::{debug, info, trace, warn}; +use parking_lot::Mutex; +use tokio::{ + net::UdpSocket, + sync::mpsc::{ + channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, + }, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +struct OutstandingRequest { + transaction_id: u16, + addr: SocketAddr, + request: Request, +} + +struct DhtState { + id: Id20, + next_transaction_id: u16, + outstanding_requests: Vec, + searching_for_peers: Vec, + routing_table: RoutingTable, + sender: UnboundedSender<(Message, SocketAddr)>, + + // TODO: convert to broadcast + subscribers: HashMap>>, + + made_requests: HashSet<(Request, SocketAddr)>, +} + +impl DhtState { + pub fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { + Self { + id, + next_transaction_id: 0, + outstanding_requests: Vec::new(), + searching_for_peers: Vec::new(), + routing_table: RoutingTable::new(id), + sender, + subscribers: Default::default(), + made_requests: Default::default(), + } + } + + pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { + let transaction_id = self.next_transaction_id; + let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; + self.next_transaction_id += 1; + let message = match request { + Request::GetPeers(info_hash) => Message { + transaction_id: ByteString::from(transaction_id_buf.as_ref()), + version: None, + ip: None, + kind: MessageKind::GetPeersRequest(GetPeersRequest { + id: self.id, + info_hash, + }), + }, + Request::FindNode(target) => Message { + transaction_id: ByteString::from(transaction_id_buf.as_ref()), + version: None, + ip: None, + kind: MessageKind::FindNodeRequest(FindNodeRequest { + id: self.id, + target, + }), + }, + }; + self.outstanding_requests.push(OutstandingRequest { + transaction_id, + addr, + request, + // time: Instant::now(), + }); + message + } + fn on_incoming_from_remote( + &mut self, + msg: Message, + addr: SocketAddr, + ) -> anyhow::Result<()> { + match &msg.kind { + MessageKind::Error(_) | MessageKind::Response(_) => {} + MessageKind::PingRequest(_) => { + let response = bprotocol::Response { + id: self.id, + nodes: None, + values: None, + token: None, + }; + let message = Message { + transaction_id: msg.transaction_id, + version: None, + ip: None, + kind: MessageKind::Response(response), + }; + self.sender.send((message, addr))?; + return Ok(()); + } + MessageKind::FindNodeRequest(_) | MessageKind::GetPeersRequest(_) => { + let target = match &msg.kind { + MessageKind::FindNodeRequest(req) => req.target, + MessageKind::GetPeersRequest(req) => req.info_hash, + _ => unreachable!(), + }; + // we don't track who is downloading a torrent (we don't have a peer store), + // so send nodes all the time. + let nodes = self + .routing_table + .sorted_by_distance_from(target) + .into_iter() + .filter_map(|r| { + Some(Node { + id: r.id(), + addr: match r.addr() { + SocketAddr::V4(v4) => v4, + SocketAddr::V6(_) => return None, + }, + }) + }) + .take(8) + .collect::>(); + let compact_node_info = CompactNodeInfo { nodes }; + let response = bprotocol::Response { + id: self.id, + nodes: Some(compact_node_info), + values: None, + token: None, + }; + let message = Message { + transaction_id: msg.transaction_id, + version: None, + ip: None, + kind: MessageKind::Response(response), + }; + self.sender.send((message, addr))?; + return Ok(()); + } + }; + if msg.transaction_id.len() != 2 { + anyhow::bail!( + "{}: transaction id unrecognized, we didn't ask for it. Message: {:?}", + addr, + msg + ) + } + let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); + // O(n) but whatever + let outstanding_id = self + .outstanding_requests + .iter() + .position(|req| req.transaction_id == tid && req.addr == addr) + .ok_or_else(|| anyhow::anyhow!("outstanding request not found. Message: {:?}", msg))?; + let outstanding = self.outstanding_requests.remove(outstanding_id); + let response = match msg.kind { + MessageKind::Error(e) => { + anyhow::bail!( + "request {:?} received error response {:?}", + outstanding.request, + e + ) + } + MessageKind::Response(r) => r, + _ => unreachable!(), + }; + match outstanding.request { + Request::FindNode(id) => { + let nodes = response + .nodes + .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; + self.on_found_nodes(response.id, addr, id, nodes) + } + Request::GetPeers(id) => self.on_found_peers_or_nodes(response.id, addr, id, response), + } + } + + pub fn on_request( + &mut self, + request: Request, + sender: UnboundedSender, + ) -> anyhow::Result<()> { + match request { + Request::GetPeers(info_hash) => { + let subs = self.subscribers.entry(info_hash).or_default(); + subs.push(sender); + self.searching_for_peers.push(info_hash); + + // workaround borrow checker. + let mut addrs = Vec::new(); + for node in self + .routing_table + .sorted_by_distance_from_mut(info_hash) + .into_iter() + .take(8) + { + node.mark_outgoing_request(); + addrs.push(node.addr()); + } + for addr in addrs { + let request = self.create_request(Request::GetPeers(info_hash), addr); + self.sender + .send((request, addr)) + .context("DhtState: error sending to self.sender")?; + } + } + Request::FindNode(_) => todo!(), + }; + Ok(()) + } + + fn on_found_nodes( + &mut self, + source: Id20, + source_addr: SocketAddr, + _target: Id20, + nodes: CompactNodeInfo, + ) -> anyhow::Result<()> { + match self.routing_table.add_node(source, source_addr) { + InsertResult::ReplacedBad(_) | InsertResult::Added => { + for idx in 0..self.searching_for_peers.len() { + let info_hash = self.searching_for_peers[idx]; + let request = Request::GetPeers(info_hash); + if self.made_requests.insert((request, source_addr)) { + self.routing_table.mark_outgoing_request(&source); + let msg = self.create_request(request, source_addr); + self.sender.send((msg, source_addr))?; + } + } + } + InsertResult::WasExisting => { + self.routing_table.mark_response(&source); + } + _ => {} + }; + for node in nodes.nodes { + match self.routing_table.add_node(node.id, node.addr.into()) { + InsertResult::ReplacedBad(_) | InsertResult::Added => { + for idx in 0..self.searching_for_peers.len() { + let info_hash = self.searching_for_peers[idx]; + let request = Request::GetPeers(info_hash); + if self.made_requests.insert((request, node.addr.into())) { + let msg = + self.create_request(Request::GetPeers(info_hash), node.addr.into()); + self.routing_table.mark_outgoing_request(&node.id); + self.sender.send((msg, node.addr.into()))? + } + } + } + _ => {} + }; + } + Ok(()) + } + + fn on_found_peers_or_nodes( + &mut self, + source: Id20, + source_addr: SocketAddr, + target: Id20, + data: bprotocol::Response, + ) -> anyhow::Result<()> { + self.routing_table.add_node(source, source_addr); + if let Some(peers) = data.values { + let subscribers = match self.subscribers.get(&target) { + Some(subscribers) => subscribers, + None => { + warn!( + "ignoring peers for {:?}: no subscribers left. Peers: {:?}", + target, peers + ); + return Ok(()); + } + }; + for subscriber in subscribers { + for peer in peers.iter() { + subscriber.send(Response::Peer(peer.addr.into()))? + } + } + }; + if let Some(nodes) = data.nodes { + for node in nodes.nodes { + self.routing_table.add_node(node.id, node.addr.into()); + let request = Request::GetPeers(target); + if self.made_requests.insert((request, node.addr.into())) { + let msg = self.create_request(Request::GetPeers(target), node.addr.into()); + self.routing_table.mark_outgoing_request(&node.id); + self.sender.send((msg, node.addr.into()))? + } + } + }; + Ok(()) + } +} + +async fn run_framer( + socket: &UdpSocket, + mut input_rx: UnboundedReceiver<(Message, SocketAddr)>, + output_tx: Sender<(Message, SocketAddr)>, +) -> anyhow::Result<()> { + let writer = async { + let mut buf = Vec::new(); + while let Some((msg, addr)) = input_rx.recv().await { + let addr = match addr { + SocketAddr::V4(v4) => v4, + SocketAddr::V6(_) => continue, + }; + trace!("{}: sending {:?}", addr, &msg); + buf.clear(); + bprotocol::serialize_message( + &mut buf, + msg.transaction_id, + msg.version, + msg.ip, + msg.kind, + ) + .unwrap(); + if let Err(e) = socket.send_to(&buf, addr).await { + warn!("could not send to {:?}: {}", addr, e) + } + } + Err::<(), _>(anyhow::anyhow!( + "DHT UDP socket writer over, nowhere to read messages from" + )) + }; + let reader = async { + let mut buf = vec![0u8; 16384]; + loop { + let (size, addr) = socket + .recv_from(&mut buf) + .await + .context("error reading from UDP socket")?; + match bprotocol::deserialize_message::(&buf[..size]) { + Ok(msg) => { + trace!("{}: received {:?}", addr, &msg); + match output_tx.send((msg, addr)).await { + Ok(_) => {} + Err(_) => break, + } + } + Err(e) => log::debug!("{}: error deserializing incoming message: {}", addr, e), + } + } + Err::<(), _>(anyhow::anyhow!( + "DHT UDP socket reader over, nowhere to read messages from" + )) + }; + let result = tokio::select! { + err = writer => err, + err = reader => err, + }; + result.context("DHT UDP framer closed") +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum Request { + GetPeers(Id20), + FindNode(Id20), +} + +#[derive(Debug)] +enum Response { + Peer(SocketAddr), +} + +pub struct Dht { + request_tx: Sender<(Request, UnboundedSender)>, +} + +struct DhtWorker { + socket: UdpSocket, + peer_id: Id20, + state: Mutex, +} + +impl DhtWorker { + fn on_request( + &self, + request: Request, + sender: UnboundedSender, + ) -> anyhow::Result<()> { + self.state.lock().on_request(request, sender) + } + fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { + self.state.lock().on_incoming_from_remote(msg, addr) + } + + async fn start( + self, + in_tx: UnboundedSender<(Message, SocketAddr)>, + in_rx: UnboundedReceiver<(Message, SocketAddr)>, + mut request_rx: Receiver<(Request, UnboundedSender)>, + bootstrap_addrs: &[String], + ) -> anyhow::Result<()> { + let (out_tx, mut out_rx) = channel(1); + let framer = run_framer(&self.socket, in_rx, out_tx); + + let bootstrap = async { + let mut futs = FuturesUnordered::new(); + // bootstrap + for addr in bootstrap_addrs.iter() { + let addr = addr; + let this = &self; + let in_tx = &in_tx; + futs.push(async move { + match tokio::net::lookup_host(addr).await { + Ok(addrs) => { + for addr in addrs { + let request = this + .state + .lock() + .create_request(Request::FindNode(this.peer_id), addr); + in_tx.send((request, addr))?; + } + } + Err(e) => warn!("error looking up {}: {}", addr, e), + } + Ok::<_, anyhow::Error>(()) + }); + } + let mut successes = 0; + while let Some(resp) = futs.next().await { + match resp { + Ok(_) => successes += 1, + Err(e) => warn!("error in one of the bootstrappers: {}", e), + } + } + if successes == 0 { + anyhow::bail!("bootstrapping did not succeed") + } + Ok(()) + }; + let mut bootstrap_done = false; + + let request_reader = { + let this = &self; + async move { + while let Some((request, sender)) = request_rx.recv().await { + this.on_request(request, sender) + .context("error processing request")?; + } + Err::<(), _>(anyhow::anyhow!( + "closed request reader, no more subscribers" + )) + } + }; + + let response_reader = { + let this = &self; + async move { + while let Some((response, addr)) = out_rx.recv().await { + if let Err(e) = this.on_response(response, addr) { + debug!("error in on_response, addr={:?}: {}", addr, e) + } + } + Err::<(), _>(anyhow::anyhow!( + "closed response reader, nowhere to send results to, DHT closed" + )) + } + }; + + tokio::pin!(framer); + tokio::pin!(bootstrap); + tokio::pin!(request_reader); + tokio::pin!(response_reader); + + loop { + tokio::select! { + err = &mut framer => { + anyhow::bail!("framer quit: {:?}", err) + }, + result = &mut bootstrap, if !bootstrap_done => { + bootstrap_done = true; + result?; + }, + err = &mut request_reader => {anyhow::bail!("request reader quit: {:?}", err)} + err = &mut response_reader => {anyhow::bail!("response reader quit: {:?}", err)} + } + } + } +} + +impl Dht { + pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result { + let (request_tx, request_rx) = channel(1); + let socket = UdpSocket::bind("0.0.0.0:0") + .await + .context("error binding socket")?; + let peer_id = Id20(generate_peer_id()); + info!("starting up DHT with peer id {:?}", peer_id); + let bootstrap_addrs = bootstrap_addrs + .iter() + .map(|s| s.to_string()) + .collect::>(); + + tokio::spawn(async move { + let (in_tx, in_rx) = unbounded_channel(); + let worker = DhtWorker { + socket, + peer_id, + state: Mutex::new(DhtState::new(peer_id, in_tx.clone())), + }; + let result = worker + .start(in_tx, in_rx, request_rx, &bootstrap_addrs) + .await; + warn!("DHT worker finished with {:?}", result); + }); + Ok(Dht { request_tx }) + } + pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt { + let (tx, rx) = unbounded_channel::(); + self.request_tx + .send((Request::GetPeers(info_hash), tx)) + .await + .unwrap(); + UnboundedReceiverStream::new(rx).map(|r| match r { + Response::Peer(addr) => addr, + }) + } +} diff --git a/crates/dht/src/id20.rs b/crates/dht/src/id20.rs index afd2ba3..6d7bc01 100644 --- a/crates/dht/src/id20.rs +++ b/crates/dht/src/id20.rs @@ -1,10 +1,23 @@ -use std::cmp::Ordering; +use std::{cmp::Ordering, str::FromStr}; use serde::{Deserialize, Deserializer, Serialize}; #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct Id20(pub [u8; 20]); +impl FromStr for Id20 { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let mut out = [0u8; 20]; + if s.len() != 40 { + anyhow::bail!("expected a hex string of length 40") + }; + hex::decode_to_slice(s, &mut out)?; + Ok(Id20(out)) + } +} + impl std::fmt::Debug for Id20 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "<")?; diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index 6c46a8b..3658052 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -1,3 +1,4 @@ pub mod bprotocol; +pub mod dht; pub mod id20; pub mod routing_table; diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs index cd0478d..16fed0a 100644 --- a/crates/dht/src/main.rs +++ b/crates/dht/src/main.rs @@ -1,397 +1,23 @@ -use std::{ - cell::RefCell, - collections::{BTreeMap, HashMap}, - net::{SocketAddr, SocketAddrV4}, - time::Instant, -}; +use std::{collections::HashSet, str::FromStr, time::Duration}; -use bencode::ByteString; -use dht::{ - bprotocol::{ - self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message, - MessageKind, - }, - id20::Id20, - routing_table::RoutingTable, -}; -use futures::{stream::FuturesUnordered, StreamExt}; -use librqbit_core::peer_id::generate_peer_id; -use log::{debug, warn}; -use parking_lot::Mutex; -use tokio::{ - net::UdpSocket, - sync::mpsc::{channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}, -}; -use tokio_stream::wrappers::ReceiverStream; - -struct OutstandingRequest { - transaction_id: u16, - addr: SocketAddr, - request: Request, - time: Instant, -} - -struct DhtState { - id: Id20, - next_transaction_id: u16, - outstanding_requests: Vec, - searching_for_peers: Vec, - routing_table: RoutingTable, - sender: UnboundedSender<(Message, SocketAddr)>, - - // TODO: convert to broadcast - subscribers: HashMap>>, -} - -enum PeersOrNodes { - Nodes(CompactNodeInfo), - Peers(Vec), -} - -impl DhtState { - pub fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { - Self { - id, - next_transaction_id: 0, - outstanding_requests: Vec::new(), - searching_for_peers: Vec::new(), - routing_table: RoutingTable::new(id), - sender, - subscribers: Default::default(), - } - } - - fn add_searching_for_peers(&mut self, info_hash: Id20) { - self.searching_for_peers.push(info_hash) - } - pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { - let transaction_id = self.next_transaction_id; - let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; - let message = match request { - Request::GetPeers(info_hash) => Message { - transaction_id: ByteString::from(transaction_id_buf.as_ref()), - version: None, - ip: None, - kind: MessageKind::GetPeersRequest(GetPeersRequest { - id: self.id, - info_hash, - }), - }, - Request::FindNode(target) => Message { - transaction_id: ByteString::from(transaction_id_buf.as_ref()), - version: None, - ip: None, - kind: MessageKind::FindNodeRequest(FindNodeRequest { - id: self.id, - target, - }), - }, - }; - self.outstanding_requests.push(OutstandingRequest { - transaction_id, - addr, - request, - time: Instant::now(), - }); - message - } - fn on_incoming_from_remote( - &mut self, - msg: Message, - addr: SocketAddr, - ) -> anyhow::Result<()> { - match msg.kind { - MessageKind::Error(_) | MessageKind::Response(_) => {} - other => anyhow::bail!("requests from DHT not supported, but got {:?}", other), - }; - if msg.transaction_id.len() != 2 { - anyhow::bail!("transaction id unrecognized") - } - let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); - // O(n) but whatever - let outstanding_id = self - .outstanding_requests - .iter() - .position(|req| req.transaction_id == tid && req.addr == addr) - .ok_or_else(|| anyhow::anyhow!("outstanding request not found"))?; - let outstanding = self.outstanding_requests.remove(outstanding_id); - let response = match msg.kind { - MessageKind::Error(e) => { - anyhow::bail!( - "request {:?} received error response {:?}", - outstanding.request, - e - ) - } - MessageKind::Response(r) => r, - _ => unreachable!(), - }; - match outstanding.request { - Request::FindNode(id) => { - let nodes = response - .nodes - .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; - self.on_found_nodes(response.id, addr, id, nodes) - } - Request::GetPeers(id) => { - if response.id != id { - anyhow::bail!( - "response id does not match: expected {:?}, received {:?}", - id, - response.id - ) - }; - let pn = match (response.nodes, response.values) { - (Some(nodes), None) => PeersOrNodes::Nodes(nodes), - (None, Some(peers)) => PeersOrNodes::Peers(peers), - _ => anyhow::bail!("expected nodes or values to be set in find_peers response"), - }; - self.on_found_peers_or_nodes(response.id, addr, id, pn) - } - } - } - - pub fn on_request(&mut self, request: Request, sender: Sender) -> anyhow::Result<()> { - match request { - Request::GetPeers(info_hash) => { - let subs = self.subscribers.entry(info_hash).or_default(); - subs.push(sender); - self.add_searching_for_peers(info_hash); - - // workaround borrow checker. - let mut addrs = Vec::new(); - for node in self - .routing_table - .sorted_by_distance_from_mut(info_hash) - .into_iter() - .take(8) - { - node.mark_outgoing_request(); - addrs.push(node.addr()); - } - for addr in addrs { - let request = self.create_request(Request::GetPeers(info_hash), addr); - self.sender.send((request, addr))?; - } - } - Request::FindNode(_) => todo!(), - }; - Ok(()) - } - - fn on_found_nodes( - &mut self, - source: Id20, - source_addr: SocketAddr, - target: Id20, - nodes: CompactNodeInfo, - ) -> anyhow::Result<()> { - todo!("on_found_nodes not implemented") - } - - fn on_found_peers_or_nodes( - &mut self, - source: Id20, - source_addr: SocketAddr, - target: Id20, - data: PeersOrNodes, - ) -> anyhow::Result<()> { - todo!("on_found_peers_or_nodes not implemented") - } -} - -async fn run_framer( - socket: &UdpSocket, - mut input_rx: UnboundedReceiver<(Message, SocketAddr)>, - output_tx: Sender<(Message, SocketAddr)>, -) -> anyhow::Result<()> { - let writer = async { - let mut buf = Vec::new(); - while let Some((msg, addr)) = input_rx.recv().await { - buf.clear(); - bprotocol::serialize_message( - &mut buf, - msg.transaction_id, - msg.version, - msg.ip, - msg.kind, - ) - .unwrap(); - socket.send_to(&buf, addr).await.unwrap(); - } - }; - let reader = async { - let mut buf = vec![0u8; 16384]; - while let Ok((size, addr)) = socket.recv_from(&mut buf).await { - match bprotocol::deserialize_message::(&buf[..size]) { - Ok(msg) => match output_tx.send((msg, addr)).await { - Ok(_) => {} - Err(_) => break, - }, - Err(e) => log::warn!("error deseriaizing msg: {}", e), - } - } - }; - tokio::select! { - _ = writer => {}, - _ = reader => {}, - }; - Ok(()) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum Request { - GetPeers(Id20), - FindNode(Id20), -} - -#[derive(Debug)] -enum Response { - Peer(SocketAddr), -} - -struct Dht { - request_tx: Sender<(Request, Sender)>, -} - -struct DhtWorker { - socket: UdpSocket, - peer_id: Id20, - state: Mutex, -} - -impl DhtWorker { - fn on_request(&self, request: Request, sender: Sender) -> anyhow::Result<()> { - self.state.lock().on_request(request, sender) - } - fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { - self.state.lock().on_incoming_from_remote(msg, addr) - } - - async fn start( - self, - in_tx: UnboundedSender<(Message, SocketAddr)>, - in_rx: UnboundedReceiver<(Message, SocketAddr)>, - mut request_rx: Receiver<(Request, Sender)>, - bootstrap_addrs: &[String], - ) -> anyhow::Result<()> { - let (out_tx, mut out_rx) = channel(1); - let framer = run_framer(&self.socket, in_rx, out_tx); - - let bootstrap = async { - let mut futs = FuturesUnordered::new(); - // bootstrap - for addr in bootstrap_addrs.iter() { - let addr = addr; - let this = &self; - let in_tx = &in_tx; - futs.push(async move { - match tokio::net::lookup_host(addr).await { - Ok(addrs) => { - for addr in addrs { - let request = this - .state - .lock() - .create_request(Request::FindNode(this.peer_id), addr); - match in_tx.send((request, addr)) { - Ok(_) => {} - Err(e) => { - debug!("bootstrap: channel closed, did not send {:?}", e) - } - }; - } - } - Err(e) => warn!("error looking up {}", addr), - } - }); - } - while futs.next().await.is_some() {} - }; - let mut bootstrap_done = false; - - let request_reader = { - let this = &self; - async move { - while let Some((request, sender)) = request_rx.recv().await { - this.on_request(request, sender).unwrap(); - } - } - }; - - let response_reader = { - let this = &self; - async move { - while let Some((response, addr)) = out_rx.recv().await { - this.on_response(response, addr).unwrap(); - } - } - }; - - tokio::pin!(framer); - tokio::pin!(bootstrap); - tokio::pin!(request_reader); - tokio::pin!(response_reader); - - loop { - tokio::select! { - _ = &mut framer => { - anyhow::bail!("framer quit") - }, - _ = &mut bootstrap, if !bootstrap_done => { - bootstrap_done = true - }, - _ = &mut request_reader => {anyhow::bail!("request reader quit")} - _ = &mut response_reader => {anyhow::bail!("response reader quit")} - } - } - } -} - -impl Dht { - pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result { - let (request_tx, request_rx) = channel(1); - let socket = UdpSocket::bind("0.0.0.0:0").await?; - let peer_id = Id20(generate_peer_id()); - let bootstrap_addrs = bootstrap_addrs - .iter() - .map(|s| s.to_string()) - .collect::>(); - - tokio::spawn(async move { - let (in_tx, in_rx) = tokio::sync::mpsc::unbounded_channel(); - let worker = DhtWorker { - socket, - peer_id, - state: Mutex::new(DhtState::new(peer_id, in_tx.clone())), - }; - worker - .start(in_tx, in_rx, request_rx, &bootstrap_addrs) - .await - }); - Ok(Dht { request_tx }) - } - pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt { - let (tx, rx) = channel::(1); - self.request_tx - .send((Request::GetPeers(info_hash), tx)) - .await - .unwrap(); - ReceiverStream::new(rx).map(|r| match r { - Response::Peer(addr) => addr, - _ => panic!("programming error"), - }) - } -} +use anyhow::Context; +use dht::{dht::Dht, id20::Id20}; +use tokio_stream::StreamExt; #[tokio::main] async fn main() -> anyhow::Result<()> { pretty_env_logger::init(); - let info_hash = Id20([0u8; 20]); - let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap(); + let info_hash = Id20::from_str("64a980abe6e448226bb930ba061592e44c3781a1").unwrap(); + let dht = Dht::new(&["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]) + .await + .context("error initializing dht")?; let mut stream = dht.get_peers(info_hash).await; + let mut seen = HashSet::new(); while let Some(peer) = stream.next().await { - log::info!("peer found: {}", peer) + if seen.insert(peer) { + log::info!("peer found: {}", peer) + } } Ok(()) } diff --git a/crates/dht/src/routing_table.rs b/crates/dht/src/routing_table.rs index dcbea69..e284f99 100644 --- a/crates/dht/src/routing_table.rs +++ b/crates/dht/src/routing_table.rs @@ -3,6 +3,10 @@ use std::{ time::{Duration, Instant}, }; +use log::debug; + +use crate::id20::Id20; + #[derive(Debug)] enum BucketTreeNode { Leaf(Vec), @@ -248,6 +252,8 @@ impl BucketTree { .find(|r| matches!(r.status(), NodeStatus::Bad)) { std::mem::swap(bad_node, &mut new_node); + nodes.sort_by_key(|n| n.id); + debug!("replaced bad node {:?}", new_node); return InsertResult::ReplacedBad(new_node); } @@ -297,8 +303,6 @@ impl Default for BucketTree { } } -use crate::id20::Id20; - #[derive(Debug)] pub struct RoutingTableNode { id: Id20, @@ -344,7 +348,11 @@ impl RoutingTableNode { } pub fn mark_response(&mut self) { - self.last_response = Some(Instant::now()); + let now = Instant::now(); + self.last_response = Some(now); + if self.last_request.is_none() { + self.last_request = Some(now); + } self.outstanding_queries_in_a_row = 0; } } @@ -386,7 +394,7 @@ impl RoutingTable { let res = self.buckets.add_node(&self.id, id, addr); let replaced = match &res { InsertResult::WasExisting => false, - InsertResult::ReplacedBad(_) => true, + InsertResult::ReplacedBad(..) => true, InsertResult::Added => true, InsertResult::Ignored => false, }; diff --git a/crates/librqbit/src/dht/inforead.rs b/crates/librqbit/src/dht/inforead.rs index 999fbc2..5385b32 100644 --- a/crates/librqbit/src/dht/inforead.rs +++ b/crates/librqbit/src/dht/inforead.rs @@ -1,7 +1,7 @@ use std::net::SocketAddr; use buffers::ByteString; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{stream::FuturesUnordered, Stream, StreamExt}; use librqbit_core::torrent_metainfo::TorrentMetaV1Info; use log::debug; use tokio::sync::mpsc::UnboundedReceiver; @@ -23,10 +23,10 @@ pub enum ReadMetainfoResult { pub async fn read_metainfo_from_peer_receiver( peer_id: [u8; 20], info_hash: [u8; 20], - mut addrs: UnboundedReceiver, + mut addrs: impl StreamExt + Unpin, ) -> ReadMetainfoResult { let mut seen = Vec::::new(); - let first_addr = match addrs.recv().await { + let first_addr = match addrs.next().await { Some(addr) => addr, None => return ReadMetainfoResult::ChannelClosed { seen }, }; @@ -39,7 +39,7 @@ pub async fn read_metainfo_from_peer_receiver( loop { tokio::select! { - next_addr = addrs.recv() => { + next_addr = addrs.next() => { match next_addr { Some(addr) => { seen.push(addr);