From d57079c75aaaf24b1eca8473f7a489891a552301 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Mon, 12 Jul 2021 16:24:26 +0100 Subject: [PATCH] Will start to test soon --- crates/dht/src/id20.rs | 2 +- crates/dht/src/main.rs | 242 ++++++++++++++++++++++---------- crates/dht/src/routing_table.rs | 115 ++++++++++++++- 3 files changed, 285 insertions(+), 74 deletions(-) diff --git a/crates/dht/src/id20.rs b/crates/dht/src/id20.rs index 37c8cd0..afd2ba3 100644 --- a/crates/dht/src/id20.rs +++ b/crates/dht/src/id20.rs @@ -2,7 +2,7 @@ use std::cmp::Ordering; use serde::{Deserialize, Deserializer, Serialize}; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct Id20(pub [u8; 20]); impl std::fmt::Debug for Id20 { diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs index ae6606c..cd0478d 100644 --- a/crates/dht/src/main.rs +++ b/crates/dht/src/main.rs @@ -1,5 +1,6 @@ use std::{ - collections::BTreeMap, + cell::RefCell, + collections::{BTreeMap, HashMap}, net::{SocketAddr, SocketAddrV4}, time::Instant, }; @@ -11,12 +12,15 @@ use dht::{ MessageKind, }, id20::Id20, + routing_table::RoutingTable, }; -use futures::StreamExt; +use futures::{stream::FuturesUnordered, StreamExt}; use librqbit_core::peer_id::generate_peer_id; +use log::{debug, warn}; +use parking_lot::Mutex; use tokio::{ net::UdpSocket, - sync::mpsc::{channel, Receiver, Sender}, + sync::mpsc::{channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}, }; use tokio_stream::wrappers::ReceiverStream; @@ -32,6 +36,11 @@ struct DhtState { next_transaction_id: u16, outstanding_requests: Vec, searching_for_peers: Vec, + routing_table: RoutingTable, + sender: UnboundedSender<(Message, SocketAddr)>, + + // TODO: convert to broadcast + subscribers: HashMap>>, } enum PeersOrNodes { @@ -40,10 +49,22 @@ enum PeersOrNodes { } impl DhtState { + pub fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { + Self { + id, + next_transaction_id: 0, + outstanding_requests: Vec::new(), + searching_for_peers: Vec::new(), + routing_table: RoutingTable::new(id), + sender, + subscribers: Default::default(), + } + } + fn add_searching_for_peers(&mut self, info_hash: Id20) { self.searching_for_peers.push(info_hash) } - fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { + pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { let transaction_id = self.next_transaction_id; let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; let message = match request { @@ -107,17 +128,10 @@ impl DhtState { }; match outstanding.request { Request::FindNode(id) => { - if response.id != id { - anyhow::bail!( - "response id does not match: expected {:?}, received {:?}", - id, - response.id - ) - }; let nodes = response .nodes .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; - self.on_found_nodes(id, nodes) + self.on_found_nodes(response.id, addr, id, nodes) } Request::GetPeers(id) => { if response.id != id { @@ -127,32 +141,69 @@ impl DhtState { response.id ) }; - let nodes = response - .nodes - .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; - // let pn = match (response.nodes, response.values) { - // (Some(nodes), None) => PeersOrNodes::Nodes(nodes), - // (None, Some(peers)) => PeersOrNodes::Peers(peers), - // _ => anyhow::bail!("expected nodes or values to be set in find_peers response"), - // }; - // self.on_found_peers_or_nodes(id, pn) + let pn = match (response.nodes, response.values) { + (Some(nodes), None) => PeersOrNodes::Nodes(nodes), + (None, Some(peers)) => PeersOrNodes::Peers(peers), + _ => anyhow::bail!("expected nodes or values to be set in find_peers response"), + }; + self.on_found_peers_or_nodes(response.id, addr, id, pn) } + } + } + + pub fn on_request(&mut self, request: Request, sender: Sender) -> anyhow::Result<()> { + match request { + Request::GetPeers(info_hash) => { + let subs = self.subscribers.entry(info_hash).or_default(); + subs.push(sender); + self.add_searching_for_peers(info_hash); + + // workaround borrow checker. + let mut addrs = Vec::new(); + for node in self + .routing_table + .sorted_by_distance_from_mut(info_hash) + .into_iter() + .take(8) + { + node.mark_outgoing_request(); + addrs.push(node.addr()); + } + for addr in addrs { + let request = self.create_request(Request::GetPeers(info_hash), addr); + self.sender.send((request, addr))?; + } + } + Request::FindNode(_) => todo!(), }; Ok(()) } - fn on_found_nodes(&mut self, target: Id20, nodes: CompactNodeInfo) { + + fn on_found_nodes( + &mut self, + source: Id20, + source_addr: SocketAddr, + target: Id20, + nodes: CompactNodeInfo, + ) -> anyhow::Result<()> { todo!("on_found_nodes not implemented") } - fn on_found_peers_or_nodes(&mut self, target: Id20, data: PeersOrNodes) { - todo!("on_found_nodes not implemented") + fn on_found_peers_or_nodes( + &mut self, + source: Id20, + source_addr: SocketAddr, + target: Id20, + data: PeersOrNodes, + ) -> anyhow::Result<()> { + todo!("on_found_peers_or_nodes not implemented") } } async fn run_framer( socket: &UdpSocket, - mut input_rx: Receiver<(Message, SocketAddr)>, - output_tx: Sender>, + mut input_rx: UnboundedReceiver<(Message, SocketAddr)>, + output_tx: Sender<(Message, SocketAddr)>, ) -> anyhow::Result<()> { let writer = async { let mut buf = Vec::new(); @@ -173,7 +224,7 @@ async fn run_framer( let mut buf = vec![0u8; 16384]; while let Ok((size, addr)) = socket.recv_from(&mut buf).await { match bprotocol::deserialize_message::(&buf[..size]) { - Ok(msg) => match output_tx.send(msg).await { + Ok(msg) => match output_tx.send((msg, addr)).await { Ok(_) => {} Err(_) => break, }, @@ -188,7 +239,7 @@ async fn run_framer( Ok(()) } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum Request { GetPeers(Id20), FindNode(Id20), @@ -205,50 +256,94 @@ struct Dht { struct DhtWorker { socket: UdpSocket, - request_rx: Receiver<(Request, Sender)>, - next_transaction_id: u16, peer_id: Id20, + state: Mutex, } impl DhtWorker { - fn on_request(&self, request: Request, sender: Sender) {} + fn on_request(&self, request: Request, sender: Sender) -> anyhow::Result<()> { + self.state.lock().on_request(request, sender) + } + fn on_response(&self, msg: Message, addr: SocketAddr) -> anyhow::Result<()> { + self.state.lock().on_incoming_from_remote(msg, addr) + } - async fn start(&mut self, bootstrap_addrs: Vec) -> anyhow::Result<()> { - let (in_tx, in_rx) = channel(1); - let (out_tx, out_rx) = channel(1); + async fn start( + self, + in_tx: UnboundedSender<(Message, SocketAddr)>, + in_rx: UnboundedReceiver<(Message, SocketAddr)>, + mut request_rx: Receiver<(Request, Sender)>, + bootstrap_addrs: &[String], + ) -> anyhow::Result<()> { + let (out_tx, mut out_rx) = channel(1); let framer = run_framer(&self.socket, in_rx, out_tx); let bootstrap = async { + let mut futs = FuturesUnordered::new(); // bootstrap - for addr in bootstrap_addrs { - for addr in tokio::net::lookup_host(addr).await.unwrap() { - // let msg = MessageKind::FindNodeRequest(FindNodeRequest { - // id: self.peer_id, - // target: self.peer_id, - // }); - // in_tx.send((msg, addr)).await.unwrap(); - } + for addr in bootstrap_addrs.iter() { + let addr = addr; + let this = &self; + let in_tx = &in_tx; + futs.push(async move { + match tokio::net::lookup_host(addr).await { + Ok(addrs) => { + for addr in addrs { + let request = this + .state + .lock() + .create_request(Request::FindNode(this.peer_id), addr); + match in_tx.send((request, addr)) { + Ok(_) => {} + Err(e) => { + debug!("bootstrap: channel closed, did not send {:?}", e) + } + }; + } + } + Err(e) => warn!("error looking up {}", addr), + } + }); } + while futs.next().await.is_some() {} }; let mut bootstrap_done = false; - // let request_reader = async { - // while let Some((request, sender)) = self.request_rx.recv().await { - // self.on_request(request, sender) - // } - // }; + let request_reader = { + let this = &self; + async move { + while let Some((request, sender)) = request_rx.recv().await { + this.on_request(request, sender).unwrap(); + } + } + }; - // tokio::select! { - // _ = framer => { - // anyhow::bail!("framer quit") - // }, - // _ = bootstrap, if !bootstrap_done => { - // bootstrap_done = true - // }, - // _ = request_reader => {} - // } + let response_reader = { + let this = &self; + async move { + while let Some((response, addr)) = out_rx.recv().await { + this.on_response(response, addr).unwrap(); + } + } + }; - todo!() + tokio::pin!(framer); + tokio::pin!(bootstrap); + tokio::pin!(request_reader); + tokio::pin!(response_reader); + + loop { + tokio::select! { + _ = &mut framer => { + anyhow::bail!("framer quit") + }, + _ = &mut bootstrap, if !bootstrap_done => { + bootstrap_done = true + }, + _ = &mut request_reader => {anyhow::bail!("request reader quit")} + _ = &mut response_reader => {anyhow::bail!("response reader quit")} + } + } } } @@ -256,14 +351,23 @@ impl Dht { pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result { let (request_tx, request_rx) = channel(1); let socket = UdpSocket::bind("0.0.0.0:0").await?; - let mut worker = DhtWorker { - socket, - request_rx, - next_transaction_id: 0, - peer_id: Id20(generate_peer_id()), - }; - let bootstrap_addrs = bootstrap_addrs.iter().map(|s| s.to_string()).collect(); - tokio::spawn(async move { worker.start(bootstrap_addrs).await }); + let peer_id = Id20(generate_peer_id()); + let bootstrap_addrs = bootstrap_addrs + .iter() + .map(|s| s.to_string()) + .collect::>(); + + tokio::spawn(async move { + let (in_tx, in_rx) = tokio::sync::mpsc::unbounded_channel(); + let worker = DhtWorker { + socket, + peer_id, + state: Mutex::new(DhtState::new(peer_id, in_tx.clone())), + }; + worker + .start(in_tx, in_rx, request_rx, &bootstrap_addrs) + .await + }); Ok(Dht { request_tx }) } pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt { @@ -277,16 +381,12 @@ impl Dht { _ => panic!("programming error"), }) } - // async fn run(self) -> anyhow::Result { - // let socket = UdpSocket::bind("0.0.0.0:0").await?; - // let (in_tx, in_rx) = channel(1); - // let (out_tx, out_rx) = channel(1); - // let framer = run_framer(socket, in_rx, out_tx); - // } } #[tokio::main] async fn main() -> anyhow::Result<()> { + pretty_env_logger::init(); + let info_hash = Id20([0u8; 20]); let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap(); let mut stream = dht.get_peers(info_hash).await; diff --git a/crates/dht/src/routing_table.rs b/crates/dht/src/routing_table.rs index e5f09cd..dcbea69 100644 --- a/crates/dht/src/routing_table.rs +++ b/crates/dht/src/routing_table.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, time::Instant}; +use std::{ + net::SocketAddr, + time::{Duration, Instant}, +}; #[derive(Debug)] enum BucketTreeNode { @@ -63,6 +66,55 @@ impl<'a> Iterator for BucketTreeNodeIterator<'a> { } } +pub struct BucketTreeNodeIteratorMut<'a> { + current: std::slice::IterMut<'a, RoutingTableNode>, + queue: Vec<&'a mut BucketTree>, +} + +impl<'a> BucketTreeNodeIteratorMut<'a> { + fn new(mut tree: &'a mut BucketTree) -> Self { + let mut queue = Vec::new(); + let current = loop { + match &mut tree.data { + BucketTreeNode::Leaf(nodes) => break nodes.iter_mut(), + BucketTreeNode::LeftRight(left, right) => { + queue.push(right.as_mut()); + tree = left.as_mut() + } + } + }; + BucketTreeNodeIteratorMut { current, queue } + } +} + +impl<'a> Iterator for BucketTreeNodeIteratorMut<'a> { + type Item = &'a mut RoutingTableNode; + + fn next(&mut self) -> Option { + if let Some(v) = self.current.next() { + return Some(v); + }; + + loop { + let tree = self.queue.pop()?; + match &mut tree.data { + BucketTreeNode::Leaf(nodes) => { + self.current = nodes.iter_mut(); + match self.current.next() { + Some(v) => return Some(v), + None => continue, + } + } + BucketTreeNode::LeftRight(left, right) => { + self.queue.push(right.as_mut()); + self.queue.push(left.as_mut()); + continue; + } + } + } + } +} + fn compute_split_start_end( start: Id20, end_inclusive: Id20, @@ -129,6 +181,23 @@ impl BucketTree { pub fn iter(&self) -> BucketTreeNodeIterator<'_> { BucketTreeNodeIterator::new(self) } + + pub fn iter_mut(&mut self) -> BucketTreeNodeIteratorMut<'_> { + BucketTreeNodeIteratorMut::new(self) + } + + pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> { + if !(*id >= self.start && *id <= self.end_inclusive) { + return None; + } + match &mut self.data { + BucketTreeNode::Leaf(nodes) => nodes.iter_mut().find(|b| b.id == *id), + BucketTreeNode::LeftRight(left, right) => { + left.get_mut(id).or_else(move || right.get_mut(id)) + } + } + } + pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult { let mut tree = self; loop { @@ -259,11 +328,25 @@ impl RoutingTableNode { Some(v) => v, None => return NodeStatus::Unknown, }; + if self.outstanding_queries_in_a_row > 0 && last_request.elapsed() > Duration::from_secs(10) + { + return NodeStatus::Bad; + } if self.last_response.is_some() { return NodeStatus::Good; } NodeStatus::Questionable } + + pub fn mark_outgoing_request(&mut self) { + self.last_request = Some(Instant::now()); + self.outstanding_queries_in_a_row += 1; + } + + pub fn mark_response(&mut self) { + self.last_response = Some(Instant::now()); + self.outstanding_queries_in_a_row = 0; + } } #[derive(Debug)] @@ -289,6 +372,16 @@ impl RoutingTable { result.sort_by_key(|n| id.distance(&n.id)); result } + + pub fn sorted_by_distance_from_mut(&mut self, id: Id20) -> Vec<&mut RoutingTableNode> { + let mut result = Vec::with_capacity(self.size); + for node in self.buckets.iter_mut() { + result.push(node); + } + result.sort_by_key(|n| id.distance(&n.id)); + result + } + pub fn add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult { let res = self.buckets.add_node(&self.id, id, addr); let replaced = match &res { @@ -302,6 +395,23 @@ impl RoutingTable { } res } + pub fn mark_outgoing_request(&mut self, id: &Id20) -> bool { + let r = match self.buckets.get_mut(id) { + Some(r) => r, + None => return false, + }; + r.mark_outgoing_request(); + true + } + + pub fn mark_response(&mut self, id: &Id20) -> bool { + let r = match self.buckets.get_mut(id) { + Some(r) => r, + None => return false, + }; + r.mark_response(); + true + } } #[cfg(test)] @@ -410,6 +520,7 @@ mod tests { let addr = std::net::SocketAddr::V4(SocketAddrV4::new("0.0.0.0".parse().unwrap(), i)); rtable.add_node(other_id, addr); } - dbg!(rtable); + dbg!(&rtable); + assert_eq!(rtable.sorted_by_distance_from(my_id).len(), rtable.size); } }