Fixing up bugs, refactored DHT works alright now
This commit is contained in:
parent
69b9918e4f
commit
aa2a41a53c
2 changed files with 90 additions and 162 deletions
|
|
@ -36,6 +36,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
let mut f = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.open(filename)
|
||||
.unwrap();
|
||||
serde_json::to_writer_pretty(&mut f, r).unwrap();
|
||||
|
|
|
|||
|
|
@ -81,15 +81,10 @@ fn make_rate_limiter() -> RateLimiter {
|
|||
}
|
||||
|
||||
trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static {
|
||||
fn on_request_start(
|
||||
&self,
|
||||
req: &Arc<RecursiveRequest<Self>>,
|
||||
target_node: Id20,
|
||||
addr: SocketAddr,
|
||||
);
|
||||
fn on_request_start(&self, req: &RecursiveRequest<Self>, target_node: Id20, addr: SocketAddr);
|
||||
fn on_request_end(
|
||||
&self,
|
||||
req: &Arc<RecursiveRequest<Self>>,
|
||||
req: &RecursiveRequest<Self>,
|
||||
target_node: Id20,
|
||||
addr: SocketAddr,
|
||||
resp: &anyhow::Result<ResponseOrError>,
|
||||
|
|
@ -98,11 +93,11 @@ trait RecursiveRequestCallbacks: Sized + Send + Sync + 'static {
|
|||
|
||||
struct RecursiveRequestCallbacksGetPeers {}
|
||||
impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers {
|
||||
fn on_request_start(&self, _: &Arc<RecursiveRequest<Self>>, _: Id20, _: SocketAddr) {}
|
||||
fn on_request_start(&self, _: &RecursiveRequest<Self>, _: Id20, _: SocketAddr) {}
|
||||
|
||||
fn on_request_end(
|
||||
&self,
|
||||
_: &Arc<RecursiveRequest<Self>>,
|
||||
_: &RecursiveRequest<Self>,
|
||||
_: Id20,
|
||||
_: SocketAddr,
|
||||
_: &anyhow::Result<ResponseOrError>,
|
||||
|
|
@ -112,12 +107,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksGetPeers {
|
|||
|
||||
struct RecursiveRequestCallbacksFindNodes {}
|
||||
impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes {
|
||||
fn on_request_start(
|
||||
&self,
|
||||
req: &Arc<RecursiveRequest<Self>>,
|
||||
target_node: Id20,
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
fn on_request_start(&self, req: &RecursiveRequest<Self>, target_node: Id20, addr: SocketAddr) {
|
||||
match req.dht.routing_table_add_node(target_node, addr) {
|
||||
InsertResult::WasExisting | InsertResult::ReplacedBad(_) | InsertResult::Added => {
|
||||
req.dht
|
||||
|
|
@ -131,7 +121,7 @@ impl RecursiveRequestCallbacks for RecursiveRequestCallbacksFindNodes {
|
|||
|
||||
fn on_request_end(
|
||||
&self,
|
||||
req: &Arc<RecursiveRequest<Self>>,
|
||||
req: &RecursiveRequest<Self>,
|
||||
target_node: Id20,
|
||||
_addr: SocketAddr,
|
||||
resp: &anyhow::Result<ResponseOrError>,
|
||||
|
|
@ -150,8 +140,8 @@ struct RecursiveRequest<C: RecursiveRequestCallbacks> {
|
|||
request: Request,
|
||||
dht: Arc<DhtState>,
|
||||
useful_nodes: RwLock<Vec<MaybeUsefulNode>>,
|
||||
// peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>,
|
||||
// node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>,
|
||||
peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>,
|
||||
node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>,
|
||||
callbacks: C,
|
||||
}
|
||||
|
||||
|
|
@ -169,11 +159,11 @@ impl RequestPeersStream {
|
|||
request: Request::GetPeers(info_hash),
|
||||
dht,
|
||||
useful_nodes: RwLock::new(Vec::new()),
|
||||
// peer_tx,
|
||||
// node_tx,
|
||||
peer_tx,
|
||||
node_tx,
|
||||
callbacks: RecursiveRequestCallbacksGetPeers {},
|
||||
});
|
||||
let join_handle = rp.clone().request_peers_forever(node_rx, node_tx, peer_tx);
|
||||
let join_handle = rp.request_peers_forever(node_rx);
|
||||
Self {
|
||||
rx: peer_rx,
|
||||
cancel_join_handle: join_handle,
|
||||
|
|
@ -199,77 +189,101 @@ impl Stream for RequestPeersStream {
|
|||
}
|
||||
|
||||
impl RecursiveRequest<RecursiveRequestCallbacksFindNodes> {
|
||||
async fn find_node(
|
||||
dht: Arc<DhtState>,
|
||||
target: Id20,
|
||||
root_addrs: impl Iterator<Item = SocketAddr>,
|
||||
) -> anyhow::Result<()> {
|
||||
let (peer_tx, peer_rx) = unbounded_channel();
|
||||
drop(peer_rx);
|
||||
|
||||
async fn bootstrap(dht: Arc<DhtState>, target: Id20, hostname: &str) -> anyhow::Result<()> {
|
||||
let addrs = tokio::net::lookup_host(hostname)
|
||||
.await
|
||||
.with_context(|| format!("error looking up {}", hostname))?;
|
||||
let (node_tx, mut node_rx) = unbounded_channel();
|
||||
let req = Arc::new(RecursiveRequest {
|
||||
let req = RecursiveRequest {
|
||||
info_hash: target,
|
||||
request: Request::FindNode(target),
|
||||
dht,
|
||||
useful_nodes: RwLock::new(Vec::new()),
|
||||
// peer_tx: unbounded_channel().0,
|
||||
// node_tx,
|
||||
peer_tx: unbounded_channel().0,
|
||||
node_tx,
|
||||
callbacks: RecursiveRequestCallbacksFindNodes {},
|
||||
});
|
||||
};
|
||||
|
||||
let request_one = |id, addr| {
|
||||
req.request_one(id, addr)
|
||||
.map_err(|e| {
|
||||
debug!("error: {e:?}");
|
||||
e
|
||||
})
|
||||
.instrument(error_span!(
|
||||
"find_node",
|
||||
target = format!("{target:?}"),
|
||||
addr = addr.to_string()
|
||||
))
|
||||
};
|
||||
|
||||
let mut futs = FuturesUnordered::new();
|
||||
|
||||
for addr in root_addrs {
|
||||
node_tx.send((None, addr)).unwrap();
|
||||
let mut initial_addrs = 0;
|
||||
for addr in addrs {
|
||||
futs.push(request_one(None, addr));
|
||||
initial_addrs += 1;
|
||||
}
|
||||
|
||||
let mut successes = 0;
|
||||
let mut errors = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
r = node_rx.recv() => {
|
||||
let (id, addr) = r.unwrap();
|
||||
futs.push(
|
||||
req.request_one(id, addr, node_tx.clone(), peer_tx.clone())
|
||||
.instrument(
|
||||
error_span!("find_node", target=format!("{target:?}"), addr=addr.to_string())
|
||||
)
|
||||
)
|
||||
futs.push(request_one(id, addr))
|
||||
},
|
||||
Some(f) = futs.next(), if !futs.is_empty() => {
|
||||
if let Err(e) = f {
|
||||
error!("error: {e:?}");
|
||||
f = futs.next() => {
|
||||
let f = match f {
|
||||
Some(f) => f,
|
||||
None => {
|
||||
// find_node recursion finished.
|
||||
break;
|
||||
}
|
||||
};
|
||||
if f.is_ok() {
|
||||
successes += 1;
|
||||
} else {
|
||||
errors += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if successes == 0 {
|
||||
bail!("no successful lookups, errors = {errors}");
|
||||
}
|
||||
debug!(
|
||||
"finished, successes = {successes}, errors = {errors}, initial_addrs = {initial_addrs}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
|
||||
fn request_peers_forever(
|
||||
self: Arc<Self>,
|
||||
self: &Arc<Self>,
|
||||
mut node_rx: tokio::sync::mpsc::UnboundedReceiver<(Option<Id20>, SocketAddr)>,
|
||||
node_tx: tokio::sync::mpsc::UnboundedSender<(Option<Id20>, SocketAddr)>,
|
||||
peer_tx: tokio::sync::mpsc::UnboundedSender<SocketAddr>,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
let this = self.clone();
|
||||
spawn(
|
||||
error_span!("get_peers", info_hash = format!("{:?}", self.info_hash)),
|
||||
error_span!(parent: None, "get_peers", info_hash = format!("{:?}", self.info_hash)),
|
||||
async move {
|
||||
let this = &this;
|
||||
// Looper adds root nodes to the queue every 60 seconds.
|
||||
let looper = {
|
||||
let this = self.clone();
|
||||
let node_tx = node_tx.clone();
|
||||
async move {
|
||||
let mut iteration = 0;
|
||||
loop {
|
||||
debug!("iteration {}", iteration);
|
||||
let sleep = match this.get_peers_root(&node_tx) {
|
||||
let sleep = match this.get_peers_root() {
|
||||
Ok(0) => Duration::from_secs(1),
|
||||
Ok(n) if n < 8 => REQUERY_INTERVAL / 2,
|
||||
Ok(_) => REQUERY_INTERVAL,
|
||||
Err(e) => {
|
||||
error!("error: {e:?}");
|
||||
error!("error in get_peers_root(): {e:?}");
|
||||
return Err::<(), anyhow::Error>(e);
|
||||
}
|
||||
};
|
||||
|
|
@ -286,7 +300,7 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
|
|||
addr = node_rx.recv() => {
|
||||
let (id, addr) = addr.unwrap();
|
||||
futs.push(
|
||||
self.request_one(id, addr, node_tx.clone(), peer_tx.clone())
|
||||
this.request_one(id, addr)
|
||||
.map_err(|e| debug!("error: {e:?}"))
|
||||
.instrument(error_span!("addr", addr=addr.to_string()))
|
||||
);
|
||||
|
|
@ -299,10 +313,7 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
|
|||
)
|
||||
}
|
||||
|
||||
fn get_peers_root(
|
||||
self: &Arc<Self>,
|
||||
node_tx: &UnboundedSender<(Option<Id20>, SocketAddr)>,
|
||||
) -> anyhow::Result<usize> {
|
||||
fn get_peers_root(&self) -> anyhow::Result<usize> {
|
||||
let mut count = 0;
|
||||
for (id, addr) in self
|
||||
.dht
|
||||
|
|
@ -314,20 +325,14 @@ impl RecursiveRequest<RecursiveRequestCallbacksGetPeers> {
|
|||
.take(8)
|
||||
{
|
||||
count += 1;
|
||||
node_tx.send((Some(id), addr))?;
|
||||
self.node_tx.send((Some(id), addr))?;
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: RecursiveRequestCallbacks> RecursiveRequest<C> {
|
||||
async fn request_one<'a>(
|
||||
self: &'a Arc<Self>,
|
||||
id: Option<Id20>,
|
||||
addr: SocketAddr,
|
||||
node_tx: UnboundedSender<(Option<Id20>, SocketAddr)>,
|
||||
peer_tx: UnboundedSender<SocketAddr>,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn request_one(&self, id: Option<Id20>, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
if let Some(id) = id {
|
||||
self.callbacks.on_request_start(self, id, addr);
|
||||
}
|
||||
|
|
@ -348,18 +353,26 @@ impl<C: RecursiveRequestCallbacks> RecursiveRequest<C> {
|
|||
return Err(e);
|
||||
}
|
||||
};
|
||||
trace!("received {response:?}");
|
||||
|
||||
if let Some(peers) = response.values {
|
||||
for peer in peers {
|
||||
peer_tx.send(SocketAddr::V4(peer.addr))?;
|
||||
self.peer_tx.send(SocketAddr::V4(peer.addr))?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(nodes) = response.nodes {
|
||||
for node in nodes.nodes {
|
||||
let addr = SocketAddr::V4(node.addr);
|
||||
if self.should_request_node(node.id, addr) {
|
||||
node_tx.send((Some(node.id), addr))?;
|
||||
let should_request = self.should_request_node(node.id, addr);
|
||||
trace!(
|
||||
"should_request={}, id={:?}, addr={}",
|
||||
should_request,
|
||||
node.id,
|
||||
addr
|
||||
);
|
||||
if should_request {
|
||||
self.node_tx.send((Some(node.id), addr))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -471,20 +484,6 @@ impl DhtState {
|
|||
}
|
||||
}
|
||||
|
||||
async fn send_request_and_handle_response(
|
||||
self: &Arc<Self>,
|
||||
request: Request,
|
||||
addr: SocketAddr,
|
||||
) -> anyhow::Result<()> {
|
||||
let resp = self.request(request, addr).await?;
|
||||
match resp {
|
||||
ResponseOrError::Response(r) => self.on_response(addr, request, r),
|
||||
ResponseOrError::Error(e) => {
|
||||
bail!("received error: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
|
||||
self.rate_limiter.acquire_one().await;
|
||||
let (tid, message) = self.create_request(request);
|
||||
|
|
@ -550,24 +549,6 @@ impl DhtState {
|
|||
(transaction_id, message)
|
||||
}
|
||||
|
||||
fn on_response(
|
||||
self: &Arc<Self>,
|
||||
addr: SocketAddr,
|
||||
request: Request,
|
||||
response: Response<ByteString>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.routing_table.write().mark_response(&response.id);
|
||||
match request {
|
||||
Request::FindNode(id) => {
|
||||
todo!()
|
||||
}
|
||||
Request::Ping => Ok(()),
|
||||
Request::GetPeers(_info_hash) => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_received_message(
|
||||
self: &Arc<Self>,
|
||||
msg: Message<ByteString>,
|
||||
|
|
@ -615,7 +596,7 @@ impl DhtState {
|
|||
match request.done.send(Ok(response_or_error)) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
debug!(
|
||||
"recieved response, but the receiver task is closed: {:?}",
|
||||
e
|
||||
);
|
||||
|
|
@ -746,68 +727,22 @@ enum ResponseOrError {
|
|||
|
||||
struct DhtWorker {
|
||||
socket: UdpSocket,
|
||||
peer_id: Id20,
|
||||
state: Arc<DhtState>,
|
||||
dht: Arc<DhtState>,
|
||||
}
|
||||
|
||||
impl DhtWorker {
|
||||
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
|
||||
if let Some((_, OutstandingRequest { done })) =
|
||||
self.state.inflight_by_transaction_id.remove(&(tid, addr))
|
||||
self.dht.inflight_by_transaction_id.remove(&(tid, addr))
|
||||
{
|
||||
let _ = done.send(Err(err)).is_err();
|
||||
};
|
||||
}
|
||||
|
||||
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
let mut backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(Duration::from_secs(10))
|
||||
.with_multiplier(1.5)
|
||||
.with_max_interval(Duration::from_secs(60))
|
||||
.with_max_elapsed_time(Some(Duration::from_secs(86400)))
|
||||
.build();
|
||||
|
||||
loop {
|
||||
let res = self
|
||||
.state
|
||||
.send_request_and_handle_response(Request::FindNode(self.peer_id), addr)
|
||||
.await;
|
||||
match res {
|
||||
Ok(r) => return Ok(r),
|
||||
Err(e) => {
|
||||
debug!("error: {:?}", e);
|
||||
if let Some(backoff) = backoff.next_backoff() {
|
||||
tokio::time::sleep(backoff).await;
|
||||
continue;
|
||||
}
|
||||
bail!("given up bootstrapping, timed out")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn bootstrap_hostname(&self, hostname: &str) -> anyhow::Result<()> {
|
||||
let addrs = tokio::net::lookup_host(hostname)
|
||||
RecursiveRequest::bootstrap(self.dht.clone(), self.dht.id, hostname)
|
||||
.instrument(error_span!("bootstrap", hostname = hostname))
|
||||
.await
|
||||
.with_context(|| format!("error looking up {}", hostname))?;
|
||||
let mut futs = FuturesUnordered::new();
|
||||
for addr in addrs {
|
||||
futs.push(
|
||||
self.bootstrap_one_ip_with_backoff(addr)
|
||||
.instrument(error_span!("addr", addr = addr.to_string())),
|
||||
);
|
||||
}
|
||||
let requests = futs.len();
|
||||
let mut successes = 0;
|
||||
while let Some(resp) = futs.next().await {
|
||||
if resp.is_ok() {
|
||||
successes += 1
|
||||
};
|
||||
}
|
||||
if successes == 0 {
|
||||
bail!("none of the {} bootstrap requests succeded", requests);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn bootstrap_hostname_with_backoff(&self, addr: &str) -> anyhow::Result<()> {
|
||||
|
|
@ -838,11 +773,7 @@ impl DhtWorker {
|
|||
let mut futs = FuturesUnordered::new();
|
||||
|
||||
for addr in bootstrap_addrs.iter() {
|
||||
let this = &self;
|
||||
futs.push(
|
||||
this.bootstrap_hostname_with_backoff(addr)
|
||||
.instrument(error_span!("bootstrap", hostname = addr)),
|
||||
);
|
||||
futs.push(self.bootstrap_hostname_with_backoff(addr));
|
||||
}
|
||||
let mut successes = 0;
|
||||
while let Some(resp) = futs.next().await {
|
||||
|
|
@ -937,7 +868,7 @@ impl DhtWorker {
|
|||
let this = &self;
|
||||
async move {
|
||||
while let Some((response, addr)) = out_rx.recv().await {
|
||||
if let Err(e) = this.state.on_received_message(response, addr) {
|
||||
if let Err(e) = this.dht.on_received_message(response, addr) {
|
||||
debug!("error in on_response, addr={:?}: {}", addr, e)
|
||||
}
|
||||
}
|
||||
|
|
@ -1011,11 +942,7 @@ impl DhtState {
|
|||
spawn(error_span!("dht"), {
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let worker = DhtWorker {
|
||||
socket,
|
||||
peer_id,
|
||||
state,
|
||||
};
|
||||
let worker = DhtWorker { socket, dht: state };
|
||||
worker.start(in_rx, &bootstrap_addrs).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue