UDP send errors now kill requests right away
This commit is contained in:
parent
e9b7103c26
commit
7da46d0bbf
3 changed files with 85 additions and 75 deletions
|
|
@ -4,7 +4,7 @@ use std::{
|
||||||
net::{Ipv4Addr, SocketAddrV4},
|
net::{Ipv4Addr, SocketAddrV4},
|
||||||
};
|
};
|
||||||
|
|
||||||
use bencode::ByteBuf;
|
use bencode::{ByteBuf, ByteString};
|
||||||
use clone_to_owned::CloneToOwned;
|
use clone_to_owned::CloneToOwned;
|
||||||
use librqbit_core::id20::Id20;
|
use librqbit_core::id20::Id20;
|
||||||
use serde::{
|
use serde::{
|
||||||
|
|
@ -332,6 +332,16 @@ pub struct Message<BufT> {
|
||||||
pub kind: MessageKind<BufT>,
|
pub kind: MessageKind<BufT>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Message<ByteString> {
|
||||||
|
pub fn get_transaction_id(&self) -> Option<u16> {
|
||||||
|
if self.transaction_id.len() != 2 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let tid = ((self.transaction_id[0] as u16) << 8) + (self.transaction_id[1] as u16);
|
||||||
|
Some(tid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum MessageKind<BufT> {
|
pub enum MessageKind<BufT> {
|
||||||
Error(ErrorDescription<BufT>),
|
Error(ErrorDescription<BufT>),
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ pub struct DhtStats {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct OutstandingRequest {
|
struct OutstandingRequest {
|
||||||
done: tokio::sync::oneshot::Sender<ResponseOrError>,
|
done: tokio::sync::oneshot::Sender<anyhow::Result<ResponseOrError>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct DhtState {
|
pub struct DhtState {
|
||||||
|
|
@ -134,13 +134,13 @@ impl DhtState {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
|
match tokio::time::timeout(RESPONSE_TIMEOUT, rx).await {
|
||||||
Ok(Ok(r)) => Ok(r),
|
Ok(Ok(r)) => r,
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
self.inflight.remove(&key);
|
self.inflight.remove(&key);
|
||||||
warn!("recv error, did not expect this: {:?}", e);
|
warn!("recv error, did not expect this: {:?}", e);
|
||||||
Err(e.into())
|
Err(e.into())
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(_) => {
|
||||||
self.inflight.remove(&key);
|
self.inflight.remove(&key);
|
||||||
anyhow::bail!("timeout")
|
anyhow::bail!("timeout")
|
||||||
}
|
}
|
||||||
|
|
@ -227,14 +227,7 @@ impl DhtState {
|
||||||
// If it's a response to a request we made, find the request task, notify it with the response,
|
// If it's a response to a request we made, find the request task, notify it with the response,
|
||||||
// and let it handle it.
|
// and let it handle it.
|
||||||
MessageKind::Error(_) | MessageKind::Response(_) => {
|
MessageKind::Error(_) | MessageKind::Response(_) => {
|
||||||
if msg.transaction_id.len() != 2 {
|
let tid = msg.get_transaction_id().context("bad transaction id")?;
|
||||||
anyhow::bail!(
|
|
||||||
"{}: transaction id unrecognized, expected its length == 2. Message: {:?}",
|
|
||||||
addr,
|
|
||||||
msg
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let tid = ((msg.transaction_id[0] as u16) << 8) + (msg.transaction_id[1] as u16);
|
|
||||||
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
|
let request = match self.inflight.remove(&(tid, addr)).map(|(_, v)| v) {
|
||||||
Some(req) => req,
|
Some(req) => req,
|
||||||
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
|
None => anyhow::bail!("outstanding request not found. Message: {:?}", msg),
|
||||||
|
|
@ -248,7 +241,7 @@ impl DhtState {
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
match request.done.send(response_or_error) {
|
match request.done.send(Ok(response_or_error)) {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
|
@ -550,67 +543,6 @@ fn make_rate_limiter() -> RateLimiter {
|
||||||
.build()
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_framer(
|
|
||||||
socket: &UdpSocket,
|
|
||||||
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
|
||||||
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let writer = async {
|
|
||||||
let mut buf = Vec::new();
|
|
||||||
let rate_limiter = make_rate_limiter();
|
|
||||||
while let Some((msg, addr)) = input_rx.recv().await {
|
|
||||||
let addr = match addr {
|
|
||||||
SocketAddr::V4(v4) => v4,
|
|
||||||
SocketAddr::V6(_) => continue,
|
|
||||||
};
|
|
||||||
rate_limiter.acquire_one().await;
|
|
||||||
trace!("{}: sending {:?}", addr, &msg);
|
|
||||||
buf.clear();
|
|
||||||
bprotocol::serialize_message(
|
|
||||||
&mut buf,
|
|
||||||
msg.transaction_id,
|
|
||||||
msg.version,
|
|
||||||
msg.ip,
|
|
||||||
msg.kind,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
if let Err(e) = socket.send_to(&buf, addr).await {
|
|
||||||
warn!("could not send to {:?}: {}", addr, e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err::<(), _>(anyhow::anyhow!(
|
|
||||||
"DHT UDP socket writer over, nowhere to read messages from"
|
|
||||||
))
|
|
||||||
};
|
|
||||||
let reader = async {
|
|
||||||
let mut buf = vec![0u8; 16384];
|
|
||||||
loop {
|
|
||||||
let (size, addr) = socket
|
|
||||||
.recv_from(&mut buf)
|
|
||||||
.await
|
|
||||||
.context("error reading from UDP socket")?;
|
|
||||||
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
|
|
||||||
Ok(msg) => {
|
|
||||||
trace!("{}: received {:?}", addr, &msg);
|
|
||||||
match output_tx.send((msg, addr)).await {
|
|
||||||
Ok(_) => {}
|
|
||||||
Err(_) => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err::<(), _>(anyhow::anyhow!(
|
|
||||||
"DHT UDP socket reader over, nowhere to send responses to"
|
|
||||||
))
|
|
||||||
};
|
|
||||||
let result = tokio::select! {
|
|
||||||
err = writer => err,
|
|
||||||
err = reader => err,
|
|
||||||
};
|
|
||||||
result.context("DHT UDP framer closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
enum Request {
|
enum Request {
|
||||||
GetPeers(Id20),
|
GetPeers(Id20),
|
||||||
|
|
@ -635,6 +567,12 @@ impl DhtWorker {
|
||||||
self.state.on_incoming_from_remote(msg, addr)
|
self.state.on_incoming_from_remote(msg, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn on_send_error(&self, tid: u16, addr: SocketAddr, err: anyhow::Error) {
|
||||||
|
if let Some((_, OutstandingRequest { done })) = self.state.inflight.remove(&(tid, addr)) {
|
||||||
|
let _ = done.send(Err(err)).is_err();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
|
async fn bootstrap_one_ip_with_backoff(&self, addr: SocketAddr) -> anyhow::Result<()> {
|
||||||
let mut backoff = ExponentialBackoffBuilder::new()
|
let mut backoff = ExponentialBackoffBuilder::new()
|
||||||
.with_initial_interval(Duration::from_secs(10))
|
.with_initial_interval(Duration::from_secs(10))
|
||||||
|
|
@ -732,13 +670,74 @@ impl DhtWorker {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn framer(
|
||||||
|
&self,
|
||||||
|
socket: &UdpSocket,
|
||||||
|
mut input_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||||
|
output_tx: Sender<(Message<ByteString>, SocketAddr)>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let writer = async {
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
let rate_limiter = make_rate_limiter();
|
||||||
|
while let Some((msg, addr)) = input_rx.recv().await {
|
||||||
|
rate_limiter.acquire_one().await;
|
||||||
|
trace!("{}: sending {:?}", addr, &msg);
|
||||||
|
buf.clear();
|
||||||
|
let tid = msg.get_transaction_id().unwrap();
|
||||||
|
bprotocol::serialize_message(
|
||||||
|
&mut buf,
|
||||||
|
msg.transaction_id,
|
||||||
|
msg.version,
|
||||||
|
msg.ip,
|
||||||
|
msg.kind,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
if let Err(e) = socket.send_to(&buf, addr).await {
|
||||||
|
self.on_send_error(tid, addr, e.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err::<(), _>(anyhow::anyhow!(
|
||||||
|
"DHT UDP socket writer over, nowhere to read messages from"
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let reader = async {
|
||||||
|
let mut buf = vec![0u8; 16384];
|
||||||
|
loop {
|
||||||
|
let (size, addr) = socket
|
||||||
|
.recv_from(&mut buf)
|
||||||
|
.await
|
||||||
|
.context("error reading from UDP socket")?;
|
||||||
|
match bprotocol::deserialize_message::<ByteString>(&buf[..size]) {
|
||||||
|
Ok(msg) => {
|
||||||
|
trace!("{}: received {:?}", addr, &msg);
|
||||||
|
match output_tx.send((msg, addr)).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => debug!("{}: error deserializing incoming message: {}", addr, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err::<(), _>(anyhow::anyhow!(
|
||||||
|
"DHT UDP socket reader over, nowhere to send responses to"
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let result = tokio::select! {
|
||||||
|
err = writer => err,
|
||||||
|
err = reader => err,
|
||||||
|
};
|
||||||
|
result.context("DHT UDP framer closed")
|
||||||
|
}
|
||||||
|
|
||||||
async fn start(
|
async fn start(
|
||||||
self,
|
self,
|
||||||
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
in_rx: UnboundedReceiver<(Message<ByteString>, SocketAddr)>,
|
||||||
bootstrap_addrs: &[String],
|
bootstrap_addrs: &[String],
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let (out_tx, mut out_rx) = channel(1);
|
let (out_tx, mut out_rx) = channel(1);
|
||||||
let framer = run_framer(&self.socket, in_rx, out_tx).instrument(debug_span!("dht_framer"));
|
let framer = self
|
||||||
|
.framer(&self.socket, in_rx, out_tx)
|
||||||
|
.instrument(debug_span!("dht_framer"));
|
||||||
|
|
||||||
let bootstrap = self.bootstrap(bootstrap_addrs);
|
let bootstrap = self.bootstrap(bootstrap_addrs);
|
||||||
let mut bootstrap_done = false;
|
let mut bootstrap_done = false;
|
||||||
|
|
|
||||||
|
|
@ -218,6 +218,7 @@ fn init_logging(opts: &Opts) -> tokio::sync::mpsc::UnboundedSender<String> {
|
||||||
layered
|
layered
|
||||||
.with(
|
.with(
|
||||||
fmt::layer()
|
fmt::layer()
|
||||||
|
.with_ansi(false)
|
||||||
.with_writer(log_file)
|
.with_writer(log_file)
|
||||||
.with_filter(EnvFilter::builder().parse(&opts.log_file_rust_log).unwrap()),
|
.with_filter(EnvFilter::builder().parse(&opts.log_file_rust_log).unwrap()),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue