Fix bugs; the refactored DHT works alright now

This commit is contained in:
Igor Katson 2023-11-29 23:12:20 +00:00
parent 69b9918e4f
commit aa2a41a53c
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
2 changed files with 90 additions and 162 deletions

View file

@ -36,6 +36,7 @@ async fn main() -> anyhow::Result<()> {
let mut f = std::fs::OpenOptions::new() let mut f = std::fs::OpenOptions::new()
.create(true) .create(true)
.write(true) .write(true)
.truncate(true)
.open(filename) .open(filename)
.unwrap(); .unwrap();
serde_json::to_writer_pretty(&mut f, r).unwrap(); serde_json::to_writer_pretty(&mut f, r).unwrap();

View file

@ -81,15 +81,10 @@ fn make_rate_limiter() -> RateLimiter {
} }
trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static { trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static {
fn on_request_start( fn on_request_start(&self, req: &RecursiveRequest<Self>, target_node: Id20, addr: SocketAddr);
&self,
req: &Arc<RecursiveRequest<Self>>,
target_node: Id20,
addr: SocketAddr,
);
fn on_request_end( fn on_request_end(
&self, &self,
req: &Arc<RecursiveRequest<Self>>, req: &RecursiveRequest<Self>,
target_node: Id20, target_node: Id20,
addr: SocketAddr, addr: SocketAddr,
resp: &anyhow::Result<ResponseOrError>, resp: &anyhow::Result<ResponseOrError>,
@ -98,11 +93,11 @@ trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static {
struct RecursiveRequestCallbacksGetPeers {} struct RecursiveRequestCallbacksGetPeers {}
impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers { impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers {
fn on_request_start(&self, _: &Arc<RecursiveRequest<Self>>, _: Id20, _: SocketAddr) {} fn on_request_start(&self, _: &RecursiveRequest<Self>, _: Id20, _: SocketAddr) {}
fn on_request_end( fn on_request_end(
&self, &self,
_: &Arc<RecursiveRequest<Self>>, _: &RecursiveRequest<Self>,
_: Id20, _: Id20,
_: SocketAddr, _: SocketAddr,
_: &anyhow::Result<ResponseOrError>, _: &anyhow::Result<ResponseOrError>,
@ -112,12 +107,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers {
struct RecursiveRequestCallbacksFindNodes {} struct RecursiveRequestCallbacksFindNodes {}
impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes { impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes {
fn on_request_start( fn on_request_start(&self, req: &RecursiveRequest<Self>, target_node: Id20, addr: SocketAddr) {
&self,
req: &Arc<RecursiveRequest<Self>>,
target_node: Id20,
addr: SocketAddr,
) {
match req.dht.routing_table_add_node(target_node, addr) { match req.dht.routing_table_add_node(target_node, addr) {
InsertResult::WasExisting | InsertResult::ReplacedBad(_) | InsertResult::Added => { InsertResult::WasExisting | InsertResult::ReplacedBad(_) | InsertResult::Added => {
req.dht req.dht
@ -131,7 +121,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes {
fn on_request_end( fn on_request_end(
&self, &self,
req: &Arc<RecursiveRequest<Self>>, req: &RecursiveRequest<Self>,
target_node: Id20, target_node: Id20,
_addr: SocketAddr, _addr: SocketAddr,
resp: &anyhow::Result<ResponseOrError>, resp: &anyhow::Result<ResponseOrError>,
@ -150,8 +140,8 @@ struct RecursiveRequest<C: RecursiveRequestCallbacks> {
request: Request, request: Request,
dht: Arc<DhtState>, dht: Arc<DhtState>,
useful_nodes: RwLock<Vec<MaybeUsefulNode>>, useful_nodes: RwLock<Vec<MaybeUsefulNode>>,
// peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>, peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>,
// node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>, node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>,
callbacks: C, callbacks: C,
} }
@ -169,11 +159,11 @@ impl RequestPeersStream {
request: Request::GetPeers(info_hash), request: Request::GetPeers(info_hash),
dht, dht,
useful_nodes: RwLock::new(Vec::new()), useful_nodes: RwLock::new(Vec::new()),
// peer_tx, peer_tx,
// node_tx, node_tx,
callbacks: RecursiveRequestCallbacksGetPeers {}, callbacks: RecursiveRequestCallbacksGetPeers {},
}); });
let join_handle = rp.clone().request_peers_forever(node_rx, node_tx, peer_tx); let join_handle = rp.request_peers_forever(node_rx);
Self { Self {
rx: peer_rx, rx: peer_rx,
cancel_join_handle: join_handle, cancel_join_handle: join_handle,
@ -199,77 +189,101 @@ impl Stream for RequestPeersStream {
} }
impl RecursiveRequest<RecursiveRequestCallbacksFindNodes> { impl RecursiveRequest<RecursiveRequestCallbacksFindNodes> {
async fn find_node( async fn bootstrap(dht: Arc<DhtState>, target: Id20, hostname: &str) -> anyhow::Result<()> {
dht: Arc<DhtState>, let addrs = tokio::net::lookup_host(hostname)
target: Id20, .await
root_addrs: impl Iterator<Item = SocketAddr>, .with_context(|| format!("error looking up {}", hostname))?;
) -> anyhow::Result<()> {
let (peer_tx, peer_rx) = unbounded_channel();
drop(peer_rx);
let (node_tx, mut node_rx) = unbounded_channel(); let (node_tx, mut node_rx) = unbounded_channel();
let req = Arc::new(RecursiveRequest { let req = RecursiveRequest {
info_hash: target, info_hash: target,
request: Request::FindNode(target), request: Request::FindNode(target),
dht, dht,
useful_nodes: RwLock::new(Vec::new()), useful_nodes: RwLock::new(Vec::new()),
// peer_tx: unbounded_channel().0, peer_tx: unbounded_channel().0,
// node_tx, node_tx,
callbacks: RecursiveRequestCallbacksFindNodes {}, callbacks: RecursiveRequestCallbacksFindNodes {},
}); };
let request_one = |id, addr| {
req.request_one(id, addr)
.map_err(|e| {
debug!("error: {e:?}");
e
})
.instrument(error_span!(
"find_node",
target = format!("{target:?}"),
addr = addr.to_string()
))
};
let mut futs = FuturesUnordered::new(); let mut futs = FuturesUnordered::new();
for addr in root_addrs { let mut initial_addrs = 0;
node_tx.send((None, addr)).unwrap(); for addr in addrs {
futs.push(request_one(None, addr));
initial_addrs += 1;
} }
let mut successes = 0;
let mut errors = 0;
loop { loop {
tokio::select! { tokio::select! {
biased;
r = node_rx.recv() => { r = node_rx.recv() => {
let (id, addr) = r.unwrap(); let (id, addr) = r.unwrap();
futs.push( futs.push(request_one(id, addr))
req.request_one(id, addr, node_tx.clone(), peer_tx.clone())
.instrument(
error_span!("find_node", target=format!("{target:?}"), addr=addr.to_string())
)
)
}, },
Some(f) = futs.next(), if !futs.is_empty() => { f = futs.next() => {
if let Err(e) = f { let f = match f {
error!("error: {e:?}"); Some(f) => f,
None => {
// find_node recursion finished.
break;
}
};
if f.is_ok() {
successes += 1;
} else {
errors += 1;
} }
} }
} }
} }
if successes == 0 {
bail!("no successful lookups, errors = {errors}");
}
debug!(
"finished, successes = {successes}, errors = {errors}, initial_addrs = {initial_addrs}"
);
Ok(()) Ok(())
} }
} }
impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> { impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
fn request_peers_forever( fn request_peers_forever(
self: Arc<Self>, self: &Arc<Self>,
mut node_rx: tokio::sync::mpsc::UnboundedReceiver<(Option<Id20>, SocketAddr)>, mut node_rx: tokio::sync::mpsc::UnboundedReceiver<(Option<Id20>, SocketAddr)>,
node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>,
peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>,
) -> tokio::task::JoinHandle<()> { ) -> tokio::task::JoinHandle<()> {
let this = self.clone();
spawn( spawn(
error_span!("get_peers", info_hash = format!("{:?}", self.info_hash)), error_span!(parent: None, "get_peers", info_hash = format!("{:?}", self.info_hash)),
async move { async move {
let this = &this;
// Looper adds root nodes to the queue every 60 seconds. // Looper adds root nodes to the queue every 60 seconds.
let looper = { let looper = {
let this = self.clone();
let node_tx = node_tx.clone();
async move { async move {
let mut iteration = 0; let mut iteration = 0;
loop { loop {
debug!("iteration {}", iteration); debug!("iteration {}", iteration);
let sleep = match this.get_peers_root(&node_tx) { let sleep = match this.get_peers_root() {
Ok(0) => Duration::from_secs(1), Ok(0) => Duration::from_secs(1),
Ok(n) if n < 8 => REQUERY_INTERVAL / 2, Ok(n) if n < 8 => REQUERY_INTERVAL / 2,
Ok(_) => REQUERY_INTERVAL, Ok(_) => REQUERY_INTERVAL,
Err(e) => { Err(e) => {
error!("error: {e:?}"); error!("error in get_peers_root(): {e:?}");
return Err::<(), anyhow::Error>(e); return Err::<(), anyhow::Error>(e);
} }
}; };
@ -286,7 +300,7 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
addr = node_rx.recv() => { addr = node_rx.recv() => {
let (id, addr) = addr.unwrap(); let (id, addr) = addr.unwrap();
futs.push( futs.push(
self.request_one(id, addr, node_tx.clone(), peer_tx.clone()) this.request_one(id, addr)
.map_err(|e| debug!("error: {e:?}")) .map_err(|e| debug!("error: {e:?}"))
.instrument(error_span!("addr", addr=addr.to_string())) .instrument(error_span!("addr", addr=addr.to_string()))
); );
@ -299,10 +313,7 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
) )
} }
fn get_peers_root( fn get_peers_root(&self) -> anyhow::Result<usize> {
self: &Arc<Self>,
node_tx: &UnboundedSender<(Option<Id20>, SocketAddr)>,
) -> anyhow::Result<usize> {
let mut count = 0; let mut count = 0;
for (id, addr) in self for (id, addr) in self
.dht .dht
@ -314,20 +325,14 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
.take(8) .take(8)
{ {
count += 1; count += 1;
node_tx.send((Some(id), addr))?; self.node_tx.send((Some(id), addr))?;
} }
Ok(count) Ok(count)
} }
} }
impl<C: RecursiveRequestCallbacks> RecursiveRequest<C> { impl<C: RecursiveRequestCallbacks> RecursiveRequest<C> {
async fn request_one<'a>( async fn request_one(&self, id: Option<Id20>, addr: SocketAddr) -> anyhow::Result<()> {
self: &'a Arc<Self>,
id: Option<Id20>,
addr: SocketAddr,
node_tx: UnboundedSender<(Option<Id20>, SocketAddr)>,
peer_tx: UnboundedSender<SocketAddr>,
) -> anyhow::Result<()> {
if let Some(id) = id { if let Some(id) = id {
self.callbacks.on_request_start(self, id, addr); self.callbacks.on_request_start(self, id, addr);
} }
@ -348,18 +353,26 @@ impl<C: RecursiveRequestCallbacks> RecursiveRequest<C> {
return Err(e); return Err(e);
} }
}; };
trace!("received {response:?}");
if let Some(peers) = response.values { if let Some(peers) = response.values {
for peer in peers { for peer in peers {
peer_tx.send(SocketAddr::V4(peer.addr))?; self.peer_tx.send(SocketAddr::V4(peer.addr))?;
} }
} }
if let Some(nodes) = response.nodes { if let Some(nodes) = response.nodes {
for node in nodes.nodes { for node in nodes.nodes {
let addr = SocketAddr::V4(node.addr); let addr = SocketAddr::V4(node.addr);
if self.should_request_node(node.id, addr) { let should_request = self.should_request_node(node.id, addr);
node_tx.send((Some(node.id), addr))?; trace!(
"should_request={}, id={:?}, addr={}",
should_request,
node.id,
addr
);
if should_request {
self.node_tx.send((Some(node.id), addr))?;
} }
} }
} }
@ -471,20 +484,6 @@ impl DhtState {
} }
} }
async fn send_request_and_handle_response(
self: &Arc<Self>,
request: Request,
addr: SocketAddr,
) -> anyhow::Result<()> {
let resp = self.request(request, addr).await?;
match resp {
ResponseOrError::Response(r) => self.on_response(addr, request, r),
ResponseOrError::Error(e) => {
bail!("received error: {:?}", e);
}
}
}
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> { async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
self.rate_limiter.acquire_one().await; self.rate_limiter.acquire_one().await;
let (tid, message) = self.create_request(request); let (tid, message) = self.create_request(request);
@ -550,24 +549,6 @@ impl DhtState {
(transaction_id, message) (transaction_id, message)
} }
fn on_response(
self: &Arc<Self>,
addr: SocketAddr,
request: Request,
response: Response<ByteString>,
) -> anyhow::Result<()> {
self.routing_table.write().mark_response(&response.id);
match request {
Request::FindNode(id) => {
todo!()
}
Request::Ping => Ok(()),
Request::GetPeers(_info_hash) => {
todo!()
}
}
}
fn on_received_message( fn on_received_message(
self: &Arc<Self>, self: &Arc<Self>,
msg: Message<ByteString>, msg: Message<ByteString>,
@ -615,7 +596,7 @@ impl DhtState {
match request.done.send(Ok(response_or_error)) { match request.done.send(Ok(response_or_error)) {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
warn!( debug!(
"recieved response, but the receiver task is closed: {:?}", "recieved response, but the receiver task is closed: {:?}",
e e
); );
@ -746,68 +727,22 @@ enum ResponseOrError {
struct DhtWorker { struct DhtWorker {
socket: UdpSocket, socket: UdpSocket,
peer_id: Id20, dht: Arc<DhtState>,
state: Arc<DhtState>,
} }
impl DhtWorker { impl DhtWorker {
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) { fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
if let Some((_, OutstandingRequest { done })) = if let Some((_, OutstandingRequest { done })) =
self.state.inflight_by_transaction_id.remove(&(tid, addr)) self.dht.inflight_by_transaction_id.remove(&(tid, addr))
{ {
let _ = done.send(Err(err)).is_err(); let _ = done.send(Err(err)).is_err();
}; };
} }
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
.with_multiplier(1.5)
.with_max_interval(Duration::from_secs(60))
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
.build();
loop {
let res = self
.state
.send_request_and_handle_response(Request::FindNode(self.peer_id), addr)
.await;
match res {
Ok(r) => return Ok(r),
Err(e) => {
debug!("error: {:?}", e);
if let Some(backoff) = backoff.next_backoff() {
tokio::time::sleep(backoff).await;
continue;
}
bail!("given up bootstrapping, timed out")
}
}
}
}
async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> { async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> {
let addrs = tokio::net::lookup_host(hostname) RecursiveRequest::bootstrap(self.dht.clone(), self.dht.id, hostname)
.instrument(error_span!("bootstrap", hostname = hostname))
.await .await
.with_context(|| format!("error looking up {}", hostname))?;
let mut futs = FuturesUnordered::new();
for addr in addrs {
futs.push(
self.bootstrap_one_ip_with_backoff(addr)
.instrument(error_span!("addr", addr = addr.to_string())),
);
}
let requests = futs.len();
let mut successes = 0;
while let Some(resp) = futs.next().await {
if resp.is_ok() {
successes += 1
};
}
if successes == 0 {
bail!("none of the {} bootstrap requests succeded", requests);
}
Ok(())
} }
async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> { async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> {
@ -838,11 +773,7 @@ impl DhtWorker {
let mut futs = FuturesUnordered::new(); let mut futs = FuturesUnordered::new();
for addr in bootstrap_addrs.iter() { for addr in bootstrap_addrs.iter() {
let this = &self; futs.push(self.bootstrap_hostname_with_backoff(addr));
futs.push(
this.bootstrap_hostname_with_backoff(addr)
.instrument(error_span!("bootstrap", hostname = addr)),
);
} }
let mut successes = 0; let mut successes = 0;
while let Some(resp) = futs.next().await { while let Some(resp) = futs.next().await {
@ -937,7 +868,7 @@ impl DhtWorker {
let this = &self; let this = &self;
async move { async move {
while let Some((response, addr)) = out_rx.recv().await { while let Some((response, addr)) = out_rx.recv().await {
if let Err(e) = this.state.on_received_message(response, addr) { if let Err(e) = this.dht.on_received_message(response, addr) {
debug!("error in on_response, addr={:?}: {}", addr, e) debug!("error in on_response, addr={:?}: {}", addr, e)
} }
} }
@ -1011,11 +942,7 @@ impl DhtState {
spawn(error_span!("dht"), { spawn(error_span!("dht"), {
let state = state.clone(); let state = state.clone();
async move { async move {
let worker = DhtWorker { let worker = DhtWorker { socket, dht: state };
socket,
peer_id,
state,
};
worker.start(in_rx, &bootstrap_addrs).await?; worker.start(in_rx, &bootstrap_addrs).await?;
Ok(()) Ok(())
} }