Initial commit

This commit is contained in:
Igor Katson 2021-06-25 13:47:51 +01:00
commit 87d6fe27ce
20 changed files with 4780 additions and 0 deletions

View file

@ -0,0 +1,112 @@
use serde::Deserialize;
use crate::clone_to_owned::CloneToOwned;
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct ByteString(pub Vec<u8>);
impl std::fmt::Debug for ByteString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0.as_slice()) {
Ok(bytes) => bytes.fmt(f),
Err(_e) => write!(f, "<{} bytes>", self.0.len()),
}
}
}
#[derive(Deserialize, PartialEq, Eq, Hash, Clone)]
#[serde(transparent)]
pub struct ByteBuf<'a>(pub &'a [u8]);
impl<'a> ByteBuf<'a> {
pub fn as_bytes(&'a self) -> &'a [u8] {
self.0
}
}
fn debug_raw_bytes(b: &[u8], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<{} bytes>", b.len())
}
impl<'a> std::fmt::Debug for ByteBuf<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0) {
Ok(bytes) => bytes.fmt(f),
Err(_e) => debug_raw_bytes(&self.0, f),
}
}
}
impl<'a> std::fmt::Display for ByteBuf<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.0.iter().all(|b| *b == 0) {
return write!(f, "<{} bytes, all zeroes>", self.0.len());
}
match std::str::from_utf8(self.0) {
Ok(bytes) => f.write_str(bytes),
Err(_e) => debug_raw_bytes(&self.0, f),
}
}
}
impl<'a> CloneToOwned for ByteBuf<'a> {
type Target = ByteString;
fn clone_to_owned(&self) -> Self::Target {
ByteString(self.0.into())
}
}
impl CloneToOwned for ByteString {
type Target = ByteString;
fn clone_to_owned(&self) -> Self::Target {
self.clone()
}
}
impl<'a> std::convert::AsRef<[u8]> for ByteBuf<'a> {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl std::convert::AsRef<[u8]> for ByteString {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl<'a> std::ops::Deref for ByteBuf<'a> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::Deref for ByteString {
type Target = [u8];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'a> From<&'a [u8]> for ByteBuf<'a> {
fn from(b: &'a [u8]) -> Self {
Self(b)
}
}
impl<'a> From<&'a [u8]> for ByteString {
fn from(b: &'a [u8]) -> Self {
Self(b.into())
}
}

View file

@ -0,0 +1,71 @@
use log::{debug, info};
use crate::{
buffers::ByteString,
lengths::{Lengths, ValidPieceIndex},
peer_comms::Piece,
type_aliases::BF,
};
pub struct ChunkTracker {
needed_pieces: BF,
chunk_status: BF,
lengths: Lengths,
}
fn compute_chunk_status(lengths: &Lengths, needed_pieces: &BF) -> BF {
let required_bits = lengths.total_chunks();
let required_size = (required_bits as usize + 1) / 8;
let vec = vec![0u8; required_size];
let mut chunk_bf = BF::from_vec(vec);
for bit in needed_pieces.iter_zeros() {
let offset = bit * 8;
for i in 0..8 {
chunk_bf.set(offset + i, true);
}
}
chunk_bf
}
impl ChunkTracker {
pub fn new(needed_pieces: BF, lengths: Lengths) -> Self {
Self {
chunk_status: compute_chunk_status(&lengths, &needed_pieces),
needed_pieces,
lengths,
}
}
pub fn get_needed_pieces(&self) -> &BF {
&self.needed_pieces
}
pub fn reserve_needed_piece(&mut self, index: ValidPieceIndex) {
self.needed_pieces.set(index.get() as usize, false)
}
pub fn mark_piece_needed(&mut self, index: ValidPieceIndex) -> bool {
info!("remarking piece={} as needed", index);
self.needed_pieces.set(index.get() as usize, true);
self.chunk_status
.get_mut(self.lengths.chunk_range(index))
.map(|s| {
s.set_all(false);
true
})
.unwrap_or_default()
}
// return true if the whole piece is marked downloaded
pub fn mark_chunk_downloaded(&mut self, piece: &Piece<ByteString>) -> Option<bool> {
let chunk_info = self.lengths.chunk_info_from_received_piece_data(piece)?;
self.chunk_status
.set(chunk_info.absolute_index as usize, true);
let chunk_range = self.lengths.chunk_range(chunk_info.piece_index);
let chunk_range = self.chunk_status.get(chunk_range).unwrap();
let all = chunk_range.all();
debug!(
"piece={}, chunk_info={:?}, bits={:?}",
piece.index, chunk_info, chunk_range,
);
Some(all)
}
}

View file

@ -0,0 +1,27 @@
pub trait CloneToOwned {
type Target;
fn clone_to_owned(&self) -> Self::Target;
}
impl<T> CloneToOwned for Option<T>
where
T: CloneToOwned,
{
type Target = Option<<T as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
self.as_ref().map(|i| i.clone_to_owned())
}
}
impl<T> CloneToOwned for Vec<T>
where
T: CloneToOwned,
{
type Target = Vec<<T as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
self.iter().map(|i| i.clone_to_owned()).collect()
}
}

View file

@ -0,0 +1 @@
pub const CHUNK_SIZE: u32 = 16384;

View file

