Merge pull request #338 from ikatson/udp-tracker-socket-reuse

[enhancement] UDP tracker socket reuse
This commit is contained in:
Igor Katson 2025-02-27 15:13:01 +00:00 committed by GitHub
commit a16247aadd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 274 additions and 116 deletions

2
Cargo.lock generated
View file

@ -2827,10 +2827,12 @@ dependencies = [
"librqbit-bencode",
"librqbit-buffers",
"librqbit-core",
"parking_lot",
"rand 0.8.5",
"reqwest",
"serde",
"tokio",
"tokio-util",
"tracing",
"url",
"urlencoding",

View file

@ -61,7 +61,7 @@ use tokio::{
};
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span};
use tracker_comms::TrackerComms;
use tracker_comms::{TrackerComms, UdpTrackerClient};
pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
@ -110,6 +110,7 @@ pub struct Session {
dht: Option<Dht>,
pub(crate) connector: Arc<StreamConnector>,
reqwest_client: reqwest::Client,
udp_tracker_client: UdpTrackerClient,
// Lifecycle management
cancellation_token: CancellationToken,
@ -625,6 +626,10 @@ impl Session {
blocklist::Blocklist::empty()
};
let udp_tracker_client = UdpTrackerClient::new(token.clone())
.await
.context("error creating UDP tracker client")?;
let session = Arc::new(Self {
persistence,
bitv_factory,
@ -647,6 +652,7 @@ impl Session {
concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new(
opts.concurrent_init_limit.unwrap_or(3),
)),
udp_tracker_client,
ratelimits: Limits::new(opts.ratelimits),
trackers: opts.trackers,
#[cfg(feature = "disable-upload")]
@ -1355,6 +1361,7 @@ impl Session {
force_tracker_interval,
announce_port,
self.reqwest_client.clone(),
self.udp_tracker_client.clone(),
);
let initial_peers_rx = if initial_peers.is_empty() {

View file

@ -32,3 +32,5 @@ tracing = "0.1.40"
reqwest = { version = "0.12", default-features = false, features = ["json"] }
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" }
url = { version = "2", default-features = false }
parking_lot = "0.12.3"
tokio-util = "0.7.13"

View file

@ -3,3 +3,4 @@ mod tracker_comms_http;
mod tracker_comms_udp;
pub use tracker_comms::*;
pub use tracker_comms_udp::UdpTrackerClient;

View file

@ -18,6 +18,7 @@ use url::Url;
use crate::tracker_comms_http;
use crate::tracker_comms_udp;
use crate::tracker_comms_udp::UdpTrackerClient;
use librqbit_core::hash_id::Id20;
pub struct TrackerComms {
@ -89,6 +90,7 @@ impl std::fmt::Debug for SupportedTracker {
}
impl TrackerComms {
#[allow(clippy::too_many_arguments)]
pub fn start(
info_hash: Id20,
peer_id: Id20,
@ -97,6 +99,7 @@ impl TrackerComms {
force_interval: Option<Duration>,
tcp_listen_port: Option<u16>,
reqwest_client: reqwest::Client,
udp_client: UdpTrackerClient,
) -> Option<BoxStream<'static, SocketAddr>> {
let trackers = trackers
.into_iter()
@ -131,7 +134,7 @@ impl TrackerComms {
});
let mut futures = FuturesUnordered::new();
for tracker in trackers {
futures.push(comms.add_tracker(tracker))
futures.push(comms.add_tracker(tracker, &udp_client))
}
while !(futures.is_empty()) {
tokio::select! {
@ -155,6 +158,7 @@ impl TrackerComms {
fn add_tracker(
&self,
url: SupportedTracker,
client: &UdpTrackerClient,
) -> Either<
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
@ -163,7 +167,7 @@ impl TrackerComms {
match url {
SupportedTracker::Udp(url) => {
let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
self.task_single_tracker_monitor_udp(url)
self.task_single_tracker_monitor_udp(url, client.clone())
.instrument(span)
.right_future()
}
@ -183,7 +187,7 @@ impl TrackerComms {
async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> {
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
trace!(url=?tracker_url, "starting monitor");
trace!(url=%tracker_url, "starting monitor");
loop {
let stats = self.stats.get();
let request = tracker_comms_http::TrackerRequest {
@ -227,7 +231,7 @@ impl TrackerComms {
}
async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result<u64> {
debug!(url = ?tracker_url, "calling tracker over http");
debug!(url = %tracker_url, "calling tracker over http");
let response: reqwest::Response = self.reqwest_client.get(tracker_url).send().await?;
if !response.status().is_success() {
anyhow::bail!("tracker responded with {:?}", response.status());
@ -247,19 +251,20 @@ impl TrackerComms {
Ok(response.interval)
}
async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> {
async fn task_single_tracker_monitor_udp(
&self,
url: Url,
client: UdpTrackerClient,
) -> anyhow::Result<()> {
use tracker_comms_udp::*;
if url.scheme() != "udp" {
bail!("expected UDP scheme in {}", url);
}
let hp: (&str, u16) = (
url.host_str().context("missing host")?,
let hp: (String, u16) = (
url.host_str().context("missing host")?.to_owned(),
url.port().context("missing port")?,
);
let mut requester = UdpTrackerRequester::new(hp)
.await
.context("error creating UDP tracker requester")?;
let mut sleep_interval: Option<Duration> = None;
loop {
@ -291,7 +296,7 @@ impl TrackerComms {
port: self.tcp_listen_port.unwrap_or(0),
};
match requester.announce(request).await {
match client.announce(&hp, request).await {
Ok(response) => {
trace!(len = response.addrs.len(), "received announce response");
for addr in response.addrs {
@ -305,7 +310,7 @@ impl TrackerComms {
sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval));
}
Err(e) => {
debug!(url = ?url, "error reading announce response: {e:#}");
debug!(url = %url, "error reading announce response: {e:#}");
if sleep_interval.is_none() {
sleep_interval = Some(
self.force_tracker_interval

View file

@ -1,15 +1,22 @@
use std::net::{Ipv4Addr, SocketAddrV4};
use std::{
collections::{hash_map::Entry, HashMap},
ffi::CStr,
net::{Ipv4Addr, SocketAddrV4},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{bail, Context};
use librqbit_core::hash_id::Id20;
use librqbit_core::{hash_id::Id20, spawn_utils::spawn_with_cancel};
use parking_lot::RwLock;
use rand::Rng;
use tokio::net::ToSocketAddrs;
use tracing::trace;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error_span, trace, warn};
const ACTION_CONNECT: u32 = 0;
const ACTION_ANNOUNCE: u32 = 1;
// const ACTION_SCRAPE: u32 = 2;
// const ACTION_ERROR: u32 = 3;
const ACTION_ERROR: u32 = 3;
pub const EVENT_NONE: u32 = 0;
pub const EVENT_COMPLETED: u32 = 1;
@ -44,31 +51,51 @@ pub enum Request {
}
impl Request {
pub fn serialize(&self, transaction_id: TransactionId, buf: &mut Vec<u8>) -> usize {
let cur_len = buf.len();
match self {
Request::Connect => {
buf.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes());
buf.extend_from_slice(&ACTION_CONNECT.to_be_bytes());
buf.extend_from_slice(&transaction_id.to_be_bytes());
}
Request::Announce(connection_id, fields) => {
buf.extend_from_slice(&connection_id.to_be_bytes());
buf.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes());
buf.extend_from_slice(&transaction_id.to_be_bytes());
buf.extend_from_slice(&fields.info_hash.0);
buf.extend_from_slice(&fields.peer_id.0);
buf.extend_from_slice(&fields.downloaded.to_be_bytes());
buf.extend_from_slice(&fields.left.to_be_bytes());
buf.extend_from_slice(&fields.uploaded.to_be_bytes());
buf.extend_from_slice(&fields.event.to_be_bytes());
buf.extend_from_slice(&0u32.to_be_bytes()); // ip address 0
buf.extend_from_slice(&fields.key.to_be_bytes());
buf.extend_from_slice(&(-1i32).to_be_bytes()); // num want -1
buf.extend_from_slice(&fields.port.to_be_bytes());
pub fn serialize(
&self,
transaction_id: TransactionId,
buf: &mut [u8],
) -> anyhow::Result<usize> {
struct W<'a> {
buf: &'a mut [u8],
offset: usize,
}
impl W<'_> {
fn extend_from_slice(&mut self, s: &[u8]) -> anyhow::Result<()> {
if self.buf.len() < self.offset + s.len() {
bail!("not enough space in buffer")
}
self.buf[self.offset..self.offset + s.len()].copy_from_slice(s);
self.offset += s.len();
Ok(())
}
}
buf.len() - cur_len
let mut w = W { buf, offset: 0 };
match self {
Request::Connect => {
w.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes())?;
w.extend_from_slice(&ACTION_CONNECT.to_be_bytes())?;
w.extend_from_slice(&transaction_id.to_be_bytes())?;
}
Request::Announce(connection_id, fields) => {
w.extend_from_slice(&connection_id.to_be_bytes())?;
w.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes())?;
w.extend_from_slice(&transaction_id.to_be_bytes())?;
w.extend_from_slice(&fields.info_hash.0)?;
w.extend_from_slice(&fields.peer_id.0)?;
w.extend_from_slice(&fields.downloaded.to_be_bytes())?;
w.extend_from_slice(&fields.left.to_be_bytes())?;
w.extend_from_slice(&fields.uploaded.to_be_bytes())?;
w.extend_from_slice(&fields.event.to_be_bytes())?;
w.extend_from_slice(&0u32.to_be_bytes())?; // ip address 0
w.extend_from_slice(&fields.key.to_be_bytes())?;
w.extend_from_slice(&(-1i32).to_be_bytes())?; // num want -1
w.extend_from_slice(&fields.port.to_be_bytes())?;
}
}
Ok(w.offset)
}
}
@ -86,6 +113,9 @@ pub struct AnnounceResponse {
pub enum Response {
Connect(ConnectionId),
Announce(AnnounceResponse),
#[allow(dead_code)]
Error(String),
Unknown,
}
fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> {
@ -128,7 +158,20 @@ parse_impl!(i16, 2);
impl Response {
pub fn parse(buf: &[u8]) -> anyhow::Result<(TransactionId, Self)> {
let (action, buf) = u32::parse_num(buf).context("can't parse action")?;
let (tid, mut buf) = u32::parse_num(buf).context("can't parse transaction id")?;
let (tid, buf) = u32::parse_num(buf).context("can't parse transaction id")?;
let response = match Self::parse_response(action, buf) {
Ok(r) => r,
Err(e) => {
debug!("error parsing: {e:#}");
Response::Unknown
}
};
Ok((tid, response))
}
fn parse_response(action: u32, mut buf: &[u8]) -> anyhow::Result<Self> {
let response = match action {
ACTION_CONNECT => {
let (connection_id, b) =
@ -158,6 +201,15 @@ impl Response {
addrs,
})
}
ACTION_ERROR => {
let msg = CStr::from_bytes_with_nul(buf)
.ok()
.and_then(|s| s.to_str().ok())
.or_else(|| std::str::from_utf8(buf).ok())
.unwrap_or("<invalid UTF-8>")
.to_owned();
return Ok(Response::Error(msg));
}
_ => bail!("unsupported action {action}"),
};
@ -168,92 +220,182 @@ impl Response {
);
}
Ok((tid, response))
Ok(response)
}
}
pub struct UdpTrackerRequester {
sock: tokio::net::UdpSocket,
connection_id: ConnectionId,
read_buf: Vec<u8>,
write_buf: Vec<u8>,
pub type TrackerAddr = (String, u16);
struct ConnectionIdMeta {
id: ConnectionId,
created: Instant,
}
impl UdpTrackerRequester {
// Addr is "host:port"
pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
#[derive(Default)]
struct ClientLocked {
connections: HashMap<TrackerAddr, ConnectionIdMeta>,
transactions: HashMap<TransactionId, tokio::sync::oneshot::Sender<Response>>,
}
struct ClientShared {
sock: tokio::net::UdpSocket,
locked: RwLock<ClientLocked>,
}
#[derive(Clone)]
pub struct UdpTrackerClient {
state: Arc<ClientShared>,
}
struct TransactionIdGuard<'a> {
tid: TransactionId,
state: &'a ClientShared,
}
impl Drop for TransactionIdGuard<'_> {
fn drop(&mut self) {
let mut g = self.state.locked.write();
g.transactions.remove(&self.tid);
}
}
impl UdpTrackerClient {
pub async fn new(cancel_token: CancellationToken) -> anyhow::Result<Self> {
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
.await
.context("error binding UDP socket")?;
sock.connect(addr)
.await
.context("error connecting UDP socket")?;
let tid = new_transaction_id();
let mut write_buf = Vec::new();
let mut read_buf = vec![0u8; 4096];
trace!("sending connect request");
Request::Connect.serialize(tid, &mut write_buf);
sock.send(&write_buf)
.await
.context("error sending to socket")?;
let size = sock
.recv(&mut read_buf)
.await
.context("error receiving from socket")?;
let (rtid, response) =
Response::parse(&read_buf[..size]).context("error parsing response")?;
if tid != rtid {
bail!("expected transaction id {} == {}", tid, rtid);
}
trace!(response=?response, "received");
let connection_id = match response {
Response::Connect(connection_id) => connection_id,
other => bail!("unexpected response {other:?}"),
.context("error binding UDP for tracker")?;
let client = Self {
state: Arc::new(ClientShared {
sock,
locked: RwLock::new(Default::default()),
}),
};
trace!(connection_id);
spawn_with_cancel(error_span!("udp_tracker"), cancel_token, {
let client = client.clone();
async move { client.run().await }
});
Ok(Self {
sock,
connection_id,
read_buf,
write_buf,
})
Ok(client)
}
pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result<AnnounceResponse> {
let request = Request::Announce(self.connection_id, fields);
let response = self.request(request).await?;
async fn run(self) -> anyhow::Result<()> {
let mut buf = [0u8; 16384];
loop {
let (len, addr) = match self.state.sock.recv_from(&mut buf).await {
Ok(r) => r,
Err(e) => {
warn!("error in UdpSocket::recv_from: {e:#}");
continue;
}
};
let (tid, response) = match Response::parse(&buf[..len]) {
Ok(r) => r,
Err(e) => {
debug!(?addr, "error parsing UDP response: {e:#}");
continue;
}
};
trace!(?tid, ?response, ?addr, "received");
let t = self.state.locked.write().transactions.remove(&tid);
match t {
Some(tx) => match tx.send(response) {
Ok(_) => {}
Err(_) => {
debug!(tid, "reader dead");
}
},
None => {
debug!(tid, "nowhere to send response");
}
};
}
}
async fn get_connection_id(&self, addr: &TrackerAddr) -> anyhow::Result<ConnectionId> {
if let Some(m) = self.state.locked.read().connections.get(addr) {
if m.created.elapsed() < Duration::from_secs(60) {
return Ok(m.id);
}
}
let response = self.request(addr, Request::Connect).await?;
match response {
Response::Connect(connection_id) => {
self.state.locked.write().connections.insert(
addr.clone(),
ConnectionIdMeta {
id: connection_id,
created: Instant::now(),
},
);
Ok(connection_id)
}
_ => anyhow::bail!("expected connect response"),
}
}
async fn request(&self, addr: &TrackerAddr, request: Request) -> anyhow::Result<Response> {
let (tx, rx) = tokio::sync::oneshot::channel();
let tid_g = self.reserve_transaction_id(tx)?;
let mut write_buf = [0u8; 1024];
let len = request.serialize(tid_g.tid, &mut write_buf)?;
self.state.sock.send_to(&write_buf[..len], addr).await?;
let response = tokio::time::timeout(Duration::from_secs(10), rx)
.await
.context("timeout connecting")?
.context("sender dead")?;
match &response {
Response::Error(e) => {
anyhow::bail!("remote errored: {e}")
}
Response::Unknown => {
anyhow::bail!("remote replied with something we could not parse")
}
_ => {}
}
Ok(response)
}
fn reserve_transaction_id(
&self,
tx: tokio::sync::oneshot::Sender<Response>,
) -> anyhow::Result<TransactionIdGuard<'_>> {
let mut g = self.state.locked.write();
for _ in 0..10 {
let t = new_transaction_id();
match g.transactions.entry(t) {
Entry::Occupied(_) => continue,
Entry::Vacant(vac) => {
vac.insert(tx);
return Ok(TransactionIdGuard {
tid: t,
state: &self.state,
});
}
}
}
bail!("cant generate transaction id")
}
pub async fn announce(
&self,
tracker: &TrackerAddr,
fields: AnnounceFields,
) -> anyhow::Result<AnnounceResponse> {
let connection_id = self.get_connection_id(tracker).await?;
let request = Request::Announce(connection_id, fields);
let response = self.request(tracker, request).await?;
match response {
Response::Announce(r) => Ok(r),
other => bail!("unexpected response {other:?}, expected announce"),
}
}
pub async fn request(&mut self, request: Request) -> anyhow::Result<Response> {
let tid = new_transaction_id();
self.write_buf.clear();
let size = request.serialize(tid, &mut self.write_buf);
trace!(request=?request, tid, "sending");
self.sock
.send(&self.write_buf[..size])
.await
.context("error sending")?;
let size = self.sock.recv(&mut self.read_buf).await?;
let (rtid, response) = Response::parse(&self.read_buf[..size])?;
trace!("received response");
if tid != rtid {
bail!("unexpected transaction id");
}
Ok(response)
}
}
#[cfg(test)]
@ -280,12 +422,12 @@ mod tests {
sock.connect("opentor.net:6969").await.unwrap();
let tid = new_transaction_id();
let mut write_buf = Vec::new();
let mut write_buf = [0u8; 16384];
let mut read_buf = vec![0u8; 4096];
Request::Connect.serialize(tid, &mut write_buf);
let len = Request::Connect.serialize(tid, &mut write_buf).unwrap();
sock.send(&write_buf).await.unwrap();
sock.send(&write_buf[..len]).await.unwrap();
let size = sock.recv(&mut read_buf).await.unwrap();
@ -314,8 +456,7 @@ mod tests {
port: 24563,
},
);
write_buf.clear();
let size = request.serialize(tid, &mut write_buf);
let size = request.serialize(tid, &mut write_buf).unwrap();
sock.send(&write_buf[..size]).await.unwrap();
let size = sock.recv(&mut read_buf).await.unwrap();