Merge pull request #338 from ikatson/udp-tracker-socket-reuse
[enhancement] UDP tracker socket reuse
This commit is contained in:
commit
a16247aadd
6 changed files with 274 additions and 116 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -2827,10 +2827,12 @@ dependencies = [
|
||||||
"librqbit-bencode",
|
"librqbit-bencode",
|
||||||
"librqbit-buffers",
|
"librqbit-buffers",
|
||||||
"librqbit-core",
|
"librqbit-core",
|
||||||
|
"parking_lot",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
"tracing",
|
"tracing",
|
||||||
"url",
|
"url",
|
||||||
"urlencoding",
|
"urlencoding",
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ use tokio::{
|
||||||
};
|
};
|
||||||
use tokio_util::sync::{CancellationToken, DropGuard};
|
use tokio_util::sync::{CancellationToken, DropGuard};
|
||||||
use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span};
|
use tracing::{debug, error, error_span, info, trace, warn, Instrument, Span};
|
||||||
use tracker_comms::TrackerComms;
|
use tracker_comms::{TrackerComms, UdpTrackerClient};
|
||||||
|
|
||||||
pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
|
pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
|
||||||
|
|
||||||
|
|
@ -110,6 +110,7 @@ pub struct Session {
|
||||||
dht: Option<Dht>,
|
dht: Option<Dht>,
|
||||||
pub(crate) connector: Arc<StreamConnector>,
|
pub(crate) connector: Arc<StreamConnector>,
|
||||||
reqwest_client: reqwest::Client,
|
reqwest_client: reqwest::Client,
|
||||||
|
udp_tracker_client: UdpTrackerClient,
|
||||||
|
|
||||||
// Lifecycle management
|
// Lifecycle management
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
|
|
@ -625,6 +626,10 @@ impl Session {
|
||||||
blocklist::Blocklist::empty()
|
blocklist::Blocklist::empty()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let udp_tracker_client = UdpTrackerClient::new(token.clone())
|
||||||
|
.await
|
||||||
|
.context("error creating UDP tracker client")?;
|
||||||
|
|
||||||
let session = Arc::new(Self {
|
let session = Arc::new(Self {
|
||||||
persistence,
|
persistence,
|
||||||
bitv_factory,
|
bitv_factory,
|
||||||
|
|
@ -647,6 +652,7 @@ impl Session {
|
||||||
concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new(
|
concurrent_initialize_semaphore: Arc::new(tokio::sync::Semaphore::new(
|
||||||
opts.concurrent_init_limit.unwrap_or(3),
|
opts.concurrent_init_limit.unwrap_or(3),
|
||||||
)),
|
)),
|
||||||
|
udp_tracker_client,
|
||||||
ratelimits: Limits::new(opts.ratelimits),
|
ratelimits: Limits::new(opts.ratelimits),
|
||||||
trackers: opts.trackers,
|
trackers: opts.trackers,
|
||||||
#[cfg(feature = "disable-upload")]
|
#[cfg(feature = "disable-upload")]
|
||||||
|
|
@ -1355,6 +1361,7 @@ impl Session {
|
||||||
force_tracker_interval,
|
force_tracker_interval,
|
||||||
announce_port,
|
announce_port,
|
||||||
self.reqwest_client.clone(),
|
self.reqwest_client.clone(),
|
||||||
|
self.udp_tracker_client.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let initial_peers_rx = if initial_peers.is_empty() {
|
let initial_peers_rx = if initial_peers.is_empty() {
|
||||||
|
|
|
||||||
|
|
@ -32,3 +32,5 @@ tracing = "0.1.40"
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json"] }
|
||||||
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" }
|
bencode = { path = "../bencode", default-features = false, package = "librqbit-bencode", version = "3" }
|
||||||
url = { version = "2", default-features = false }
|
url = { version = "2", default-features = false }
|
||||||
|
parking_lot = "0.12.3"
|
||||||
|
tokio-util = "0.7.13"
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ mod tracker_comms_http;
|
||||||
mod tracker_comms_udp;
|
mod tracker_comms_udp;
|
||||||
|
|
||||||
pub use tracker_comms::*;
|
pub use tracker_comms::*;
|
||||||
|
pub use tracker_comms_udp::UdpTrackerClient;
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ use url::Url;
|
||||||
|
|
||||||
use crate::tracker_comms_http;
|
use crate::tracker_comms_http;
|
||||||
use crate::tracker_comms_udp;
|
use crate::tracker_comms_udp;
|
||||||
|
use crate::tracker_comms_udp::UdpTrackerClient;
|
||||||
use librqbit_core::hash_id::Id20;
|
use librqbit_core::hash_id::Id20;
|
||||||
|
|
||||||
pub struct TrackerComms {
|
pub struct TrackerComms {
|
||||||
|
|
@ -89,6 +90,7 @@ impl std::fmt::Debug for SupportedTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrackerComms {
|
impl TrackerComms {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn start(
|
pub fn start(
|
||||||
info_hash: Id20,
|
info_hash: Id20,
|
||||||
peer_id: Id20,
|
peer_id: Id20,
|
||||||
|
|
@ -97,6 +99,7 @@ impl TrackerComms {
|
||||||
force_interval: Option<Duration>,
|
force_interval: Option<Duration>,
|
||||||
tcp_listen_port: Option<u16>,
|
tcp_listen_port: Option<u16>,
|
||||||
reqwest_client: reqwest::Client,
|
reqwest_client: reqwest::Client,
|
||||||
|
udp_client: UdpTrackerClient,
|
||||||
) -> Option<BoxStream<'static, SocketAddr>> {
|
) -> Option<BoxStream<'static, SocketAddr>> {
|
||||||
let trackers = trackers
|
let trackers = trackers
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
@ -131,7 +134,7 @@ impl TrackerComms {
|
||||||
});
|
});
|
||||||
let mut futures = FuturesUnordered::new();
|
let mut futures = FuturesUnordered::new();
|
||||||
for tracker in trackers {
|
for tracker in trackers {
|
||||||
futures.push(comms.add_tracker(tracker))
|
futures.push(comms.add_tracker(tracker, &udp_client))
|
||||||
}
|
}
|
||||||
while !(futures.is_empty()) {
|
while !(futures.is_empty()) {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
|
@ -155,6 +158,7 @@ impl TrackerComms {
|
||||||
fn add_tracker(
|
fn add_tracker(
|
||||||
&self,
|
&self,
|
||||||
url: SupportedTracker,
|
url: SupportedTracker,
|
||||||
|
client: &UdpTrackerClient,
|
||||||
) -> Either<
|
) -> Either<
|
||||||
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
||||||
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
impl std::future::Future<Output = anyhow::Result<()>> + '_ + Send,
|
||||||
|
|
@ -163,7 +167,7 @@ impl TrackerComms {
|
||||||
match url {
|
match url {
|
||||||
SupportedTracker::Udp(url) => {
|
SupportedTracker::Udp(url) => {
|
||||||
let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
|
let span = error_span!(parent: None, "udp_tracker", tracker = %url, info_hash = ?info_hash);
|
||||||
self.task_single_tracker_monitor_udp(url)
|
self.task_single_tracker_monitor_udp(url, client.clone())
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.right_future()
|
.right_future()
|
||||||
}
|
}
|
||||||
|
|
@ -183,7 +187,7 @@ impl TrackerComms {
|
||||||
|
|
||||||
async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> {
|
async fn task_single_tracker_monitor_http(&self, mut tracker_url: Url) -> anyhow::Result<()> {
|
||||||
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
|
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
|
||||||
trace!(url=?tracker_url, "starting monitor");
|
trace!(url=%tracker_url, "starting monitor");
|
||||||
loop {
|
loop {
|
||||||
let stats = self.stats.get();
|
let stats = self.stats.get();
|
||||||
let request = tracker_comms_http::TrackerRequest {
|
let request = tracker_comms_http::TrackerRequest {
|
||||||
|
|
@ -227,7 +231,7 @@ impl TrackerComms {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result<u64> {
|
async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result<u64> {
|
||||||
debug!(url = ?tracker_url, "calling tracker over http");
|
debug!(url = %tracker_url, "calling tracker over http");
|
||||||
let response: reqwest::Response = self.reqwest_client.get(tracker_url).send().await?;
|
let response: reqwest::Response = self.reqwest_client.get(tracker_url).send().await?;
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
anyhow::bail!("tracker responded with {:?}", response.status());
|
anyhow::bail!("tracker responded with {:?}", response.status());
|
||||||
|
|
@ -247,19 +251,20 @@ impl TrackerComms {
|
||||||
Ok(response.interval)
|
Ok(response.interval)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> {
|
async fn task_single_tracker_monitor_udp(
|
||||||
|
&self,
|
||||||
|
url: Url,
|
||||||
|
client: UdpTrackerClient,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
use tracker_comms_udp::*;
|
use tracker_comms_udp::*;
|
||||||
|
|
||||||
if url.scheme() != "udp" {
|
if url.scheme() != "udp" {
|
||||||
bail!("expected UDP scheme in {}", url);
|
bail!("expected UDP scheme in {}", url);
|
||||||
}
|
}
|
||||||
let hp: (&str, u16) = (
|
let hp: (String, u16) = (
|
||||||
url.host_str().context("missing host")?,
|
url.host_str().context("missing host")?.to_owned(),
|
||||||
url.port().context("missing port")?,
|
url.port().context("missing port")?,
|
||||||
);
|
);
|
||||||
let mut requester = UdpTrackerRequester::new(hp)
|
|
||||||
.await
|
|
||||||
.context("error creating UDP tracker requester")?;
|
|
||||||
|
|
||||||
let mut sleep_interval: Option<Duration> = None;
|
let mut sleep_interval: Option<Duration> = None;
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -291,7 +296,7 @@ impl TrackerComms {
|
||||||
port: self.tcp_listen_port.unwrap_or(0),
|
port: self.tcp_listen_port.unwrap_or(0),
|
||||||
};
|
};
|
||||||
|
|
||||||
match requester.announce(request).await {
|
match client.announce(&hp, request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
trace!(len = response.addrs.len(), "received announce response");
|
trace!(len = response.addrs.len(), "received announce response");
|
||||||
for addr in response.addrs {
|
for addr in response.addrs {
|
||||||
|
|
@ -305,7 +310,7 @@ impl TrackerComms {
|
||||||
sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval));
|
sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!(url = ?url, "error reading announce response: {e:#}");
|
debug!(url = %url, "error reading announce response: {e:#}");
|
||||||
if sleep_interval.is_none() {
|
if sleep_interval.is_none() {
|
||||||
sleep_interval = Some(
|
sleep_interval = Some(
|
||||||
self.force_tracker_interval
|
self.force_tracker_interval
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,22 @@
|
||||||
use std::net::{Ipv4Addr, SocketAddrV4};
|
use std::{
|
||||||
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
ffi::CStr,
|
||||||
|
net::{Ipv4Addr, SocketAddrV4},
|
||||||
|
sync::Arc,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{bail, Context};
|
use anyhow::{bail, Context};
|
||||||
use librqbit_core::hash_id::Id20;
|
use librqbit_core::{hash_id::Id20, spawn_utils::spawn_with_cancel};
|
||||||
|
use parking_lot::RwLock;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use tokio::net::ToSocketAddrs;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::trace;
|
use tracing::{debug, error_span, trace, warn};
|
||||||
|
|
||||||
const ACTION_CONNECT: u32 = 0;
|
const ACTION_CONNECT: u32 = 0;
|
||||||
const ACTION_ANNOUNCE: u32 = 1;
|
const ACTION_ANNOUNCE: u32 = 1;
|
||||||
// const ACTION_SCRAPE: u32 = 2;
|
// const ACTION_SCRAPE: u32 = 2;
|
||||||
// const ACTION_ERROR: u32 = 3;
|
const ACTION_ERROR: u32 = 3;
|
||||||
|
|
||||||
pub const EVENT_NONE: u32 = 0;
|
pub const EVENT_NONE: u32 = 0;
|
||||||
pub const EVENT_COMPLETED: u32 = 1;
|
pub const EVENT_COMPLETED: u32 = 1;
|
||||||
|
|
@ -44,31 +51,51 @@ pub enum Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Request {
|
impl Request {
|
||||||
pub fn serialize(&self, transaction_id: TransactionId, buf: &mut Vec<u8>) -> usize {
|
pub fn serialize(
|
||||||
let cur_len = buf.len();
|
&self,
|
||||||
match self {
|
transaction_id: TransactionId,
|
||||||
Request::Connect => {
|
buf: &mut [u8],
|
||||||
buf.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes());
|
) -> anyhow::Result<usize> {
|
||||||
buf.extend_from_slice(&ACTION_CONNECT.to_be_bytes());
|
struct W<'a> {
|
||||||
buf.extend_from_slice(&transaction_id.to_be_bytes());
|
buf: &'a mut [u8],
|
||||||
}
|
offset: usize,
|
||||||
Request::Announce(connection_id, fields) => {
|
}
|
||||||
buf.extend_from_slice(&connection_id.to_be_bytes());
|
impl W<'_> {
|
||||||
buf.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes());
|
fn extend_from_slice(&mut self, s: &[u8]) -> anyhow::Result<()> {
|
||||||
buf.extend_from_slice(&transaction_id.to_be_bytes());
|
if self.buf.len() < self.offset + s.len() {
|
||||||
buf.extend_from_slice(&fields.info_hash.0);
|
bail!("not enough space in buffer")
|
||||||
buf.extend_from_slice(&fields.peer_id.0);
|
}
|
||||||
buf.extend_from_slice(&fields.downloaded.to_be_bytes());
|
self.buf[self.offset..self.offset + s.len()].copy_from_slice(s);
|
||||||
buf.extend_from_slice(&fields.left.to_be_bytes());
|
self.offset += s.len();
|
||||||
buf.extend_from_slice(&fields.uploaded.to_be_bytes());
|
Ok(())
|
||||||
buf.extend_from_slice(&fields.event.to_be_bytes());
|
|
||||||
buf.extend_from_slice(&0u32.to_be_bytes()); // ip address 0
|
|
||||||
buf.extend_from_slice(&fields.key.to_be_bytes());
|
|
||||||
buf.extend_from_slice(&(-1i32).to_be_bytes()); // num want -1
|
|
||||||
buf.extend_from_slice(&fields.port.to_be_bytes());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
buf.len() - cur_len
|
|
||||||
|
let mut w = W { buf, offset: 0 };
|
||||||
|
|
||||||
|
match self {
|
||||||
|
Request::Connect => {
|
||||||
|
w.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&ACTION_CONNECT.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&transaction_id.to_be_bytes())?;
|
||||||
|
}
|
||||||
|
Request::Announce(connection_id, fields) => {
|
||||||
|
w.extend_from_slice(&connection_id.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&transaction_id.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&fields.info_hash.0)?;
|
||||||
|
w.extend_from_slice(&fields.peer_id.0)?;
|
||||||
|
w.extend_from_slice(&fields.downloaded.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&fields.left.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&fields.uploaded.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&fields.event.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&0u32.to_be_bytes())?; // ip address 0
|
||||||
|
w.extend_from_slice(&fields.key.to_be_bytes())?;
|
||||||
|
w.extend_from_slice(&(-1i32).to_be_bytes())?; // num want -1
|
||||||
|
w.extend_from_slice(&fields.port.to_be_bytes())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(w.offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,6 +113,9 @@ pub struct AnnounceResponse {
|
||||||
pub enum Response {
|
pub enum Response {
|
||||||
Connect(ConnectionId),
|
Connect(ConnectionId),
|
||||||
Announce(AnnounceResponse),
|
Announce(AnnounceResponse),
|
||||||
|
#[allow(dead_code)]
|
||||||
|
Error(String),
|
||||||
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> {
|
fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> {
|
||||||
|
|
@ -128,7 +158,20 @@ parse_impl!(i16, 2);
|
||||||
impl Response {
|
impl Response {
|
||||||
pub fn parse(buf: &[u8]) -> anyhow::Result<(TransactionId, Self)> {
|
pub fn parse(buf: &[u8]) -> anyhow::Result<(TransactionId, Self)> {
|
||||||
let (action, buf) = u32::parse_num(buf).context("can't parse action")?;
|
let (action, buf) = u32::parse_num(buf).context("can't parse action")?;
|
||||||
let (tid, mut buf) = u32::parse_num(buf).context("can't parse transaction id")?;
|
let (tid, buf) = u32::parse_num(buf).context("can't parse transaction id")?;
|
||||||
|
|
||||||
|
let response = match Self::parse_response(action, buf) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("error parsing: {e:#}");
|
||||||
|
Response::Unknown
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((tid, response))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_response(action: u32, mut buf: &[u8]) -> anyhow::Result<Self> {
|
||||||
let response = match action {
|
let response = match action {
|
||||||
ACTION_CONNECT => {
|
ACTION_CONNECT => {
|
||||||
let (connection_id, b) =
|
let (connection_id, b) =
|
||||||
|
|
@ -158,6 +201,15 @@ impl Response {
|
||||||
addrs,
|
addrs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
ACTION_ERROR => {
|
||||||
|
let msg = CStr::from_bytes_with_nul(buf)
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.to_str().ok())
|
||||||
|
.or_else(|| std::str::from_utf8(buf).ok())
|
||||||
|
.unwrap_or("<invalid UTF-8>")
|
||||||
|
.to_owned();
|
||||||
|
return Ok(Response::Error(msg));
|
||||||
|
}
|
||||||
_ => bail!("unsupported action {action}"),
|
_ => bail!("unsupported action {action}"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -168,92 +220,182 @@ impl Response {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((tid, response))
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UdpTrackerRequester {
|
pub type TrackerAddr = (String, u16);
|
||||||
sock: tokio::net::UdpSocket,
|
|
||||||
connection_id: ConnectionId,
|
struct ConnectionIdMeta {
|
||||||
read_buf: Vec<u8>,
|
id: ConnectionId,
|
||||||
write_buf: Vec<u8>,
|
created: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpTrackerRequester {
|
#[derive(Default)]
|
||||||
// Addr is "host:port"
|
struct ClientLocked {
|
||||||
pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
|
connections: HashMap<TrackerAddr, ConnectionIdMeta>,
|
||||||
|
transactions: HashMap<TransactionId, tokio::sync::oneshot::Sender<Response>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ClientShared {
|
||||||
|
sock: tokio::net::UdpSocket,
|
||||||
|
locked: RwLock<ClientLocked>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct UdpTrackerClient {
|
||||||
|
state: Arc<ClientShared>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TransactionIdGuard<'a> {
|
||||||
|
tid: TransactionId,
|
||||||
|
state: &'a ClientShared,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TransactionIdGuard<'_> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let mut g = self.state.locked.write();
|
||||||
|
g.transactions.remove(&self.tid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpTrackerClient {
|
||||||
|
pub async fn new(cancel_token: CancellationToken) -> anyhow::Result<Self> {
|
||||||
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
|
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
|
||||||
.await
|
.await
|
||||||
.context("error binding UDP socket")?;
|
.context("error binding UDP for tracker")?;
|
||||||
sock.connect(addr)
|
let client = Self {
|
||||||
.await
|
state: Arc::new(ClientShared {
|
||||||
.context("error connecting UDP socket")?;
|
sock,
|
||||||
|
locked: RwLock::new(Default::default()),
|
||||||
let tid = new_transaction_id();
|
}),
|
||||||
let mut write_buf = Vec::new();
|
|
||||||
let mut read_buf = vec![0u8; 4096];
|
|
||||||
|
|
||||||
trace!("sending connect request");
|
|
||||||
Request::Connect.serialize(tid, &mut write_buf);
|
|
||||||
|
|
||||||
sock.send(&write_buf)
|
|
||||||
.await
|
|
||||||
.context("error sending to socket")?;
|
|
||||||
|
|
||||||
let size = sock
|
|
||||||
.recv(&mut read_buf)
|
|
||||||
.await
|
|
||||||
.context("error receiving from socket")?;
|
|
||||||
|
|
||||||
let (rtid, response) =
|
|
||||||
Response::parse(&read_buf[..size]).context("error parsing response")?;
|
|
||||||
if tid != rtid {
|
|
||||||
bail!("expected transaction id {} == {}", tid, rtid);
|
|
||||||
}
|
|
||||||
trace!(response=?response, "received");
|
|
||||||
|
|
||||||
let connection_id = match response {
|
|
||||||
Response::Connect(connection_id) => connection_id,
|
|
||||||
other => bail!("unexpected response {other:?}"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!(connection_id);
|
spawn_with_cancel(error_span!("udp_tracker"), cancel_token, {
|
||||||
|
let client = client.clone();
|
||||||
|
async move { client.run().await }
|
||||||
|
});
|
||||||
|
|
||||||
Ok(Self {
|
Ok(client)
|
||||||
sock,
|
|
||||||
connection_id,
|
|
||||||
read_buf,
|
|
||||||
write_buf,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result<AnnounceResponse> {
|
async fn run(self) -> anyhow::Result<()> {
|
||||||
let request = Request::Announce(self.connection_id, fields);
|
let mut buf = [0u8; 16384];
|
||||||
let response = self.request(request).await?;
|
loop {
|
||||||
|
let (len, addr) = match self.state.sock.recv_from(&mut buf).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("error in UdpSocket::recv_from: {e:#}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tid, response) = match Response::parse(&buf[..len]) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
debug!(?addr, "error parsing UDP response: {e:#}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!(?tid, ?response, ?addr, "received");
|
||||||
|
|
||||||
|
let t = self.state.locked.write().transactions.remove(&tid);
|
||||||
|
match t {
|
||||||
|
Some(tx) => match tx.send(response) {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(_) => {
|
||||||
|
debug!(tid, "reader dead");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
debug!(tid, "nowhere to send response");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_connection_id(&self, addr: &TrackerAddr) -> anyhow::Result<ConnectionId> {
|
||||||
|
if let Some(m) = self.state.locked.read().connections.get(addr) {
|
||||||
|
if m.created.elapsed() < Duration::from_secs(60) {
|
||||||
|
return Ok(m.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.request(addr, Request::Connect).await?;
|
||||||
|
match response {
|
||||||
|
Response::Connect(connection_id) => {
|
||||||
|
self.state.locked.write().connections.insert(
|
||||||
|
addr.clone(),
|
||||||
|
ConnectionIdMeta {
|
||||||
|
id: connection_id,
|
||||||
|
created: Instant::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
Ok(connection_id)
|
||||||
|
}
|
||||||
|
_ => anyhow::bail!("expected connect response"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn request(&self, addr: &TrackerAddr, request: Request) -> anyhow::Result<Response> {
|
||||||
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
|
let tid_g = self.reserve_transaction_id(tx)?;
|
||||||
|
|
||||||
|
let mut write_buf = [0u8; 1024];
|
||||||
|
let len = request.serialize(tid_g.tid, &mut write_buf)?;
|
||||||
|
self.state.sock.send_to(&write_buf[..len], addr).await?;
|
||||||
|
|
||||||
|
let response = tokio::time::timeout(Duration::from_secs(10), rx)
|
||||||
|
.await
|
||||||
|
.context("timeout connecting")?
|
||||||
|
.context("sender dead")?;
|
||||||
|
match &response {
|
||||||
|
Response::Error(e) => {
|
||||||
|
anyhow::bail!("remote errored: {e}")
|
||||||
|
}
|
||||||
|
Response::Unknown => {
|
||||||
|
anyhow::bail!("remote replied with something we could not parse")
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reserve_transaction_id(
|
||||||
|
&self,
|
||||||
|
tx: tokio::sync::oneshot::Sender<Response>,
|
||||||
|
) -> anyhow::Result<TransactionIdGuard<'_>> {
|
||||||
|
let mut g = self.state.locked.write();
|
||||||
|
for _ in 0..10 {
|
||||||
|
let t = new_transaction_id();
|
||||||
|
match g.transactions.entry(t) {
|
||||||
|
Entry::Occupied(_) => continue,
|
||||||
|
Entry::Vacant(vac) => {
|
||||||
|
vac.insert(tx);
|
||||||
|
return Ok(TransactionIdGuard {
|
||||||
|
tid: t,
|
||||||
|
state: &self.state,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bail!("cant generate transaction id")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn announce(
|
||||||
|
&self,
|
||||||
|
tracker: &TrackerAddr,
|
||||||
|
fields: AnnounceFields,
|
||||||
|
) -> anyhow::Result<AnnounceResponse> {
|
||||||
|
let connection_id = self.get_connection_id(tracker).await?;
|
||||||
|
let request = Request::Announce(connection_id, fields);
|
||||||
|
let response = self.request(tracker, request).await?;
|
||||||
match response {
|
match response {
|
||||||
Response::Announce(r) => Ok(r),
|
Response::Announce(r) => Ok(r),
|
||||||
other => bail!("unexpected response {other:?}, expected announce"),
|
other => bail!("unexpected response {other:?}, expected announce"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn request(&mut self, request: Request) -> anyhow::Result<Response> {
|
|
||||||
let tid = new_transaction_id();
|
|
||||||
self.write_buf.clear();
|
|
||||||
let size = request.serialize(tid, &mut self.write_buf);
|
|
||||||
trace!(request=?request, tid, "sending");
|
|
||||||
self.sock
|
|
||||||
.send(&self.write_buf[..size])
|
|
||||||
.await
|
|
||||||
.context("error sending")?;
|
|
||||||
let size = self.sock.recv(&mut self.read_buf).await?;
|
|
||||||
|
|
||||||
let (rtid, response) = Response::parse(&self.read_buf[..size])?;
|
|
||||||
trace!("received response");
|
|
||||||
if tid != rtid {
|
|
||||||
bail!("unexpected transaction id");
|
|
||||||
}
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -280,12 +422,12 @@ mod tests {
|
||||||
sock.connect("opentor.net:6969").await.unwrap();
|
sock.connect("opentor.net:6969").await.unwrap();
|
||||||
|
|
||||||
let tid = new_transaction_id();
|
let tid = new_transaction_id();
|
||||||
let mut write_buf = Vec::new();
|
let mut write_buf = [0u8; 16384];
|
||||||
let mut read_buf = vec![0u8; 4096];
|
let mut read_buf = vec![0u8; 4096];
|
||||||
|
|
||||||
Request::Connect.serialize(tid, &mut write_buf);
|
let len = Request::Connect.serialize(tid, &mut write_buf).unwrap();
|
||||||
|
|
||||||
sock.send(&write_buf).await.unwrap();
|
sock.send(&write_buf[..len]).await.unwrap();
|
||||||
|
|
||||||
let size = sock.recv(&mut read_buf).await.unwrap();
|
let size = sock.recv(&mut read_buf).await.unwrap();
|
||||||
|
|
||||||
|
|
@ -314,8 +456,7 @@ mod tests {
|
||||||
port: 24563,
|
port: 24563,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
write_buf.clear();
|
let size = request.serialize(tid, &mut write_buf).unwrap();
|
||||||
let size = request.serialize(tid, &mut write_buf);
|
|
||||||
|
|
||||||
sock.send(&write_buf[..size]).await.unwrap();
|
sock.send(&write_buf[..size]).await.unwrap();
|
||||||
let size = sock.recv(&mut read_buf).await.unwrap();
|
let size = sock.recv(&mut read_buf).await.unwrap();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue