Abstracting tracker comms

This commit is contained in:
Igor Katson 2024-02-17 10:51:09 +00:00
parent 6f3383050e
commit 8733538d83
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
6 changed files with 539 additions and 312 deletions

View file

@ -36,7 +36,8 @@ mod session;
mod spawn_utils;
mod torrent_state;
pub mod tracing_subscriber_config_utils;
mod tracker_comms;
pub mod tracker_comms;
pub mod tracker_comms_http;
pub mod tracker_comms_udp;
mod type_aliases;

View file

@ -57,7 +57,6 @@ use std::{
use anyhow::{bail, Context};
use backoff::backoff::Backoff;
use bencode::from_bytes;
use buffers::{ByteBuf, ByteString};
use clone_to_owned::CloneToOwned;
use futures::{stream::FuturesUnordered, StreamExt};
@ -83,7 +82,6 @@ use tokio::{
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace, warn};
use url::Url;
use crate::{
chunk_tracker::{ChunkMarkingResult, ChunkTracker},
@ -93,7 +91,6 @@ use crate::{
},
session::CheckedIncomingConnection,
torrent_state::{peer::Peer, utils::atomic_inc},
tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse},
type_aliases::{PeerHandle, BF},
};
@ -237,13 +234,6 @@ impl TorrentStateLive {
cancellation_token,
});
for tracker in state.meta.trackers.iter() {
state.spawn(
error_span!(parent: state.meta.span.clone(), "tracker_monitor", url = tracker.to_string()),
state.clone().task_single_tracker_monitor(tracker.clone()),
);
}
state.spawn(
error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"),
{
@ -297,74 +287,6 @@ impl TorrentStateLive {
&self.up_speed_estimator
}
async fn tracker_one_request(&self, tracker_url: Url) -> anyhow::Result<u64> {
let response: reqwest::Response = reqwest::get(tracker_url).await?;
if !response.status().is_success() {
anyhow::bail!("tracker responded with {:?}", response.status());
}
let bytes = response.bytes().await?;
if let Ok(error) = from_bytes::<TrackerError>(&bytes) {
anyhow::bail!(
"tracker returned failure. Failure reason: {}",
error.failure_reason
)
};
let response = from_bytes::<TrackerResponse>(&bytes)?;
for peer in response.peers.iter_sockaddrs() {
self.add_peer_if_not_seen(peer)?;
}
Ok(response.interval)
}
async fn task_single_tracker_monitor(
self: Arc<Self>,
mut tracker_url: Url,
) -> anyhow::Result<()> {
let mut event = Some(TrackerRequestEvent::Started);
loop {
let request = TrackerRequest {
info_hash: self.info_hash(),
peer_id: self.peer_id(),
port: 6778,
uploaded: self.get_uploaded_bytes(),
downloaded: self.get_downloaded_bytes(),
left: self.get_left_to_download_bytes(),
compact: true,
no_peer_id: false,
event,
ip: None,
numwant: None,
key: None,
trackerid: None,
};
let request_query = request.as_querystring();
tracker_url.set_query(Some(&request_query));
match self.tracker_one_request(tracker_url.clone()).await {
Ok(interval) => {
event = None;
let interval = self
.meta
.options
.force_tracker_interval
.unwrap_or_else(|| Duration::from_secs(interval));
debug!(
"sleeping for {:?} after calling tracker {}",
interval,
tracker_url.host().unwrap()
);
tokio::time::sleep(interval).await;
}
Err(e) => {
debug!("error calling the tracker {}: {:#}", tracker_url, e);
tokio::time::sleep(Duration::from_secs(60)).await;
}
};
}
}
pub(crate) fn add_incoming_peer(
self: &Arc<Self>,
checked_peer: CheckedIncomingConnection,

View file

@ -1,233 +1,226 @@
use buffers::ByteBuf;
use byteorder::ByteOrder;
use serde::{Deserialize, Deserializer};
use std::{
fmt::Write,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
str::FromStr,
};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::Context;
use futures::Stream;
use librqbit_core::spawn_utils::spawn_with_cancel;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error_span;
use tracing::info;
use url::Url;
use crate::tracker_comms_http;
use crate::tracker_comms_udp;
use librqbit_core::hash_id::Id20;
#[derive(Clone, Copy)]
pub enum TrackerRequestEvent {
Started,
#[allow(dead_code)]
Stopped,
#[allow(dead_code)]
Completed,
pub struct TrackerComms {
info_hash: Id20,
peer_id: Id20,
stats: Box<dyn TorrentStatsForTracker>,
force_tracker_interval: Option<Duration>,
cancellation_token: CancellationToken,
tx: Sender,
tcp_listen_port: Option<u16>,
}
pub struct TrackerRequest {
pub info_hash: Id20,
pub peer_id: Id20,
pub event: Option<TrackerRequestEvent>,
pub port: u16,
pub uploaded: u64,
pub downloaded: u64,
pub left: u64,
pub compact: bool,
pub no_peer_id: bool,
pub trait TorrentStatsForTracker: Send + Sync {
fn get_uploaded_bytes(&self) -> u64;
fn get_downloaded_bytes(&self) -> u64;
fn get_total_bytes(&self) -> u64;
pub ip: Option<std::net::IpAddr>,
pub numwant: Option<usize>,
pub key: Option<String>,
pub trackerid: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct TrackerError<'a> {
#[serde(rename = "failure reason", borrow)]
pub failure_reason: ByteBuf<'a>,
}
#[derive(Deserialize, Debug)]
pub struct DictPeer<'a> {
#[serde(deserialize_with = "deserialize_ip_string")]
ip: IpAddr,
#[serde(borrow)]
#[allow(dead_code)]
peer_id: Option<ByteBuf<'a>>,
port: u16,
}
impl<'a> DictPeer<'a> {
fn as_sockaddr(&self) -> SocketAddr {
SocketAddr::new(self.ip, self.port)
fn get_left_to_download_bytes(&self) -> u64 {
let total = self.get_total_bytes();
let down = self.get_downloaded_bytes();
if total >= down {
return total - down;
}
0
}
}
#[derive(Debug)]
pub struct Peers {
addrs: Vec<SocketAddr>,
}
type Sender = tokio::sync::mpsc::Sender<SocketAddr>;
impl Peers {
pub fn iter_sockaddrs(&self) -> impl Iterator<Item = std::net::SocketAddr> + '_ {
self.addrs.iter().copied()
}
}
impl<'de> serde::de::Deserialize<'de> for Peers {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'de> {
phantom: std::marker::PhantomData<&'de ()>,
}
impl<'de> serde::de::Visitor<'de> for Visitor<'de> {
type Value = Peers;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a list of peers in dict or binary format")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut peers = Vec::new();
while let Some(peer) = seq.next_element::<DictPeer>()? {
peers.push(peer.as_sockaddr())
}
Ok(Peers { addrs: peers })
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Peers {
addrs: parse_compact_peers(v)
.into_iter()
.map(|v| v.into())
.collect(),
})
}
}
deserializer.deserialize_any(Visitor {
phantom: PhantomData,
})
}
}
fn deserialize_ip_string<'de, D>(de: D) -> Result<IpAddr, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = IpAddr;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("expecting an IPv4 address")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
IpAddr::from_str(v).map_err(|e| E::custom(format!("cannot parse ip: {e}")))
}
}
de.deserialize_str(Visitor {})
}
fn parse_compact_peers(b: &[u8]) -> Vec<SocketAddrV4> {
let mut ips = Vec::new();
for chunk in b.chunks_exact(6) {
let ip_chunk = &chunk[..4];
let port_chunk = &chunk[4..6];
let ipaddr = Ipv4Addr::new(ip_chunk[0], ip_chunk[1], ip_chunk[2], ip_chunk[3]);
let port = byteorder::BigEndian::read_u16(port_chunk);
ips.push(SocketAddrV4::new(ipaddr, port));
}
ips
}
#[derive(Deserialize, Debug)]
pub struct TrackerResponse<'a> {
#[serde(rename = "warning message", borrow)]
pub warning_message: Option<ByteBuf<'a>>,
pub complete: u64,
pub interval: u64,
#[serde(rename = "min interval")]
pub min_interval: Option<u64>,
pub tracker_id: Option<ByteBuf<'a>>,
pub incomplete: u64,
pub peers: Peers,
}
impl TrackerRequest {
pub fn as_querystring(&self) -> String {
use urlencoding as u;
let mut s = String::new();
s.push_str("info_hash=");
s.push_str(u::encode_binary(&self.info_hash.0).as_ref());
s.push_str("&peer_id=");
s.push_str(u::encode_binary(&self.peer_id.0).as_ref());
if let Some(event) = self.event {
write!(
s,
"&event={}",
match event {
TrackerRequestEvent::Started => "started",
TrackerRequestEvent::Stopped => "stopped",
TrackerRequestEvent::Completed => "completed",
}
)
.unwrap();
}
write!(s, "&port={}", self.port).unwrap();
write!(s, "&uploaded={}", self.uploaded).unwrap();
write!(s, "&downloaded={}", self.downloaded).unwrap();
write!(s, "&left={}", self.left).unwrap();
write!(s, "&compact={}", if self.compact { 1 } else { 0 }).unwrap();
write!(s, "&no_peer_id={}", if self.no_peer_id { 1 } else { 0 }).unwrap();
if let Some(ip) = &self.ip {
write!(s, "&ip={ip}").unwrap();
}
if let Some(numwant) = &self.numwant {
write!(s, "&numwant={numwant}").unwrap();
}
if let Some(key) = &self.key {
write!(s, "&key={key}").unwrap();
}
if let Some(trackerid) = &self.trackerid {
write!(s, "&trackerid={trackerid}").unwrap();
}
s
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize() {
let info_hash = Id20::new([
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
]);
let peer_id = Id20::new([
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
]);
let request = TrackerRequest {
impl TrackerComms {
pub fn start(
info_hash: Id20,
peer_id: Id20,
trackers: Vec<String>,
stats: Box<dyn TorrentStatsForTracker>,
force_interval: Option<Duration>,
cancellation_token: CancellationToken,
tcp_listen_port: Option<u16>,
) -> anyhow::Result<impl Stream<Item = SocketAddr> + Send + Sync + Unpin + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<SocketAddr>(16);
let comms = Arc::new(Self {
info_hash,
peer_id,
port: 6881,
uploaded: 0,
downloaded: 0,
left: 1024 * 1024,
compact: true,
no_peer_id: false,
event: Some(TrackerRequestEvent::Started),
ip: Some("127.0.0.1".parse().unwrap()),
numwant: None,
key: None,
trackerid: None,
stats,
force_tracker_interval: force_interval,
cancellation_token,
tx,
tcp_listen_port,
});
for tracker in trackers {
if let Err(e) = comms.clone().add_tracker(&tracker) {
info!(tracker = tracker, "error adding tracker: {:#}", e)
}
}
Ok(tokio_stream::wrappers::ReceiverStream::new(rx))
}
fn add_tracker(self: Arc<Self>, tracker: &str) -> anyhow::Result<()> {
if tracker.starts_with("http://") || tracker.starts_with("https://") {
spawn_with_cancel(
error_span!(
"http_tracker",
tracker = tracker,
info_hash = ?self.info_hash
),
self.cancellation_token.clone(),
{
let comms = self;
let url = Url::parse(tracker).context("can't parse URL")?;
async move { comms.task_single_tracker_monitor_http(url).await }
},
);
} else if tracker.starts_with("udp://") {
spawn_with_cancel(
error_span!("udp_tracker", tracker = tracker),
self.cancellation_token.clone(),
{
let comms = self;
let url = Url::parse(tracker).context("can't parse URL")?;
async move { comms.task_single_tracker_monitor_udp(url).await }
},
);
} else {
bail!("unsupported tracker url {}", tracker)
}
Ok(())
}
async fn task_single_tracker_monitor_http(
self: Arc<Self>,
mut tracker_url: Url,
) -> anyhow::Result<()> {
let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started);
loop {
let request = tracker_comms_http::TrackerRequest {
info_hash: self.info_hash,
peer_id: self.peer_id,
port: 6778,
uploaded: self.stats.get_uploaded_bytes(),
downloaded: self.stats.get_downloaded_bytes(),
left: self.stats.get_left_to_download_bytes(),
compact: true,
no_peer_id: false,
event,
ip: None,
numwant: None,
key: None,
trackerid: None,
};
let request_query = request.as_querystring();
tracker_url.set_query(Some(&request_query));
match self.tracker_one_request_http(tracker_url.clone()).await {
Ok(interval) => {
event = None;
let interval = self
.force_tracker_interval
.unwrap_or_else(|| Duration::from_secs(interval));
debug!(
"sleeping for {:?} after calling tracker {}",
interval,
tracker_url.host().unwrap()
);
tokio::time::sleep(interval).await;
}
Err(e) => {
debug!("error calling the tracker {}: {:#}", tracker_url, e);
tokio::time::sleep(Duration::from_secs(60)).await;
}
};
}
}
async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result<u64> {
let response: reqwest::Response = reqwest::get(tracker_url).await?;
if !response.status().is_success() {
anyhow::bail!("tracker responded with {:?}", response.status());
}
let bytes = response.bytes().await?;
if let Ok(error) = bencode::from_bytes::<tracker_comms_http::TrackerError>(&bytes) {
anyhow::bail!(
"tracker returned failure. Failure reason: {}",
error.failure_reason
)
};
dbg!(request.as_querystring());
let response = bencode::from_bytes::<tracker_comms_http::TrackerResponse>(&bytes)?;
for peer in response.peers.iter_sockaddrs() {
self.tx.send(peer).await?;
}
Ok(response.interval)
}
async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> {
use tracker_comms_udp::*;
if url.scheme() != "udp" {
bail!("expected UDP scheme in {}", url);
}
let hp: (&str, u16) = (
url.host_str().context("missing host")?,
url.port().context("missing port")?,
);
let mut requester = UdpTrackerRequester::new(hp)
.await
.context("error creating UDP tracker requester")?;
let mut sleep_interval: Option<Duration> = None;
loop {
if let Some(i) = sleep_interval {
tokio::time::sleep(i).await;
}
let request = AnnounceFields {
info_hash: self.info_hash,
peer_id: self.peer_id,
downloaded: self.stats.get_downloaded_bytes(),
left: self.stats.get_left_to_download_bytes(),
uploaded: self.stats.get_uploaded_bytes(),
event: EVENT_NONE,
key: 0, // whatever that is?
port: self.tcp_listen_port.unwrap_or(0),
};
match requester.announce(request).await {
Ok(response) => {
for addr in response.addrs {
self.tx
.send(SocketAddr::V4(addr))
.await
.context("rx closed")?;
}
let new_interval = response.interval.max(5);
let new_interval = Duration::from_secs(new_interval as u64);
sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval));
}
Err(e) => {
debug!(url = ?url, "error reading announce response: {e:#}");
if sleep_interval.is_none() {
sleep_interval = Some(
self.force_tracker_interval
.unwrap_or(Duration::from_secs(60)),
);
}
}
}
}
}
}

