This commit is contained in:
Igor Katson 2021-07-12 11:56:26 +01:00
parent f6656841c0
commit 950d47ab31
7 changed files with 451 additions and 167 deletions

2
Cargo.lock generated
View file

@ -276,6 +276,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bencode", "bencode",
"futures 0.3.15",
"hex 0.4.3", "hex 0.4.3",
"kad", "kad",
"librqbit_core", "librqbit_core",
@ -284,6 +285,7 @@ dependencies = [
"pretty_env_logger", "pretty_env_logger",
"serde", "serde",
"tokio", "tokio",
"tokio-stream",
] ]
[[package]] [[package]]

View file

@ -8,6 +8,7 @@ edition = "2018"
[dependencies] [dependencies]
kad = "0.6" kad = "0.6"
tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]} tokio = {version = "1", features = ["macros", "rt-multi-thread", "net", "sync"]}
tokio-stream = "0.1"
serde = {version = "1", features = ["derive"]} serde = {version = "1", features = ["derive"]}
hex = "0.4" hex = "0.4"
bencode = {path = "../bencode"} bencode = {path = "../bencode"}
@ -15,6 +16,7 @@ anyhow = "1"
parking_lot = "0.11" parking_lot = "0.11"
log = "0.4" log = "0.4"
pretty_env_logger = "0.4" pretty_env_logger = "0.4"
futures = "0.3"
librqbit_core = {path="../librqbit_core"} librqbit_core = {path="../librqbit_core"}

View file

@ -10,6 +10,8 @@ use serde::{
Deserialize, Deserializer, Serialize, Deserialize, Deserializer, Serialize,
}; };
use crate::id20::Id20;
#[derive(Debug)] #[derive(Debug)]
enum MessageType { enum MessageType {
Request, Request,
@ -17,57 +19,6 @@ enum MessageType {
Error, Error,
} }
#[derive(Clone, Copy)]
pub struct Id20(pub [u8; 20]);
impl std::fmt::Debug for Id20 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<")?;
for byte in self.0 {
write!(f, "{:02x?}", byte)?;
}
write!(f, ">")?;
Ok(())
}
}
impl Serialize for Id20 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0)
}
}
impl<'de> Deserialize<'de> for Id20 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Id20;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a 20 byte slice")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() != 20 {
return Err(E::invalid_length(20, &self));
}
let mut buf = [0u8; 20];
buf.copy_from_slice(&v);
Ok(Id20(buf))
}
}
deserializer.deserialize_bytes(Visitor {})
}
}
impl<'de> Deserialize<'de> for MessageType { impl<'de> Deserialize<'de> for MessageType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where

89
crates/dht/src/id20.rs Normal file
View file

@ -0,0 +1,89 @@
use std::cmp::Ordering;
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Id20(pub [u8; 20]);
impl std::fmt::Debug for Id20 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<")?;
for byte in self.0 {
write!(f, "{:02x?}", byte)?;
}
write!(f, ">")?;
Ok(())
}
}
impl Serialize for Id20 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0)
}
}
impl<'de> Deserialize<'de> for Id20 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Id20;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a 20 byte slice")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() != 20 {
return Err(E::invalid_length(20, &self));
}
let mut buf = [0u8; 20];
buf.copy_from_slice(&v);
Ok(Id20(buf))
}
}
deserializer.deserialize_bytes(Visitor {})
}
}
impl Id20 {
pub fn distance(&self, other: &Id20) -> Id20 {
let mut xor = [0u8; 20];
for (idx, (s, o)) in self
.0
.iter()
.copied()
.zip(other.0.iter().copied())
.enumerate()
{
xor[idx] = s ^ o;
}
Id20(xor)
}
}
impl Ord for Id20 {
fn cmp(&self, other: &Id20) -> Ordering {
for (s, o) in self.0.iter().copied().zip(other.0.iter().copied()) {
match s.cmp(&o) {
Ordering::Less => return Ordering::Less,
Ordering::Equal => continue,
Ordering::Greater => return Ordering::Greater,
}
}
Ordering::Equal
}
}
impl PartialOrd<Id20> for Id20 {
fn partial_cmp(&self, other: &Id20) -> Option<Ordering> {
Some(self.cmp(other))
}
}

View file

@ -1 +1,3 @@
pub mod bprotocol; pub mod bprotocol;
pub mod id20;
pub mod routing_table;

View file

@ -1,143 +1,297 @@
use std::{collections::HashMap, net::SocketAddrV4}; use std::{
collections::BTreeMap,
net::{SocketAddr, SocketAddrV4},
time::Instant,
};
use crate::bprotocol::MessageKind;
use bencode::ByteString; use bencode::ByteString;
use dht::{
bprotocol::{
self, CompactNodeInfo, CompactPeerInfo, FindNodeRequest, GetPeersRequest, Message,
MessageKind,
},
id20::Id20,
};
use futures::StreamExt;
use librqbit_core::peer_id::generate_peer_id; use librqbit_core::peer_id::generate_peer_id;
use log::debug; use tokio::{
use parking_lot::Mutex; net::UdpSocket,
sync::mpsc::{channel, Receiver, Sender},
};
use tokio_stream::wrappers::ReceiverStream;
use crate::bprotocol::Message; struct OutstandingRequest {
transaction_id: u16,
mod bprotocol; addr: SocketAddr,
request: Request,
struct SocketManager { time: Instant,
socket: tokio::net::UdpSocket,
rx: tokio::sync::mpsc::Receiver<(
SocketAddrV4,
MessageKind<ByteString>,
tokio::sync::oneshot::Sender<Message<ByteString>>,
)>,
} }
impl SocketManager { struct DhtState {
pub async fn spawn() -> anyhow::Result<SocketManagerHandle> { id: Id20,
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?; next_transaction_id: u16,
let (tx, rx) = tokio::sync::mpsc::channel(1); outstanding_requests: Vec<OutstandingRequest>,
let mgr = SocketManager { socket, rx }; searching_for_peers: Vec<Id20>,
tokio::spawn(mgr.run()); }
Ok(SocketManagerHandle { tx })
enum PeersOrNodes {
Nodes(CompactNodeInfo),
Peers(Vec<CompactPeerInfo>),
}
impl DhtState {
fn add_searching_for_peers(&mut self, info_hash: Id20) {
self.searching_for_peers.push(info_hash)
} }
pub async fn run(self) -> anyhow::Result<()> { fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
let Self { socket, mut rx } = self; let transaction_id = self.next_transaction_id;
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
let mut transaction_id = 0u16; let message = match request {
let mut next_transaction_id = move || { Request::GetPeers(info_hash) => Message {
let next = transaction_id; transaction_id: ByteString::from(transaction_id_buf.as_ref()),
transaction_id = next + 1; version: None,
next ip: None,
kind: MessageKind::GetPeersRequest(GetPeersRequest {
id: self.id,
info_hash,
}),
},
Request::FindNode(target) => Message {
transaction_id: ByteString::from(transaction_id_buf.as_ref()),
version: None,
ip: None,
kind: MessageKind::FindNodeRequest(FindNodeRequest {
id: self.id,
target,
}),
},
}; };
self.outstanding_requests.push(OutstandingRequest {
let outstanding = Mutex::new(HashMap::< transaction_id,
u16, addr,
tokio::sync::oneshot::Sender<Message<ByteString>>, request,
>::new()); time: Instant::now(),
});
let writer = async { message
let mut buf = Vec::new(); }
while let Some((addr, msg, tx)) = rx.recv().await { fn on_incoming_from_remote(
let transaction_id = next_transaction_id(); &mut self,
let transaction_id_buf = msg: Message<ByteString>,
[(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8]; addr: SocketAddr,
buf.clear(); ) -> anyhow::Result<()> {
bprotocol::serialize_message( match msg.kind {
&mut buf, MessageKind::Error(_) | MessageKind::Response(_) => {}
// this is bad, allocates other => anyhow::bail!("requests from DHT not supported, but got {:?}", other),
ByteString::from(transaction_id_buf.as_ref()), };
None, if msg.transaction_id.len() != 2 {
None, anyhow::bail!("transaction id unrecognized")
msg, }
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
// O(n) but whatever
let outstanding_id = self
.outstanding_requests
.iter()
.position(|req| req.transaction_id == tid && req.addr == addr)
.ok_or_else(|| anyhow::anyhow!("outstanding request not found"))?;
let outstanding = self.outstanding_requests.remove(outstanding_id);
let response = match msg.kind {
MessageKind::Error(e) => {
anyhow::bail!(
"request {:?} received error response {:?}",
outstanding.request,
e
) )
.unwrap(); }
MessageKind::Response(r) => r,
debug!("inserting transaction id {}", transaction_id); _ => unreachable!(),
assert!(outstanding.lock().insert(transaction_id, tx).is_none()); };
debug!("sending msg to {}", addr); match outstanding.request {
socket.send_to(&buf, addr).await.unwrap(); Request::FindNode(id) => {
if response.id != id {
anyhow::bail!(
"response id does not match: expected {:?}, received {:?}",
id,
response.id
)
};
let nodes = response
.nodes
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
self.on_found_nodes(id, nodes)
}
Request::GetPeers(id) => {
if response.id != id {
anyhow::bail!(
"response id does not match: expected {:?}, received {:?}",
id,
response.id
)
};
let nodes = response
.nodes
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
let pn = match (response.nodes, response.values) {
(Some(nodes), None) => PeersOrNodes::Nodes(nodes),
(None, Some(peers)) => PeersOrNodes::Peers(peers),
_ => anyhow::bail!("expected nodes or values to be set in find_peers response"),
};
self.on_found_peers_or_nodes(id, pn)
} }
}; };
Ok(())
}
fn on_found_nodes(&mut self, target: Id20, nodes: CompactNodeInfo) {
todo!("on_found_nodes not implemented")
}
let reader = async { fn on_found_peers_or_nodes(&mut self, target: Id20, data: PeersOrNodes) {
let mut buf = vec![0u8; 16384]; todo!("on_found_nodes not implemented")
while let Ok(size) = socket.recv(&mut buf).await { }
debug!("received {}", size); }
let msg = match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
Ok(msg) => msg, async fn run_framer(
// todo handle errors socket: &UdpSocket,
Err(e) => panic!("{}", e), mut input_rx: Receiver<(Message<ByteString>, SocketAddr)>,
}; output_tx: Sender<Message<ByteString>>,
assert!(msg.transaction_id.len() == 2); ) -> anyhow::Result<()> {
let b0 = msg.transaction_id[0]; let writer = async {
let b1 = msg.transaction_id[1]; let mut buf = Vec::new();
let tid = ((b0 as u16) << 8) + b1 as u16; while let Some((msg, addr)) = input_rx.recv().await {
let tx = outstanding.lock().remove(&tid).unwrap(); buf.clear();
debug!("sending oneshot result, tid {}", tid); bprotocol::serialize_message(
tx.send(msg).unwrap(); &mut buf,
msg.transaction_id,
msg.version,
msg.ip,
msg.kind,
)
.unwrap();
socket.send_to(&buf, addr).await.unwrap();
}
};
let reader = async {
let mut buf = vec![0u8; 16384];
while let Ok((size, addr)) = socket.recv_from(&mut buf).await {
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
Ok(msg) => match output_tx.send(msg).await {
Ok(_) => {}
Err(_) => break,
},
Err(e) => log::warn!("error deseriaizing msg: {}", e),
}
}
};
tokio::select! {
_ = writer => {},
_ = reader => {},
};
Ok(())
}
#[derive(Debug, Clone, Copy)]
enum Request {
GetPeers(Id20),
FindNode(Id20),
}
#[derive(Debug)]
enum Response {
Peer(SocketAddr),
}
struct Dht {
request_tx: Sender<(Request, Sender<Response>)>,
}
struct DhtWorker {
socket: UdpSocket,
request_rx: Receiver<(Request, Sender<Response>)>,
next_transaction_id: u16,
peer_id: Id20,
}
impl DhtWorker {
fn on_request(&self, request: Request, sender: Sender<Response>) {}
async fn start(&mut self, bootstrap_addrs: Vec<String>) -> anyhow::Result<()> {
let (in_tx, in_rx) = channel(1);
let (out_tx, out_rx) = channel(1);
let framer = run_framer(&self.socket, in_rx, out_tx);
let bootstrap = async {
// bootstrap
for addr in bootstrap_addrs {
for addr in tokio::net::lookup_host(addr).await.unwrap() {
let msg = MessageKind::FindNodeRequest(FindNodeRequest {
id: self.peer_id,
target: self.peer_id,
});
in_tx.send((msg, addr)).await.unwrap();
}
}
};
let mut bootstrap_done = false;
let request_reader = async {
while let Some((request, sender)) = self.request_rx.recv().await {
self.on_request(request, sender)
} }
}; };
tokio::select! { tokio::select! {
_ = writer => {}, _ = framer => {
_ = reader => {} anyhow::bail!("framer quit")
},
_ = bootstrap, if !bootstrap_done => {
bootstrap_done = true
},
_ = request_reader => {}
} }
Ok(()) todo!()
} }
} }
#[derive(Clone)] impl Dht {
struct SocketManagerHandle { pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result<Self> {
tx: tokio::sync::mpsc::Sender<( let (request_tx, request_rx) = channel(1);
SocketAddrV4, let socket = UdpSocket::bind("0.0.0.0:0").await?;
MessageKind<ByteString>, let mut worker = DhtWorker {
tokio::sync::oneshot::Sender<Message<ByteString>>, socket,
)>, request_rx,
} next_transaction_id: 0,
peer_id: Id20(generate_peer_id()),
impl SocketManagerHandle { };
async fn request( let bootstrap_addrs = bootstrap_addrs.iter().map(|s| s.to_string()).collect();
&self, tokio::spawn(async move { worker.start(bootstrap_addrs).await });
addr: SocketAddrV4, Ok(Dht { request_tx })
kind: MessageKind<ByteString>,
) -> anyhow::Result<bprotocol::Message<ByteString>> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.tx.send((addr, kind, tx)).await?;
let msg = rx.await?;
Ok(msg)
} }
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> {
let (tx, rx) = channel::<Response>(1);
self.request_tx
.send((Request::GetPeers(info_hash), tx))
.await
.unwrap();
ReceiverStream::new(rx).map(|r| match r {
Response::Peer(addr) => addr,
_ => panic!("programming error"),
})
}
// async fn run(self) -> anyhow::Result<Self> {
// let socket = UdpSocket::bind("0.0.0.0:0").await?;
// let (in_tx, in_rx) = channel(1);
// let (out_tx, out_rx) = channel(1);
// let framer = run_framer(socket, in_rx, out_tx);
// }
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> anyhow::Result<()> {
std::env::set_var("RUST_LOG", "trace"); let info_hash = Id20([0u8; 20]);
pretty_env_logger::init(); let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap();
let mut stream = dht.get_peers(info_hash).await;
let mgr = SocketManager::spawn().await.unwrap(); while let Some(peer) = stream.next().await {
log::info!("peer found: {}", peer)
let peer_id = bprotocol::Id20(generate_peer_id());
for first_addr in tokio::net::lookup_host("dht.transmissionbt.com:6881")
.await
.unwrap()
.filter_map(|a| match a {
std::net::SocketAddr::V4(v4) => Some(v4),
std::net::SocketAddr::V6(_) => None,
})
.skip(1)
{
let msg = bprotocol::MessageKind::FindNodeRequest(bprotocol::FindNodeRequest {
id: peer_id,
target: peer_id,
});
dbg!(mgr.request(first_addr, msg).await.unwrap());
} }
Ok(())
} }

View file

@ -0,0 +1,84 @@
use std::{
collections::BTreeMap,
net::SocketAddr,
time::{Duration, Instant},
};
use crate::id20::Id20;
pub struct RoutingTableNode {
id: Id20,
addr: SocketAddr,
last_request: Option<Instant>,
last_response: Option<Instant>,
outstanding_queries_in_a_row: usize,
}
pub enum NodeStatus {
Good,
Questionable,
Bad,
Unknown,
}
impl RoutingTableNode {
pub fn id(&self) -> Id20 {
self.id
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub fn status(&self) -> NodeStatus {
// TODO: this is just a stub with simpler logic
let last_request = match self.last_request {
Some(v) => v,
None => return NodeStatus::Unknown,
};
if self.last_response.is_some() {
return NodeStatus::Good;
}
NodeStatus::Questionable
}
}
struct Bucket {
bits: u8,
nodes: Vec<RoutingTableNode>,
end: Id20,
}
pub struct RoutingTable {
id: Id20,
size: usize,
buckets: BTreeMap<Id20, Bucket>,
}
impl RoutingTable {
pub fn new(id: Id20) -> Self {
let initial_bucket = Id20([0u8; 20]);
let mut buckets = BTreeMap::new();
buckets.insert(
initial_bucket,
Bucket {
bits: 160,
nodes: Vec::new(),
},
);
Self {
id,
buckets,
size: 0,
}
}
pub fn sorted_by_distance_from(&self, id: Id20) -> Vec<&RoutingTableNode> {
let mut result = Vec::with_capacity(self.size);
for bucket in self.buckets.values() {
for node in bucket.nodes.iter() {
result.push(node);
}
}
result.sort_by_key(|n| id.distance(&n.id));
result
}
pub fn add_node(&mut self, id: Id20, addr: SocketAddr) -> bool {}
}