UDP send errors now kill requests right away
This commit is contained in:
parent
e9b7103c26
commit
7da46d0bbf
3 changed files with 85 additions and 75 deletions
|
|
@ -4,7 +4,7 @@ use std::{
|
|||
net::{Ipv4Addr, SocketAddrV4},
|
||||
};
|
||||
|
||||
use bencode::ByteBuf;
|
||||
use bencode::{ByteBuf, ByteString};
|
||||
use clone_to_owned::CloneToOwned;
|
||||
use librqbit_core::id20::Id20;
|
||||
use serde::{
|
||||
|
|
@ -332,6 +332,16 @@ pub struct Message<BufT> {
|
|||
pub kind: MessageKind<BufT>,
|
||||
}
|
||||
|
||||
impl Message<ByteString> {
|
||||
pub fn get_transaction_id(&self) -> Option<u16> {
|
||||
if self.transaction_id.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let tid = ((self.transaction_id[0] as u16) << 8) + (self.transaction_id[1] as u16);
|
||||
Some(tid)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum MessageKind<BufT> {
|
||||
Error(ErrorDescription<BufT>),
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ pub struct DhtStats {
|
|||
}
|
||||
|
||||
struct OutstandingRequest {
|
||||
done: tokio::sync::oneshot::Sender<ResponseOrError>,
|
||||
done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>,
|
||||
}
|
||||
|
||||
pub struct DhtState {
|
||||
|
|
@ -134,13 +134,13 @@ impl DhtState {
|
|||
}
|
||||
};
|
||||
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
|
||||
Ok(Ok(r)) => Ok(r),
|
||||
Ok(Ok(r)) => r,
|
||||
Ok(Err(e)) => {
|
||||
self.inflight.remove(&key);
|
||||
warn!("recv error, did not expect this: {:?}", e);
|
||||
Err(e.into())
|
||||
}
|
||||
Err(e) => {
|
||||
Err(_) => {
|
||||
self.inflight.remove(&key);
|
||||
anyhow::bail!("timeout")
|
||||
}
|
||||
|
|
@ -227,14 +227,7 @@ impl DhtState {
|
|||
// If it's a response to a request we made, find the request task, notify it with the response,
|
||||
// and let it handle it.
|
||||
MessageKind::Error(_) | MessageKind::Response(_) => {
|
||||
if msg.transaction_id.len() != 2 {
|
||||
anyhow::bail!(
|
||||
"{}: transaction id unrecognized, expected its length == 2. Message: {:?}",
|
||||
addr,
|
||||
msg
|
||||
)
|
||||
}
|
||||
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
|
||||
let tid = msg.get_transaction_id().context("bad transaction id")?;
|
||||
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
|
||||
Some(req) => req,
|
||||
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
|
||||
|
|
@ -248,7 +241,7 @@ impl DhtState {
|
|||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
match request.done.send(response_or_error) {
|
||||
match request.done.send(Ok(response_or_error)) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
|
|
@ -550,67 +543,6 @@ fn make_rate_limiter() -> RateLimiter {
|
|||
.build()
|
||||
}
|
||||
|
||||
async fn run_framer(
|
||||
socket: &UdpSocket,
|
||||
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
|
||||
) -> anyhow::Result<()> {
|
||||
let writer = async {
|
||||
let mut buf = Vec::new();
|
||||
let rate_limiter = make_rate_limiter();
|
||||
while let Some((msg, addr)) = input_rx.recv().await {
|
||||
let addr = match addr {
|
||||
SocketAddr::V4(v4) => v4,
|
||||
SocketAddr::V6(_) => continue,
|
||||
};
|
||||
rate_limiter.acquire_one().await;
|
||||
trace!("{}: sending {:?}", addr, &msg);
|
||||
buf.clear();
|
||||
bprotocol::serialize_message(
|
||||
&mut buf,
|
||||
msg.transaction_id,
|
||||
msg.version,
|
||||
msg.ip,
|
||||
msg.kind,
|
||||
)
|
||||
.unwrap();
|
||||
if let Err(e) = socket.send_to(&buf, addr).await {
|
||||
warn!("could not send to {:?}: {}", addr, e)
|
||||
}
|
||||
}
|
||||
Err::<(), _>(anyhow::anyhow!(
|
||||
"DHT UDP socket writer over, nowhere to read messages from"
|
||||
))
|
||||
};
|
||||
let reader = async {
|
||||
let mut buf = vec![0u8; 16384];
|
||||
loop {
|
||||
let (size, addr) = socket
|
||||
.recv_from(&mut buf)
|
||||
.await
|
||||
.context("error reading from UDP socket")?;
|
||||
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
|
||||
Ok(msg) => {
|
||||
trace!("{}: received {:?}", addr, &msg);
|
||||
match output_tx.send((msg, addr)).await {
|
||||
Ok(_) => {}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
|
||||
}
|
||||
}
|
||||
Err::<(), _>(anyhow::anyhow!(
|
||||
"DHT UDP socket reader over, nowhere to send responses to"
|
||||
))
|
||||
};
|
||||
let result = tokio::select! {
|
||||
err = writer => err,
|
||||
err = reader => err,
|
||||
};
|
||||
result.context("DHT UDP framer closed")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum Request {
|
||||
GetPeers(Id20),
|
||||
|
|
@ -635,6 +567,12 @@ impl DhtWorker {
|
|||
self.state.on_incoming_from_remote(msg, addr)
|
||||
}
|
||||
|
||||
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
|
||||
if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) {
|
||||
let _ = done.send(Err(err)).is_err();
|
||||
};
|
||||
}
|
||||
|
||||
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
|
||||
let mut backoff = ExponentialBackoffBuilder::new()
|
||||
.with_initial_interval(Duration::from_secs(10))
|
||||
|
|
@ -732,13 +670,74 @@ impl DhtWorker {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn framer(
|
||||
&self,
|
||||
socket: &UdpSocket,
|
||||
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
|
||||
) -> anyhow::Result<()> {
|
||||
let writer = async {
|
||||
let mut buf = Vec::new();
|
||||
let rate_limiter = make_rate_limiter();
|
||||
while let Some((msg, addr)) = input_rx.recv().await {
|
||||
rate_limiter.acquire_one().await;
|
||||
trace!("{}: sending {:?}", addr, &msg);
|
||||
buf.clear();
|
||||
let tid = msg.get_transaction_id().unwrap();
|
||||
bprotocol::serialize_message(
|
||||
&mut buf,
|
||||
msg.transaction_id,
|
||||
msg.version,
|
||||
msg.ip,
|
||||
msg.kind,
|
||||
)
|
||||
.unwrap();
|
||||
if let Err(e) = socket.send_to(&buf, addr).await {
|
||||
self.on_send_error(tid, addr, e.into());
|
||||
}
|
||||
}
|
||||
Err::<(), _>(anyhow::anyhow!(
|
||||
"DHT UDP socket writer over, nowhere to read messages from"
|
||||
))
|
||||
};
|
||||
let reader = async {
|
||||
let mut buf = vec![0u8; 16384];
|
||||
loop {
|
||||
let (size, addr) = socket
|
||||
.recv_from(&mut buf)
|
||||
.await
|
||||
.context("error reading from UDP socket")?;
|
||||
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
|
||||
Ok(msg) => {
|
||||
trace!("{}: received {:?}", addr, &msg);
|
||||
match output_tx.send((msg, addr)).await {
|
||||
Ok(_) => {}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
|
||||
}
|
||||
}
|
||||
Err::<(), _>(anyhow::anyhow!(
|
||||
"DHT UDP socket reader over, nowhere to send responses to"
|
||||
))
|
||||
};
|
||||
let result = tokio::select! {
|
||||
err = writer => err,
|
||||
err = reader => err,
|
||||
};
|
||||
result.context("DHT UDP framer closed")
|
||||
}
|
||||
|
||||
async fn start(
|
||||
self,
|
||||
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||
bootstrap_addrs: &[String],
|
||||
) -> anyhow::Result<()> {
|
||||
let (out_tx, mut out_rx) = channel(1);
|
||||
let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer"));
|
||||
let framer = self
|
||||
.framer(&self.socket, in_rx, out_tx)
|
||||
.instrument(debug_span!("dht_framer"));
|
||||
|
||||
let bootstrap = self.bootstrap(bootstrap_addrs);
|
||||
let mut bootstrap_done = false;
|
||||
|
|
|
|||
|
|
@ -218,6 +218,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender<String> {
|
|||
layered
|
||||
.with(
|
||||
fmt::layer()
|
||||
.with_ansi(false)
|
||||
.with_writer(log_file)
|
||||
.with_filter(EnvFilter::builder().parse(&opts.log_file_rust_log).unwrap()),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue