Just messing around with Rust typing

This commit is contained in:
Igor Katson 2021-06-30 18:42:16 +01:00
parent 32f2ea4953
commit c1f34a6599
9 changed files with 181 additions and 105 deletions

View file

@ -1,73 +1,79 @@
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use crate::clone_to_owned::CloneToOwned;
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct ByteString(pub Vec<u8>);
impl std::fmt::Debug for ByteString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0.as_slice()) {
Ok(bytes) => bytes.fmt(f),
Err(_e) => write!(f, "<{} bytes>", self.0.len()),
}
}
}
#[derive(Deserialize, PartialEq, Eq, Hash, Clone)]
#[serde(transparent)]
pub struct ByteBuf<'a>(pub &'a [u8]);
impl<'a> ByteBuf<'a> {
pub fn as_bytes(&'a self) -> &'a [u8] {
self.0
pub trait ByteBufT {
fn as_slice(&self) -> &[u8];
}
impl ByteBufT for ByteString {
fn as_slice(&self) -> &[u8] {
self.as_ref()
}
}
fn debug_raw_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<{} bytes>", b.len())
impl<'a> ByteBufT for ByteBuf<'a> {
fn as_slice(&self) -> &[u8] {
self.as_ref()
}
}
fn debug_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if b.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", b.len());
}
match std::str::from_utf8(b) {
Ok(s) => write!(f, "{:?}", s),
Err(_e) => write!(f, "<{} bytes>", b.len()),
}
}
fn display_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if b.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", b.len());
}
match std::str::from_utf8(b) {
Ok(s) => write!(f, "{}", s),
Err(_e) => write!(f, "<{} bytes>", b.len()),
}
}
impl<'a> std::fmt::Debug for ByteBuf<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0) {
Ok(bytes) => bytes.fmt(f),
Err(_e) => debug_raw_bytes(&self.0, f),
}
debug_bytes(self.0, f)
}
}
impl<'a> std::fmt::Display for ByteBuf<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0) {
Ok(bytes) => f.write_str(bytes),
Err(_e) => debug_raw_bytes(&self.0, f),
}
display_bytes(self.0, f)
}
}
impl<'a> CloneToOwned for ByteBuf<'a> {
type Target = ByteString;
fn clone_to_owned(&self) -> Self::Target {
ByteString(self.0.into())
impl std::fmt::Debug for ByteString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
debug_bytes(&self.0, f)
}
}
impl CloneToOwned for ByteString {
impl std::fmt::Display for ByteString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
display_bytes(&self.0, f)
}
}
impl<B: ByteBufT> CloneToOwned for B {
type Target = ByteString;
fn clone_to_owned(&self) -> Self::Target {
self.clone()
ByteString(self.as_slice().to_owned())
}
}
@ -116,3 +122,27 @@ impl From<Vec<u8>> for ByteString {
Self(b)
}
}
impl<'de> serde::de::Deserialize<'de> for ByteString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Vec<u8>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("byte string")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_owned())
}
}
Ok(ByteString(deserializer.deserialize_byte_buf(Visitor {})?))
}
}

View file

@ -1,3 +1,5 @@
use std::collections::HashMap;
pub trait CloneToOwned {
type Target;
@ -25,3 +27,20 @@ where
self.iter().map(|i| i.clone_to_owned()).collect()
}
}
impl<K, V> CloneToOwned for HashMap<K, V>
where
K: CloneToOwned,
<K as CloneToOwned>::Target: std::hash::Hash + Eq,
V: CloneToOwned,
{
type Target = HashMap<<K as CloneToOwned>::Target, <V as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
let mut result = HashMap::with_capacity(self.capacity());
for (k, v) in self {
result.insert(k.clone_to_owned(), v.clone_to_owned());
}
result
}
}

View file

@ -2,7 +2,7 @@ use std::sync::Arc;
use std::io::Write;
use std::sync::atomic::Ordering;
use std::time::{Duration, Instant};
use std::time::Instant;
use warp::Filter;
use crate::torrent_state::TorrentState;

View file

@ -1,9 +1,12 @@
use serde::de::Deserializer;
use serde::de::Error as DeError;
use serde::Deserialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::buffers::ByteBuf;
use crate::buffers::ByteString;
use crate::clone_to_owned::CloneToOwned;
pub struct BencodeDeserializer<'de> {
buf: &'de [u8],
@ -83,7 +86,10 @@ where
Ok(T::deserialize(&mut de)?)
}
pub fn dyn_from_bytes(buf: &[u8]) -> anyhow::Result<DynBencodeNode<'_>> {
pub fn dyn_from_bytes<'de, ByteBuf>(buf: &'de [u8]) -> anyhow::Result<BencodeValue<ByteBuf>>
where
ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq,
{
from_bytes(buf)
}
@ -555,15 +561,23 @@ impl<'a, 'de> serde::de::SeqAccess<'de> for SeqAccess<'a, 'de> {
}
}
impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> {
impl<'de, ByteBuf> serde::de::Deserialize<'de> for BencodeValue<ByteBuf>
where
ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
struct Visitor<ByteBuf> {
buftype: PhantomData<ByteBuf>,
}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = DynBencodeNode<'de>;
impl<'de, ByteBuf> serde::de::Visitor<'de> for Visitor<ByteBuf>
where
ByteBuf: From<&'de [u8]> + Deserialize<'de> + std::hash::Hash + Eq,
{
type Value = BencodeValue<ByteBuf>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a bencode value")
@ -573,7 +587,7 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> {
where
E: serde::de::Error,
{
Ok(DynBencodeNode::Integer(v))
Ok(BencodeValue::Integer(v))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
@ -584,14 +598,14 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> {
while let Some(value) = seq.next_element()? {
v.push(value);
}
Ok(DynBencodeNode::List(v))
Ok(BencodeValue::List(v))
}
fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(DynBencodeNode::Bytes(ByteBuf(v)))
Ok(BencodeValue::Bytes(ByteBuf::from(v)))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
@ -603,46 +617,56 @@ impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> {
let value = map.next_value()?;
hmap.insert(key, value);
}
Ok(DynBencodeNode::Dict(hmap))
Ok(BencodeValue::Dict(hmap))
}
}
deserializer.deserialize_any(Visitor {})
deserializer.deserialize_any(Visitor {
buftype: PhantomData,
})
}
}
impl<'de> serde::de::Deserialize<'de> for ByteString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Vec<u8>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("bencode byte string")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_owned())
}
}
Ok(ByteString(deserializer.deserialize_byte_buf(Visitor {})?))
}
}
#[derive(Debug)]
pub enum DynBencodeNode<'a> {
Bytes(ByteBuf<'a>),
// A dynamic value when we don't know exactly what we are deserializing.
// Useful for debugging.
pub enum BencodeValue<ByteBuf> {
Bytes(ByteBuf),
Integer(i64),
List(Vec<DynBencodeNode<'a>>),
Dict(HashMap<ByteBuf<'a>, DynBencodeNode<'a>>),
List(Vec<BencodeValue<ByteBuf>>),
Dict(HashMap<ByteBuf, BencodeValue<ByteBuf>>),
}
impl<ByteBuf: std::fmt::Debug> std::fmt::Debug for BencodeValue<ByteBuf> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BencodeValue::Bytes(b) => std::fmt::Debug::fmt(b, f),
BencodeValue::Integer(i) => std::fmt::Debug::fmt(i, f),
BencodeValue::List(l) => std::fmt::Debug::fmt(l, f),
BencodeValue::Dict(d) => std::fmt::Debug::fmt(d, f),
}
}
}
impl<ByteBuf> CloneToOwned for BencodeValue<ByteBuf>
where
ByteBuf: CloneToOwned,
<ByteBuf as CloneToOwned>::Target: Eq + std::hash::Hash,
{
type Target = BencodeValue<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self {
BencodeValue::Bytes(b) => BencodeValue::Bytes(b.clone_to_owned()),
BencodeValue::Integer(i) => BencodeValue::Integer(*i),
BencodeValue::List(l) => BencodeValue::List(l.clone_to_owned()),
BencodeValue::Dict(d) => BencodeValue::Dict(d.clone_to_owned()),
}
}
}
pub type DynBencodeNodeBorrowed<'a> = BencodeValue<ByteBuf<'a>>;
pub type DynBencodeNodeOwned = BencodeValue<ByteString>;
#[cfg(test)]
mod tests {
use super::*;
@ -657,7 +681,9 @@ mod tests {
.read_to_end(&mut buf)
.unwrap();
let torrent: DynBencodeNode = from_bytes(&buf).unwrap();
dbg!(torrent);
let torrent_borrowed: DynBencodeNodeBorrowed = from_bytes(&buf).unwrap();
let torrent_owned: DynBencodeNodeOwned = from_bytes(&buf).unwrap();
dbg!(torrent_borrowed);
dbg!(torrent_owned);
}
}

View file

@ -24,7 +24,7 @@ use crate::{
spawn_utils::spawn,
torrent_metainfo::TorrentMetaV1Owned,
torrent_state::{AtomicStats, TorrentState, TorrentStateLocked},
tracker_comms::{CompactTrackerResponse, TrackerRequest, TrackerRequestEvent},
tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse},
};
pub struct TorrentManagerBuilder {
torrent: TorrentMetaV1Owned,
@ -263,8 +263,20 @@ impl TorrentManager {
async fn tracker_one_request(&self, tracker_url: Url) -> anyhow::Result<u64> {
let response: reqwest::Response = reqwest::get(tracker_url).await?;
if !response.status().is_success() {
anyhow::bail!("tracker responded with {:?}", response.status());
}
let bytes = response.bytes().await?;
let response = crate::serde_bencode::from_bytes::<CompactTrackerResponse>(&bytes)?;
match crate::serde_bencode::from_bytes::<TrackerError>(&bytes) {
Ok(error) => anyhow::bail!(
"tracker returned failure. Failure reason: {}",
error.failure_reason
),
Err(_) => {
// ignore, assume ok response
}
};
let response = crate::serde_bencode::from_bytes::<TrackerResponse>(&bytes)?;
for peer in response.peers.iter_sockaddrs() {
self.state.add_peer_if_not_seen(peer);

View file

@ -11,7 +11,9 @@ use crate::{
pub type TorrentMetaV1Borrowed<'a> = TorrentMetaV1<ByteBuf<'a>>;
pub type TorrentMetaV1Owned = TorrentMetaV1<ByteString>;
pub fn torrent_from_bytes(buf: &[u8]) -> anyhow::Result<TorrentMetaV1Borrowed<'_>> {
pub fn torrent_from_bytes<'de, ByteBuf: Clone + Deserialize<'de>>(
buf: &'de [u8],
) -> anyhow::Result<TorrentMetaV1<ByteBuf>> {
let mut de = BencodeDeserializer::new_from_buf(buf);
de.is_torrent_info = true;
let mut t = TorrentMetaV1::deserialize(&mut de)?;
@ -19,14 +21,6 @@ pub fn torrent_from_bytes(buf: &[u8]) -> anyhow::Result<TorrentMetaV1Borrowed<'_
Ok(t)
}
pub fn torrent_from_bytes_owned(buf: &[u8]) -> anyhow::Result<TorrentMetaV1Owned> {
let mut de = BencodeDeserializer::new_from_buf(buf);
de.is_torrent_info = true;
let mut t = TorrentMetaV1Owned::deserialize(&mut de)?;
t.info_hash = de.torrent_info_digest.unwrap();
Ok(t)
}
#[derive(Deserialize, Debug, Clone)]
pub struct TorrentMetaV1<BufType: Clone> {
pub announce: BufType,
@ -259,7 +253,7 @@ mod tests {
.read_to_end(&mut buf)
.unwrap();
let torrent: TorrentMetaV1Owned = from_bytes(&buf).unwrap();
let torrent: TorrentMetaV1Owned = torrent_from_bytes(&buf).unwrap();
dbg!(torrent);
}
@ -272,7 +266,7 @@ mod tests {
.read_to_end(&mut buf)
.unwrap();
let torrent: TorrentMetaV1Borrowed = from_bytes(&buf).unwrap();
let torrent: TorrentMetaV1Borrowed = torrent_from_bytes(&buf).unwrap();
dbg!(torrent);
}
@ -285,7 +279,7 @@ mod tests {
.read_to_end(&mut buf)
.unwrap();
let torrent = torrent_from_bytes(&buf).unwrap();
let torrent: TorrentMetaV1Borrowed = torrent_from_bytes(&buf).unwrap();
assert_eq!(
torrent.info_hash,
*b"\x64\xa9\x80\xab\xe6\xe4\x48\x22\x6b\xb9\x30\xba\x06\x15\x92\xe4\x4c\x37\x81\xa1"

View file

@ -10,7 +10,7 @@ use std::{
};
use futures::{stream::FuturesUnordered, StreamExt};
use log::{debug, info, trace, warn};
use log::{debug, trace, warn};
use parking_lot::{Mutex, RwLock};
use tokio::sync::mpsc::{channel, Sender};

View file

@ -36,7 +36,7 @@ pub struct TrackerRequest {
#[derive(Deserialize, Debug)]
pub struct TrackerError<'a> {
#[serde(rename = "failure reason", borrow)]
failure_reason: ByteBuf<'a>,
pub failure_reason: ByteBuf<'a>,
}
#[derive(Deserialize, Debug)]
@ -143,7 +143,7 @@ fn parse_compact_peers(b: &[u8]) -> Vec<SocketAddrV4> {
}
#[derive(Deserialize, Debug)]
pub struct CompactTrackerResponse<'a> {
pub struct TrackerResponse<'a> {
#[serde(rename = "warning message", borrow)]
pub warning_message: Option<ByteBuf<'a>>,
pub complete: u64,

View file

@ -3,7 +3,6 @@ use std::{fs::File, io::Read, time::Duration};
use anyhow::Context;
use clap::Clap;
use librqbit::{
clone_to_owned::CloneToOwned,
torrent_manager::TorrentManagerBuilder,
torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Owned},
};
@ -20,9 +19,7 @@ async fn torrent_from_url(url: &str) -> anyhow::Result<TorrentMetaV1Owned> {
.bytes()
.await
.with_context(|| format!("error reading repsonse body from {}", url))?;
Ok(torrent_from_bytes(&b)
.context("error decoding torrent")?
.clone_to_owned())
torrent_from_bytes(&b).context("error decoding torrent")
}
fn torrent_from_file(filename: &str) -> anyhow::Result<TorrentMetaV1Owned> {
@ -37,9 +34,7 @@ fn torrent_from_file(filename: &str) -> anyhow::Result<TorrentMetaV1Owned> {
.read_to_end(&mut buf)
.with_context(|| format!("error reading {}", filename))?;
}
Ok(torrent_from_bytes(&buf)
.context("error decoding torrent")?
.clone_to_owned())
torrent_from_bytes(&buf).context("error decoding torrent")
}
#[derive(Debug, Clap)]