From 81428e30a24cd2ad7934b68875ad9dcbce206ce1 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Tue, 28 Nov 2023 15:55:13 +0000 Subject: [PATCH] Nothing --- crates/dht/src/dht.rs | 60 +++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index a5d2ec8..bfa5018 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -48,6 +48,12 @@ struct OutstandingRequest { done: tokio::sync::oneshot::Sender>, } +pub struct WorkerSendRequest { + our_tid: Option, + message: Message, + addr: SocketAddr, +} + pub struct DhtState { id: Id20, next_transaction_id: AtomicU16, @@ -63,7 +69,7 @@ pub struct DhtState { listen_addr: SocketAddr, // Sending requests to the worker. - sender: UnboundedSender<(Message, SocketAddr)>, + sender: UnboundedSender, seen_peers: DashMap>, get_peers_subscribers: DashMap>, @@ -72,7 +78,7 @@ pub struct DhtState { impl DhtState { fn new_internal( id: Id20, - sender: UnboundedSender<(Message, SocketAddr)>, + sender: UnboundedSender, routing_table: Option, listen_addr: SocketAddr, ) -> Self { @@ -121,12 +127,16 @@ impl DhtState { } async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result { - let (tid, msg) = self.create_request(request); + let (tid, message) = self.create_request(request); let key = (tid, addr); let (tx, rx) = tokio::sync::oneshot::channel(); self.inflight_by_transaction_id .insert(key, OutstandingRequest { done: tx }); - match self.sender.send((msg, addr)) { + match self.sender.send(WorkerSendRequest { + our_tid: Some(tid), + message, + addr, + }) { Ok(_) => {} Err(e) => { self.inflight_by_transaction_id.remove(&key); @@ -270,7 +280,11 @@ impl DhtState { }), }; self.routing_table.write().mark_last_query(&req.id); - self.sender.send((message, addr))?; + self.sender.send(WorkerSendRequest { + our_tid: None, + message, + addr, + })?; Ok(()) } MessageKind::GetPeersRequest(req) => { @@ -306,7 +320,11 @@ impl DhtState { token, }), }; - self.sender.send((message, addr))?; + self.sender.send(WorkerSendRequest { + our_tid: None, + message, + addr, + })?; Ok(()) } MessageKind::FindNodeRequest(req) => { @@ -322,7 +340,11 @@ impl DhtState { ..Default::default() }), }; - self.sender.send((message, addr))?; + self.sender.send(WorkerSendRequest { + our_tid: None, + message, + addr, + })?; Ok(()) } } @@ -718,28 +740,32 @@ impl DhtWorker { async fn framer( &self, socket: &UdpSocket, - mut input_rx: UnboundedReceiver<(Message, SocketAddr)>, + mut input_rx: UnboundedReceiver, output_tx: Sender<(Message, SocketAddr)>, ) -> anyhow::Result<()> { let writer = async { let mut buf = Vec::new(); let rate_limiter = make_rate_limiter(); - while let Some((msg, addr)) = input_rx.recv().await { + while let Some(WorkerSendRequest { + our_tid, + message, + addr, + }) = input_rx.recv().await + { rate_limiter.acquire_one().await; - trace!("{}: sending {:?}", addr, &msg); + trace!("{}: sending {:?}", addr, &message); buf.clear(); - let tid = msg.get_our_transaction_id(); bprotocol::serialize_message( &mut buf, - msg.transaction_id, - msg.version, - msg.ip, - msg.kind, + message.transaction_id, + message.version, + message.ip, + message.kind, ) .unwrap(); if let Err(e) = socket.send_to(&buf, addr).await { debug!("error sending to {addr}: {e:?}"); - if let Some(tid) = tid { + if let Some(tid) = our_tid { self.on_send_error(tid, addr, e.into()); } } @@ -779,7 +805,7 @@ impl DhtWorker { async fn start( self, - in_rx: UnboundedReceiver<(Message, SocketAddr)>, + in_rx: UnboundedReceiver, bootstrap_addrs: &[String], ) -> anyhow::Result<()> { let (out_tx, mut out_rx) = channel(1);