Trackers: reuse UDP socket
This commit is contained in:
parent
94877aec6f
commit
29508014b8
6 changed files with 187 additions and 82 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -2827,10 +2827,12 @@ dependencies = [
|
|||
"librqbit-bencode",
|
||||
"librqbit-buffers",
|
||||
"librqbit-core",
|
||||
"parking_lot",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"urlencoding",
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ use tokio::{
|
|||
};
|
||||
use tokio_util::sync::{CancellationToken, DropGuard};
|
||||
use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span};
|
||||
use tracker_comms::TrackerComms;
|
||||
use tracker_comms::{TrackerComms, UdpTrackerClient};
|
||||
|
||||
pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
|
||||
|
||||
|
|
@ -110,6 +110,7 @@ pub struct Session {
|
|||
dht: Option<Dht>,
|
||||
pub(crate) connector: Arc<StreamConnector>,
|
||||
reqwest_client: reqwest::Client,
|
||||
udp_tracker_client: UdpTrackerClient,
|
||||
|
||||
// Lifecycle management
|
||||
cancellation_token: CancellationToken,
|
||||
|
|
@ -625,6 +626,10 @@ impl Session {
|
|||
blocklist::Blocklist::empty()
|
||||
};
|
||||
|
||||
let udp_tracker_client = UdpTrackerClient::new(token.clone())
|
||||
.await
|
||||
.context("error creating UDP tracker client")?;
|
||||
|
||||
let session = Arc::new(Self {
|
||||
persistence,
|
||||
bitv_factory,
|
||||
|
|
@ -647,6 +652,7 @@ impl Session {
|
|||
concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new(
|
||||
opts.concurrent_init_limit.unwrap_or(3),
|
||||
)),
|
||||
udp_tracker_client,
|
||||
ratelimits: Limits::new(opts.ratelimits),
|
||||
trackers: opts.trackers,
|
||||
#[cfg(feature = "disable-upload")]
|
||||
|
|
@ -1355,6 +1361,7 @@ impl Session {
|
|||
force_tracker_interval,
|
||||
announce_port,
|
||||
self.reqwest_client.clone(),
|
||||
self.udp_tracker_client.clone(),
|
||||
);
|
||||
|
||||
let initial_peers_rx = if initial_peers.is_empty() {
|
||||
|
|
|
|||
|
|
@ -32,3 +32,5 @@ tracing = "0.1.40"
|
|||
reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
||||
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" }
|
||||
url = { version = "2", default-features = false }
|
||||
parking_lot = "0.12.3"
|
||||
tokio-util = "0.7.13"
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ mod tracker_comms_http;
|
|||
mod tracker_comms_udp;
|
||||
|
||||
pub use tracker_comms::*;
|
||||
pub use tracker_comms_udp::UdpTrackerClient;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ use url::Url;
|
|||
|
||||
use crate::tracker_comms_http;
|
||||
use crate::tracker_comms_udp;
|
||||
use crate::tracker_comms_udp::UdpTrackerClient;
|
||||
use librqbit_core::hash_id::Id20;
|
||||
|
||||
pub struct TrackerComms {
|
||||
|
|
@ -89,6 +90,7 @@ impl std::fmt::Debug for SupportedTracker {
|
|||
}
|
||||
|
||||
impl TrackerComms {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn start(
|
||||
info_hash: Id20,
|
||||
peer_id: Id20,
|
||||
|
|
@ -97,6 +99,7 @@ impl TrackerComms {
|
|||
force_interval: Option<Duration>,
|
||||
tcp_listen_port: Option<u16>,
|
||||
reqwest_client: reqwest::Client,
|
||||
udp_client: UdpTrackerClient,
|
||||
) -> Option<BoxStream<'static, SocketAddr>> {
|
||||
let trackers = trackers
|
||||
.into_iter()
|
||||
|
|
@ -131,7 +134,7 @@ impl TrackerComms {
|
|||
});
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for tracker in trackers {
|
||||
futures.push(comms.add_tracker(tracker))
|
||||
futures.push(comms.add_tracker(tracker, &udp_client))
|
||||
}
|
||||
while !(futures.is_empty()) {
|
||||
tokio::select! {
|
||||
|
|
@ -155,6 +158,7 @@ impl TrackerComms {
|
|||
fn add_tracker(
|
||||
&self,
|
||||
url: SupportedTracker,
|
||||
client: &UdpTrackerClient,
|
||||
) -> Either<
|
||||
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
||||
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
||||
|
|
@ -163,7 +167,7 @@ impl TrackerComms {
|
|||
match url {
|
||||
SupportedTracker::Udp(url) => {
|
||||
let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
|
||||
self.task_single_tracker_monitor_udp(url)
|
||||
self.task_single_tracker_monitor_udp(url, client.clone())
|
||||
.instrument(span)
|
||||
.right_future()
|
||||
}
|
||||
|
|
@ -247,19 +251,20 @@ impl TrackerComms {
|
|||
Ok(response.interval)
|
||||
}
|
||||
|
||||
async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> {
|
||||
async fn task_single_tracker_monitor_udp(
|
||||
&self,
|
||||
url: Url,
|
||||
client: UdpTrackerClient,
|
||||
) -> anyhow::Result<()> {
|
||||
use tracker_comms_udp::*;
|
||||
|
||||
if url.scheme() != "udp" {
|
||||
bail!("expected UDP scheme in {}", url);
|
||||
}
|
||||
let hp: (&str, u16) = (
|
||||
url.host_str().context("missing host")?,
|
||||
let hp: (String, u16) = (
|
||||
url.host_str().context("missing host")?.to_owned(),
|
||||
url.port().context("missing port")?,
|
||||
);
|
||||
let mut requester = UdpTrackerRequester::new(hp)
|
||||
.await
|
||||
.context("error creating UDP tracker requester")?;
|
||||
|
||||
let mut sleep_interval: Option<Duration> = None;
|
||||
loop {
|
||||
|
|
@ -291,7 +296,7 @@ impl TrackerComms {
|
|||
port: self.tcp_listen_port.unwrap_or(0),
|
||||
};
|
||||
|
||||
match requester.announce(request).await {
|
||||
match client.announce(&hp, request).await {
|
||||
Ok(response) => {
|
||||
trace!(len = response.addrs.len(), "received announce response");
|
||||
for addr in response.addrs {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,16 @@
|
|||
use std::net::{Ipv4Addr, SocketAddrV4};
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::{Ipv4Addr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use librqbit_core::hash_id::Id20;
|
||||
use librqbit_core::{hash_id::Id20, spawn_utils::spawn_with_cancel};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use tokio::net::ToSocketAddrs;
|
||||
use tracing::trace;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error_span, trace, warn};
|
||||
|
||||
const ACTION_CONNECT: u32 = 0;
|
||||
const ACTION_ANNOUNCE: u32 = 1;
|
||||
|
|
@ -172,88 +178,170 @@ impl Response {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct UdpTrackerRequester {
|
||||
sock: tokio::net::UdpSocket,
|
||||
connection_id: ConnectionId,
|
||||
read_buf: Vec<u8>,
|
||||
write_buf: Vec<u8>,
|
||||
pub type TrackerAddr = (String, u16);
|
||||
|
||||
struct ConnectionIdMeta {
|
||||
id: ConnectionId,
|
||||
created: Instant,
|
||||
}
|
||||
|
||||
impl UdpTrackerRequester {
|
||||
// Addr is "host:port"
|
||||
pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
|
||||
#[derive(Default)]
|
||||
struct ClientLocked {
|
||||
connections: HashMap<TrackerAddr, ConnectionIdMeta>,
|
||||
transactions: HashMap<TransactionId, tokio::sync::oneshot::Sender<Response>>,
|
||||
}
|
||||
|
||||
struct ClientShared {
|
||||
sock: tokio::net::UdpSocket,
|
||||
locked: RwLock<ClientLocked>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UdpTrackerClient {
|
||||
state: Arc<ClientShared>,
|
||||
}
|
||||
|
||||
struct TransactionIdGuard<'a> {
|
||||
tid: TransactionId,
|
||||
state: &'a ClientShared,
|
||||
}
|
||||
|
||||
impl Drop for TransactionIdGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
let mut g = self.state.locked.write();
|
||||
g.transactions.remove(&self.tid);
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpTrackerClient {
|
||||
pub async fn new(cancel_token: CancellationToken) -> anyhow::Result<Self> {
|
||||
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
|
||||
.await
|
||||
.context("error binding UDP socket")?;
|
||||
sock.connect(addr)
|
||||
.await
|
||||
.context("error connecting UDP socket")?;
|
||||
|
||||
let tid = new_transaction_id();
|
||||
let mut write_buf = Vec::new();
|
||||
let mut read_buf = vec![0u8; 4096];
|
||||
|
||||
trace!("sending connect request");
|
||||
Request::Connect.serialize(tid, &mut write_buf);
|
||||
|
||||
sock.send(&write_buf)
|
||||
.await
|
||||
.context("error sending to socket")?;
|
||||
|
||||
let size = sock
|
||||
.recv(&mut read_buf)
|
||||
.await
|
||||
.context("error receiving from socket")?;
|
||||
|
||||
let (rtid, response) =
|
||||
Response::parse(&read_buf[..size]).context("error parsing response")?;
|
||||
if tid != rtid {
|
||||
bail!("expected transaction id {} == {}", tid, rtid);
|
||||
}
|
||||
trace!(response=?response, "received");
|
||||
|
||||
let connection_id = match response {
|
||||
Response::Connect(connection_id) => connection_id,
|
||||
other => bail!("unexpected response {other:?}"),
|
||||
.context("error binding UDP for tracker")?;
|
||||
let client = Self {
|
||||
state: Arc::new(ClientShared {
|
||||
sock,
|
||||
locked: RwLock::new(Default::default()),
|
||||
}),
|
||||
};
|
||||
|
||||
trace!(connection_id);
|
||||
spawn_with_cancel(error_span!("udp_tracker"), cancel_token, {
|
||||
let client = client.clone();
|
||||
async move { client.run().await }
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
sock,
|
||||
connection_id,
|
||||
read_buf,
|
||||
write_buf,
|
||||
})
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result<AnnounceResponse> {
|
||||
let request = Request::Announce(self.connection_id, fields);
|
||||
let response = self.request(request).await?;
|
||||
async fn run(self) -> anyhow::Result<()> {
|
||||
let mut buf = [0u8; 16384];
|
||||
loop {
|
||||
let (len, addr) = match self.state.sock.recv_from(&mut buf).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("error in UdpSocket::recv_from: {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (tid, response) = match Response::parse(&buf[..len]) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!(?addr, "error parsing UDP response: {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
trace!(?tid, ?response, ?addr, "received");
|
||||
|
||||
let t = self.state.locked.write().transactions.remove(&tid);
|
||||
match t {
|
||||
Some(tx) => match tx.send(response) {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
debug!(tid, "reader dead");
|
||||
}
|
||||
},
|
||||
None => {
|
||||
debug!(tid, "nowhere to send response");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_connection_id(&self, addr: &TrackerAddr) -> anyhow::Result<ConnectionId> {
|
||||
if let Some(m) = self.state.locked.read().connections.get(addr) {
|
||||
if m.created.elapsed() < Duration::from_secs(60) {
|
||||
return Ok(m.id);
|
||||
}
|
||||
}
|
||||
|
||||
let response = self.request(addr, Request::Connect).await?;
|
||||
match response {
|
||||
Response::Connect(connection_id) => {
|
||||
self.state.locked.write().connections.insert(
|
||||
addr.clone(),
|
||||
ConnectionIdMeta {
|
||||
id: connection_id,
|
||||
created: Instant::now(),
|
||||
},
|
||||
);
|
||||
Ok(connection_id)
|
||||
}
|
||||
_ => anyhow::bail!("expected connect response"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn request(&self, addr: &TrackerAddr, request: Request) -> anyhow::Result<Response> {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
let tid_g = self.reserve_transaction_id(tx)?;
|
||||
|
||||
// TODO: no allocs
|
||||
let mut write_buf = Vec::new();
|
||||
request.serialize(tid_g.tid, &mut write_buf);
|
||||
self.state.sock.send_to(&write_buf, addr).await?;
|
||||
|
||||
let response = tokio::time::timeout(Duration::from_secs(10), rx)
|
||||
.await
|
||||
.context("timeout connecting")?
|
||||
.context("sender dead")?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn reserve_transaction_id(
|
||||
&self,
|
||||
tx: tokio::sync::oneshot::Sender<Response>,
|
||||
) -> anyhow::Result<TransactionIdGuard<'_>> {
|
||||
let mut g = self.state.locked.write();
|
||||
for _ in 0..10 {
|
||||
let t = new_transaction_id();
|
||||
match g.transactions.entry(t) {
|
||||
Entry::Occupied(_) => continue,
|
||||
Entry::Vacant(vac) => {
|
||||
vac.insert(tx);
|
||||
return Ok(TransactionIdGuard {
|
||||
tid: t,
|
||||
state: &self.state,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
bail!("cant generate transaction id")
|
||||
}
|
||||
|
||||
pub async fn announce(
|
||||
&self,
|
||||
tracker: &TrackerAddr,
|
||||
fields: AnnounceFields,
|
||||
) -> anyhow::Result<AnnounceResponse> {
|
||||
let connection_id = self.get_connection_id(tracker).await?;
|
||||
let request = Request::Announce(connection_id, fields);
|
||||
let response = self.request(tracker, request).await?;
|
||||
match response {
|
||||
Response::Announce(r) => Ok(r),
|
||||
other => bail!("unexpected response {other:?}, expected announce"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn request(&mut self, request: Request) -> anyhow::Result<Response> {
|
||||
let tid = new_transaction_id();
|
||||
self.write_buf.clear();
|
||||
let size = request.serialize(tid, &mut self.write_buf);
|
||||
trace!(request=?request, tid, "sending");
|
||||
self.sock
|
||||
.send(&self.write_buf[..size])
|
||||
.await
|
||||
.context("error sending")?;
|
||||
let size = self.sock.recv(&mut self.read_buf).await?;
|
||||
|
||||
let (rtid, response) = Response::parse(&self.read_buf[..size])?;
|
||||
trace!("received response");
|
||||
if tid != rtid {
|
||||
bail!("unexpected transaction id");
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue