This commit is contained in:
Igor Katson 2023-11-28 15:55:13 +00:00
parent 93740ec84b
commit 81428e30a2
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5

View file

@ -48,6 +48,12 @@ struct OutstandingRequest {
done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>,
}
pub struct WorkerSendRequest {
our_tid: Option<u16>,
message: Message<ByteString>,
addr: SocketAddr,
}
pub struct DhtState {
id: Id20,
next_transaction_id: AtomicU16,
@ -63,7 +69,7 @@ pub struct DhtState {
listen_addr: SocketAddr,
// Sending requests to the worker.
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
sender: UnboundedSender<WorkerSendRequest>,
seen_peers: DashMap<Id20, IndexSet<SocketAddr>>,
get_peers_subscribers: DashMap<Id20, tokio::sync::broadcast::Sender<SocketAddr>>,
@ -72,7 +78,7 @@ pub struct DhtState {
impl DhtState {
fn new_internal(
id: Id20,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
sender: UnboundedSender<WorkerSendRequest>,
routing_table: Option<RoutingTable>,
listen_addr: SocketAddr,
) -> Self {
@ -121,12 +127,16 @@ impl DhtState {
}
async fn request(&self, request: Request, addr: SocketAddr) -> anyhow::Result<ResponseOrError> {
let (tid, msg) = self.create_request(request);
let (tid, message) = self.create_request(request);
let key = (tid, addr);
let (tx, rx) = tokio::sync::oneshot::channel();
self.inflight_by_transaction_id
.insert(key, OutstandingRequest { done: tx });
match self.sender.send((msg, addr)) {
match self.sender.send(WorkerSendRequest {
our_tid: Some(tid),
message,
addr,
}) {
Ok(_) => {}
Err(e) => {
self.inflight_by_transaction_id.remove(&key);
@ -270,7 +280,11 @@ impl DhtState {
}),
};
self.routing_table.write().mark_last_query(&req.id);
self.sender.send((message, addr))?;
self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(())
}
MessageKind::GetPeersRequest(req) => {
@ -306,7 +320,11 @@ impl DhtState {
token,
}),
};
self.sender.send((message, addr))?;
self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(())
}
MessageKind::FindNodeRequest(req) => {
@ -322,7 +340,11 @@ impl DhtState {
..Default::default()
}),
};
self.sender.send((message, addr))?;
self.sender.send(WorkerSendRequest {
our_tid: None,
message,
addr,
})?;
Ok(())
}
}
@ -718,28 +740,32 @@ impl DhtWorker {
async fn framer(
&self,
socket: &UdpSocket,
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
mut input_rx: UnboundedReceiver<WorkerSendRequest>,
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
) -> anyhow::Result<()> {
let writer = async {
let mut buf = Vec::new();
let rate_limiter = make_rate_limiter();
while let Some((msg, addr)) = input_rx.recv().await {
while let Some(WorkerSendRequest {
our_tid,
message,
addr,
}) = input_rx.recv().await
{
rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg);
trace!("{}: sending {:?}", addr, &message);
buf.clear();
let tid = msg.get_our_transaction_id();
bprotocol::serialize_message(
&mut buf,
msg.transaction_id,
msg.version,
msg.ip,
msg.kind,
message.transaction_id,
message.version,
message.ip,
message.kind,
)
.unwrap();
if let Err(e) = socket.send_to(&buf, addr).await {
debug!("error sending to {addr}: {e:?}");
if let Some(tid) = tid {
if let Some(tid) = our_tid {
self.on_send_error(tid, addr, e.into());
}
}
@ -779,7 +805,7 @@ impl DhtWorker {
async fn start(
self,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
in_rx: UnboundedReceiver<WorkerSendRequest>,
bootstrap_addrs: &[String],
) -> anyhow::Result<()> {
let (out_tx, mut out_rx) = channel(1);