@ -0,0 +1,268 @@
use crate::{buffers::ByteString, constants::CHUNK_SIZE, peer_comms::Piece};
const fn is_power_of_two(x: u64) -> bool {
(x != 0) && ((x & (x - 1)) == 0)
}
pub const fn ceil_div_u64(a: u64, b: u64) -> u64 {
(a + b - 1) / b
}
pub const fn last_element_size_u64(total: u64, chunk_size: u64) -> u64 {
let rem = total % chunk_size;
if rem == 0 {
return chunk_size;
}
rem
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkInfo {
pub piece_index: ValidPieceIndex,
pub chunk_index: u32,
pub absolute_index: u32,
pub size: u32,
pub offset: u32,
}
#[derive(Debug, Clone, Copy)]
pub struct Lengths {
chunk_length: u32,
total_length: u64,
piece_length: u32,
last_piece_id: u32,
last_piece_length: u32,
chunks_per_piece: u32,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct ValidPieceIndex(u32);
impl std::fmt::Display for ValidPieceIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::fmt::Debug for ValidPieceIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl ValidPieceIndex {
pub fn get(&self) -> u32 {
self.0
}
}
impl Lengths {
pub fn new(
total_length: u64,
piece_length: u32,
chunk_length: Option<u32>,
) -> anyhow::Result<Self> {
let chunk_length = chunk_length.unwrap_or(CHUNK_SIZE);
if !(is_power_of_two(piece_length as u64)) {
anyhow::bail!("piece length {} is not a power of 2", piece_length);
}
if !(is_power_of_two(chunk_length as u64)) {
anyhow::bail!("chunk length {} is not a power of 2", chunk_length);
}
if chunk_length >= piece_length {
anyhow::bail!(
"chunk length {} should be smaller than pice length {}",
chunk_length,
piece_length
);
}
Ok(Self {
chunk_length,
piece_length,
total_length,
chunks_per_piece: piece_length / chunk_length,
last_piece_id: ((total_length + 1) / piece_length as u64) as u32,
last_piece_length: last_element_size_u64(total_length, piece_length as u64) as u32,
})
}
pub const fn piece_bitfield_bytes(&self) -> usize {
ceil_div_u64(self.total_pieces() as u64, 8) as usize
}
pub const fn total_length(&self) -> u64 {
self.total_length
}
pub const fn validate_piece_index(&self, index: u32) -> Option<ValidPieceIndex> {
if index > self.last_piece_id {
return None;
}
Some(ValidPieceIndex(index))
}
pub const fn default_piece_length(&self) -> u32 {
self.piece_length
}
pub const fn default_chunk_length(&self) -> u32 {
self.chunk_length
}
pub const fn total_chunks(&self) -> u32 {
ceil_div_u64(self.total_length, self.chunk_length as u64) as u32
}
pub const fn total_pieces(&self) -> u32 {
self.last_piece_id + 1
}
pub const fn piece_length(&self, index: ValidPieceIndex) -> u32 {
if index.0 == self.last_piece_id {
return self.last_piece_length;
}
self.piece_length
}
pub const fn piece_offset(&self, index: ValidPieceIndex) -> u64 {
index.0 as u64 * self.piece_length as u64
}
pub fn iter_chunk_infos(&self, index: ValidPieceIndex) -> impl Iterator<Item = ChunkInfo> {
let mut remaining = self.piece_length(index);
let chunk_size = self.chunk_length;
let absolute_offset = index.0 * self.chunks_per_piece;
(0u32..).scan(0, move |offset, idx| {
if remaining == 0 {
return None;
}
let s = std::cmp::min(remaining, chunk_size);
let result = ChunkInfo {
piece_index: index,
chunk_index: idx,
absolute_index: absolute_offset + idx,
size: s,
offset: *offset,
};
*offset += s;
remaining -= s;
Some(result)
})
}
pub fn chunk_info_from_received_piece_data(
&self,
piece: &Piece<ByteString>,
) -> Option<ChunkInfo> {
let piece_index = self.validate_piece_index(piece.index)?;
let index = piece.begin / self.chunk_length;
let chunk_size = self.chunk_size(piece_index, index)?;
let offset = self.chunk_offset_in_piece(piece_index, index)?;
if offset != piece.begin {
return None;
}
if chunk_size as usize != piece.block.len() {
return None;
}
let absolute_index = self.chunks_per_piece * piece_index.get() + index;
Some(ChunkInfo {
piece_index,
chunk_index: index,
size: chunk_size,
offset,
absolute_index,
})
}
pub const fn chunk_range(&self, index: ValidPieceIndex) -> std::ops::Range<usize> {
let start = index.0 * self.chunks_per_piece;
let end = start + self.chunks_per_piece(index);
start as usize..end as usize
}
pub const fn chunks_per_piece(&self, index: ValidPieceIndex) -> u32 {
if index.0 == self.last_piece_id {
return (self.last_piece_length + self.chunk_length - 1) / self.chunk_length;
}
self.chunks_per_piece
}
pub const fn chunk_offset_in_piece(
&self,
piece_index: ValidPieceIndex,
chunk_index: u32,
) -> Option<u32> {
if chunk_index >= self.chunks_per_piece(piece_index) {
return None;
}
Some(chunk_index * self.chunk_length)
}
pub fn chunk_size(&self, piece_index: ValidPieceIndex, chunk_index: u32) -> Option<u32> {
let chunks_per_piece = self.chunks_per_piece(piece_index);
let pl = self.piece_length(piece_index);
if chunk_index >= chunks_per_piece {
return None;
}
let offset = chunk_index * self.chunk_length;
Some(std::cmp::min(self.chunk_length, pl - offset))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_lengths() -> Lengths {
Lengths::new(1174243328, 262144, None).unwrap()
}
#[test]
fn test_total_pieces() {
let l = make_lengths();
assert_eq!(l.total_pieces(), 4480);
}
#[test]
fn test_piece_length() {
let l = make_lengths();
let p = l.validate_piece_index(4479).unwrap();
assert_eq!(l.piece_length(l.validate_piece_index(0).unwrap()), 262144);
assert_eq!(l.piece_length(p), 100352);
}
#[test]
fn test_chunks_in_piece() {
let l = make_lengths();
let p = l.validate_piece_index(4479).unwrap();
assert_eq!(l.chunks_per_piece(l.validate_piece_index(0).unwrap()), 16);
assert_eq!(l.chunks_per_piece(p), 7);
}
#[test]
fn test_chunk_size() {
let l = make_lengths();
let p = l.validate_piece_index(4479).unwrap();
assert_eq!(l.chunk_size(p, 0), Some(16384));
assert_eq!(l.chunk_size(p, 6), Some(2048));
}
#[test]
fn test_chunk_infos() {
let l = make_lengths();
let p = l.validate_piece_index(4479).unwrap();
let mut it = l.iter_chunk_infos(p);
let first = it.next().unwrap();
let last = it.last().unwrap();
assert_eq!(
first,
ChunkInfo {
piece_index: p,
chunk_index: 0,
absolute_index: 71664,
size: 16384,
offset: 0,
}
);
assert_eq!(
last,
ChunkInfo {
piece_index: p,
chunk_index: 6,
absolute_index: 71670,
size: 2048,
offset: 98304,
}
);
}
}

View file

@ -0,0 +1,12 @@
pub mod buffers;
pub mod chunk_tracker;
pub mod clone_to_owned;
pub mod constants;
pub mod lengths;
pub mod peer_comms;
pub mod peer_id;
pub mod serde_bencode;
pub mod torrent_manager;
pub mod torrent_metainfo;
pub mod tracker_comms;
pub mod type_aliases;

View file

@ -0,0 +1,454 @@
use bincode::Options;
use byteorder::ByteOrder;
use serde::{Deserialize, Serialize};
use crate::{
buffers::{ByteBuf, ByteString},
clone_to_owned::CloneToOwned,
};
const PREAMBLE_LEN: usize = 5;
const NO_PAYLOAD_MSG_LEN: usize = PREAMBLE_LEN;
const PSTR_BT1: &str = "BitTorrent protocol";
const LEN_PREFIX_KEEPALIVE: u32 = 0;
const LEN_PREFIX_CHOKE: u32 = 1;
const LEN_PREFIX_UNCHOKE: u32 = 1;
const LEN_PREFIX_INTERESTED: u32 = 1;
const LEN_PREFIX_NOT_INTERESTED: u32 = 1;
const LEN_PREFIX_HAVE: u32 = 5;
const LEN_PREFIX_REQUEST: u32 = 13;
const MSGID_CHOKE: u8 = 0;
const MSGID_UNCHOKE: u8 = 1;
const MSGID_INTERESTED: u8 = 2;
const MSGID_NOT_INTERESTED: u8 = 3;
const MSGID_HAVE: u8 = 4;
const MSGID_BITFIELD: u8 = 5;
const MSGID_REQUEST: u8 = 6;
const MSGID_PIECE: u8 = 7;
#[derive(Debug)]
pub enum MessageDeserializeError {
NotEnoughData(usize, &'static str),
UnsupportedMessageId(u8),
IncorrectLenPrefix {
received: u32,
expected: u32,
msg_id: u8,
},
OtherBincode {
error: bincode::Error,
msg_id: u8,
len_prefix: u32,
name: &'static str,
},
}
#[derive(Debug)]
pub struct Piece<ByteBuf> {
pub index: u32,
pub begin: u32,
pub block: ByteBuf,
}
impl<ByteBuf> Piece<ByteBuf>
where
ByteBuf: AsRef<[u8]>,
{
pub fn serialize(&self, buf: &mut [u8]) -> usize {
byteorder::BigEndian::write_u32(&mut buf[0..4], self.index);
byteorder::BigEndian::write_u32(&mut buf[4..8], self.begin);
(&mut buf[8..8 + self.block.as_ref().len()]).copy_from_slice(self.block.as_ref());
self.block.as_ref().len() + 8
}
pub fn deserialize<'a>(buf: &'a [u8]) -> Piece<ByteBuf>
where
ByteBuf: From<&'a [u8]> + 'a,
{
let index = byteorder::BigEndian::read_u32(&buf[0..4]);
let begin = byteorder::BigEndian::read_u32(&buf[4..8]);
let block = ByteBuf::from(&buf[8..]);
Piece {
index,
begin,
block,
}
}
}
impl std::fmt::Display for MessageDeserializeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageDeserializeError::NotEnoughData(b, name) => {
write!(
f,
"not enough data to deserialize {}: expected at least {} more bytes",
name, b
)
}
MessageDeserializeError::UnsupportedMessageId(msg_id) => {
write!(f, "unsupported message id {}", msg_id)
}
MessageDeserializeError::IncorrectLenPrefix {
received,
expected,
msg_id,
} => write!(
f,
"incorrect len prefix for message id {}, expected {}, received {}",
msg_id, expected, received
),
MessageDeserializeError::OtherBincode {
error,
msg_id,
name,
len_prefix,
} => write!(
f,
"error deserializing {} (msg_id={}, len_prefix={}): {:?}",
name, msg_id, len_prefix, error
),
}
}
}
impl std::error::Error for MessageDeserializeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
MessageDeserializeError::OtherBincode { error, .. } => Some(error),
_ => None,
}
}
}
#[derive(Debug)]
pub enum Message<ByteBuf> {
Request(Request),
Bitfield(ByteBuf),
KeepAlive,
Have(u32),
Choke,
Unchoke,
Interested,
NotInterested,
Piece(Piece<ByteBuf>),
}
pub type MessageBorrowed<'a> = Message<ByteBuf<'a>>;
pub type MessageOwned = Message<ByteString>;
pub type BitfieldBorrowed<'a> = &'a bitvec::slice::BitSlice<bitvec::order::Lsb0, u8>;
pub type BitfieldOwned = bitvec::vec::BitVec<bitvec::order::Lsb0, u8>;
pub struct Bitfield<'a> {
pub data: BitfieldBorrowed<'a>,
}
impl<ByteBuf: CloneToOwned> CloneToOwned for Message<ByteBuf> {
type Target = Message<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self {
Message::Request(req) => Message::Request(*req),
Message::Bitfield(b) => Message::Bitfield(b.clone_to_owned()),
Message::Choke => Message::Choke,
Message::Unchoke => Message::Unchoke,
Message::Interested => Message::Interested,
Message::Piece(piece) => Message::Piece(Piece {
index: piece.index,
begin: piece.begin,
block: piece.block.clone_to_owned(),
}),
Message::KeepAlive => Message::KeepAlive,
Message::Have(v) => Message::Have(*v),
Message::NotInterested => Message::NotInterested,
}
}
}
impl<'a> Bitfield<'a> {
pub fn new_from_slice(buf: &'a [u8]) -> anyhow::Result<Self> {
Ok(Self {
data: bitvec::slice::BitSlice::from_slice(buf)?,
})
}
}
impl<'a> std::fmt::Debug for Bitfield<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Bitfield")
.field("_ones", &self.data.count_ones())
.field("_len", &self.data.len())
.finish()
}
}
impl<ByteBuf> Message<ByteBuf>
where
ByteBuf: AsRef<[u8]>,
{
pub fn len_prefix_and_msg_id(&self) -> (u32, u8) {
match self {
Message::Request(_) => (LEN_PREFIX_REQUEST, MSGID_REQUEST),
Message::Bitfield(b) => (1 + b.as_ref().len() as u32, MSGID_BITFIELD),
Message::Choke => (LEN_PREFIX_CHOKE, MSGID_CHOKE),
Message::Unchoke => (LEN_PREFIX_UNCHOKE, MSGID_UNCHOKE),
Message::Interested => (LEN_PREFIX_INTERESTED, MSGID_INTERESTED),
Message::NotInterested => (LEN_PREFIX_NOT_INTERESTED, MSGID_NOT_INTERESTED),
Message::Piece(p) => (9 + p.block.as_ref().len() as u32, MSGID_PIECE),
Message::KeepAlive => (LEN_PREFIX_KEEPALIVE, 0),
Message::Have(_) => (LEN_PREFIX_HAVE, MSGID_HAVE),
}
}
pub fn serialize(&self, out: &mut Vec<u8>) -> usize {
let (lp, msg_id) = self.len_prefix_and_msg_id();
out.resize(PREAMBLE_LEN, 0);
byteorder::BigEndian::write_u32(&mut out[..4], lp);
out[4] = msg_id;
let ser = bopts();
match self {
Message::Request(request) => {
const MSG_LEN: usize = PREAMBLE_LEN + 12;
out.resize(MSG_LEN, 0);
debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12);
ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
.unwrap();
MSG_LEN
}
Message::Bitfield(_) => todo!(),
Message::Choke | Message::Unchoke | Message::Interested => PREAMBLE_LEN,
Message::Piece(p) => {
let msg_len = PREAMBLE_LEN + 8 + p.block.as_ref().len();
out.resize(msg_len, 0);
p.serialize(&mut out[PREAMBLE_LEN..(8 + p.block.as_ref().len())]);
msg_len
}
Message::KeepAlive => 4,
Message::Have(v) => {
let msg_len = PREAMBLE_LEN + 4;
out.resize(msg_len, 0);
byteorder::BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
msg_len
}
Message::NotInterested => todo!(),
}
}
pub fn deserialize<'a>(
buf: &'a [u8],
) -> Result<(Message<ByteBuf>, usize), MessageDeserializeError>
where
ByteBuf: From<&'a [u8]> + 'a,
{
let len_prefix = match buf.get(0..4) {
Some(bytes) => byteorder::BigEndian::read_u32(bytes),
None => return Err(MessageDeserializeError::NotEnoughData(4, "message")),
};
if len_prefix == 0 {
return Ok((Message::KeepAlive, 4));
}
let msg_id = match buf.get(4) {
Some(msg_id) => *msg_id,
None => return Err(MessageDeserializeError::NotEnoughData(1, "message")),
};
let rest = &buf[5..];
let decoder_config = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_big_endian();
match msg_id {
MSGID_CHOKE => {
if len_prefix != LEN_PREFIX_CHOKE {
return Err(MessageDeserializeError::IncorrectLenPrefix {
received: len_prefix,
expected: LEN_PREFIX_CHOKE,
msg_id,
});
}
Ok((Message::Choke, NO_PAYLOAD_MSG_LEN))
}
MSGID_UNCHOKE => {
if len_prefix != LEN_PREFIX_UNCHOKE {
return Err(MessageDeserializeError::IncorrectLenPrefix {
received: len_prefix,
expected: LEN_PREFIX_UNCHOKE,
msg_id,
});
}
Ok((Message::Unchoke, NO_PAYLOAD_MSG_LEN))
}
MSGID_INTERESTED => {
if len_prefix != LEN_PREFIX_INTERESTED {
return Err(MessageDeserializeError::IncorrectLenPrefix {
received: len_prefix,
expected: LEN_PREFIX_INTERESTED,
msg_id,
});
}
Ok((Message::Interested, NO_PAYLOAD_MSG_LEN))
}
MSGID_NOT_INTERESTED => {
if len_prefix != LEN_PREFIX_NOT_INTERESTED {
return Err(MessageDeserializeError::IncorrectLenPrefix {
received: len_prefix,
expected: LEN_PREFIX_NOT_INTERESTED,
msg_id,
});
}
Ok((Message::NotInterested, NO_PAYLOAD_MSG_LEN))
}
MSGID_HAVE => {
let expected_len = 4;
match rest.get(..expected_len as usize) {
Some(h) => Ok((
Message::Have(byteorder::BE::read_u32(&h)),
PREAMBLE_LEN + expected_len,
)),
None => {
let missing = expected_len - rest.len();
Err(MessageDeserializeError::NotEnoughData(missing, "have"))
}
}
}
MSGID_BITFIELD => {
if len_prefix <= 1 {
return Err(MessageDeserializeError::IncorrectLenPrefix {
expected: 2,
received: len_prefix,
msg_id,
});
}
let expected_len = len_prefix as usize - 1;
match rest.get(..expected_len as usize) {
Some(bitfield) => Ok((
Message::Bitfield(ByteBuf::from(bitfield)),
PREAMBLE_LEN + expected_len,
)),
None => {
let missing = expected_len - rest.len();
Err(MessageDeserializeError::NotEnoughData(missing, "bitfield"))
}
}
}
MSGID_REQUEST => {
let expected_len = 12;
match rest.get(..expected_len as usize) {
Some(b) => {
let request = decoder_config.deserialize::<Request>(&b).unwrap();
Ok((Message::Request(request), PREAMBLE_LEN + expected_len))
}
None => {
let missing = expected_len - rest.len();
Err(MessageDeserializeError::NotEnoughData(missing, "request"))
}
}
}
MSGID_PIECE => {
if len_prefix <= 9 {
return Err(MessageDeserializeError::IncorrectLenPrefix {
expected: 10,
received: len_prefix,
msg_id,
});
}
// <len=0009+X> is for "9", "8" is for 2 integer fields in the piece.
let expected_len = len_prefix as usize - 9 + 8;
match rest.get(..expected_len) {
Some(b) => Ok((
Message::Piece(Piece::deserialize(&b)),
PREAMBLE_LEN + expected_len,
)),
None => Err(MessageDeserializeError::NotEnoughData(
expected_len - rest.len(),
"piece",
)),
}
}
msg_id => Err(MessageDeserializeError::UnsupportedMessageId(msg_id)),
}
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Handshake<'a> {
pub pstr: &'a str,
pub reserved: [u8; 8],
pub info_hash: [u8; 20],
pub peer_id: [u8; 20],
}
fn bopts() -> impl bincode::Options {
bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_big_endian()
}
impl<'a> Handshake<'a> {
pub fn new(info_hash: [u8; 20], peer_id: [u8; 20]) -> Handshake<'static> {
debug_assert_eq!(PSTR_BT1.len(), 19);
Handshake {
pstr: PSTR_BT1,
reserved: [0; 8],
info_hash,
peer_id,
}
}
fn bopts() -> impl bincode::Options {
bincode::DefaultOptions::new()
}
pub fn deserialize(b: &[u8]) -> Result<(Handshake<'_>, usize), MessageDeserializeError> {
let pstr_len = *b
.get(0)
.ok_or(MessageDeserializeError::NotEnoughData(1, "handshake"))?;
let expected_len = 1usize + pstr_len as usize + 48;
let hbuf = b
.get(..expected_len)
.ok_or(MessageDeserializeError::NotEnoughData(
expected_len,
"handshake",
))?;
Ok((Self::bopts().deserialize(&hbuf).unwrap(), expected_len))
}
pub fn serialize(&self) -> Vec<u8> {
Self::bopts().serialize(&self).unwrap()
}
}
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct Request {
pub index: u32,
pub begin: u32,
pub length: u32,
}
impl Request {
pub fn new(index: u32, begin: u32, length: u32) -> Self {
Self {
index,
begin,
length,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_handshake_serialize() {
let info_hash = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
];
let peer_id = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
];
let b = dbg!(Handshake::new(info_hash, peer_id).serialize());
assert_eq!(b.len(), 20 + 20 + 8 + 19 + 1);
}
}

