diff --git a/Cargo.lock b/Cargo.lock index 59b9499..8bda825 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,6 +276,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bencode", + "futures 0.3.15", "hex 0.4.3", "kad", "librqbit_core", @@ -284,6 +285,7 @@ dependencies = [ "pretty_env_logger", "serde", "tokio", + "tokio-stream", ] [[package]] diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 93a2327..9c578c8 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" [dependencies] kad = "0.6" tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]} +tokio-stream = "0.1" serde = {version = "1", features = ["derive"]} hex = "0.4" bencode = {path = "../bencode"} @@ -15,6 +16,7 @@ anyhow = "1" parking_lot = "0.11" log = "0.4" pretty_env_logger = "0.4" +futures = "0.3" librqbit_core = {path="../librqbit_core"} diff --git a/crates/dht/src/bprotocol.rs b/crates/dht/src/bprotocol.rs index f4e7093..57aa219 100644 --- a/crates/dht/src/bprotocol.rs +++ b/crates/dht/src/bprotocol.rs @@ -10,6 +10,8 @@ use serde::{ Deserialize, Deserializer, Serialize, }; +use crate::id20::Id20; + #[derive(Debug)] enum MessageType { Request, @@ -17,57 +19,6 @@ enum MessageType { Error, } -#[derive(Clone, Copy)] -pub struct Id20(pub [u8; 20]); - -impl std::fmt::Debug for Id20 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "<")?; - for byte in self.0 { - write!(f, "{:02x?}", byte)?; - } - write!(f, ">")?; - Ok(()) - } -} - -impl Serialize for Id20 { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_bytes(&self.0) - } -} - -impl<'de> Deserialize<'de> for Id20 { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Visitor; - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = Id20; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a 20 byte slice") - } - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - if v.len() != 20 { - return Err(E::invalid_length(20, &self)); - } - let mut buf = [0u8; 20]; - buf.copy_from_slice(&v); - Ok(Id20(buf)) - } - } - deserializer.deserialize_bytes(Visitor {}) - } -} - impl<'de> Deserialize<'de> for MessageType { fn deserialize(deserializer: D) -> Result where diff --git a/crates/dht/src/id20.rs b/crates/dht/src/id20.rs new file mode 100644 index 0000000..61befdc --- /dev/null +++ b/crates/dht/src/id20.rs @@ -0,0 +1,89 @@ +use std::cmp::Ordering; + +use serde::{Deserialize, Deserializer, Serialize}; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct Id20(pub [u8; 20]); + +impl std::fmt::Debug for Id20 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "<")?; + for byte in self.0 { + write!(f, "{:02x?}", byte)?; + } + write!(f, ">")?; + Ok(()) + } +} + +impl Serialize for Id20 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0) + } +} + +impl<'de> Deserialize<'de> for Id20 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = Id20; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a 20 byte slice") + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + if v.len() != 20 { + return Err(E::invalid_length(20, &self)); + } + let mut buf = [0u8; 20]; + buf.copy_from_slice(&v); + Ok(Id20(buf)) + } + } + deserializer.deserialize_bytes(Visitor {}) + } +} + +impl Id20 { + pub fn distance(&self, other: &Id20) -> Id20 { + let mut xor = [0u8; 20]; + for (idx, (s, o)) in self + .0 + .iter() + .copied() + .zip(other.0.iter().copied()) + .enumerate() + { + xor[idx] = s ^ o; + } + Id20(xor) + } +} + +impl Ord for Id20 { + fn cmp(&self, other: &Id20) -> Ordering { + for (s, o) in self.0.iter().copied().zip(other.0.iter().copied()) { + match s.cmp(&o) { + Ordering::Less => return Ordering::Less, + Ordering::Equal => continue, + Ordering::Greater => return Ordering::Greater, + } + } + Ordering::Equal + } +} + +impl PartialOrd for Id20 { + fn partial_cmp(&self, other: &Id20) -> Option { + Some(self.cmp(other)) + } +} diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index 8ae9a86..6c46a8b 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -1 +1,3 @@ pub mod bprotocol; +pub mod id20; +pub mod routing_table; diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs index 290bc70..d5f8fb6 100644 --- a/crates/dht/src/main.rs +++ b/crates/dht/src/main.rs @@ -1,143 +1,297 @@ -use std::{collections::HashMap, net::SocketAddrV4}; +use std::{ + collections::BTreeMap, + net::{SocketAddr, SocketAddrV4}, + time::Instant, +}; -use crate::bprotocol::MessageKind; use bencode::ByteString; +use dht::{ + bprotocol::{ + self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message, + MessageKind, + }, + id20::Id20, +}; +use futures::StreamExt; use librqbit_core::peer_id::generate_peer_id; -use log::debug; -use parking_lot::Mutex; +use tokio::{ + net::UdpSocket, + sync::mpsc::{channel, Receiver, Sender}, +}; +use tokio_stream::wrappers::ReceiverStream; -use crate::bprotocol::Message; - -mod bprotocol; - -struct SocketManager { - socket: tokio::net::UdpSocket, - rx: tokio::sync::mpsc::Receiver<( - SocketAddrV4, - MessageKind, - tokio::sync::oneshot::Sender>, - )>, +struct OutstandingRequest { + transaction_id: u16, + addr: SocketAddr, + request: Request, + time: Instant, } -impl SocketManager { - pub async fn spawn() -> anyhow::Result { - let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?; - let (tx, rx) = tokio::sync::mpsc::channel(1); - let mgr = SocketManager { socket, rx }; - tokio::spawn(mgr.run()); - Ok(SocketManagerHandle { tx }) +struct DhtState { + id: Id20, + next_transaction_id: u16, + outstanding_requests: Vec, + searching_for_peers: Vec, +} + +enum PeersOrNodes { + Nodes(CompactNodeInfo), + Peers(Vec), +} + +impl DhtState { + fn add_searching_for_peers(&mut self, info_hash: Id20) { + self.searching_for_peers.push(info_hash) } - pub async fn run(self) -> anyhow::Result<()> { - let Self { socket, mut rx } = self; - - let mut transaction_id = 0u16; - let mut next_transaction_id = move || { - let next = transaction_id; - transaction_id = next + 1; - next + fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message { + let transaction_id = self.next_transaction_id; + let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; + let message = match request { + Request::GetPeers(info_hash) => Message { + transaction_id: ByteString::from(transaction_id_buf.as_ref()), + version: None, + ip: None, + kind: MessageKind::GetPeersRequest(GetPeersRequest { + id: self.id, + info_hash, + }), + }, + Request::FindNode(target) => Message { + transaction_id: ByteString::from(transaction_id_buf.as_ref()), + version: None, + ip: None, + kind: MessageKind::FindNodeRequest(FindNodeRequest { + id: self.id, + target, + }), + }, }; - - let outstanding = Mutex::new(HashMap::< - u16, - tokio::sync::oneshot::Sender>, - >::new()); - - let writer = async { - let mut buf = Vec::new(); - while let Some((addr, msg, tx)) = rx.recv().await { - let transaction_id = next_transaction_id(); - let transaction_id_buf = - [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; - buf.clear(); - bprotocol::serialize_message( - &mut buf, - // this is bad, allocates - ByteString::from(transaction_id_buf.as_ref()), - None, - None, - msg, + self.outstanding_requests.push(OutstandingRequest { + transaction_id, + addr, + request, + time: Instant::now(), + }); + message + } + fn on_incoming_from_remote( + &mut self, + msg: Message, + addr: SocketAddr, + ) -> anyhow::Result<()> { + match msg.kind { + MessageKind::Error(_) | MessageKind::Response(_) => {} + other => anyhow::bail!("requests from DHT not supported, but got {:?}", other), + }; + if msg.transaction_id.len() != 2 { + anyhow::bail!("transaction id unrecognized") + } + let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16); + // O(n) but whatever + let outstanding_id = self + .outstanding_requests + .iter() + .position(|req| req.transaction_id == tid && req.addr == addr) + .ok_or_else(|| anyhow::anyhow!("outstanding request not found"))?; + let outstanding = self.outstanding_requests.remove(outstanding_id); + let response = match msg.kind { + MessageKind::Error(e) => { + anyhow::bail!( + "request {:?} received error response {:?}", + outstanding.request, + e ) - .unwrap(); - - debug!("inserting transaction id {}", transaction_id); - assert!(outstanding.lock().insert(transaction_id, tx).is_none()); - debug!("sending msg to {}", addr); - socket.send_to(&buf, addr).await.unwrap(); + } + MessageKind::Response(r) => r, + _ => unreachable!(), + }; + match outstanding.request { + Request::FindNode(id) => { + if response.id != id { + anyhow::bail!( + "response id does not match: expected {:?}, received {:?}", + id, + response.id + ) + }; + let nodes = response + .nodes + .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; + self.on_found_nodes(id, nodes) + } + Request::GetPeers(id) => { + if response.id != id { + anyhow::bail!( + "response id does not match: expected {:?}, received {:?}", + id, + response.id + ) + }; + let nodes = response + .nodes + .ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?; + let pn = match (response.nodes, response.values) { + (Some(nodes), None) => PeersOrNodes::Nodes(nodes), + (None, Some(peers)) => PeersOrNodes::Peers(peers), + _ => anyhow::bail!("expected nodes or values to be set in find_peers response"), + }; + self.on_found_peers_or_nodes(id, pn) } }; + Ok(()) + } + fn on_found_nodes(&mut self, target: Id20, nodes: CompactNodeInfo) { + todo!("on_found_nodes not implemented") + } - let reader = async { - let mut buf = vec![0u8; 16384]; - while let Ok(size) = socket.recv(&mut buf).await { - debug!("received {}", size); - let msg = match bprotocol::deserialize_message::(&buf[..size]) { - Ok(msg) => msg, - // todo handle errors - Err(e) => panic!("{}", e), - }; - assert!(msg.transaction_id.len() == 2); - let b0 = msg.transaction_id[0]; - let b1 = msg.transaction_id[1]; - let tid = ((b0 as u16) << 8) + b1 as u16; - let tx = outstanding.lock().remove(&tid).unwrap(); - debug!("sending oneshot result, tid {}", tid); - tx.send(msg).unwrap(); + fn on_found_peers_or_nodes(&mut self, target: Id20, data: PeersOrNodes) { + todo!("on_found_nodes not implemented") + } +} + +async fn run_framer( + socket: &UdpSocket, + mut input_rx: Receiver<(Message, SocketAddr)>, + output_tx: Sender>, +) -> anyhow::Result<()> { + let writer = async { + let mut buf = Vec::new(); + while let Some((msg, addr)) = input_rx.recv().await { + buf.clear(); + bprotocol::serialize_message( + &mut buf, + msg.transaction_id, + msg.version, + msg.ip, + msg.kind, + ) + .unwrap(); + socket.send_to(&buf, addr).await.unwrap(); + } + }; + let reader = async { + let mut buf = vec![0u8; 16384]; + while let Ok((size, addr)) = socket.recv_from(&mut buf).await { + match bprotocol::deserialize_message::(&buf[..size]) { + Ok(msg) => match output_tx.send(msg).await { + Ok(_) => {} + Err(_) => break, + }, + Err(e) => log::warn!("error deseriaizing msg: {}", e), + } + } + }; + tokio::select! { + _ = writer => {}, + _ = reader => {}, + }; + Ok(()) +} + +#[derive(Debug, Clone, Copy)] +enum Request { + GetPeers(Id20), + FindNode(Id20), +} + +#[derive(Debug)] +enum Response { + Peer(SocketAddr), +} + +struct Dht { + request_tx: Sender<(Request, Sender)>, +} + +struct DhtWorker { + socket: UdpSocket, + request_rx: Receiver<(Request, Sender)>, + next_transaction_id: u16, + peer_id: Id20, +} + +impl DhtWorker { + fn on_request(&self, request: Request, sender: Sender) {} + + async fn start(&mut self, bootstrap_addrs: Vec) -> anyhow::Result<()> { + let (in_tx, in_rx) = channel(1); + let (out_tx, out_rx) = channel(1); + let framer = run_framer(&self.socket, in_rx, out_tx); + + let bootstrap = async { + // bootstrap + for addr in bootstrap_addrs { + for addr in tokio::net::lookup_host(addr).await.unwrap() { + let msg = MessageKind::FindNodeRequest(FindNodeRequest { + id: self.peer_id, + target: self.peer_id, + }); + in_tx.send((msg, addr)).await.unwrap(); + } + } + }; + let mut bootstrap_done = false; + + let request_reader = async { + while let Some((request, sender)) = self.request_rx.recv().await { + self.on_request(request, sender) } }; tokio::select! { - _ = writer => {}, - _ = reader => {} + _ = framer => { + anyhow::bail!("framer quit") + }, + _ = bootstrap, if !bootstrap_done => { + bootstrap_done = true + }, + _ = request_reader => {} } - Ok(()) + todo!() } } -#[derive(Clone)] -struct SocketManagerHandle { - tx: tokio::sync::mpsc::Sender<( - SocketAddrV4, - MessageKind, - tokio::sync::oneshot::Sender>, - )>, -} - -impl SocketManagerHandle { - async fn request( - &self, - addr: SocketAddrV4, - kind: MessageKind, - ) -> anyhow::Result> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.tx.send((addr, kind, tx)).await?; - let msg = rx.await?; - Ok(msg) +impl Dht { + pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result { + let (request_tx, request_rx) = channel(1); + let socket = UdpSocket::bind("0.0.0.0:0").await?; + let mut worker = DhtWorker { + socket, + request_rx, + next_transaction_id: 0, + peer_id: Id20(generate_peer_id()), + }; + let bootstrap_addrs = bootstrap_addrs.iter().map(|s| s.to_string()).collect(); + tokio::spawn(async move { worker.start(bootstrap_addrs).await }); + Ok(Dht { request_tx }) } + pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt { + let (tx, rx) = channel::(1); + self.request_tx + .send((Request::GetPeers(info_hash), tx)) + .await + .unwrap(); + ReceiverStream::new(rx).map(|r| match r { + Response::Peer(addr) => addr, + _ => panic!("programming error"), + }) + } + // async fn run(self) -> anyhow::Result { + // let socket = UdpSocket::bind("0.0.0.0:0").await?; + // let (in_tx, in_rx) = channel(1); + // let (out_tx, out_rx) = channel(1); + // let framer = run_framer(socket, in_rx, out_tx); + // } } #[tokio::main] -async fn main() { - std::env::set_var("RUST_LOG", "trace"); - pretty_env_logger::init(); - - let mgr = SocketManager::spawn().await.unwrap(); - - let peer_id = bprotocol::Id20(generate_peer_id()); - for first_addr in tokio::net::lookup_host("dht.transmissionbt.com:6881") - .await - .unwrap() - .filter_map(|a| match a { - std::net::SocketAddr::V4(v4) => Some(v4), - std::net::SocketAddr::V6(_) => None, - }) - .skip(1) - { - let msg = bprotocol::MessageKind::FindNodeRequest(bprotocol::FindNodeRequest { - id: peer_id, - target: peer_id, - }); - - dbg!(mgr.request(first_addr, msg).await.unwrap()); +async fn main() -> anyhow::Result<()> { + let info_hash = Id20([0u8; 20]); + let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap(); + let mut stream = dht.get_peers(info_hash).await; + while let Some(peer) = stream.next().await { + log::info!("peer found: {}", peer) } + Ok(()) } diff --git a/crates/dht/src/routing_table.rs b/crates/dht/src/routing_table.rs new file mode 100644 index 0000000..ff51bcd --- /dev/null +++ b/crates/dht/src/routing_table.rs @@ -0,0 +1,84 @@ +use std::{ + collections::BTreeMap, + net::SocketAddr, + time::{Duration, Instant}, +}; + +use crate::id20::Id20; + +pub struct RoutingTableNode { + id: Id20, + addr: SocketAddr, + last_request: Option, + last_response: Option, + outstanding_queries_in_a_row: usize, +} + +pub enum NodeStatus { + Good, + Questionable, + Bad, + Unknown, +} + +impl RoutingTableNode { + pub fn id(&self) -> Id20 { + self.id + } + pub fn addr(&self) -> SocketAddr { + self.addr + } + pub fn status(&self) -> NodeStatus { + // TODO: this is just a stub with simpler logic + let last_request = match self.last_request { + Some(v) => v, + None => return NodeStatus::Unknown, + }; + if self.last_response.is_some() { + return NodeStatus::Good; + } + NodeStatus::Questionable + } +} + +struct Bucket { + bits: u8, + nodes: Vec, + end: Id20, +} + +pub struct RoutingTable { + id: Id20, + size: usize, + buckets: BTreeMap, +} + +impl RoutingTable { + pub fn new(id: Id20) -> Self { + let initial_bucket = Id20([0u8; 20]); + let mut buckets = BTreeMap::new(); + buckets.insert( + initial_bucket, + Bucket { + bits: 160, + nodes: Vec::new(), + }, + ); + Self { + id, + buckets, + size: 0, + } + } + pub fn sorted_by_distance_from(&self, id: Id20) -> Vec<&RoutingTableNode> { + let mut result = Vec::with_capacity(self.size); + for bucket in self.buckets.values() { + for node in bucket.nodes.iter() { + result.push(node); + } + } + result.sort_by_key(|n| id.distance(&n.id)); + result + } + pub fn add_node(&mut self, id: Id20, addr: SocketAddr) -> bool {} +}