diff --git a/crates/tracker_comms/src/tracker_comms_udp.rs b/crates/tracker_comms/src/tracker_comms_udp.rs index baef767..30fc65d 100644 --- a/crates/tracker_comms/src/tracker_comms_udp.rs +++ b/crates/tracker_comms/src/tracker_comms_udp.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + ffi::CStr, net::{Ipv4Addr, SocketAddrV4}, sync::Arc, time::{Duration, Instant}, @@ -15,7 +16,7 @@ use tracing::{debug, error_span, trace, warn}; const ACTION_CONNECT: u32 = 0; const ACTION_ANNOUNCE: u32 = 1; // const ACTION_SCRAPE: u32 = 2; -// const ACTION_ERROR: u32 = 3; +const ACTION_ERROR: u32 = 3; pub const EVENT_NONE: u32 = 0; pub const EVENT_COMPLETED: u32 = 1; @@ -50,31 +51,51 @@ pub enum Request { } impl Request { - pub fn serialize(&self, transaction_id: TransactionId, buf: &mut Vec) -> usize { - let cur_len = buf.len(); - match self { - Request::Connect => { - buf.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes()); - buf.extend_from_slice(&ACTION_CONNECT.to_be_bytes()); - buf.extend_from_slice(&transaction_id.to_be_bytes()); - } - Request::Announce(connection_id, fields) => { - buf.extend_from_slice(&connection_id.to_be_bytes()); - buf.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes()); - buf.extend_from_slice(&transaction_id.to_be_bytes()); - buf.extend_from_slice(&fields.info_hash.0); - buf.extend_from_slice(&fields.peer_id.0); - buf.extend_from_slice(&fields.downloaded.to_be_bytes()); - buf.extend_from_slice(&fields.left.to_be_bytes()); - buf.extend_from_slice(&fields.uploaded.to_be_bytes()); - buf.extend_from_slice(&fields.event.to_be_bytes()); - buf.extend_from_slice(&0u32.to_be_bytes()); // ip address 0 - buf.extend_from_slice(&fields.key.to_be_bytes()); - buf.extend_from_slice(&(-1i32).to_be_bytes()); // num want -1 - buf.extend_from_slice(&fields.port.to_be_bytes()); + pub fn serialize( + &self, + transaction_id: TransactionId, + buf: &mut [u8], + ) -> anyhow::Result { + struct W<'a> { + buf: &'a mut [u8], + offset: usize, + } + impl W<'_> { + fn extend_from_slice(&mut self, s: &[u8]) -> anyhow::Result<()> { + if self.buf.len() < self.offset + s.len() { + bail!("not enough space in buffer") + } + self.buf[self.offset..self.offset + s.len()].copy_from_slice(s); + self.offset += s.len(); + Ok(()) } } - buf.len() - cur_len + + let mut w = W { buf, offset: 0 }; + + match self { + Request::Connect => { + w.extend_from_slice(&CONNECTION_ID_MAGIC.to_be_bytes())?; + w.extend_from_slice(&ACTION_CONNECT.to_be_bytes())?; + w.extend_from_slice(&transaction_id.to_be_bytes())?; + } + Request::Announce(connection_id, fields) => { + w.extend_from_slice(&connection_id.to_be_bytes())?; + w.extend_from_slice(&ACTION_ANNOUNCE.to_be_bytes())?; + w.extend_from_slice(&transaction_id.to_be_bytes())?; + w.extend_from_slice(&fields.info_hash.0)?; + w.extend_from_slice(&fields.peer_id.0)?; + w.extend_from_slice(&fields.downloaded.to_be_bytes())?; + w.extend_from_slice(&fields.left.to_be_bytes())?; + w.extend_from_slice(&fields.uploaded.to_be_bytes())?; + w.extend_from_slice(&fields.event.to_be_bytes())?; + w.extend_from_slice(&0u32.to_be_bytes())?; // ip address 0 + w.extend_from_slice(&fields.key.to_be_bytes())?; + w.extend_from_slice(&(-1i32).to_be_bytes())?; // num want -1 + w.extend_from_slice(&fields.port.to_be_bytes())?; + } + } + Ok(w.offset) } } @@ -92,6 +113,9 @@ pub struct AnnounceResponse { pub enum Response { Connect(ConnectionId), Announce(AnnounceResponse), + #[allow(dead_code)] + Error(String), + Unknown, } fn split_slice(s: &[u8], first_len: usize) -> Option<(&[u8], &[u8])> { @@ -134,7 +158,20 @@ parse_impl!(i16, 2); impl Response { pub fn parse(buf: &[u8]) -> anyhow::Result<(TransactionId, Self)> { let (action, buf) = u32::parse_num(buf).context("can't parse action")?; - let (tid, mut buf) = u32::parse_num(buf).context("can't parse transaction id")?; + let (tid, buf) = u32::parse_num(buf).context("can't parse transaction id")?; + + let response = match Self::parse_response(action, buf) { + Ok(r) => r, + Err(e) => { + debug!("error parsing: {e:#}"); + Response::Unknown + } + }; + + Ok((tid, response)) + } + + fn parse_response(action: u32, mut buf: &[u8]) -> anyhow::Result { let response = match action { ACTION_CONNECT => { let (connection_id, b) = @@ -164,6 +201,15 @@ impl Response { addrs, }) } + ACTION_ERROR => { + let msg = CStr::from_bytes_with_nul(buf) + .ok() + .and_then(|s| s.to_str().ok()) + .or_else(|| std::str::from_utf8(buf).ok()) + .unwrap_or("") + .to_owned(); + return Ok(Response::Error(msg)); + } _ => bail!("unsupported action {action}"), }; @@ -174,7 +220,7 @@ impl Response { ); } - Ok((tid, response)) + Ok(response) } } @@ -296,15 +342,23 @@ impl UdpTrackerClient { let (tx, rx) = tokio::sync::oneshot::channel(); let tid_g = self.reserve_transaction_id(tx)?; - // TODO: no allocs - let mut write_buf = Vec::new(); - request.serialize(tid_g.tid, &mut write_buf); - self.state.sock.send_to(&write_buf, addr).await?; + let mut write_buf = [0u8; 1024]; + let len = request.serialize(tid_g.tid, &mut write_buf)?; + self.state.sock.send_to(&write_buf[..len], addr).await?; let response = tokio::time::timeout(Duration::from_secs(10), rx) .await .context("timeout connecting")? .context("sender dead")?; + match &response { + Response::Error(e) => { + anyhow::bail!("remote errored: {e}") + } + Response::Unknown => { + anyhow::bail!("remote replied with something we could not parse") + } + _ => {} + } Ok(response) } @@ -368,12 +422,12 @@ mod tests { sock.connect("opentor.net:6969").await.unwrap(); let tid = new_transaction_id(); - let mut write_buf = Vec::new(); + let mut write_buf = [0u8; 16384]; let mut read_buf = vec![0u8; 4096]; - Request::Connect.serialize(tid, &mut write_buf); + let len = Request::Connect.serialize(tid, &mut write_buf).unwrap(); - sock.send(&write_buf).await.unwrap(); + sock.send(&write_buf[..len]).await.unwrap(); let size = sock.recv(&mut read_buf).await.unwrap(); @@ -402,8 +456,7 @@ mod tests { port: 24563, }, ); - write_buf.clear(); - let size = request.serialize(tid, &mut write_buf); + let size = request.serialize(tid, &mut write_buf).unwrap(); sock.send(&write_buf[..size]).await.unwrap(); let size = sock.recv(&mut read_buf).await.unwrap();