This commit is contained in:
Igor Katson 2023-11-28 15:55:13 +00:00
parent 93740ec84b
commit 81428e30a2
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5

View file

@ -48,6 +48,12 @@ struct OutstandingRequest {
done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>, done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>,
} }
pub struct WorkerSendRequest {
our_tid: Option<u16>,
message: Message<ByteString>,
addr: SocketAddr,
}
pub struct DhtState { pub struct DhtState {
id: Id20, id: Id20,
next_transaction_id: AtomicU16, next_transaction_id: AtomicU16,
@ -63,7 +69,7 @@ pub struct DhtState {
listen_addr: SocketAddr, listen_addr: SocketAddr,
// Sending requests to the worker. // Sending requests to the worker.
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<WorkerSendRequest>,
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>, seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>, get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
@ -72,7 +78,7 @@ pub struct DhtState {
impl DhtState { impl DhtState {
fn new_internal( fn new_internal(
id: Id20, id: Id20,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>, sender: UnboundedSender<WorkerSendRequest>,
routing_table: Option<RoutingTable>, routing_table: Option<RoutingTable>,
listen_addr: SocketAddr, listen_addr: SocketAddr,
) -> Self { ) -> Self {
@ -121,12 +127,16 @@ impl DhtState {
} }
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> { async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
let (tid, msg) = self.create_request(request); let (tid, message) = self.create_request(request);
let key = (tid, addr); let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
self.inflight_by_transaction_id self.inflight_by_transaction_id
.insert(key, OutstandingRequest { done: tx }); .insert(key, OutstandingRequest { done: tx });
match self.sender.send((msg, addr)) { match self.sender.send(WorkerSendRequest {
our_tid: Some(tid),
message,
addr,
}) {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
self.inflight_by_transaction_id.remove(&key); self.inflight_by_transaction_id.remove(&key);
@ -270,7 +280,11 @@ impl DhtState {
}), }),
}; };
self.routing_table.write().mark_last_query(&req.id); self.routing_table.write().mark_last_query(&req.id);
self.sender.send((message, addr))?; self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(()) Ok(())
} }
MessageKind::GetPeersRequest(req) => { MessageKind::GetPeersRequest(req) => {
@ -306,7 +320,11 @@ impl DhtState {
token, token,
}), }),
}; };
self.sender.send((message, addr))?; self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(()) Ok(())
} }
MessageKind::FindNodeRequest(req) => { MessageKind::FindNodeRequest(req) => {
@ -322,7 +340,11 @@ impl DhtState {
..Default::default() ..Default::default()
}), }),
}; };
self.sender.send((message, addr))?; self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(()) Ok(())
} }
} }
@ -718,28 +740,32 @@ impl DhtWorker {
async fn framer( async fn framer(
&self, &self,
socket: &UdpSocket, socket: &UdpSocket,
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>, mut input_rx: UnboundedReceiver<WorkerSendRequest>,
output_tx: Sender<(Message<ByteString>, SocketAddr)>, output_tx: Sender<(Message<ByteString>, SocketAddr)>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let writer = async { let writer = async {
let mut buf = Vec::new(); let mut buf = Vec::new();
let rate_limiter = make_rate_limiter(); let rate_limiter = make_rate_limiter();
while let Some((msg, addr)) = input_rx.recv().await { while let Some(WorkerSendRequest {
our_tid,
message,
addr,
}) = input_rx.recv().await
{
rate_limiter.acquire_one().await; rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg); trace!("{}: sending {:?}", addr, &message);
buf.clear(); buf.clear();
let tid = msg.get_our_transaction_id();
bprotocol::serialize_message( bprotocol::serialize_message(
&mut buf, &mut buf,
msg.transaction_id, message.transaction_id,
msg.version, message.version,
msg.ip, message.ip,
msg.kind, message.kind,
) )
.unwrap(); .unwrap();
if let Err(e) = socket.send_to(&buf, addr).await { if let Err(e) = socket.send_to(&buf, addr).await {
debug!("error sending to {addr}: {e:?}"); debug!("error sending to {addr}: {e:?}");
if let Some(tid) = tid { if let Some(tid) = our_tid {
self.on_send_error(tid, addr, e.into()); self.on_send_error(tid, addr, e.into());
} }
} }
@ -779,7 +805,7 @@ impl DhtWorker {
async fn start( async fn start(
self, self,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>, in_rx: UnboundedReceiver<WorkerSendRequest>,
bootstrap_addrs: &[String], bootstrap_addrs: &[String],
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let (out_tx, mut out_rx) = channel(1); let (out_tx, mut out_rx) = channel(1);