Will start to test soon

This commit is contained in:
Igor Katson 2021-07-12 16:24:26 +01:00
parent dc6fc6bba2
commit d57079c75a
3 changed files with 285 additions and 74 deletions

View file

@ -2,7 +2,7 @@ use std::cmp::Ordering;
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Id20(pub [u8; 20]);
impl std::fmt::Debug for Id20 {

View file

@ -1,5 +1,6 @@
use std::{
collections::BTreeMap,
cell::RefCell,
collections::{BTreeMap, HashMap},
net::{SocketAddr, SocketAddrV4},
time::Instant,
};
@ -11,12 +12,15 @@ use dht::{
MessageKind,
},
id20::Id20,
routing_table::RoutingTable,
};
use futures::StreamExt;
use futures::{stream::FuturesUnordered, StreamExt};
use librqbit_core::peer_id::generate_peer_id;
use log::{debug, warn};
use parking_lot::Mutex;
use tokio::{
net::UdpSocket,
sync::mpsc::{channel, Receiver, Sender},
sync::mpsc::{channel, Receiver, Sender, UnboundedReceiver, UnboundedSender},
};
use tokio_stream::wrappers::ReceiverStream;
@ -32,6 +36,11 @@ struct DhtState {
next_transaction_id: u16,
outstanding_requests: Vec<OutstandingRequest>,
searching_for_peers: Vec<Id20>,
routing_table: RoutingTable,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
// TODO: convert to broadcast
subscribers: HashMap<Id20, Vec<Sender<Response>>>,
}
enum PeersOrNodes {
@ -40,10 +49,22 @@ enum PeersOrNodes {
}
impl DhtState {
pub fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
Self {
id,
next_transaction_id: 0,
outstanding_requests: Vec::new(),
searching_for_peers: Vec::new(),
routing_table: RoutingTable::new(id),
sender,
subscribers: Default::default(),
}
}
fn add_searching_for_peers(&mut self, info_hash: Id20) {
self.searching_for_peers.push(info_hash)
}
fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
pub fn create_request(&mut self, request: Request, addr: SocketAddr) -> Message<ByteString> {
let transaction_id = self.next_transaction_id;
let transaction_id_buf = [(transaction_id >> 8) as u8, (transaction_id & 0xff) as u8];
let message = match request {
@ -107,17 +128,10 @@ impl DhtState {
};
match outstanding.request {
Request::FindNode(id) => {
if response.id != id {
anyhow::bail!(
"response id does not match: expected {:?}, received {:?}",
id,
response.id
)
};
let nodes = response
.nodes
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
self.on_found_nodes(id, nodes)
self.on_found_nodes(response.id, addr, id, nodes)
}
Request::GetPeers(id) => {
if response.id != id {
@ -127,32 +141,69 @@ impl DhtState {
response.id
)
};
let nodes = response
.nodes
.ok_or_else(|| anyhow::anyhow!("expected nodes for find_node requests"))?;
// let pn = match (response.nodes, response.values) {
// (Some(nodes), None) => PeersOrNodes::Nodes(nodes),
// (None, Some(peers)) => PeersOrNodes::Peers(peers),
// _ => anyhow::bail!("expected nodes or values to be set in find_peers response"),
// };
// self.on_found_peers_or_nodes(id, pn)
let pn = match (response.nodes, response.values) {
(Some(nodes), None) => PeersOrNodes::Nodes(nodes),
(None, Some(peers)) => PeersOrNodes::Peers(peers),
_ => anyhow::bail!("expected nodes or values to be set in find_peers response"),
};
self.on_found_peers_or_nodes(response.id, addr, id, pn)
}
}
}
pub fn on_request(&mut self, request: Request, sender: Sender<Response>) -> anyhow::Result<()> {
match request {
Request::GetPeers(info_hash) => {
let subs = self.subscribers.entry(info_hash).or_default();
subs.push(sender);
self.add_searching_for_peers(info_hash);
// workaround borrow checker.
let mut addrs = Vec::new();
for node in self
.routing_table
.sorted_by_distance_from_mut(info_hash)
.into_iter()
.take(8)
{
node.mark_outgoing_request();
addrs.push(node.addr());
}
for addr in addrs {
let request = self.create_request(Request::GetPeers(info_hash), addr);
self.sender.send((request, addr))?;
}
}
Request::FindNode(_) => todo!(),
};
Ok(())
}
fn on_found_nodes(&mut self, target: Id20, nodes: CompactNodeInfo) {
fn on_found_nodes(
&mut self,
source: Id20,
source_addr: SocketAddr,
target: Id20,
nodes: CompactNodeInfo,
) -> anyhow::Result<()> {
todo!("on_found_nodes not implemented")
}
fn on_found_peers_or_nodes(&mut self, target: Id20, data: PeersOrNodes) {
todo!("on_found_nodes not implemented")
fn on_found_peers_or_nodes(
&mut self,
source: Id20,
source_addr: SocketAddr,
target: Id20,
data: PeersOrNodes,
) -> anyhow::Result<()> {
todo!("on_found_peers_or_nodes not implemented")
}
}
async fn run_framer(
socket: &UdpSocket,
mut input_rx: Receiver<(Message<ByteString>, SocketAddr)>,
output_tx: Sender<Message<ByteString>>,
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
) -> anyhow::Result<()> {
let writer = async {
let mut buf = Vec::new();
@ -173,7 +224,7 @@ async fn run_framer(
let mut buf = vec![0u8; 16384];
while let Ok((size, addr)) = socket.recv_from(&mut buf).await {
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
Ok(msg) => match output_tx.send(msg).await {
Ok(msg) => match output_tx.send((msg, addr)).await {
Ok(_) => {}
Err(_) => break,
},
@ -188,7 +239,7 @@ async fn run_framer(
Ok(())
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Request {
GetPeers(Id20),
FindNode(Id20),
@ -205,50 +256,94 @@ struct Dht {
struct DhtWorker {
socket: UdpSocket,
request_rx: Receiver<(Request, Sender<Response>)>,
next_transaction_id: u16,
peer_id: Id20,
state: Mutex<DhtState>,
}
impl DhtWorker {
fn on_request(&self, request: Request, sender: Sender<Response>) {}
fn on_request(&self, request: Request, sender: Sender<Response>) -> anyhow::Result<()> {
self.state.lock().on_request(request, sender)
}
fn on_response(&self, msg: Message<ByteString>, addr: SocketAddr) -> anyhow::Result<()> {
self.state.lock().on_incoming_from_remote(msg, addr)
}
async fn start(&mut self, bootstrap_addrs: Vec<String>) -> anyhow::Result<()> {
let (in_tx, in_rx) = channel(1);
let (out_tx, out_rx) = channel(1);
async fn start(
self,
in_tx: UnboundedSender<(Message<ByteString>, SocketAddr)>,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
mut request_rx: Receiver<(Request, Sender<Response>)>,
bootstrap_addrs: &[String],
) -> anyhow::Result<()> {
let (out_tx, mut out_rx) = channel(1);
let framer = run_framer(&self.socket, in_rx, out_tx);
let bootstrap = async {
let mut futs = FuturesUnordered::new();
// bootstrap
for addr in bootstrap_addrs {
for addr in tokio::net::lookup_host(addr).await.unwrap() {
// let msg = MessageKind::FindNodeRequest(FindNodeRequest {
// id: self.peer_id,
// target: self.peer_id,
// });
// in_tx.send((msg, addr)).await.unwrap();
}
for addr in bootstrap_addrs.iter() {
let addr = addr;
let this = &self;
let in_tx = &in_tx;
futs.push(async move {
match tokio::net::lookup_host(addr).await {
Ok(addrs) => {
for addr in addrs {
let request = this
.state
.lock()
.create_request(Request::FindNode(this.peer_id), addr);
match in_tx.send((request, addr)) {
Ok(_) => {}
Err(e) => {
debug!("bootstrap: channel closed, did not send {:?}", e)
}
};
}
}
Err(e) => warn!("error looking up {}", addr),
}
});
}
while futs.next().await.is_some() {}
};
let mut bootstrap_done = false;
// let request_reader = async {
// while let Some((request, sender)) = self.request_rx.recv().await {
// self.on_request(request, sender)
// }
// };
let request_reader = {
let this = &self;
async move {
while let Some((request, sender)) = request_rx.recv().await {
this.on_request(request, sender).unwrap();
}
}
};
// tokio::select! {
// _ = framer => {
// anyhow::bail!("framer quit")
// },
// _ = bootstrap, if !bootstrap_done => {
// bootstrap_done = true
// },
// _ = request_reader => {}
// }
let response_reader = {
let this = &self;
async move {
while let Some((response, addr)) = out_rx.recv().await {
this.on_response(response, addr).unwrap();
}
}
};
todo!()
tokio::pin!(framer);
tokio::pin!(bootstrap);
tokio::pin!(request_reader);
tokio::pin!(response_reader);
loop {
tokio::select! {
_ = &mut framer => {
anyhow::bail!("framer quit")
},
_ = &mut bootstrap, if !bootstrap_done => {
bootstrap_done = true
},
_ = &mut request_reader => {anyhow::bail!("request reader quit")}
_ = &mut response_reader => {anyhow::bail!("response reader quit")}
}
}
}
}
@ -256,14 +351,23 @@ impl Dht {
pub async fn new(bootstrap_addrs: &[&str]) -> anyhow::Result<Self> {
let (request_tx, request_rx) = channel(1);
let socket = UdpSocket::bind("0.0.0.0:0").await?;
let mut worker = DhtWorker {
socket,
request_rx,
next_transaction_id: 0,
peer_id: Id20(generate_peer_id()),
};
let bootstrap_addrs = bootstrap_addrs.iter().map(|s| s.to_string()).collect();
tokio::spawn(async move { worker.start(bootstrap_addrs).await });
let peer_id = Id20(generate_peer_id());
let bootstrap_addrs = bootstrap_addrs
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>();
tokio::spawn(async move {
let (in_tx, in_rx) = tokio::sync::mpsc::unbounded_channel();
let worker = DhtWorker {
socket,
peer_id,
state: Mutex::new(DhtState::new(peer_id, in_tx.clone())),
};
worker
.start(in_tx, in_rx, request_rx, &bootstrap_addrs)
.await
});
Ok(Dht { request_tx })
}
pub async fn get_peers(&self, info_hash: Id20) -> impl StreamExt<Item = SocketAddr> {
@ -277,16 +381,12 @@ impl Dht {
_ => panic!("programming error"),
})
}
// async fn run(self) -> anyhow::Result<Self> {
// let socket = UdpSocket::bind("0.0.0.0:0").await?;
// let (in_tx, in_rx) = channel(1);
// let (out_tx, out_rx) = channel(1);
// let framer = run_framer(socket, in_rx, out_tx);
// }
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
pretty_env_logger::init();
let info_hash = Id20([0u8; 20]);
let dht = Dht::new(&["dht.transmissionbt.com:6881"]).await.unwrap();
let mut stream = dht.get_peers(info_hash).await;

View file

@ -1,4 +1,7 @@
use std::{net::SocketAddr, time::Instant};
use std::{
net::SocketAddr,
time::{Duration, Instant},
};
#[derive(Debug)]
enum BucketTreeNode {
@ -63,6 +66,55 @@ impl<'a> Iterator for BucketTreeNodeIterator<'a> {
}
}
pub struct BucketTreeNodeIteratorMut<'a> {
current: std::slice::IterMut<'a, RoutingTableNode>,
queue: Vec<&'a mut BucketTree>,
}
impl<'a> BucketTreeNodeIteratorMut<'a> {
fn new(mut tree: &'a mut BucketTree) -> Self {
let mut queue = Vec::new();
let current = loop {
match &mut tree.data {
BucketTreeNode::Leaf(nodes) => break nodes.iter_mut(),
BucketTreeNode::LeftRight(left, right) => {
queue.push(right.as_mut());
tree = left.as_mut()
}
}
};
BucketTreeNodeIteratorMut { current, queue }
}
}
impl<'a> Iterator for BucketTreeNodeIteratorMut<'a> {
type Item = &'a mut RoutingTableNode;
fn next(&mut self) -> Option<Self::Item> {
if let Some(v) = self.current.next() {
return Some(v);
};
loop {
let tree = self.queue.pop()?;
match &mut tree.data {
BucketTreeNode::Leaf(nodes) => {
self.current = nodes.iter_mut();
match self.current.next() {
Some(v) => return Some(v),
None => continue,
}
}
BucketTreeNode::LeftRight(left, right) => {
self.queue.push(right.as_mut());
self.queue.push(left.as_mut());
continue;
}
}
}
}
}
fn compute_split_start_end(
start: Id20,
end_inclusive: Id20,
@ -129,6 +181,23 @@ impl BucketTree {
pub fn iter(&self) -> BucketTreeNodeIterator<'_> {
BucketTreeNodeIterator::new(self)
}
pub fn iter_mut(&mut self) -> BucketTreeNodeIteratorMut<'_> {
BucketTreeNodeIteratorMut::new(self)
}
pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> {
if !(*id >= self.start && *id <= self.end_inclusive) {
return None;
}
match &mut self.data {
BucketTreeNode::Leaf(nodes) => nodes.iter_mut().find(|b| b.id == *id),
BucketTreeNode::LeftRight(left, right) => {
left.get_mut(id).or_else(move || right.get_mut(id))
}
}
}
pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult {
let mut tree = self;
loop {
@ -259,11 +328,25 @@ impl RoutingTableNode {
Some(v) => v,
None => return NodeStatus::Unknown,
};
if self.outstanding_queries_in_a_row > 0 && last_request.elapsed() > Duration::from_secs(10)
{
return NodeStatus::Bad;
}
if self.last_response.is_some() {
return NodeStatus::Good;
}
NodeStatus::Questionable
}
pub fn mark_outgoing_request(&mut self) {
self.last_request = Some(Instant::now());
self.outstanding_queries_in_a_row += 1;
}
pub fn mark_response(&mut self) {
self.last_response = Some(Instant::now());
self.outstanding_queries_in_a_row = 0;
}
}
#[derive(Debug)]
@ -289,6 +372,16 @@ impl RoutingTable {
result.sort_by_key(|n| id.distance(&n.id));
result
}
pub fn sorted_by_distance_from_mut(&mut self, id: Id20) -> Vec<&mut RoutingTableNode> {
let mut result = Vec::with_capacity(self.size);
for node in self.buckets.iter_mut() {
result.push(node);
}
result.sort_by_key(|n| id.distance(&n.id));
result
}
pub fn add_node(&mut self, id: Id20, addr: SocketAddr) -> InsertResult {
let res = self.buckets.add_node(&self.id, id, addr);
let replaced = match &res {
@ -302,6 +395,23 @@ impl RoutingTable {
}
res
}
pub fn mark_outgoing_request(&mut self, id: &Id20) -> bool {
let r = match self.buckets.get_mut(id) {
Some(r) => r,
None => return false,
};
r.mark_outgoing_request();
true
}
pub fn mark_response(&mut self, id: &Id20) -> bool {
let r = match self.buckets.get_mut(id) {
Some(r) => r,
None => return false,
};
r.mark_response();
true
}
}
#[cfg(test)]
@ -410,6 +520,7 @@ mod tests {
let addr = std::net::SocketAddr::V4(SocketAddrV4::new("0.0.0.0".parse().unwrap(), i));
rtable.add_node(other_id, addr);
}
dbg!(rtable);
dbg!(&rtable);
assert_eq!(rtable.sorted_by_distance_from(my_id).len(), rtable.size);
}
}