From c1f34a6599a4736f38f011e3b5c6429b632b2d21 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Wed, 30 Jun 2021 18:42:16 +0100 Subject: [PATCH] Just messing around with Rust typing --- crates/librqbit/src/buffers.rs | 108 +++++++++++++++--------- crates/librqbit/src/clone_to_owned.rs | 19 +++++ crates/librqbit/src/http_api.rs | 2 +- crates/librqbit/src/serde_bencode.rs | 108 +++++++++++++++--------- crates/librqbit/src/torrent_manager.rs | 16 +++- crates/librqbit/src/torrent_metainfo.rs | 18 ++-- crates/librqbit/src/torrent_state.rs | 2 +- crates/librqbit/src/tracker_comms.rs | 4 +- src/main.rs | 9 +- 9 files changed, 181 insertions(+), 105 deletions(-) diff --git a/crates/librqbit/src/buffers.rs b/crates/librqbit/src/buffers.rs index 2027738..4d58bb5 100644 --- a/crates/librqbit/src/buffers.rs +++ b/crates/librqbit/src/buffers.rs @@ -1,73 +1,79 @@ -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use crate::clone_to_owned::CloneToOwned; #[derive(PartialEq, Eq, Hash, Clone)] pub struct ByteString(pub Vec); -impl std::fmt::Debug for ByteString { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.0.iter().all(|b| *b == 0) { - return write!(f, "<{} bytes, all zeroes>", self.0.len()); - } - match std::str::from_utf8(self.0.as_slice()) { - Ok(bytes) => bytes.fmt(f), - Err(_e) => write!(f, "<{} bytes>", self.0.len()), - } - } -} - #[derive(Deserialize, PartialEq, Eq, Hash, Clone)] #[serde(transparent)] pub struct ByteBuf<'a>(pub &'a [u8]); -impl<'a> ByteBuf<'a> { - pub fn as_bytes(&'a self) -> &'a [u8] { - self.0 +pub trait ByteBufT { + fn as_slice(&self) -> &[u8]; +} + +impl ByteBufT for ByteString { + fn as_slice(&self) -> &[u8] { + self.as_ref() } } -fn debug_raw_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "<{} bytes>", b.len()) +impl<'a> ByteBufT for ByteBuf<'a> { + fn as_slice(&self) -> &[u8] { + self.as_ref() + } +} + +fn debug_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if b.iter().all(|b| *b == 0) { + return write!(f, "<{} bytes, all zeroes>", b.len()); + } + match std::str::from_utf8(b) { + Ok(s) => write!(f, "{:?}", s), + Err(_e) => write!(f, "<{} bytes>", b.len()), + } +} + +fn display_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if b.iter().all(|b| *b == 0) { + return write!(f, "<{} bytes, all zeroes>", b.len()); + } + match std::str::from_utf8(b) { + Ok(s) => write!(f, "{}", s), + Err(_e) => write!(f, "<{} bytes>", b.len()), + } } impl<'a> std::fmt::Debug for ByteBuf<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.0.iter().all(|b| *b == 0) { - return write!(f, "<{} bytes, all zeroes>", self.0.len()); - } - match std::str::from_utf8(self.0) { - Ok(bytes) => bytes.fmt(f), - Err(_e) => debug_raw_bytes(&self.0, f), - } + debug_bytes(self.0, f) } } impl<'a> std::fmt::Display for ByteBuf<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.0.iter().all(|b| *b == 0) { - return write!(f, "<{} bytes, all zeroes>", self.0.len()); - } - match std::str::from_utf8(self.0) { - Ok(bytes) => f.write_str(bytes), - Err(_e) => debug_raw_bytes(&self.0, f), - } + display_bytes(self.0, f) } } -impl<'a> CloneToOwned for ByteBuf<'a> { - type Target = ByteString; - - fn clone_to_owned(&self) -> Self::Target { - ByteString(self.0.into()) +impl std::fmt::Debug for ByteString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_bytes(&self.0, f) } } -impl CloneToOwned for ByteString { +impl std::fmt::Display for ByteString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + display_bytes(&self.0, f) + } +} + +impl CloneToOwned for B { type Target = ByteString; fn clone_to_owned(&self) -> Self::Target { - self.clone() + ByteString(self.as_slice().to_owned()) } } @@ -116,3 +122,27 @@ impl From> for ByteString { Self(b) } } + +impl<'de> serde::de::Deserialize<'de> for ByteString { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("byte string") + } + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(v.to_owned()) + } + } + Ok(ByteString(deserializer.deserialize_byte_buf(Visitor {})?)) + } +} diff --git a/crates/librqbit/src/clone_to_owned.rs b/crates/librqbit/src/clone_to_owned.rs index 36b76c2..0cd7002 100644 --- a/crates/librqbit/src/clone_to_owned.rs +++ b/crates/librqbit/src/clone_to_owned.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + pub trait CloneToOwned { type Target; @@ -25,3 +27,20 @@ where self.iter().map(|i| i.clone_to_owned()).collect() } } + +impl CloneToOwned for HashMap +where + K: CloneToOwned, + ::Target: std::hash::Hash + Eq, + V: CloneToOwned, +{ + type Target = HashMap<::Target, ::Target>; + + fn clone_to_owned(&self) -> Self::Target { + let mut result = HashMap::with_capacity(self.capacity()); + for (k, v) in self { + result.insert(k.clone_to_owned(), v.clone_to_owned()); + } + result + } +} diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index 65a08be..49e4303 100644 --- a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use std::io::Write; use std::sync::atomic::Ordering; -use std::time::{Duration, Instant}; +use std::time::Instant; use warp::Filter; use crate::torrent_state::TorrentState; diff --git a/crates/librqbit/src/serde_bencode.rs b/crates/librqbit/src/serde_bencode.rs index 48f01e1..1f854d3 100644 --- a/crates/librqbit/src/serde_bencode.rs +++ b/crates/librqbit/src/serde_bencode.rs @@ -1,9 +1,12 @@ use serde::de::Deserializer; use serde::de::Error as DeError; +use serde::Deserialize; use std::collections::HashMap; +use std::marker::PhantomData; use crate::buffers::ByteBuf; use crate::buffers::ByteString; +use crate::clone_to_owned::CloneToOwned; pub struct BencodeDeserializer<'de> { buf: &'de [u8], @@ -83,7 +86,10 @@ where Ok(T::deserialize(&mut de)?) } -pub fn dyn_from_bytes(buf: &[u8]) -> anyhow::Result> { +pub fn dyn_from_bytes<'de, ByteBuf>(buf: &'de [u8]) -> anyhow::Result> +where + ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq, +{ from_bytes(buf) } @@ -555,15 +561,23 @@ impl<'a, 'de> serde::de::SeqAccess<'de> for SeqAccess<'a, 'de> { } } -impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> { +impl<'de, ByteBuf> serde::de::Deserialize<'de> for BencodeValue +where + ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq, +{ fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - struct Visitor; + struct Visitor { + buftype: PhantomData, + } - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = DynBencodeNode<'de>; + impl<'de, ByteBuf> serde::de::Visitor<'de> for Visitor + where + ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq, + { + type Value = BencodeValue; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a bencode value") @@ -573,7 +587,7 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> { where E: serde::de::Error, { - Ok(DynBencodeNode::Integer(v)) + Ok(BencodeValue::Integer(v)) } fn visit_seq(self, mut seq: A) -> Result @@ -584,14 +598,14 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> { while let Some(value) = seq.next_element()? { v.push(value); } - Ok(DynBencodeNode::List(v)) + Ok(BencodeValue::List(v)) } fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result where E: serde::de::Error, { - Ok(DynBencodeNode::Bytes(ByteBuf(v))) + Ok(BencodeValue::Bytes(ByteBuf::from(v))) } fn visit_map(self, mut map: A) -> Result @@ -603,46 +617,56 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> { let value = map.next_value()?; hmap.insert(key, value); } - Ok(DynBencodeNode::Dict(hmap)) + Ok(BencodeValue::Dict(hmap)) } } - deserializer.deserialize_any(Visitor {}) + deserializer.deserialize_any(Visitor { + buftype: PhantomData, + }) } } -impl<'de> serde::de::Deserialize<'de> for ByteString { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Visitor; - - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("bencode byte string") - } - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - Ok(v.to_owned()) - } - } - Ok(ByteString(deserializer.deserialize_byte_buf(Visitor {})?)) - } -} - -#[derive(Debug)] -pub enum DynBencodeNode<'a> { - Bytes(ByteBuf<'a>), +// A dynamic value when we don't know exactly what we are deserializing. +// Useful for debugging. +pub enum BencodeValue { + Bytes(ByteBuf), Integer(i64), - List(Vec>), - Dict(HashMap, DynBencodeNode<'a>>), + List(Vec>), + Dict(HashMap>), } +impl std::fmt::Debug for BencodeValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BencodeValue::Bytes(b) => std::fmt::Debug::fmt(b, f), + BencodeValue::Integer(i) => std::fmt::Debug::fmt(i, f), + BencodeValue::List(l) => std::fmt::Debug::fmt(l, f), + BencodeValue::Dict(d) => std::fmt::Debug::fmt(d, f), + } + } +} + +impl CloneToOwned for BencodeValue +where + ByteBuf: CloneToOwned, + ::Target: Eq + std::hash::Hash, +{ + type Target = BencodeValue<::Target>; + + fn clone_to_owned(&self) -> Self::Target { + match self { + BencodeValue::Bytes(b) => BencodeValue::Bytes(b.clone_to_owned()), + BencodeValue::Integer(i) => BencodeValue::Integer(*i), + BencodeValue::List(l) => BencodeValue::List(l.clone_to_owned()), + BencodeValue::Dict(d) => BencodeValue::Dict(d.clone_to_owned()), + } + } +} + +pub type DynBencodeNodeBorrowed<'a> = BencodeValue>; +pub type DynBencodeNodeOwned = BencodeValue; + #[cfg(test)] mod tests { use super::*; @@ -657,7 +681,9 @@ mod tests { .read_to_end(&mut buf) .unwrap(); - let torrent: DynBencodeNode = from_bytes(&buf).unwrap(); - dbg!(torrent); + let torrent_borrowed: DynBencodeNodeBorrowed = from_bytes(&buf).unwrap(); + let torrent_owned: DynBencodeNodeOwned = from_bytes(&buf).unwrap(); + dbg!(torrent_borrowed); + dbg!(torrent_owned); } } diff --git a/crates/librqbit/src/torrent_manager.rs b/crates/librqbit/src/torrent_manager.rs index 38a4258..69780a8 100644 --- a/crates/librqbit/src/torrent_manager.rs +++ b/crates/librqbit/src/torrent_manager.rs @@ -24,7 +24,7 @@ use crate::{ spawn_utils::spawn, torrent_metainfo::TorrentMetaV1Owned, torrent_state::{AtomicStats, TorrentState, TorrentStateLocked}, - tracker_comms::{CompactTrackerResponse, TrackerRequest, TrackerRequestEvent}, + tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse}, }; pub struct TorrentManagerBuilder { torrent: TorrentMetaV1Owned, @@ -263,8 +263,20 @@ impl TorrentManager { async fn tracker_one_request(&self, tracker_url: Url) -> anyhow::Result { let response: reqwest::Response = reqwest::get(tracker_url).await?; + if !response.status().is_success() { + anyhow::bail!("tracker responded with {:?}", response.status()); + } let bytes = response.bytes().await?; - let response = crate::serde_bencode::from_bytes::(&bytes)?; + match crate::serde_bencode::from_bytes::(&bytes) { + Ok(error) => anyhow::bail!( + "tracker returned failure. Failure reason: {}", + error.failure_reason + ), + Err(_) => { + // ignore, assume ok response + } + }; + let response = crate::serde_bencode::from_bytes::(&bytes)?; for peer in response.peers.iter_sockaddrs() { self.state.add_peer_if_not_seen(peer); diff --git a/crates/librqbit/src/torrent_metainfo.rs b/crates/librqbit/src/torrent_metainfo.rs index 69a8fec..fe30999 100644 --- a/crates/librqbit/src/torrent_metainfo.rs +++ b/crates/librqbit/src/torrent_metainfo.rs @@ -11,7 +11,9 @@ use crate::{ pub type TorrentMetaV1Borrowed<'a> = TorrentMetaV1>; pub type TorrentMetaV1Owned = TorrentMetaV1; -pub fn torrent_from_bytes(buf: &[u8]) -> anyhow::Result> { +pub fn torrent_from_bytes<'de, ByteBuf: Clone + Deserialize<'de>>( + buf: &'de [u8], +) -> anyhow::Result> { let mut de = BencodeDeserializer::new_from_buf(buf); de.is_torrent_info = true; let mut t = TorrentMetaV1::deserialize(&mut de)?; @@ -19,14 +21,6 @@ pub fn torrent_from_bytes(buf: &[u8]) -> anyhow::Result anyhow::Result { - let mut de = BencodeDeserializer::new_from_buf(buf); - de.is_torrent_info = true; - let mut t = TorrentMetaV1Owned::deserialize(&mut de)?; - t.info_hash = de.torrent_info_digest.unwrap(); - Ok(t) -} - #[derive(Deserialize, Debug, Clone)] pub struct TorrentMetaV1 { pub announce: BufType, @@ -259,7 +253,7 @@ mod tests { .read_to_end(&mut buf) .unwrap(); - let torrent: TorrentMetaV1Owned = from_bytes(&buf).unwrap(); + let torrent: TorrentMetaV1Owned = torrent_from_bytes(&buf).unwrap(); dbg!(torrent); } @@ -272,7 +266,7 @@ mod tests { .read_to_end(&mut buf) .unwrap(); - let torrent: TorrentMetaV1Borrowed = from_bytes(&buf).unwrap(); + let torrent: TorrentMetaV1Borrowed = torrent_from_bytes(&buf).unwrap(); dbg!(torrent); } @@ -285,7 +279,7 @@ mod tests { .read_to_end(&mut buf) .unwrap(); - let torrent = torrent_from_bytes(&buf).unwrap(); + let torrent: TorrentMetaV1Borrowed = torrent_from_bytes(&buf).unwrap(); assert_eq!( torrent.info_hash, *b"\x64\xa9\x80\xab\xe6\xe4\x48\x22\x6b\xb9\x30\xba\x06\x15\x92\xe4\x4c\x37\x81\xa1" diff --git a/crates/librqbit/src/torrent_state.rs b/crates/librqbit/src/torrent_state.rs index 0e8f544..1746822 100644 --- a/crates/librqbit/src/torrent_state.rs +++ b/crates/librqbit/src/torrent_state.rs @@ -10,7 +10,7 @@ use std::{ }; use futures::{stream::FuturesUnordered, StreamExt}; -use log::{debug, info, trace, warn}; +use log::{debug, trace, warn}; use parking_lot::{Mutex, RwLock}; use tokio::sync::mpsc::{channel, Sender}; diff --git a/crates/librqbit/src/tracker_comms.rs b/crates/librqbit/src/tracker_comms.rs index 18de7a9..2c4412b 100644 --- a/crates/librqbit/src/tracker_comms.rs +++ b/crates/librqbit/src/tracker_comms.rs @@ -36,7 +36,7 @@ pub struct TrackerRequest { #[derive(Deserialize, Debug)] pub struct TrackerError<'a> { #[serde(rename = "failure reason", borrow)] - failure_reason: ByteBuf<'a>, + pub failure_reason: ByteBuf<'a>, } #[derive(Deserialize, Debug)] @@ -143,7 +143,7 @@ fn parse_compact_peers(b: &[u8]) -> Vec { } #[derive(Deserialize, Debug)] -pub struct CompactTrackerResponse<'a> { +pub struct TrackerResponse<'a> { #[serde(rename = "warning message", borrow)] pub warning_message: Option>, pub complete: u64, diff --git a/src/main.rs b/src/main.rs index 2efcdbb..31a356d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ use std::{fs::File, io::Read, time::Duration}; use anyhow::Context; use clap::Clap; use librqbit::{ - clone_to_owned::CloneToOwned, torrent_manager::TorrentManagerBuilder, torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Owned}, }; @@ -20,9 +19,7 @@ async fn torrent_from_url(url: &str) -> anyhow::Result { .bytes() .await .with_context(|| format!("error reading repsonse body from {}", url))?; - Ok(torrent_from_bytes(&b) - .context("error decoding torrent")? - .clone_to_owned()) + torrent_from_bytes(&b).context("error decoding torrent") } fn torrent_from_file(filename: &str) -> anyhow::Result { @@ -37,9 +34,7 @@ fn torrent_from_file(filename: &str) -> anyhow::Result { .read_to_end(&mut buf) .with_context(|| format!("error reading {}", filename))?; } - Ok(torrent_from_bytes(&buf) - .context("error decoding torrent")? - .clone_to_owned()) + torrent_from_bytes(&buf).context("error decoding torrent") } #[derive(Debug, Clap)]