View file

@ -0,0 +1,45 @@
#[derive(Debug)]
pub enum AzureusStyleKind {
Deluge,
LibTorrent,
Transmission,
Other([char; 2]),
}
#[derive(Debug)]
pub struct AzureusStyle {
pub kind: AzureusStyleKind,
pub version: [char; 4],
}
impl AzureusStyleKind {
pub const fn from_bytes(b1: u8, b2: u8) -> Self {
match &[b1, b2] {
b"DE" => AzureusStyleKind::Deluge,
b"lt" | b"LT" => AzureusStyleKind::LibTorrent,
b"TR" => AzureusStyleKind::Transmission,
_ => AzureusStyleKind::Other([b1 as char, b2 as char]),
}
}
}
fn try_decode_azureus_style(p: &[u8; 20]) -> Option<AzureusStyle> {
if !(p[0] == b'-' && p[7] == b'-') {
return None;
}
let mut version = ['0'; 4];
for (i, c) in (&p[3..7]).iter().copied().enumerate() {
version[i] = c as char;
}
let kind = AzureusStyleKind::from_bytes(p[1], p[2]);
Some(AzureusStyle { kind, version })
}
#[derive(Debug)]
pub enum PeerId {
AzureusStyle(AzureusStyle),
}
pub fn try_decode_peer_id(p: [u8; 20]) -> Option<PeerId> {
Some(PeerId::AzureusStyle(try_decode_azureus_style(&p)?))
}

View file

@ -0,0 +1,663 @@
use serde::de::Deserializer;
use serde::de::Error as DeError;
use std::collections::HashMap;
use crate::buffers::ByteBuf;
use crate::buffers::ByteString;
pub struct BencodeDeserializer<'de> {
buf: &'de [u8],
field_context: Vec<ByteBuf<'de>>,
parsing_key: bool,
pub(crate) is_torrent_info: bool,
pub(crate) torrent_info_digest: Option<[u8; 20]>,
}
impl<'de> BencodeDeserializer<'de> {
pub fn new_from_buf(buf: &'de [u8]) -> BencodeDeserializer<'de> {
Self {
buf,
field_context: Default::default(),
parsing_key: false,
is_torrent_info: false,
torrent_info_digest: None,
}
}
fn parse_integer(&mut self) -> Result<i64, Error> {
match self.buf.iter().copied().position(|e| e == b'e') {
Some(end) => {
let intbytes = &self.buf[1..end];
let value: i64 = std::str::from_utf8(intbytes)
.map_err(|e| Error::new_from_err(e).set_context(self))?
.parse()
.map_err(|e| Error::new_from_err(e).set_context(self))?;
let rem = self.buf.get(end + 1..).unwrap_or_default();
self.buf = rem;
Ok(value)
}
None => Err(Error::custom("cannot parse integer, unexpected EOF").set_context(self)),
}
}
fn parse_bytes(&mut self) -> Result<&'de [u8], Error> {
match self.buf.iter().copied().position(|e| e == b':') {
Some(length_delim) => {
let lenbytes = &self.buf[..length_delim];
let length: usize = std::str::from_utf8(lenbytes)
.map_err(|e| Error::new_from_err(e).set_context(self))?
.parse()
.map_err(|e| Error::new_from_err(e).set_context(self))?;
let bytes_start = length_delim + 1;
let bytes_end = bytes_start + length;
let bytes = &self.buf[bytes_start..bytes_end];
let rem = self.buf.get(bytes_end..).unwrap_or_default();
self.buf = rem;
Ok(bytes)
}
None => Err(Error::custom("cannot parse bytes, unexpected EOF").set_context(self)),
}
}
fn parse_bytes_checked(&mut self) -> Result<&'de [u8], Error> {
let first = match self.buf.first().copied() {
Some(first) => first,
None => return Err(Error::custom("expected bencode bytes, got EOF").set_context(self)),
};
match first {
b'0'..=b'9' => {}
_ => return Err(Error::custom("expected bencode bytes").set_context(self)),
}
let b = self.parse_bytes()?;
if self.parsing_key {
self.field_context.push(ByteBuf(b));
}
Ok(b)
}
}
pub fn from_bytes<'a, T>(buf: &'a [u8]) -> anyhow::Result<T>
where
T: serde::de::Deserialize<'a>,
{
let mut de = BencodeDeserializer::new_from_buf(buf);
Ok(T::deserialize(&mut de)?)
}
pub fn dyn_from_bytes(buf: &[u8]) -> anyhow::Result<DynBencodeNode<'_>> {
from_bytes(buf)
}
#[derive(Debug)]
enum ErrorKind {
Other(anyhow::Error),
NotSupported(&'static str),
}
#[derive(Debug, Default)]
pub struct ErrorContext {
field_stack: Vec<String>,
}
impl std::fmt::Display for ErrorContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut it = self.field_stack.iter();
if let Some(field) = it.next() {
write!(f, "\"{}\"", field)?;
} else {
return Ok(());
}
for field in self.field_stack.iter().skip(1) {
write!(f, " -> \"{}\"", field)?;
}
f.write_str(": ")
}
}
#[derive(Debug)]
pub struct Error {
kind: ErrorKind,
context: ErrorContext,
}
impl std::fmt::Display for ErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorKind::Other(err) => err.fmt(f),
ErrorKind::NotSupported(s) => write!(f, "{} is not supported by bencode", s),
}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}{}", self.context, self.kind)
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.kind {
ErrorKind::Other(err) => err.source(),
_ => None,
}
}
}
impl Error {
fn new_from_err<E>(e: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Error {
kind: ErrorKind::Other(anyhow::Error::new(e)),
context: Default::default(),
}
}
fn new_from_kind(kind: ErrorKind) -> Self {
Self {
kind,
context: Default::default(),
}
}
fn new_from_anyhow(e: anyhow::Error) -> Self {
Error {
kind: ErrorKind::Other(e),
context: Default::default(),
}
}
fn custom_with_de<M: std::fmt::Display>(msg: M, de: &BencodeDeserializer<'_>) -> Self {
Self::custom(msg).set_context(de)
}
fn set_context(mut self, de: &BencodeDeserializer<'_>) -> Self {
self.context = ErrorContext {
field_stack: de.field_context.iter().map(|s| format!("{}", s)).collect(),
};
self
}
}
impl serde::de::Error for Error {
fn custom<T>(msg: T) -> Self
where
T: std::fmt::Display,
{
Self {
kind: ErrorKind::Other(anyhow::anyhow!("{}", msg)),
context: Default::default(),
}
}
}
impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BencodeDeserializer<'de> {
type Error = Error;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
match self.buf.first().copied() {
Some(b'd') => self.deserialize_map(visitor),
Some(b'i') => self.deserialize_u64(visitor),
Some(b'l') => self.deserialize_seq(visitor),
Some(_) => self.deserialize_bytes(visitor),
None => Err(Error::custom_with_de("empty input", self)),
}
}
fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support booleans"))
.set_context(self),
)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
if !self.buf.starts_with(b"i") {
return Err(Error::custom_with_de("expected bencode int", self));
}
visitor
.visit_i64(self.parse_integer()?)
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_i64(visitor)
}
fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
.set_context(self),
)
}
fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
.set_context(self),
)
}
fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support chars"))
.set_context(self),
)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let first = match self.buf.first().copied() {
Some(first) => first,
None => {
return Err(Error::custom_with_de(
"expected bencode string, got EOF",
self,
))
}
};
match first {
b'0'..=b'9' => {}
_ => return Err(Error::custom_with_de("expected bencode string", self)),
}
let b = self.parse_bytes()?;
let s = std::str::from_utf8(b).map_err(|e| {
Error::new_from_anyhow(anyhow::anyhow!("error reading utf-8: {}", e)).set_context(self)
})?;
visitor
.visit_borrowed_str(s)
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let b = self.parse_bytes_checked()?;
visitor
.visit_borrowed_bytes(b)
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_bytes(visitor)
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor
.visit_some(&mut *self)
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(Error::new_from_kind(ErrorKind::NotSupported(
"bencode doesn't support unit types",
))
.set_context(self))
}
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(Error::new_from_kind(ErrorKind::NotSupported(
"bencode doesn't support unit structs",
))
.set_context(self))
}
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't newtype structs"))
.set_context(self),
)
}
fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
if !self.buf.starts_with(b"l") {
return Err(Error::custom(format!(
"expected bencode list, but got {}",
self.buf[0] as char,
)));
}
self.buf = self.buf.get(1..).unwrap_or_default();
visitor
.visit_seq(SeqAccess { de: &mut self })
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
if !self.buf.starts_with(b"d") {
return Err(Error::custom("expected bencode dict"));
}
self.buf = self.buf.get(1..).unwrap_or_default();
visitor
.visit_map(MapAccess { de: &mut self })
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_map(visitor)
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(
Error::new_from_kind(ErrorKind::NotSupported("deserializing enums not supported"))
.set_context(self),
)
}
fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let name = self.parse_bytes_checked()?;
visitor
.visit_borrowed_bytes(name)
.map_err(|e: Self::Error| e.set_context(self))
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_any(visitor)
}
}
struct MapAccess<'a, 'de> {
de: &'a mut BencodeDeserializer<'de>,
}
struct SeqAccess<'a, 'de> {
de: &'a mut BencodeDeserializer<'de>,
}
impl<'a, 'de> serde::de::MapAccess<'de> for MapAccess<'a, 'de> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: serde::de::DeserializeSeed<'de>,
{
if self.de.buf.starts_with(b"e") {
self.de.buf = self.de.buf.get(1..).unwrap_or_default();
return Ok(None);
}
self.de.parsing_key = true;
let retval = seed.deserialize(&mut *self.de)?;
self.de.parsing_key = false;
Ok(Some(retval))
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
let buf_before = self.de.buf;
let value = seed.deserialize(&mut *self.de)?;
if self.de.is_torrent_info && self.de.field_context.as_slice() == [ByteBuf(b"info")] {
let len = self.de.buf.as_ptr() as usize - buf_before.as_ptr() as usize;
let mut hash = sha1::Sha1::new();
hash.update(&buf_before[..len]);
let digest = hash.digest().bytes();
self.de.torrent_info_digest = Some(digest)
}
self.de.field_context.pop();
Ok(value)
}
}
impl<'a, 'de> serde::de::SeqAccess<'de> for SeqAccess<'a, 'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: serde::de::DeserializeSeed<'de>,
{
if self.de.buf.starts_with(b"e") {
self.de.buf = self.de.buf.get(1..).unwrap_or_default();
return Ok(None);
}
Ok(Some(seed.deserialize(&mut *self.de)?))
}
}
impl<'de> serde::de::Deserialize<'de> for DynBencodeNode<'de> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = DynBencodeNode<'de>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a bencode value")
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(DynBencodeNode::Integer(v))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v = Vec::new();
while let Some(value) = seq.next_element()? {
v.push(value);
}
Ok(DynBencodeNode::List(v))
}
fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(DynBencodeNode::Bytes(ByteBuf(v)))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut hmap = HashMap::new();
while let Some(key) = map.next_key()? {
let value = map.next_value()?;
hmap.insert(key, value);
}
Ok(DynBencodeNode::Dict(hmap))
}
}
deserializer.deserialize_any(Visitor {})
}
}
impl<'de> serde::de::Deserialize<'de> for ByteString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = Vec<u8>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("bencode byte string")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.to_owned())
}
}
Ok(ByteString(deserializer.deserialize_byte_buf(Visitor {})?))
}
}
#[derive(Debug)]
pub enum DynBencodeNode<'a> {
Bytes(ByteBuf<'a>),
Integer(i64),
List(Vec<DynBencodeNode<'a>>),
Dict(HashMap<ByteBuf<'a>, DynBencodeNode<'a>>),
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Read;
#[test]
fn test_deserialize_torrent_dyn() {
let mut buf = Vec::new();
let filename = "resources/ubuntu-21.04-desktop-amd64.iso.torrent";
std::fs::File::open(filename)
.unwrap()
.read_to_end(&mut buf)
.unwrap();
let torrent: DynBencodeNode = from_bytes(&buf).unwrap();
dbg!(torrent);
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,253 @@
use std::{fs::File, ops::Deref, path::PathBuf};
use serde::Deserialize;
use crate::{
buffers::{ByteBuf, ByteString},
clone_to_owned::CloneToOwned,
serde_bencode::BencodeDeserializer,
};
pub type TorrentMetaV1Borrowed<'a> = TorrentMetaV1<ByteBuf<'a>>;
pub type TorrentMetaV1Owned = TorrentMetaV1<ByteString>;
pub fn torrent_from_bytes(buf: &[u8]) -> anyhow::Result<TorrentMetaV1Borrowed<'_>> {
let mut de = BencodeDeserializer::new_from_buf(buf);
de.is_torrent_info = true;
let mut t = TorrentMetaV1::deserialize(&mut de)?;
t.info_hash = de.torrent_info_digest.unwrap();
Ok(t)
}
pub fn torrent_from_bytes_owned(buf: &[u8]) -> anyhow::Result<TorrentMetaV1Owned> {
let mut de = BencodeDeserializer::new_from_buf(buf);
de.is_torrent_info = true;
let mut t = TorrentMetaV1Owned::deserialize(&mut de)?;
t.info_hash = de.torrent_info_digest.unwrap();
Ok(t)
}
#[derive(Deserialize, Debug, Clone)]
pub struct TorrentMetaV1<BufType: Clone> {
pub announce: BufType,
#[serde(rename = "announce-list")]
pub announce_list: Vec<Vec<BufType>>,
pub info: TorrentMetaV1Info<BufType>,
pub comment: Option<BufType>,
#[serde(rename = "created by")]
pub created_by: Option<BufType>,
pub encoding: Option<BufType>,
pub publisher: Option<BufType>,
#[serde(rename = "publisher-url")]
pub publisher_url: Option<BufType>,
#[serde(rename = "creation date")]
pub creation_date: Option<usize>,
#[serde(skip)]
pub info_hash: [u8; 20],
}
impl<BufType: Clone> TorrentMetaV1<BufType> {
pub fn iter_announce(&self) -> impl Iterator<Item = &BufType> {
std::iter::once(&self.announce).chain(self.announce_list.iter().flatten())
}
}
#[derive(Deserialize, Debug, Clone)]
pub struct TorrentMetaV1Info<BufType: Clone> {
pub name: Option<BufType>,
pub pieces: BufType,
#[serde(rename = "piece length")]
pub piece_length: u32,
// Single-file mode
pub length: Option<u64>,
pub md5sum: Option<BufType>,
// Multi-file mode
pub files: Option<Vec<TorrentMetaV1File<BufType>>>,
}
pub enum FileIteratorName<'a, ByteBuf> {
Single(Option<&'a ByteBuf>),
Tree(&'a [ByteBuf]),
}
impl<'a, ByteBuf> FileIteratorName<'a, ByteBuf> {
pub fn iter_components(&self) -> impl Iterator<Item = Option<&'a ByteBuf>> {
let single_it = std::iter::once(match self {
FileIteratorName::Single(n) => Some(*n),
FileIteratorName::Tree(_) => None,
});
let multi_it = match self {
FileIteratorName::Single(_) => &[],
FileIteratorName::Tree(t) => *t,
}
.iter()
.map(|p| Some(Some(p)));
single_it.chain(multi_it).flatten()
}
}
impl<BufType: Clone + Deref<Target = [u8]>> TorrentMetaV1Info<BufType> {
pub fn compare_hash(&self, piece: u32, hash: &sha1::Sha1) -> Option<bool> {
let start = piece as usize * 20;
let end = start + 20;
let expected_hash = self.pieces.deref().get(start..end)?;
Some(expected_hash == hash.digest().bytes())
}
pub fn iter_filenames_and_lengths(
&self,
) -> impl Iterator<Item = (FileIteratorName<'_, BufType>, u64)> {
let single_it = std::iter::once(match (self.name.as_ref(), self.length) {
(Some(n), Some(l)) => Some((FileIteratorName::Single(Some(n)), l)),
_ => None,
});
let multi_it = self
.files
.as_deref()
.unwrap_or_default()
.iter()
.map(|f| Some((FileIteratorName::Tree(&f.path), f.length)));
single_it.chain(multi_it).flatten()
}
pub fn iter_file_lengths(&self) -> impl Iterator<Item = u64> + '_ {
std::iter::once(self.length)
.chain(
self.files
.as_deref()
.unwrap_or_default()
.iter()
.map(|f| Some(f.length)),
)
.flatten()
}
}
#[derive(Deserialize, Debug, Clone)]
pub struct TorrentMetaV1File<BufType: Clone> {
pub length: u64,
pub path: Vec<BufType>,
}
impl<BufType> TorrentMetaV1File<BufType>
where
BufType: Clone + AsRef<[u8]>,
{
pub fn full_path(&self, parent: &mut PathBuf) -> anyhow::Result<()> {
for p in self.path.iter() {
let bit = std::str::from_utf8(p.as_ref())?;
parent.push(bit);
}
Ok(())
}
}
impl<ByteBuf> CloneToOwned for TorrentMetaV1File<ByteBuf>
where
ByteBuf: CloneToOwned + Clone,
<ByteBuf as CloneToOwned>::Target: Clone,
{
type Target = TorrentMetaV1File<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
TorrentMetaV1File {
length: self.length,
path: self.path.clone_to_owned(),
}
}
}
impl<ByteBuf> CloneToOwned for TorrentMetaV1Info<ByteBuf>
where
ByteBuf: CloneToOwned + Clone,
<ByteBuf as CloneToOwned>::Target: Clone,
{
type Target = TorrentMetaV1Info<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
TorrentMetaV1Info {
name: self.name.clone_to_owned(),
pieces: self.pieces.clone_to_owned(),
piece_length: self.piece_length,
length: self.length,
md5sum: self.md5sum.clone_to_owned(),
files: self.files.clone_to_owned(),
}
}
}
impl<ByteBuf> CloneToOwned for TorrentMetaV1<ByteBuf>
where
ByteBuf: CloneToOwned + Clone,
<ByteBuf as CloneToOwned>::Target: Clone,
{
type Target = TorrentMetaV1<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
TorrentMetaV1 {
announce: self.announce.clone_to_owned(),
announce_list: self.announce_list.clone_to_owned(),
info: self.info.clone_to_owned(),
comment: self.comment.clone_to_owned(),
created_by: self.created_by.clone_to_owned(),
encoding: self.encoding.clone_to_owned(),
publisher: self.publisher.clone_to_owned(),
publisher_url: self.publisher_url.clone_to_owned(),
creation_date: self.creation_date,
info_hash: self.info_hash,
}
}
}
#[cfg(test)]
mod tests {
use std::io::Read;
use crate::serde_bencode::from_bytes;
use super::*;
#[test]
fn test_deserialize_torrent_owned() {
let mut buf = Vec::new();
let filename = "resources/ubuntu-21.04-desktop-amd64.iso.torrent";
std::fs::File::open(filename)
.unwrap()
.read_to_end(&mut buf)
.unwrap();
let torrent: TorrentMetaV1Owned = from_bytes(&buf).unwrap();
dbg!(torrent);
}
#[test]
fn test_deserialize_torrent_borrowed() {
let mut buf = Vec::new();
let filename = "resources/ubuntu-21.04-desktop-amd64.iso.torrent";
std::fs::File::open(filename)
.unwrap()
.read_to_end(&mut buf)
.unwrap();
let torrent: TorrentMetaV1Borrowed = from_bytes(&buf).unwrap();
dbg!(torrent);
}
#[test]
fn test_deserialize_torrent_with_info_hash() {
let mut buf = Vec::new();
let filename = "resources/ubuntu-21.04-desktop-amd64.iso.torrent";
std::fs::File::open(filename)
.unwrap()
.read_to_end(&mut buf)
.unwrap();
let torrent = torrent_from_bytes(&buf).unwrap();
assert_eq!(
torrent.info_hash,
*b"\x64\xa9\x80\xab\xe6\xe4\x48\x22\x6b\xb9\x30\xba\x06\x15\x92\xe4\x4c\x37\x81\xa1"
);
}
}

