Trackers: reuse UDP socket

This commit is contained in:
Igor Katson 2025-02-27 14:25:24 +00:00
parent 94877aec6f
commit 29508014b8
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
6 changed files with 187 additions and 82 deletions

2
Cargo.lock generated
View file

@ -2827,10 +2827,12 @@ dependencies = [
"librqbit-bencode",
"librqbit-buffers",
"librqbit-core",
"parking_lot",
"rand 0.8.5",
"reqwest",
"serde",
"tokio",
"tokio-util",
"tracing",
"url",
"urlencoding",

View file

@ -61,7 +61,7 @@ use tokio::{
};
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span};
use tracker_comms::TrackerComms;
use tracker_comms::{TrackerComms, UdpTrackerClient};
pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
@ -110,6 +110,7 @@ pub struct Session {
dht: Option<Dht>,
pub(crate) connector: Arc<StreamConnector>,
reqwest_client: reqwest::Client,
udp_tracker_client: UdpTrackerClient,
// Lifecycle management
cancellation_token: CancellationToken,
@ -625,6 +626,10 @@ impl Session {
blocklist::Blocklist::empty()
};
let udp_tracker_client = UdpTrackerClient::new(token.clone())
.await
.context("error creating UDP tracker client")?;
let session = Arc::new(Self {
persistence,
bitv_factory,
@ -647,6 +652,7 @@ impl Session {
concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new(
opts.concurrent_init_limit.unwrap_or(3),
)),
udp_tracker_client,
ratelimits: Limits::new(opts.ratelimits),
trackers: opts.trackers,
#[cfg(feature = "disable-upload")]
@ -1355,6 +1361,7 @@ impl Session {
force_tracker_interval,
announce_port,
self.reqwest_client.clone(),
self.udp_tracker_client.clone(),
);
let initial_peers_rx = if initial_peers.is_empty() {

View file

@ -32,3 +32,5 @@ tracing = "0.1.40"
reqwest = { version = "0.12", default-features = false, features = ["json"] }
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" }
url = { version = "2", default-features = false }
parking_lot = "0.12.3"
tokio-util = "0.7.13"

View file

@ -3,3 +3,4 @@ mod tracker_comms_http;
mod tracker_comms_udp;
pub use tracker_comms::*;
pub use tracker_comms_udp::UdpTrackerClient;

View file

@ -18,6 +18,7 @@ use url::Url;
use crate::tracker_comms_http;
use crate::tracker_comms_udp;
use crate::tracker_comms_udp::UdpTrackerClient;
use librqbit_core::hash_id::Id20;
pub struct TrackerComms {
@ -89,6 +90,7 @@ impl std::fmt::Debug for SupportedTracker {
}
impl TrackerComms {
#[allow(clippy::too_many_arguments)]
pub fn start(
info_hash: Id20,
peer_id: Id20,
@ -97,6 +99,7 @@ impl TrackerComms {
force_interval: Option<Duration>,
tcp_listen_port: Option<u16>,
reqwest_client: reqwest::Client,
udp_client: UdpTrackerClient,
) -> Option<BoxStream<'static, SocketAddr>> {
let trackers = trackers
.into_iter()
@ -131,7 +134,7 @@ impl TrackerComms {
});
let mut futures = FuturesUnordered::new();
for tracker in trackers {
futures.push(comms.add_tracker(tracker))
futures.push(comms.add_tracker(tracker, &udp_client))
}
while !(futures.is_empty()) {
tokio::select! {
@ -155,6 +158,7 @@ impl TrackerComms {
fn add_tracker(
&self,
url: SupportedTracker,
client: &UdpTrackerClient,
) -> Either<
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
@ -163,7 +167,7 @@ impl TrackerComms {
match url {
SupportedTracker::Udp(url) => {
let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
self.task_single_tracker_monitor_udp(url)
self.task_single_tracker_monitor_udp(url, client.clone())
.instrument(span)
.right_future()
}
@ -247,19 +251,20 @@ impl TrackerComms {
Ok(response.interval)
}
async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> {
async fn task_single_tracker_monitor_udp(
&self,
url: Url,
client: UdpTrackerClient,
) -> anyhow::Result<()> {
use tracker_comms_udp::*;
if url.scheme() != "udp" {
bail!("expected UDP scheme in {}", url);
}
let hp: (&str, u16) = (
url.host_str().context("missing host")?,
let hp: (String, u16) = (
url.host_str().context("missing host")?.to_owned(),
url.port().context("missing port")?,
);
let mut requester = UdpTrackerRequester::new(hp)
.await
.context("error creating UDP tracker requester")?;
let mut sleep_interval: Option<Duration> = None;
loop {
@ -291,7 +296,7 @@ impl TrackerComms {
port: self.tcp_listen_port.unwrap_or(0),
};
match requester.announce(request).await {
match client.announce(&hp, request).await {
Ok(response) => {
trace!(len = response.addrs.len(), "received announce response");
for addr in response.addrs {

View file

@ -1,10 +1,16 @@
use std::net::{Ipv4Addr, SocketAddrV4};
use std::{
collections::{hash_map::Entry, HashMap},
net::{Ipv4Addr, SocketAddrV4},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{bail, Context};
use librqbit_core::hash_id::Id20;
use librqbit_core::{hash_id::Id20, spawn_utils::spawn_with_cancel};
use parking_lot::RwLock;
use rand::Rng;
use tokio::net::ToSocketAddrs;
use tracing::trace;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error_span, trace, warn};
const ACTION_CONNECT: u32 = 0;
const ACTION_ANNOUNCE: u32 = 1;
@ -172,88 +178,170 @@ impl Response {
}
}
pub struct UdpTrackerRequester {
sock: tokio::net::UdpSocket,
connection_id: ConnectionId,
read_buf: Vec<u8>,
write_buf: Vec<u8>,
pub type TrackerAddr = (String, u16);
struct ConnectionIdMeta {
id: ConnectionId,
created: Instant,
}
impl UdpTrackerRequester {
// Addr is "host:port"
pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
#[derive(Default)]
struct ClientLocked {
connections: HashMap<TrackerAddr, ConnectionIdMeta>,
transactions: HashMap<TransactionId, tokio::sync::oneshot::Sender<Response>>,
}
struct ClientShared {
sock: tokio::net::UdpSocket,
locked: RwLock<ClientLocked>,
}
#[derive(Clone)]
pub struct UdpTrackerClient {
state: Arc<ClientShared>,
}
struct TransactionIdGuard<'a> {
tid: TransactionId,
state: &'a ClientShared,
}
impl Drop for TransactionIdGuard<'_> {
fn drop(&mut self) {
let mut g = self.state.locked.write();
g.transactions.remove(&self.tid);
}
}
impl UdpTrackerClient {
pub async fn new(cancel_token: CancellationToken) -> anyhow::Result<Self> {
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
.await
.context("error binding UDP socket")?;
sock.connect(addr)
.await
.context("error connecting UDP socket")?;
let tid = new_transaction_id();
let mut write_buf = Vec::new();
let mut read_buf = vec![0u8; 4096];
trace!("sending connect request");
Request::Connect.serialize(tid, &mut write_buf);
sock.send(&write_buf)
.await
.context("error sending to socket")?;
let size = sock
.recv(&mut read_buf)
.await
.context("error receiving from socket")?;
let (rtid, response) =
Response::parse(&read_buf[..size]).context("error parsing response")?;
if tid != rtid {
bail!("expected transaction id {} == {}", tid, rtid);
}
trace!(response=?response, "received");
let connection_id = match response {
Response::Connect(connection_id) => connection_id,
other => bail!("unexpected response {other:?}"),
.context("error binding UDP for tracker")?;
let client = Self {
state: Arc::new(ClientShared {
sock,
locked: RwLock::new(Default::default()),
}),
};
trace!(connection_id);
spawn_with_cancel(error_span!("udp_tracker"), cancel_token, {
let client = client.clone();
async move { client.run().await }
});
Ok(Self {
sock,
connection_id,
read_buf,
write_buf,
})
Ok(client)
}
pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result<AnnounceResponse> {
let request = Request::Announce(self.connection_id, fields);
let response = self.request(request).await?;
async fn run(self) -> anyhow::Result<()> {
let mut buf = [0u8; 16384];
loop {
let (len, addr) = match self.state.sock.recv_from(&mut buf).await {
Ok(r) => r,
Err(e) => {
warn!("error in UdpSocket::recv_from: {e:#}");
continue;
}
};
let (tid, response) = match Response::parse(&buf[..len]) {
Ok(r) => r,
Err(e) => {
debug!(?addr, "error parsing UDP response: {e:#}");
continue;
}
};
trace!(?tid, ?response, ?addr, "received");
let t = self.state.locked.write().transactions.remove(&tid);
match t {
Some(tx) => match tx.send(response) {
Ok(_) => {}
Err(_) => {
debug!(tid, "reader dead");
}
},
None => {
debug!(tid, "nowhere to send response");
}
};
}
}
async fn get_connection_id(&self, addr: &TrackerAddr) -> anyhow::Result<ConnectionId> {
if let Some(m) = self.state.locked.read().connections.get(addr) {
if m.created.elapsed() < Duration::from_secs(60) {
return Ok(m.id);
}
}
let response = self.request(addr, Request::Connect).await?;
match response {
Response::Connect(connection_id) => {
self.state.locked.write().connections.insert(
addr.clone(),
ConnectionIdMeta {
id: connection_id,
created: Instant::now(),
},
);
Ok(connection_id)
}
_ => anyhow::bail!("expected connect response"),
}
}
async fn request(&self, addr: &TrackerAddr, request: Request) -> anyhow::Result<Response> {
let (tx, rx) = tokio::sync::oneshot::channel();
let tid_g = self.reserve_transaction_id(tx)?;
// TODO: no allocs
let mut write_buf = Vec::new();
request.serialize(tid_g.tid, &mut write_buf);
self.state.sock.send_to(&write_buf, addr).await?;
let response = tokio::time::timeout(Duration::from_secs(10), rx)
.await
.context("timeout connecting")?
.context("sender dead")?;
Ok(response)
}
fn reserve_transaction_id(
&self,
tx: tokio::sync::oneshot::Sender<Response>,
) -> anyhow::Result<TransactionIdGuard<'_>> {
let mut g = self.state.locked.write();
for _ in 0..10 {
let t = new_transaction_id();
match g.transactions.entry(t) {
Entry::Occupied(_) => continue,
Entry::Vacant(vac) => {
vac.insert(tx);
return Ok(TransactionIdGuard {
tid: t,
state: &self.state,
});
}
}
}
bail!("cant generate transaction id")
}
pub async fn announce(
&self,
tracker: &TrackerAddr,
fields: AnnounceFields,
) -> anyhow::Result<AnnounceResponse> {
let connection_id = self.get_connection_id(tracker).await?;
let request = Request::Announce(connection_id, fields);
let response = self.request(tracker, request).await?;
match response {
Response::Announce(r) => Ok(r),
other => bail!("unexpected response {other:?}, expected announce"),
}
}
pub async fn request(&mut self, request: Request) -> anyhow::Result<Response> {
let tid = new_transaction_id();
self.write_buf.clear();
let size = request.serialize(tid, &mut self.write_buf);
trace!(request=?request, tid, "sending");
self.sock
.send(&self.write_buf[..size])
.await
.context("error sending")?;
let size = self.sock.recv(&mut self.read_buf).await?;
let (rtid, response) = Response::parse(&self.read_buf[..size])?;
trace!("received response");
if tid != rtid {
bail!("unexpected transaction id");
}
Ok(response)
}
}
#[cfg(test)]