View file

@ -0,0 +1,233 @@
use buffers::ByteBuf;
use byteorder::ByteOrder;
use serde::{Deserialize, Deserializer};
use std::{
fmt::Write,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
str::FromStr,
};
use librqbit_core::hash_id::Id20;
#[derive(Clone, Copy)]
pub enum TrackerRequestEvent {
Started,
#[allow(dead_code)]
Stopped,
#[allow(dead_code)]
Completed,
}
pub struct TrackerRequest {
pub info_hash: Id20,
pub peer_id: Id20,
pub event: Option<TrackerRequestEvent>,
pub port: u16,
pub uploaded: u64,
pub downloaded: u64,
pub left: u64,
pub compact: bool,
pub no_peer_id: bool,
pub ip: Option<std::net::IpAddr>,
pub numwant: Option<usize>,
pub key: Option<String>,
pub trackerid: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct TrackerError<'a> {
#[serde(rename = "failure reason", borrow)]
pub failure_reason: ByteBuf<'a>,
}
#[derive(Deserialize, Debug)]
pub struct DictPeer<'a> {
#[serde(deserialize_with = "deserialize_ip_string")]
ip: IpAddr,
#[serde(borrow)]
#[allow(dead_code)]
peer_id: Option<ByteBuf<'a>>,
port: u16,
}
impl<'a> DictPeer<'a> {
fn as_sockaddr(&self) -> SocketAddr {
SocketAddr::new(self.ip, self.port)
}
}
#[derive(Debug)]
pub struct Peers {
addrs: Vec<SocketAddr>,
}
impl Peers {
pub fn iter_sockaddrs(&self) -> impl Iterator<Item = std::net::SocketAddr> + '_ {
self.addrs.iter().copied()
}
}
impl<'de> serde::de::Deserialize<'de> for Peers {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'de> {
phantom: std::marker::PhantomData<&'de ()>,
}
impl<'de> serde::de::Visitor<'de> for Visitor<'de> {
type Value = Peers;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a list of peers in dict or binary format")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut peers = Vec::new();
while let Some(peer) = seq.next_element::<DictPeer>()? {
peers.push(peer.as_sockaddr())
}
Ok(Peers { addrs: peers })
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Peers {
addrs: parse_compact_peers(v)
.into_iter()
.map(|v| v.into())
.collect(),
})
}
}
deserializer.deserialize_any(Visitor {
phantom: PhantomData,
})
}
}
fn deserialize_ip_string<'de, D>(de: D) -> Result<IpAddr, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = IpAddr;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("expecting an IPv4 address")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
IpAddr::from_str(v).map_err(|e| E::custom(format!("cannot parse ip: {e}")))
}
}
de.deserialize_str(Visitor {})
}
fn parse_compact_peers(b: &[u8]) -> Vec<SocketAddrV4> {
let mut ips = Vec::new();
for chunk in b.chunks_exact(6) {
let ip_chunk = &chunk[..4];
let port_chunk = &chunk[4..6];
let ipaddr = Ipv4Addr::new(ip_chunk[0], ip_chunk[1], ip_chunk[2], ip_chunk[3]);
let port = byteorder::BigEndian::read_u16(port_chunk);
ips.push(SocketAddrV4::new(ipaddr, port));
}
ips
}
#[derive(Deserialize, Debug)]
pub struct TrackerResponse<'a> {
#[serde(rename = "warning message", borrow)]
pub warning_message: Option<ByteBuf<'a>>,
pub complete: u64,
pub interval: u64,
#[serde(rename = "min interval")]
pub min_interval: Option<u64>,
pub tracker_id: Option<ByteBuf<'a>>,
pub incomplete: u64,
pub peers: Peers,
}
impl TrackerRequest {
pub fn as_querystring(&self) -> String {
use urlencoding as u;
let mut s = String::new();
s.push_str("info_hash=");
s.push_str(u::encode_binary(&self.info_hash.0).as_ref());
s.push_str("&peer_id=");
s.push_str(u::encode_binary(&self.peer_id.0).as_ref());
if let Some(event) = self.event {
write!(
s,
"&event={}",
match event {
TrackerRequestEvent::Started => "started",
TrackerRequestEvent::Stopped => "stopped",
TrackerRequestEvent::Completed => "completed",
}
)
.unwrap();
}
write!(s, "&port={}", self.port).unwrap();
write!(s, "&uploaded={}", self.uploaded).unwrap();
write!(s, "&downloaded={}", self.downloaded).unwrap();
write!(s, "&left={}", self.left).unwrap();
write!(s, "&compact={}", if self.compact { 1 } else { 0 }).unwrap();
write!(s, "&no_peer_id={}", if self.no_peer_id { 1 } else { 0 }).unwrap();
if let Some(ip) = &self.ip {
write!(s, "&ip={ip}").unwrap();
}
if let Some(numwant) = &self.numwant {
write!(s, "&numwant={numwant}").unwrap();
}
if let Some(key) = &self.key {
write!(s, "&key={key}").unwrap();
}
if let Some(trackerid) = &self.trackerid {
write!(s, "&trackerid={trackerid}").unwrap();
}
s
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize() {
let info_hash = Id20::new([
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
]);
let peer_id = Id20::new([
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
]);
let request = TrackerRequest {
info_hash,
peer_id,
port: 6881,
uploaded: 0,
downloaded: 0,
left: 1024 * 1024,
compact: true,
no_peer_id: false,
event: Some(TrackerRequestEvent::Started),
ip: Some("127.0.0.1".parse().unwrap()),
numwant: None,
key: None,
trackerid: None,
};
dbg!(request.as_querystring());
}
}

View file

@ -3,6 +3,7 @@ use std::net::{Ipv4Addr, SocketAddrV4};
use anyhow::{bail, Context};
use librqbit_core::hash_id::Id20;
use rand::Rng;
use tokio::net::ToSocketAddrs;
const ACTION_CONNECT: u32 = 0;
const ACTION_ANNOUNCE: u32 = 1;
@ -70,15 +71,18 @@ impl Request {
}
}
#[derive(Debug)]
pub struct AnnounceResponse {
pub interval: u32,
pub leechers: u32,
pub seeders: u32,
pub addrs: Vec<SocketAddrV4>,
}
#[derive(Debug)]
pub enum Response {
Connect(ConnectionId),
Announce {
interval: u32,
leechers: u32,
seeders: u32,
addrs: Vec<SocketAddrV4>,
},
Announce(AnnounceResponse),
}
fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> {
@ -144,12 +148,12 @@ impl Response {
addrs.push(SocketAddrV4::new(ip, port));
}
buf = b;
Response::Announce {
Response::Announce(AnnounceResponse {
interval,
leechers,
seeders,
addrs,
}
})
}
_ => bail!("unsupported action {action}"),
};
@ -165,6 +169,83 @@ impl Response {
}
}
pub struct UdpTrackerRequester {
sock: tokio::net::UdpSocket,
connection_id: ConnectionId,
read_buf: Vec<u8>,
write_buf: Vec<u8>,
}
impl UdpTrackerRequester {
// Addr is "host:port"
pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
let sock = tokio::net::UdpSocket::bind("0.0.0.0:0")
.await
.context("error binding UDP socket")?;
sock.connect(addr)
.await
.context("error connecting UDP socket")?;
let tid = new_transaction_id();
let mut write_buf = Vec::new();
let mut read_buf = vec![0u8; 4096];
Request::Connect.serialize(tid, &mut write_buf);
sock.send(&write_buf)
.await
.context("error sending to socket")?;
let size = sock
.recv(&mut read_buf)
.await
.context("error receiving from socket")?;
let (rtid, response) =
Response::parse(&read_buf[..size]).context("error parsing response")?;
if tid != rtid {
bail!("expected transaction id {} == {}", tid, rtid);
}
let connection_id = match response {
Response::Connect(connection_id) => connection_id,
other => bail!("unexpected response {other:?}"),
};
Ok(Self {
sock,
connection_id,
read_buf,
write_buf,
})
}
pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result<AnnounceResponse> {
let request = Request::Announce(self.connection_id, fields);
let response = self.request(request).await?;
match response {
Response::Announce(r) => Ok(r),
other => bail!("unexpected response {other:?}, expected announce"),
}
}
pub async fn request(&mut self, request: Request) -> anyhow::Result<Response> {
let tid = new_transaction_id();
self.write_buf.clear();
let size = request.serialize(tid, &mut self.write_buf);
self.sock
.send(&self.write_buf[..size])
.await
.context("error sending")?;
let size = self.sock.recv(&mut self.read_buf).await.unwrap();
let (rtid, response) = Response::parse(&self.read_buf[..size]).unwrap();
assert_eq!(tid, rtid);
Ok(response)
}
}
#[cfg(test)]
mod tests {
use std::{io::Write, str::FromStr};
@ -244,13 +325,8 @@ mod tests {
let (rtid, response) = Response::parse(&read_buf[..size]).unwrap();
assert_eq!(tid, rtid);
match response {
Response::Announce {
interval,
leechers,
seeders,
addrs,
} => {
dbg!(interval, leechers, seeders, addrs);
Response::Announce(r) => {
dbg!(r);
}
other => panic!("unexpected response {other:?}"),
}

View file

@ -7,3 +7,5 @@ pub mod peer_id;
pub mod spawn_utils;
pub mod speed_estimator;
pub mod torrent_metainfo;
pub use hash_id::Id20;