Will start to test soon
This commit is contained in:
parent
dc6fc6bba2
commit
d57079c75a
3 changed files with 285 additions and 74 deletions
|
|
@ -2,7 +2,7 @@ use std::cmp::Ordering;
|
|||
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct Id20(pub [u8; 20]);
|
||||
|
||||
impl std::fmt::Debug for Id20 {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
cell::RefCell,
|
||||
collections::{BTreeMap, HashMap},
|
||||
net::{SocketAddr, SocketAddrV4},
|
||||
time::Instant,
|
||||
};
|
||||
|
|
@ -11,12 +12,15 @@ use dht::{
|
|||
MessageKind,
|
||||
},
|
||||
id20::Id20,
|
||||
routing_table::RoutingTable,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use librqbit_core::peer_id::generate_peer_id;
|
||||
use log::{debug, warn};
|
||||
use parking_lot::Mutex;
|
||||
use tokio::{
|
||||
net::UdpSocket,
|
||||
sync::mpsc::{channel, Receiver, Sender},
|
||||
sync::mpsc::{channel, Receiver, Sender, UnboundedReceiver, UnboundedSender},
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
|
|
@ -32,6 +36,11 @@ struct DhtState {
|
|||
next_transaction_id: u16,
|
||||
outstanding_requests: Vec<OutstandingRequest>,
|
||||
searching_for_peers: Vec<Id20>,
|
||||
routing_table: RoutingTable,
|
||||
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
|
||||
|
||||
// TODO: convert to broadcast
|
||||
subscribers: HashMap<Id20, Vec<Sender<Response>>>,
|
||||
}
|
||||
|
||||
enum PeersOrNodes {
|
||||
|
|
@ -40,10 +49,22 @@ enum PeersOrNodes {
|
|||
}
|
||||
|
||||
impl DhtState {
|
||||
pub fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
next_transaction_id: 0,
|
||||
outstanding_requests: Vec::new(),
|
||||
searching_for_peers: Vec::new(),
|
||||
routing_table: RoutingTable::new(id),
|
||||
sender,
|
||||
subscribers: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_searching_for_peers(&mut self, info_hash: Id20) {
|
||||
self.searching_for_peers.push(info_hash)
|
||||
}
|
||||
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
|
||||
pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
|
||||
let transaction_id = self.next_transaction_id;
|
||||
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
|
||||
let message = match request {
|
||||
|
|
@ -107,17 +128,10 @@ impl DhtState {
|
|||
};
|
||||
match outstanding.request {
|
||||
Request::FindNode(id) => {
|
||||
if response.id != id {
|
||||
anyhow::bail!(
|
||||
"response id does not match: expected {:?}, received {:?}",
|
||||
id,
|
||||
response.id
|
||||
)
|
||||
};
|
||||
let nodes = response
|
||||
.nodes
|
||||
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
|
||||
self.on_found_nodes(id, nodes)
|
||||
self.on_found_nodes(response.id, addr, id, nodes)
|
||||
}
|
||||
Request::GetPeers(id) => {
|
||||
if response.id != id {
|
||||
|
|
@ -127,32 +141,69 @@ impl DhtState {
|
|||
response.id
|
||||
)
|
||||
};
|
||||
let nodes = response
|
||||
.nodes
|
||||
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
|
||||
// let pn = match (response.nodes, response.values) {
|
||||
// (Some(nodes), None) => PeersOrNodes::Nodes(nodes),
|
||||
// (None, Some(peers)) => PeersOrNodes::Peers(peers),
|
||||
// _ => anyhow::bail!("expected nodes or values to be set in find_peers response"),
|
||||
// };
|
||||
// self.on_found_peers_or_nodes(id, pn)
|
||||
let pn = match (response.nodes, response.values) {
|
||||
(Some(nodes), None) => PeersOrNodes::Nodes(nodes),
|
||||
(None, Some(peers)) => PeersOrNodes::Peers(peers),
|
||||
_ => anyhow::bail!("expected nodes or values to be set in find_peers response"),
|
||||
};
|
||||
self.on_found_peers_or_nodes(response.id, addr, id, pn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_request(&mut self, request: Request, sender: Sender<Response>) -> anyhow::Result<()> {
|
||||
match request {
|
||||
Request::GetPeers(info_hash) => {
|
||||
let subs = self.subscribers.entry(info_hash).or_default();
|
||||
subs.push(sender);
|
||||
self.add_searching_for_peers(info_hash);
|
||||
|
||||
// workaround borrow checker.
|
||||
let mut addrs = Vec::new();
|
||||
for node in self
|
||||
.routing_table
|
||||
.sorted_by_distance_from_mut(info_hash)
|
||||
.into_iter()
|
||||
.take(8)
|
||||
{
|
||||
node.mark_outgoing_request();
|
||||
addrs.push(node.addr());
|
||||
}
|
||||
for addr in addrs {
|
||||
let request = self.create_request(Request::GetPeers(info_hash), addr);
|
||||
self.sender.send((request, addr))?;
|
||||
}
|
||||
}
|
||||
Request::FindNode(_) => todo!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
fn on_found_nodes(&mut self, target: Id20, nodes: CompactNodeInfo) {
|
||||
|
||||
fn on_found_nodes(
|
||||
&mut self,
|
||||
source: Id20,
|
||||
source_addr: SocketAddr,
|
||||
target: Id20,
|
||||
nodes: CompactNodeInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
todo!("on_found_nodes not implemented")
|
||||
}
|
||||
|
||||
fn on_found_peers_or_nodes(&mut self, target: Id20, data: PeersOrNodes) {
|
||||
todo!("on_found_nodes not implemented")
|
||||
fn on_found_peers_or_nodes(
|
||||
&mut self,
|
||||
source: Id20,
|
||||
source_addr: SocketAddr,
|
||||
target: Id20,
|
||||
data: PeersOrNodes,
|
||||
) -> anyhow::Result<()> {
|
||||
todo!("on_found_peers_or_nodes not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_framer(
|
||||
socket: &UdpSocket,
|
||||
mut input_rx: Receiver<(Message<ByteString>, SocketAddr)>,
|
||||
output_tx: Sender<Message<ByteString>>,
|
||||
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
|
||||
) -> anyhow::Result<()> {
|
||||
let writer = async {
|
||||
let mut buf = Vec::new();
|
||||
|
|
@ -173,7 +224,7 @@ async fn run_framer(
|
|||
let mut buf = vec![0u8; 16384];
|
||||
while let Ok((size, addr)) = socket.recv_from(&mut buf).await {
|
||||
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
|
||||
Ok(msg) => match output_tx.send(msg).await {
|
||||
Ok(msg) => match output_tx.send((msg, addr)).await {
|
||||
Ok(_) => {}
|
||||
Err(_) => break,
|
||||
},
|
||||
|
|
@ -188,7 +239,7 @@ async fn run_framer(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum Request {
|
||||
GetPeers(Id20),
|
||||
FindNode(Id20),
|
||||
|
|
@ -205,50 +256,94 @@ struct Dht {
|
|||
|
||||
struct DhtWorker {
|
||||
socket: UdpSocket,
|
||||
request_rx: Receiver<(Request, Sender<Response>)>,
|
||||
next_transaction_id: u16,
|
||||
peer_id: Id20,
|
||||
state: Mutex<DhtState>,
|
||||
}
|
||||
|
||||
impl DhtWorker {
|
||||
fn on_request(&self, request: Request, sender: Sender<Response>) {}
|
||||
fn on_request(&self, request: Request, sender: Sender<Response>) -> anyhow::Result<()> {
|
||||
self.state.lock().on_request(request, sender)
|
||||
}
|
||||
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
self.state.lock().on_incoming_from_remote(msg, addr)
|
||||
}
|
||||
|
||||
async fn start(&mut self, bootstrap_addrs: Vec<String>) -> anyhow::Result<()> {
|
||||
let (in_tx, in_rx) = channel(1);
|
||||
let (out_tx, out_rx) = channel(1);
|
||||
async fn start(
|
||||
self,
|
||||
in_tx: UnboundedSender<(Message<ByteString>, SocketAddr)>,
|
||||
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||
mut request_rx: Receiver<(Request, Sender<Response>)>,
|
||||
bootstrap_addrs: &[String],
|
||||
) -> anyhow::Result<()> {
|
||||
let (out_tx, mut out_rx) = channel(1);
|
||||
let framer = run_framer(&self.socket, in_rx, out_tx);
|
||||
|
||||
let bootstrap = async {
|
||||
let mut futs = FuturesUnordered::new();
|
||||
// bootstrap
|
||||
for addr in bootstrap_addrs {
|
||||
for addr in tokio::net::lookup_host(addr).await.unwrap() {
|
||||
// let msg = MessageKind::FindNodeRequest(FindNodeRequest {
|
||||
// id: self.peer_id,
|
||||
// target: self.peer_id,
|
||||
// });
|
||||
// in_tx.send((msg, addr)).await.unwrap();
|
||||
}
|
||||
for addr in bootstrap_addrs.iter() {
|
||||
let addr = addr;
|
||||
let this = &self;
|
||||
let in_tx = &in_tx;
|
||||
futs.push(async move {
|
||||
match tokio::net::lookup_host(addr).await {
|
||||
Ok(addrs) => {
|
||||
for addr in addrs {
|
||||
let request = this
|
||||
.state
|
||||
.lock()
|
||||
.create_request(Request::FindNode(this.peer_id), addr);
|
||||
match in_tx.send((request, addr)) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
debug!("bootstrap: channel closed, did not send {:?}", e)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("error looking up {}", addr),
|
||||
}
|
||||
});
|
||||
}
|
||||
while futs.next().await.is_some() {}
|
||||
};
|
||||
let mut bootstrap_done = false;
|
||||
|
||||
// let request_reader = async {
|
||||
// while let Some((request, sender)) = self.request_rx.recv().await {
|
||||
// self.on_request(request, sender)
|
||||
// }
|
||||
// };
|
||||
let request_reader = {
|
||||
let this = &self;
|
||||
async move {
|
||||
while let Some((request, sender)) = request_rx.recv().await {
|
||||
this.on_request(request, sender).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// tokio::select! {
|
||||
// _ = framer => {
|
||||
// anyhow::bail!("framer quit")
|
||||
// },
|
||||
// _ = bootstrap, if !bootstrap_done => {
|
||||
// bootstrap_done = true
|
||||
// },
|
||||
// _ = request_reader => {}
|
||||
// }
|
||||
let response_reader = {
|
||||
let this = &self;
|
||||
async move {
|
||||
while let Some((response, addr)) = out_rx.recv().await {
|
||||
this.on_response(response, addr).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
todo!()
|
||||
tokio::pin!(framer);
|
||||
tokio::pin!(bootstrap);
|
||||
tokio::pin!(request_reader);
|
||||
tokio::pin!(response_reader);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut framer => {
|
||||
anyhow::bail!("framer quit")
|
||||
},
|
||||
_ = &mut bootstrap, if !bootstrap_done => {
|
||||
bootstrap_done = true
|
||||
},
|
||||
_ = &mut request_reader => {anyhow::bail!("request reader quit")}
|
||||
_ = &mut response_reader => {anyhow::bail!("response reader quit")}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -256,14 +351,23 @@ impl Dht {
|
|||
pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result<Self> {
|
||||
let (request_tx, request_rx) = channel(1);
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
let mut worker = DhtWorker {
|
||||
socket,
|
||||
request_rx,
|
||||
next_transaction_id: 0,
|
||||
peer_id: Id20(generate_peer_id()),
|
||||
};
|
||||
let bootstrap_addrs = bootstrap_addrs.iter().map(|s| s.to_string()).collect();
|
||||
tokio::spawn(async move { worker.start(bootstrap_addrs).await });
|
||||
let peer_id = Id20(generate_peer_id());
|
||||
let bootstrap_addrs = bootstrap_addrs
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (in_tx, in_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
let worker = DhtWorker {
|
||||
socket,
|
||||
peer_id,
|
||||
state: Mutex::new(DhtState::new(peer_id, in_tx.clone())),
|
||||
};
|
||||
worker
|
||||
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
|
||||
.await
|
||||
});
|
||||
Ok(Dht { request_tx })
|
||||
}
|
||||
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> {
|
||||
|
|
@ -277,16 +381,12 @@ impl Dht {
|
|||
_ => panic!("programming error"),
|
||||
})
|
||||
}
|
||||
// async fn run(self) -> anyhow::Result<Self> {
|
||||
// let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
// let (in_tx, in_rx) = channel(1);
|
||||
// let (out_tx, out_rx) = channel(1);
|
||||
// let framer = run_framer(socket, in_rx, out_tx);
|
||||
// }
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
pretty_env_logger::init();
|
||||
|
||||
let info_hash = Id20([0u8; 20]);
|
||||
let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap();
|
||||
let mut stream = dht.get_peers(info_hash).await;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use std::{net::SocketAddr, time::Instant};
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BucketTreeNode {
|
||||
|
|
@ -63,6 +66,55 @@ impl<'a> Iterator for BucketTreeNodeIterator<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct BucketTreeNodeIteratorMut<'a> {
|
||||
current: std::slice::IterMut<'a, RoutingTableNode>,
|
||||
queue: Vec<&'a mut BucketTree>,
|
||||
}
|
||||
|
||||
impl<'a> BucketTreeNodeIteratorMut<'a> {
|
||||
fn new(mut tree: &'a mut BucketTree) -> Self {
|
||||
let mut queue = Vec::new();
|
||||
let current = loop {
|
||||
match &mut tree.data {
|
||||
BucketTreeNode::Leaf(nodes) => break nodes.iter_mut(),
|
||||
BucketTreeNode::LeftRight(left, right) => {
|
||||
queue.push(right.as_mut());
|
||||
tree = left.as_mut()
|
||||
}
|
||||
}
|
||||
};
|
||||
BucketTreeNodeIteratorMut { current, queue }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for BucketTreeNodeIteratorMut<'a> {
|
||||
type Item = &'a mut RoutingTableNode;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(v) = self.current.next() {
|
||||
return Some(v);
|
||||
};
|
||||
|
||||
loop {
|
||||
let tree = self.queue.pop()?;
|
||||
match &mut tree.data {
|
||||
BucketTreeNode::Leaf(nodes) => {
|
||||
self.current = nodes.iter_mut();
|
||||
match self.current.next() {
|
||||
Some(v) => return Some(v),
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
BucketTreeNode::LeftRight(left, right) => {
|
||||
self.queue.push(right.as_mut());
|
||||
self.queue.push(left.as_mut());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_split_start_end(
|
||||
start: Id20,
|
||||
end_inclusive: Id20,
|
||||
|
|
@ -129,6 +181,23 @@ impl BucketTree {
|
|||
pub fn iter(&self) -> BucketTreeNodeIterator<'_> {
|
||||
BucketTreeNodeIterator::new(self)
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> BucketTreeNodeIteratorMut<'_> {
|
||||
BucketTreeNodeIteratorMut::new(self)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> {
|
||||
if !(*id >= self.start && *id <= self.end_inclusive) {
|
||||
return None;
|
||||
}
|
||||
match &mut self.data {
|
||||
BucketTreeNode::Leaf(nodes) => nodes.iter_mut().find(|b| b.id == *id),
|
||||
BucketTreeNode::LeftRight(left, right) => {
|
||||
left.get_mut(id).or_else(move || right.get_mut(id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult {
|
||||
let mut tree = self;
|
||||
loop {
|
||||
|
|
@ -259,11 +328,25 @@ impl RoutingTableNode {
|
|||
Some(v) => v,
|
||||
None => return NodeStatus::Unknown,
|
||||
};
|
||||
if self.outstanding_queries_in_a_row > 0 && last_request.elapsed() > Duration::from_secs(10)
|
||||
{
|
||||
return NodeStatus::Bad;
|
||||
}
|
||||
if self.last_response.is_some() {
|
||||
return NodeStatus::Good;
|
||||
}
|
||||
NodeStatus::Questionable
|
||||
}
|
||||
|
||||
pub fn mark_outgoing_request(&mut self) {
|
||||
self.last_request = Some(Instant::now());
|
||||
self.outstanding_queries_in_a_row += 1;
|
||||
}
|
||||
|
||||
pub fn mark_response(&mut self) {
|
||||
self.last_response = Some(Instant::now());
|
||||
self.outstanding_queries_in_a_row = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
@ -289,6 +372,16 @@ impl RoutingTable {
|
|||
result.sort_by_key(|n| id.distance(&n.id));
|
||||
result
|
||||
}
|
||||
|
||||
pub fn sorted_by_distance_from_mut(&mut self, id: Id20) -> Vec<&mut RoutingTableNode> {
|
||||
let mut result = Vec::with_capacity(self.size);
|
||||
for node in self.buckets.iter_mut() {
|
||||
result.push(node);
|
||||
}
|
||||
result.sort_by_key(|n| id.distance(&n.id));
|
||||
result
|
||||
}
|
||||
|
||||
pub fn add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult {
|
||||
let res = self.buckets.add_node(&self.id, id, addr);
|
||||
let replaced = match &res {
|
||||
|
|
@ -302,6 +395,23 @@ impl RoutingTable {
|
|||
}
|
||||
res
|
||||
}
|
||||
pub fn mark_outgoing_request(&mut self, id: &Id20) -> bool {
|
||||
let r = match self.buckets.get_mut(id) {
|
||||
Some(r) => r,
|
||||
None => return false,
|
||||
};
|
||||
r.mark_outgoing_request();
|
||||
true
|
||||
}
|
||||
|
||||
pub fn mark_response(&mut self, id: &Id20) -> bool {
|
||||
let r = match self.buckets.get_mut(id) {
|
||||
Some(r) => r,
|
||||
None => return false,
|
||||
};
|
||||
r.mark_response();
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -410,6 +520,7 @@ mod tests {
|
|||
let addr = std::net::SocketAddr::V4(SocketAddrV4::new("0.0.0.0".parse().unwrap(), i));
|
||||
rtable.add_node(other_id, addr);
|
||||
}
|
||||
dbg!(rtable);
|
||||
dbg!(&rtable);
|
||||
assert_eq!(rtable.sorted_by_distance_from(my_id).len(), rtable.size);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue