diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index a1897cb..817c086 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -36,7 +36,8 @@ mod session; mod spawn_utils; mod torrent_state; pub mod tracing_subscriber_config_utils; -mod tracker_comms; +pub mod tracker_comms; +pub mod tracker_comms_http; pub mod tracker_comms_udp; mod type_aliases; diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 859d7f2..3de7f0c 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -57,7 +57,6 @@ use std::{ use anyhow::{bail, Context}; use backoff::backoff::Backoff; -use bencode::from_bytes; use buffers::{ByteBuf, ByteString}; use clone_to_owned::CloneToOwned; use futures::{stream::FuturesUnordered, StreamExt}; @@ -83,7 +82,6 @@ use tokio::{ }; use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn}; -use url::Url; use crate::{ chunk_tracker::{ChunkMarkingResult, ChunkTracker}, @@ -93,7 +91,6 @@ use crate::{ }, session::CheckedIncomingConnection, torrent_state::{peer::Peer, utils::atomic_inc}, - tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse}, type_aliases::{PeerHandle, BF}, }; @@ -237,13 +234,6 @@ impl TorrentStateLive { cancellation_token, }); - for tracker in state.meta.trackers.iter() { - state.spawn( - error_span!(parent: state.meta.span.clone(), "tracker_monitor", url = tracker.to_string()), - state.clone().task_single_tracker_monitor(tracker.clone()), - ); - } - state.spawn( error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"), { @@ -297,74 +287,6 @@ impl TorrentStateLive { &self.up_speed_estimator } - async fn tracker_one_request(&self, tracker_url: Url) -> anyhow::Result { - let response: reqwest::Response = reqwest::get(tracker_url).await?; - if !response.status().is_success() { - anyhow::bail!("tracker responded with {:?}", response.status()); - } - let bytes = response.bytes().await?; - if let Ok(error) = from_bytes::(&bytes) { - anyhow::bail!( - "tracker returned failure. Failure reason: {}", - error.failure_reason - ) - }; - let response = from_bytes::(&bytes)?; - - for peer in response.peers.iter_sockaddrs() { - self.add_peer_if_not_seen(peer)?; - } - Ok(response.interval) - } - - async fn task_single_tracker_monitor( - self: Arc, - mut tracker_url: Url, - ) -> anyhow::Result<()> { - let mut event = Some(TrackerRequestEvent::Started); - loop { - let request = TrackerRequest { - info_hash: self.info_hash(), - peer_id: self.peer_id(), - port: 6778, - uploaded: self.get_uploaded_bytes(), - downloaded: self.get_downloaded_bytes(), - left: self.get_left_to_download_bytes(), - compact: true, - no_peer_id: false, - event, - ip: None, - numwant: None, - key: None, - trackerid: None, - }; - - let request_query = request.as_querystring(); - tracker_url.set_query(Some(&request_query)); - - match self.tracker_one_request(tracker_url.clone()).await { - Ok(interval) => { - event = None; - let interval = self - .meta - .options - .force_tracker_interval - .unwrap_or_else(|| Duration::from_secs(interval)); - debug!( - "sleeping for {:?} after calling tracker {}", - interval, - tracker_url.host().unwrap() - ); - tokio::time::sleep(interval).await; - } - Err(e) => { - debug!("error calling the tracker {}: {:#}", tracker_url, e); - tokio::time::sleep(Duration::from_secs(60)).await; - } - }; - } - } - pub(crate) fn add_incoming_peer( self: &Arc, checked_peer: CheckedIncomingConnection, diff --git a/crates/librqbit/src/tracker_comms.rs b/crates/librqbit/src/tracker_comms.rs index e263be7..8482552 100644 --- a/crates/librqbit/src/tracker_comms.rs +++ b/crates/librqbit/src/tracker_comms.rs @@ -1,233 +1,226 @@ -use buffers::ByteBuf; -use byteorder::ByteOrder; -use serde::{Deserialize, Deserializer}; -use std::{ - fmt::Write, - marker::PhantomData, - net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, - str::FromStr, -}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use anyhow::bail; +use anyhow::Context; +use futures::Stream; +use librqbit_core::spawn_utils::spawn_with_cancel; +use tokio_util::sync::CancellationToken; +use tracing::debug; +use tracing::error_span; +use tracing::info; +use url::Url; + +use crate::tracker_comms_http; +use crate::tracker_comms_udp; use librqbit_core::hash_id::Id20; -#[derive(Clone, Copy)] -pub enum TrackerRequestEvent { - Started, - #[allow(dead_code)] - Stopped, - #[allow(dead_code)] - Completed, +pub struct TrackerComms { + info_hash: Id20, + peer_id: Id20, + stats: Box, + force_tracker_interval: Option, + cancellation_token: CancellationToken, + tx: Sender, + tcp_listen_port: Option, } -pub struct TrackerRequest { - pub info_hash: Id20, - pub peer_id: Id20, - pub event: Option, - pub port: u16, - pub uploaded: u64, - pub downloaded: u64, - pub left: u64, - pub compact: bool, - pub no_peer_id: bool, +pub trait TorrentStatsForTracker: Send + Sync { + fn get_uploaded_bytes(&self) -> u64; + fn get_downloaded_bytes(&self) -> u64; + fn get_total_bytes(&self) -> u64; - pub ip: Option, - pub numwant: Option, - pub key: Option, - pub trackerid: Option, -} - -#[derive(Deserialize, Debug)] -pub struct TrackerError<'a> { - #[serde(rename = "failure reason", borrow)] - pub failure_reason: ByteBuf<'a>, -} - -#[derive(Deserialize, Debug)] -pub struct DictPeer<'a> { - #[serde(deserialize_with = "deserialize_ip_string")] - ip: IpAddr, - #[serde(borrow)] - #[allow(dead_code)] - peer_id: Option>, - port: u16, -} - -impl<'a> DictPeer<'a> { - fn as_sockaddr(&self) -> SocketAddr { - SocketAddr::new(self.ip, self.port) + fn get_left_to_download_bytes(&self) -> u64 { + let total = self.get_total_bytes(); + let down = self.get_downloaded_bytes(); + if total >= down { + return total - down; + } + 0 } } -#[derive(Debug)] -pub struct Peers { - addrs: Vec, -} +type Sender = tokio::sync::mpsc::Sender; -impl Peers { - pub fn iter_sockaddrs(&self) -> impl Iterator + '_ { - self.addrs.iter().copied() - } -} - -impl<'de> serde::de::Deserialize<'de> for Peers { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Visitor<'de> { - phantom: std::marker::PhantomData<&'de ()>, - } - impl<'de> serde::de::Visitor<'de> for Visitor<'de> { - type Value = Peers; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a list of peers in dict or binary format") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut peers = Vec::new(); - while let Some(peer) = seq.next_element::()? { - peers.push(peer.as_sockaddr()) - } - Ok(Peers { addrs: peers }) - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - Ok(Peers { - addrs: parse_compact_peers(v) - .into_iter() - .map(|v| v.into()) - .collect(), - }) - } - } - deserializer.deserialize_any(Visitor { - phantom: PhantomData, - }) - } -} - -fn deserialize_ip_string<'de, D>(de: D) -> Result -where - D: Deserializer<'de>, -{ - struct Visitor; - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = IpAddr; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("expecting an IPv4 address") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - IpAddr::from_str(v).map_err(|e| E::custom(format!("cannot parse ip: {e}"))) - } - } - de.deserialize_str(Visitor {}) -} - -fn parse_compact_peers(b: &[u8]) -> Vec { - let mut ips = Vec::new(); - for chunk in b.chunks_exact(6) { - let ip_chunk = &chunk[..4]; - let port_chunk = &chunk[4..6]; - let ipaddr = Ipv4Addr::new(ip_chunk[0], ip_chunk[1], ip_chunk[2], ip_chunk[3]); - let port = byteorder::BigEndian::read_u16(port_chunk); - ips.push(SocketAddrV4::new(ipaddr, port)); - } - ips -} - -#[derive(Deserialize, Debug)] -pub struct TrackerResponse<'a> { - #[serde(rename = "warning message", borrow)] - pub warning_message: Option>, - pub complete: u64, - pub interval: u64, - #[serde(rename = "min interval")] - pub min_interval: Option, - pub tracker_id: Option>, - pub incomplete: u64, - pub peers: Peers, -} - -impl TrackerRequest { - pub fn as_querystring(&self) -> String { - use urlencoding as u; - let mut s = String::new(); - s.push_str("info_hash="); - s.push_str(u::encode_binary(&self.info_hash.0).as_ref()); - s.push_str("&peer_id="); - s.push_str(u::encode_binary(&self.peer_id.0).as_ref()); - if let Some(event) = self.event { - write!( - s, - "&event={}", - match event { - TrackerRequestEvent::Started => "started", - TrackerRequestEvent::Stopped => "stopped", - TrackerRequestEvent::Completed => "completed", - } - ) - .unwrap(); - } - write!(s, "&port={}", self.port).unwrap(); - write!(s, "&uploaded={}", self.uploaded).unwrap(); - write!(s, "&downloaded={}", self.downloaded).unwrap(); - write!(s, "&left={}", self.left).unwrap(); - write!(s, "&compact={}", if self.compact { 1 } else { 0 }).unwrap(); - write!(s, "&no_peer_id={}", if self.no_peer_id { 1 } else { 0 }).unwrap(); - if let Some(ip) = &self.ip { - write!(s, "&ip={ip}").unwrap(); - } - if let Some(numwant) = &self.numwant { - write!(s, "&numwant={numwant}").unwrap(); - } - if let Some(key) = &self.key { - write!(s, "&key={key}").unwrap(); - } - if let Some(trackerid) = &self.trackerid { - write!(s, "&trackerid={trackerid}").unwrap(); - } - s - } -} - -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_serialize() { - let info_hash = Id20::new([ - 1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - ]); - let peer_id = Id20::new([ - 1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - ]); - let request = TrackerRequest { +impl TrackerComms { + pub fn start( + info_hash: Id20, + peer_id: Id20, + trackers: Vec, + stats: Box, + force_interval: Option, + cancellation_token: CancellationToken, + tcp_listen_port: Option, + ) -> anyhow::Result + Send + Sync + Unpin + 'static> { + let (tx, rx) = tokio::sync::mpsc::channel::(16); + let comms = Arc::new(Self { info_hash, peer_id, - port: 6881, - uploaded: 0, - downloaded: 0, - left: 1024 * 1024, - compact: true, - no_peer_id: false, - event: Some(TrackerRequestEvent::Started), - ip: Some("127.0.0.1".parse().unwrap()), - numwant: None, - key: None, - trackerid: None, + stats, + force_tracker_interval: force_interval, + cancellation_token, + tx, + tcp_listen_port, + }); + for tracker in trackers { + if let Err(e) = comms.clone().add_tracker(&tracker) { + info!(tracker = tracker, "error adding tracker: {:#}", e) + } + } + Ok(tokio_stream::wrappers::ReceiverStream::new(rx)) + } + + fn add_tracker(self: Arc, tracker: &str) -> anyhow::Result<()> { + if tracker.starts_with("http://") || tracker.starts_with("https://") { + spawn_with_cancel( + error_span!( + "http_tracker", + tracker = tracker, + info_hash = ?self.info_hash + ), + self.cancellation_token.clone(), + { + let comms = self; + let url = Url::parse(tracker).context("can't parse URL")?; + async move { comms.task_single_tracker_monitor_http(url).await } + }, + ); + } else if tracker.starts_with("udp://") { + spawn_with_cancel( + error_span!("udp_tracker", tracker = tracker), + self.cancellation_token.clone(), + { + let comms = self; + let url = Url::parse(tracker).context("can't parse URL")?; + async move { comms.task_single_tracker_monitor_udp(url).await } + }, + ); + } else { + bail!("unsupported tracker url {}", tracker) + } + Ok(()) + } + + async fn task_single_tracker_monitor_http( + self: Arc, + mut tracker_url: Url, + ) -> anyhow::Result<()> { + let mut event = Some(tracker_comms_http::TrackerRequestEvent::Started); + loop { + let request = tracker_comms_http::TrackerRequest { + info_hash: self.info_hash, + peer_id: self.peer_id, + port: 6778, + uploaded: self.stats.get_uploaded_bytes(), + downloaded: self.stats.get_downloaded_bytes(), + left: self.stats.get_left_to_download_bytes(), + compact: true, + no_peer_id: false, + event, + ip: None, + numwant: None, + key: None, + trackerid: None, + }; + + let request_query = request.as_querystring(); + tracker_url.set_query(Some(&request_query)); + + match self.tracker_one_request_http(tracker_url.clone()).await { + Ok(interval) => { + event = None; + let interval = self + .force_tracker_interval + .unwrap_or_else(|| Duration::from_secs(interval)); + debug!( + "sleeping for {:?} after calling tracker {}", + interval, + tracker_url.host().unwrap() + ); + tokio::time::sleep(interval).await; + } + Err(e) => { + debug!("error calling the tracker {}: {:#}", tracker_url, e); + tokio::time::sleep(Duration::from_secs(60)).await; + } + }; + } + } + + async fn tracker_one_request_http(&self, tracker_url: Url) -> anyhow::Result { + let response: reqwest::Response = reqwest::get(tracker_url).await?; + if !response.status().is_success() { + anyhow::bail!("tracker responded with {:?}", response.status()); + } + let bytes = response.bytes().await?; + if let Ok(error) = bencode::from_bytes::(&bytes) { + anyhow::bail!( + "tracker returned failure. Failure reason: {}", + error.failure_reason + ) }; - dbg!(request.as_querystring()); + let response = bencode::from_bytes::(&bytes)?; + + for peer in response.peers.iter_sockaddrs() { + self.tx.send(peer).await?; + } + Ok(response.interval) + } + + async fn task_single_tracker_monitor_udp(&self, url: Url) -> anyhow::Result<()> { + use tracker_comms_udp::*; + + if url.scheme() != "udp" { + bail!("expected UDP scheme in {}", url); + } + let hp: (&str, u16) = ( + url.host_str().context("missing host")?, + url.port().context("missing port")?, + ); + let mut requester = UdpTrackerRequester::new(hp) + .await + .context("error creating UDP tracker requester")?; + + let mut sleep_interval: Option = None; + loop { + if let Some(i) = sleep_interval { + tokio::time::sleep(i).await; + } + + let request = AnnounceFields { + info_hash: self.info_hash, + peer_id: self.peer_id, + downloaded: self.stats.get_downloaded_bytes(), + left: self.stats.get_left_to_download_bytes(), + uploaded: self.stats.get_uploaded_bytes(), + event: EVENT_NONE, + key: 0, // whatever that is? + port: self.tcp_listen_port.unwrap_or(0), + }; + + match requester.announce(request).await { + Ok(response) => { + for addr in response.addrs { + self.tx + .send(SocketAddr::V4(addr)) + .await + .context("rx closed")?; + } + let new_interval = response.interval.max(5); + let new_interval = Duration::from_secs(new_interval as u64); + sleep_interval = Some(self.force_tracker_interval.unwrap_or(new_interval)); + } + Err(e) => { + debug!(url = ?url, "error reading announce response: {e:#}"); + if sleep_interval.is_none() { + sleep_interval = Some( + self.force_tracker_interval + .unwrap_or(Duration::from_secs(60)), + ); + } + } + } + } } } diff --git a/crates/librqbit/src/tracker_comms_http.rs b/crates/librqbit/src/tracker_comms_http.rs new file mode 100644 index 0000000..e263be7 --- /dev/null +++ b/crates/librqbit/src/tracker_comms_http.rs @@ -0,0 +1,233 @@ +use buffers::ByteBuf; +use byteorder::ByteOrder; +use serde::{Deserialize, Deserializer}; +use std::{ + fmt::Write, + marker::PhantomData, + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, + str::FromStr, +}; + +use librqbit_core::hash_id::Id20; + +#[derive(Clone, Copy)] +pub enum TrackerRequestEvent { + Started, + #[allow(dead_code)] + Stopped, + #[allow(dead_code)] + Completed, +} + +pub struct TrackerRequest { + pub info_hash: Id20, + pub peer_id: Id20, + pub event: Option, + pub port: u16, + pub uploaded: u64, + pub downloaded: u64, + pub left: u64, + pub compact: bool, + pub no_peer_id: bool, + + pub ip: Option, + pub numwant: Option, + pub key: Option, + pub trackerid: Option, +} + +#[derive(Deserialize, Debug)] +pub struct TrackerError<'a> { + #[serde(rename = "failure reason", borrow)] + pub failure_reason: ByteBuf<'a>, +} + +#[derive(Deserialize, Debug)] +pub struct DictPeer<'a> { + #[serde(deserialize_with = "deserialize_ip_string")] + ip: IpAddr, + #[serde(borrow)] + #[allow(dead_code)] + peer_id: Option>, + port: u16, +} + +impl<'a> DictPeer<'a> { + fn as_sockaddr(&self) -> SocketAddr { + SocketAddr::new(self.ip, self.port) + } +} + +#[derive(Debug)] +pub struct Peers { + addrs: Vec, +} + +impl Peers { + pub fn iter_sockaddrs(&self) -> impl Iterator + '_ { + self.addrs.iter().copied() + } +} + +impl<'de> serde::de::Deserialize<'de> for Peers { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor<'de> { + phantom: std::marker::PhantomData<&'de ()>, + } + impl<'de> serde::de::Visitor<'de> for Visitor<'de> { + type Value = Peers; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a list of peers in dict or binary format") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut peers = Vec::new(); + while let Some(peer) = seq.next_element::()? { + peers.push(peer.as_sockaddr()) + } + Ok(Peers { addrs: peers }) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(Peers { + addrs: parse_compact_peers(v) + .into_iter() + .map(|v| v.into()) + .collect(), + }) + } + } + deserializer.deserialize_any(Visitor { + phantom: PhantomData, + }) + } +} + +fn deserialize_ip_string<'de, D>(de: D) -> Result +where + D: Deserializer<'de>, +{ + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = IpAddr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("expecting an IPv4 address") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + IpAddr::from_str(v).map_err(|e| E::custom(format!("cannot parse ip: {e}"))) + } + } + de.deserialize_str(Visitor {}) +} + +fn parse_compact_peers(b: &[u8]) -> Vec { + let mut ips = Vec::new(); + for chunk in b.chunks_exact(6) { + let ip_chunk = &chunk[..4]; + let port_chunk = &chunk[4..6]; + let ipaddr = Ipv4Addr::new(ip_chunk[0], ip_chunk[1], ip_chunk[2], ip_chunk[3]); + let port = byteorder::BigEndian::read_u16(port_chunk); + ips.push(SocketAddrV4::new(ipaddr, port)); + } + ips +} + +#[derive(Deserialize, Debug)] +pub struct TrackerResponse<'a> { + #[serde(rename = "warning message", borrow)] + pub warning_message: Option>, + pub complete: u64, + pub interval: u64, + #[serde(rename = "min interval")] + pub min_interval: Option, + pub tracker_id: Option>, + pub incomplete: u64, + pub peers: Peers, +} + +impl TrackerRequest { + pub fn as_querystring(&self) -> String { + use urlencoding as u; + let mut s = String::new(); + s.push_str("info_hash="); + s.push_str(u::encode_binary(&self.info_hash.0).as_ref()); + s.push_str("&peer_id="); + s.push_str(u::encode_binary(&self.peer_id.0).as_ref()); + if let Some(event) = self.event { + write!( + s, + "&event={}", + match event { + TrackerRequestEvent::Started => "started", + TrackerRequestEvent::Stopped => "stopped", + TrackerRequestEvent::Completed => "completed", + } + ) + .unwrap(); + } + write!(s, "&port={}", self.port).unwrap(); + write!(s, "&uploaded={}", self.uploaded).unwrap(); + write!(s, "&downloaded={}", self.downloaded).unwrap(); + write!(s, "&left={}", self.left).unwrap(); + write!(s, "&compact={}", if self.compact { 1 } else { 0 }).unwrap(); + write!(s, "&no_peer_id={}", if self.no_peer_id { 1 } else { 0 }).unwrap(); + if let Some(ip) = &self.ip { + write!(s, "&ip={ip}").unwrap(); + } + if let Some(numwant) = &self.numwant { + write!(s, "&numwant={numwant}").unwrap(); + } + if let Some(key) = &self.key { + write!(s, "&key={key}").unwrap(); + } + if let Some(trackerid) = &self.trackerid { + write!(s, "&trackerid={trackerid}").unwrap(); + } + s + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_serialize() { + let info_hash = Id20::new([ + 1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + ]); + let peer_id = Id20::new([ + 1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + ]); + let request = TrackerRequest { + info_hash, + peer_id, + port: 6881, + uploaded: 0, + downloaded: 0, + left: 1024 * 1024, + compact: true, + no_peer_id: false, + event: Some(TrackerRequestEvent::Started), + ip: Some("127.0.0.1".parse().unwrap()), + numwant: None, + key: None, + trackerid: None, + }; + dbg!(request.as_querystring()); + } +} diff --git a/crates/librqbit/src/tracker_comms_udp.rs b/crates/librqbit/src/tracker_comms_udp.rs index 306076b..c32a143 100644 --- a/crates/librqbit/src/tracker_comms_udp.rs +++ b/crates/librqbit/src/tracker_comms_udp.rs @@ -3,6 +3,7 @@ use std::net::{Ipv4Addr, SocketAddrV4}; use anyhow::{bail, Context}; use librqbit_core::hash_id::Id20; use rand::Rng; +use tokio::net::ToSocketAddrs; const ACTION_CONNECT: u32 = 0; const ACTION_ANNOUNCE: u32 = 1; @@ -70,15 +71,18 @@ impl Request { } } +#[derive(Debug)] +pub struct AnnounceResponse { + pub interval: u32, + pub leechers: u32, + pub seeders: u32, + pub addrs: Vec, +} + #[derive(Debug)] pub enum Response { Connect(ConnectionId), - Announce { - interval: u32, - leechers: u32, - seeders: u32, - addrs: Vec, - }, + Announce(AnnounceResponse), } fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> { @@ -144,12 +148,12 @@ impl Response { addrs.push(SocketAddrV4::new(ip, port)); } buf = b; - Response::Announce { + Response::Announce(AnnounceResponse { interval, leechers, seeders, addrs, - } + }) } _ => bail!("unsupported action {action}"), }; @@ -165,6 +169,83 @@ impl Response { } } +pub struct UdpTrackerRequester { + sock: tokio::net::UdpSocket, + connection_id: ConnectionId, + read_buf: Vec, + write_buf: Vec, +} + +impl UdpTrackerRequester { + // Addr is "host:port" + pub async fn new(addr: impl ToSocketAddrs) -> anyhow::Result { + let sock = tokio::net::UdpSocket::bind("0.0.0.0:0") + .await + .context("error binding UDP socket")?; + sock.connect(addr) + .await + .context("error connecting UDP socket")?; + + let tid = new_transaction_id(); + let mut write_buf = Vec::new(); + let mut read_buf = vec![0u8; 4096]; + + Request::Connect.serialize(tid, &mut write_buf); + + sock.send(&write_buf) + .await + .context("error sending to socket")?; + + let size = sock + .recv(&mut read_buf) + .await + .context("error receiving from socket")?; + + let (rtid, response) = + Response::parse(&read_buf[..size]).context("error parsing response")?; + if tid != rtid { + bail!("expected transaction id {} == {}", tid, rtid); + } + + let connection_id = match response { + Response::Connect(connection_id) => connection_id, + other => bail!("unexpected response {other:?}"), + }; + + Ok(Self { + sock, + connection_id, + read_buf, + write_buf, + }) + } + + pub async fn announce(&mut self, fields: AnnounceFields) -> anyhow::Result { + let request = Request::Announce(self.connection_id, fields); + let response = self.request(request).await?; + match response { + Response::Announce(r) => Ok(r), + other => bail!("unexpected response {other:?}, expected announce"), + } + } + + pub async fn request(&mut self, request: Request) -> anyhow::Result { + let tid = new_transaction_id(); + self.write_buf.clear(); + let size = request.serialize(tid, &mut self.write_buf); + + self.sock + .send(&self.write_buf[..size]) + .await + .context("error sending")?; + let size = self.sock.recv(&mut self.read_buf).await.unwrap(); + + let (rtid, response) = Response::parse(&self.read_buf[..size]).unwrap(); + assert_eq!(tid, rtid); + Ok(response) + } +} + #[cfg(test)] mod tests { use std::{io::Write, str::FromStr}; @@ -244,13 +325,8 @@ mod tests { let (rtid, response) = Response::parse(&read_buf[..size]).unwrap(); assert_eq!(tid, rtid); match response { - Response::Announce { - interval, - leechers, - seeders, - addrs, - } => { - dbg!(interval, leechers, seeders, addrs); + Response::Announce(r) => { + dbg!(r); } other => panic!("unexpected response {other:?}"), } diff --git a/crates/librqbit_core/src/lib.rs b/crates/librqbit_core/src/lib.rs index 6086598..63577d6 100644 --- a/crates/librqbit_core/src/lib.rs +++ b/crates/librqbit_core/src/lib.rs @@ -7,3 +7,5 @@ pub mod peer_id; pub mod spawn_utils; pub mod speed_estimator; pub mod torrent_metainfo; + +pub use hash_id::Id20;