From f6656841c0264327a35593e62e5a6770a92cb799 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Sat, 10 Jul 2021 23:56:42 +0100 Subject: [PATCH] Playing with DHT over UDP --- Cargo.lock | 4 + crates/dht/Cargo.toml | 11 ++- crates/dht/src/bprotocol.rs | 27 ++++--- crates/dht/src/lib.rs | 3 +- crates/dht/src/main.rs | 143 ++++++++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 crates/dht/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index dca0d1b..59b9499 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -278,6 +278,10 @@ dependencies = [ "bencode", "hex 0.4.3", "kad", + "librqbit_core", + "log", + "parking_lot", + "pretty_env_logger", "serde", "tokio", ] diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index c219f43..93a2327 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -7,8 +7,15 @@ edition = "2018" [dependencies] kad = "0.6" -tokio = {version = "1", features = ["macros", "rt"]} +tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]} serde = {version = "1", features = ["derive"]} hex = "0.4" bencode = {path = "../bencode"} -anyhow = "1" \ No newline at end of file +anyhow = "1" +parking_lot = "0.11" +log = "0.4" +pretty_env_logger = "0.4" + +librqbit_core = {path="../librqbit_core"} + +[dev-dependencies] diff --git a/crates/dht/src/bprotocol.rs b/crates/dht/src/bprotocol.rs index eac9f65..f4e7093 100644 --- a/crates/dht/src/bprotocol.rs +++ b/crates/dht/src/bprotocol.rs @@ -17,6 +17,7 @@ enum MessageType { Error, } +#[derive(Clone, Copy)] pub struct Id20(pub [u8; 20]); impl std::fmt::Debug for Id20 { @@ -110,9 +111,9 @@ impl Serialize for MessageType { } #[derive(Debug)] -struct ErrorDescription { - code: i32, - description: BufT, +pub struct ErrorDescription { + pub code: i32, + pub description: BufT, } impl Serialize for ErrorDescription @@ -325,7 +326,7 @@ pub struct FindNodeRequest { } #[derive(Debug, Serialize, Deserialize)] -struct Response { +pub struct Response { pub id: Id20, #[serde(skip_serializing_if = "Option::is_none")] pub nodes: Option, @@ -437,7 +438,7 @@ pub fn deserialize_message<'de, BufT>(buf: &'de [u8]) -> anyhow::Result + AsRef<[u8]>, { - let de: RawMessage = bencode::from_bytes(buf)?; + let de: RawMessage = bencode::from_bytes(buf)?; match de.message_type { MessageType::Request => match (de.arguments, de.method_name, de.response, de.error) { (Some(_), Some(method_name), None, None) => match method_name.as_ref() { @@ -480,12 +481,15 @@ where ), }, MessageType::Error => match (de.arguments, de.method_name, de.response, de.error) { - (None, None, None, Some(e)) => Ok(Message { - transaction_id: de.transaction_id, - version: de.version, - ip: de.ip.map(|c| c.addr), - kind: MessageKind::Error(e), - }), + (None, None, None, Some(e)) => { + let de: RawMessage> = bencode::from_bytes(buf)?; + Ok(Message { + transaction_id: de.transaction_id, + version: de.version, + ip: de.ip.map(|c| c.addr), + kind: MessageKind::Error(de.error.unwrap()), + }) + } _ => anyhow::bail!( "cannot deserialize message as response, expected exactly \"r\" to be set" ), @@ -499,6 +503,7 @@ mod tests { use crate::bprotocol; use bencode::ByteBuf; + use librqbit_core::peer_id::generate_peer_id; // Dumped with wireshark. const FIND_NODE_REQUEST: &[u8] = b"64313a6164323a696432303abd7b477cfbcd10f30b705da20201e7101d8df155363a74617267657432303abd7b477cfbcd10f30b705da20201e7101d8df15565313a71393a66696e645f6e6f6465313a74323a0005313a79313a7165"; diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index bb9a673..8ae9a86 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -1,2 +1 @@ -mod bprotocol; - +pub mod bprotocol; diff --git a/crates/dht/src/main.rs b/crates/dht/src/main.rs new file mode 100644 index 0000000..290bc70 --- /dev/null +++ b/crates/dht/src/main.rs @@ -0,0 +1,143 @@ +use std::{collections::HashMap, net::SocketAddrV4}; + +use crate::bprotocol::MessageKind; +use bencode::ByteString; +use librqbit_core::peer_id::generate_peer_id; +use log::debug; +use parking_lot::Mutex; + +use crate::bprotocol::Message; + +mod bprotocol; + +struct SocketManager { + socket: tokio::net::UdpSocket, + rx: tokio::sync::mpsc::Receiver<( + SocketAddrV4, + MessageKind, + tokio::sync::oneshot::Sender>, + )>, +} + +impl SocketManager { + pub async fn spawn() -> anyhow::Result { + let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?; + let (tx, rx) = tokio::sync::mpsc::channel(1); + let mgr = SocketManager { socket, rx }; + tokio::spawn(mgr.run()); + Ok(SocketManagerHandle { tx }) + } + pub async fn run(self) -> anyhow::Result<()> { + let Self { socket, mut rx } = self; + + let mut transaction_id = 0u16; + let mut next_transaction_id = move || { + let next = transaction_id; + transaction_id = next + 1; + next + }; + + let outstanding = Mutex::new(HashMap::< + u16, + tokio::sync::oneshot::Sender>, + >::new()); + + let writer = async { + let mut buf = Vec::new(); + while let Some((addr, msg, tx)) = rx.recv().await { + let transaction_id = next_transaction_id(); + let transaction_id_buf = + [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; + buf.clear(); + bprotocol::serialize_message( + &mut buf, + // this is bad, allocates + ByteString::from(transaction_id_buf.as_ref()), + None, + None, + msg, + ) + .unwrap(); + + debug!("inserting transaction id {}", transaction_id); + assert!(outstanding.lock().insert(transaction_id, tx).is_none()); + debug!("sending msg to {}", addr); + socket.send_to(&buf, addr).await.unwrap(); + } + }; + + let reader = async { + let mut buf = vec![0u8; 16384]; + while let Ok(size) = socket.recv(&mut buf).await { + debug!("received {}", size); + let msg = match bprotocol::deserialize_message::(&buf[..size]) { + Ok(msg) => msg, + // todo handle errors + Err(e) => panic!("{}", e), + }; + assert!(msg.transaction_id.len() == 2); + let b0 = msg.transaction_id[0]; + let b1 = msg.transaction_id[1]; + let tid = ((b0 as u16) << 8) + b1 as u16; + let tx = outstanding.lock().remove(&tid).unwrap(); + debug!("sending oneshot result, tid {}", tid); + tx.send(msg).unwrap(); + } + }; + + tokio::select! { + _ = writer => {}, + _ = reader => {} + } + + Ok(()) + } +} + +#[derive(Clone)] +struct SocketManagerHandle { + tx: tokio::sync::mpsc::Sender<( + SocketAddrV4, + MessageKind, + tokio::sync::oneshot::Sender>, + )>, +} + +impl SocketManagerHandle { + async fn request( + &self, + addr: SocketAddrV4, + kind: MessageKind, + ) -> anyhow::Result> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.tx.send((addr, kind, tx)).await?; + let msg = rx.await?; + Ok(msg) + } +} + +#[tokio::main] +async fn main() { + std::env::set_var("RUST_LOG", "trace"); + pretty_env_logger::init(); + + let mgr = SocketManager::spawn().await.unwrap(); + + let peer_id = bprotocol::Id20(generate_peer_id()); + for first_addr in tokio::net::lookup_host("dht.transmissionbt.com:6881") + .await + .unwrap() + .filter_map(|a| match a { + std::net::SocketAddr::V4(v4) => Some(v4), + std::net::SocketAddr::V6(_) => None, + }) + .skip(1) + { + let msg = bprotocol::MessageKind::FindNodeRequest(bprotocol::FindNodeRequest { + id: peer_id, + target: peer_id, + }); + + dbg!(mgr.request(first_addr, msg).await.unwrap()); + } +}