View file

@ -0,0 +1,228 @@
use byteorder::ByteOrder;
use serde::{Deserialize, Deserializer};
use std::{
fmt::Write,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
str::FromStr,
};
use crate::buffers::ByteBuf;
#[derive(Clone, Copy)]
pub enum TrackerRequestEvent {
Started,
Stopped,
Completed,
}
pub struct TrackerRequest {
pub info_hash: [u8; 20],
pub peer_id: [u8; 20],
pub event: Option<TrackerRequestEvent>,
pub port: u16,
pub uploaded: u64,
pub downloaded: u64,
pub left: u64,
pub compact: bool,
pub no_peer_id: bool,
pub ip: Option<std::net::IpAddr>,
pub numwant: Option<usize>,
pub key: Option<String>,
pub trackerid: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct TrackerError<'a> {
#[serde(rename = "failure reason", borrow)]
failure_reason: ByteBuf<'a>,
}
#[derive(Deserialize, Debug)]
pub struct DictPeer<'a> {
#[serde(deserialize_with = "deserialize_ip_string")]
ip: IpAddr,
#[serde(borrow)]
peer_id: Option<ByteBuf<'a>>,
port: u16,
}
impl<'a> DictPeer<'a> {
fn as_sockaddr(&self) -> SocketAddr {
SocketAddr::new(self.ip, self.port)
}
}
#[derive(Debug)]
pub enum Peers<'a> {
Full(Vec<DictPeer<'a>>),
Compact(Vec<SocketAddrV4>),
}
impl<'a> Peers<'a> {
pub fn iter_sockaddrs(&self) -> Box<dyn Iterator<Item = std::net::SocketAddr> + '_> {
match self {
Peers::Full(d) => Box::new(d.iter().map(DictPeer::as_sockaddr)),
Peers::Compact(c) => Box::new(c.iter().copied().map(SocketAddr::V4)),
}
}
}
impl<'de: 'a, 'a> serde::de::Deserialize<'de> for Peers<'a> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'de> {
phantom: std::marker::PhantomData<&'de ()>,
}
impl<'de> serde::de::Visitor<'de> for Visitor<'de> {
type Value = Peers<'de>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a list of peers in dict or binary format")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut peers = Vec::new();
while let Some(peer) = seq.next_element::<DictPeer>()? {
peers.push(peer)
}
Ok(Peers::Full(peers))
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Peers::Compact(parse_compact_peers(v)))
}
}
deserializer.deserialize_any(Visitor {
phantom: PhantomData,
})
}
}
fn deserialize_ip_string<'de, D>(de: D) -> Result<IpAddr, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = IpAddr;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("expecting an IPv4 address")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
IpAddr::from_str(v).map_err(|e| E::custom(format!("cannot parse ip: {}", e)))
}
}
de.deserialize_str(Visitor {})
}
fn parse_compact_peers(b: &[u8]) -> Vec<SocketAddrV4> {
let mut ips = Vec::new();
for chunk in b.chunks_exact(6) {
let ip_chunk = &chunk[..4];
let port_chunk = &chunk[4..6];
let ipaddr = Ipv4Addr::new(ip_chunk[0], ip_chunk[1], ip_chunk[2], ip_chunk[3]);
let port = byteorder::BigEndian::read_u16(port_chunk);
ips.push(SocketAddrV4::new(ipaddr, port));
}
ips
}
#[derive(Deserialize, Debug)]
pub struct CompactTrackerResponse<'a> {
#[serde(rename = "warning message", borrow)]
pub warning_message: Option<ByteBuf<'a>>,
pub complete: u64,
pub interval: u64,
#[serde(rename = "min interval")]
pub min_interval: Option<u64>,
pub tracker_id: Option<ByteBuf<'a>>,
pub incomplete: u64,
pub peers: Peers<'a>,
}
impl TrackerRequest {
pub fn as_querystring(&self) -> String {
use urlencoding as u;
let mut s = String::new();
s.push_str("info_hash=");
s.push_str(u::encode_binary(&self.info_hash).as_ref());
s.push_str("&peer_id=");
s.push_str(u::encode_binary(&self.peer_id).as_ref());
if let Some(event) = self.event {
write!(
s,
"&event={}",
match event {
TrackerRequestEvent::Started => "started",
TrackerRequestEvent::Stopped => "stopped",
TrackerRequestEvent::Completed => "completed",
}
)
.unwrap();
}
write!(s, "&port={}", self.port).unwrap();
write!(s, "&uploaded={}", self.uploaded).unwrap();
write!(s, "&downloaded={}", self.downloaded).unwrap();
write!(s, "&left={}", self.left).unwrap();
write!(s, "&compact={}", if self.compact { 1 } else { 0 }).unwrap();
write!(s, "&no_peer_id={}", if self.no_peer_id { 1 } else { 0 }).unwrap();
if let Some(ip) = &self.ip {
write!(s, "&ip={}", ip).unwrap();
}
if let Some(numwant) = &self.numwant {
write!(s, "&numwant={}", numwant).unwrap();
}
if let Some(key) = &self.key {
write!(s, "&key={}", key).unwrap();
}
if let Some(trackerid) = &self.trackerid {
write!(s, "&trackerid={}", trackerid).unwrap();
}
s
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize() {
let info_hash = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
];
let peer_id = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
];
let request = TrackerRequest {
info_hash,
peer_id,
port: 6881,
uploaded: 0,
downloaded: 0,
left: 1024 * 1024,
compact: true,
no_peer_id: false,
event: Some(TrackerRequestEvent::Started),
ip: Some("127.0.0.1".parse().unwrap()),
numwant: None,
key: None,
trackerid: None,
};
dbg!(request.as_querystring());
}
}

View file

@ -0,0 +1 @@
pub type BF = bitvec::vec::BitVec<bitvec::order::Msb0, u8>;