UDP send errors now kill requests right away

This commit is contained in:
Igor Katson 2023-11-28 11:31:34 +00:00
parent e9b7103c26
commit 7da46d0bbf
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
3 changed files with 85 additions and 75 deletions

View file

@ -4,7 +4,7 @@ use std::{
net::{Ipv4Addr, SocketAddrV4},
};
use bencode::ByteBuf;
use bencode::{ByteBuf, ByteString};
use clone_to_owned::CloneToOwned;
use librqbit_core::id20::Id20;
use serde::{
@ -332,6 +332,16 @@ pub struct Message<BufT> {
pub kind: MessageKind<BufT>,
}
impl Message<ByteString> {
pub fn get_transaction_id(&self) -> Option<u16> {
if self.transaction_id.len() != 2 {
return None;
}
let tid = ((self.transaction_id[0] as u16) << 8) + (self.transaction_id[1] as u16);
Some(tid)
}
}
#[derive(Debug)]
pub enum MessageKind<BufT> {
Error(ErrorDescription<BufT>),

View file

@ -46,7 +46,7 @@ pub struct DhtStats {
}
struct OutstandingRequest {
done: tokio::sync::oneshot::Sender<ResponseOrError>,
done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>,
}
pub struct DhtState {
@ -134,13 +134,13 @@ impl DhtState {
}
};
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
Ok(Ok(r)) => Ok(r),
Ok(Ok(r)) => r,
Ok(Err(e)) => {
self.inflight.remove(&key);
warn!("recv error, did not expect this: {:?}", e);
Err(e.into())
}
Err(e) => {
Err(_) => {
self.inflight.remove(&key);
anyhow::bail!("timeout")
}
@ -227,14 +227,7 @@ impl DhtState {
// If it's a response to a request we made, find the request task, notify it with the response,
// and let it handle it.
MessageKind::Error(_) | MessageKind::Response(_) => {
if msg.transaction_id.len() != 2 {
anyhow::bail!(
"{}: transaction id unrecognized, expected its length == 2. Message: {:?}",
addr,
msg
)
}
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
let tid = msg.get_transaction_id().context("bad transaction id")?;
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
Some(req) => req,
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
@ -248,7 +241,7 @@ impl DhtState {
}
_ => unreachable!(),
};
match request.done.send(response_or_error) {
match request.done.send(Ok(response_or_error)) {
Ok(_) => {}
Err(e) => {
warn!(
@ -550,67 +543,6 @@ fn make_rate_limiter() -> RateLimiter {
.build()
}
async fn run_framer(
socket: &UdpSocket,
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
) -> anyhow::Result<()> {
let writer = async {
let mut buf = Vec::new();
let rate_limiter = make_rate_limiter();
while let Some((msg, addr)) = input_rx.recv().await {
let addr = match addr {
SocketAddr::V4(v4) => v4,
SocketAddr::V6(_) => continue,
};
rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg);
buf.clear();
bprotocol::serialize_message(
&mut buf,
msg.transaction_id,
msg.version,
msg.ip,
msg.kind,
)
.unwrap();
if let Err(e) = socket.send_to(&buf, addr).await {
warn!("could not send to {:?}: {}", addr, e)
}
}
Err::<(), _>(anyhow::anyhow!(
"DHT UDP socket writer over, nowhere to read messages from"
))
};
let reader = async {
let mut buf = vec![0u8; 16384];
loop {
let (size, addr) = socket
.recv_from(&mut buf)
.await
.context("error reading from UDP socket")?;
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
Ok(msg) => {
trace!("{}: received {:?}", addr, &msg);
match output_tx.send((msg, addr)).await {
Ok(_) => {}
Err(_) => break,
}
}
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
}
}
Err::<(), _>(anyhow::anyhow!(
"DHT UDP socket reader over, nowhere to send responses to"
))
};
let result = tokio::select! {
err = writer => err,
err = reader => err,
};
result.context("DHT UDP framer closed")
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Request {
GetPeers(Id20),
@ -635,6 +567,12 @@ impl DhtWorker {
self.state.on_incoming_from_remote(msg, addr)
}
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) {
let _ = done.send(Err(err)).is_err();
};
}
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_secs(10))
@ -732,13 +670,74 @@ impl DhtWorker {
Ok(())
}
async fn framer(
&self,
socket: &UdpSocket,
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
) -> anyhow::Result<()> {
let writer = async {
let mut buf = Vec::new();
let rate_limiter = make_rate_limiter();
while let Some((msg, addr)) = input_rx.recv().await {
rate_limiter.acquire_one().await;
trace!("{}: sending {:?}", addr, &msg);
buf.clear();
let tid = msg.get_transaction_id().unwrap();
bprotocol::serialize_message(
&mut buf,
msg.transaction_id,
msg.version,
msg.ip,
msg.kind,
)
.unwrap();
if let Err(e) = socket.send_to(&buf, addr).await {
self.on_send_error(tid, addr, e.into());
}
}
Err::<(), _>(anyhow::anyhow!(
"DHT UDP socket writer over, nowhere to read messages from"
))
};
let reader = async {
let mut buf = vec![0u8; 16384];
loop {
let (size, addr) = socket
.recv_from(&mut buf)
.await
.context("error reading from UDP socket")?;
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
Ok(msg) => {
trace!("{}: received {:?}", addr, &msg);
match output_tx.send((msg, addr)).await {
Ok(_) => {}
Err(_) => break,
}
}
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
}
}
Err::<(), _>(anyhow::anyhow!(
"DHT UDP socket reader over, nowhere to send responses to"
))
};
let result = tokio::select! {
err = writer => err,
err = reader => err,
};
result.context("DHT UDP framer closed")
}
async fn start(
self,
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
bootstrap_addrs: &[String],
) -> anyhow::Result<()> {
let (out_tx, mut out_rx) = channel(1);
let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer"));
let framer = self
.framer(&self.socket, in_rx, out_tx)
.instrument(debug_span!("dht_framer"));
let bootstrap = self.bootstrap(bootstrap_addrs);
let mut bootstrap_done = false;

View file

@ -218,6 +218,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender<String> {
layered
.with(
fmt::layer()
.with_ansi(false)
.with_writer(log_file)
.with_filter(EnvFilter::builder().parse(&opts.log_file_rust_log).unwrap()),
)