I can now download torrent metainfo from peers!!

This commit is contained in:
Igor Katson 2021-07-02 17:58:53 +01:00
parent 302e95649d
commit 48dcf2d1bd
5 changed files with 258 additions and 143 deletions

Binary file not shown.

View file

@ -2,10 +2,10 @@ use serde::{Deserialize, Deserializer};
use crate::clone_to_owned::CloneToOwned;
#[derive(PartialEq, Eq, Hash, Clone)]
#[derive(Default, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct ByteString(pub Vec<u8>);
#[derive(Deserialize, PartialEq, Eq, Hash, Clone)]
#[derive(Default, Deserialize, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
#[serde(transparent)]
pub struct ByteBuf<'a>(pub &'a [u8]);

View file

@ -48,7 +48,7 @@ const MSGID_REQUEST: u8 = 6;
const MSGID_PIECE: u8 = 7;
const MSGID_EXTENDED: u8 = 20;
const MY_EXTENDED_UT_METADATA: u8 = 0;
const MY_EXTENDED_UT_METADATA: u8 = 3;
#[derive(Debug)]
pub enum MessageDeserializeError {
@ -221,7 +221,7 @@ where
Message::KeepAlive => Message::KeepAlive,
Message::Have(v) => Message::Have(*v),
Message::NotInterested => Message::NotInterested,
Message::Extended(_) => unimplemented!(),
Message::Extended(e) => Message::Extended(e.clone_to_owned()),
}
}
}
@ -317,9 +317,11 @@ where
Ok(msg_len)
}
Message::Extended(e) => {
e.serialize(out, peer_extended_handshake);
e.serialize(out, peer_extended_handshake)?;
let msg_size = out.len();
BE::write_u32(&mut out[..4], msg_size as u32);
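// The length prefix counts every byte that follows the 4-byte length field,
// and that includes the 1-byte message id. Presumably PREAMBLE_LEN covers both
// the length field and the id byte (TODO confirm), hence the +1 to add it back.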
BE::write_u32(&mut out[..4], (msg_size - PREAMBLE_LEN + 1) as u32);
Ok(msg_size)
}
}
@ -528,8 +530,8 @@ impl<'a> Handshake<'a> {
))?;
Ok((Self::bopts().deserialize(&hbuf).unwrap(), expected_len))
}
pub fn serialize(&self) -> Vec<u8> {
Self::bopts().serialize(&self).unwrap()
pub fn serialize(&self, buf: &mut Vec<u8>) {
Self::bopts().serialize_into(buf, &self).unwrap()
}
}
@ -553,7 +555,11 @@ impl Request {
#[derive(Debug)]
pub enum UtMetadata<ByteBuf> {
Request(u32),
Data(u32, ByteBuf),
Data {
piece: u32,
total_size: u32,
data: ByteBuf,
},
Reject(u32),
}
@ -563,7 +569,15 @@ impl<ByteBuf: CloneToOwned> CloneToOwned for UtMetadata<ByteBuf> {
fn clone_to_owned(&self) -> Self::Target {
match self {
UtMetadata::Request(req) => UtMetadata::Request(*req),
UtMetadata::Data(piece, data) => UtMetadata::Data(*piece, data.clone_to_owned()),
UtMetadata::Data {
piece,
total_size,
data,
} => UtMetadata::Data {
piece: *piece,
total_size: *total_size,
data: data.clone_to_owned(),
},
UtMetadata::Reject(piece) => UtMetadata::Reject(*piece),
}
}
@ -590,11 +604,15 @@ impl<'a, ByteBuf: 'a> UtMetadata<ByteBuf> {
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap()
}
UtMetadata::Data(piece, data) => {
UtMetadata::Data {
piece,
total_size,
data,
} => {
let message = Message {
msg_type: 1,
piece: *piece,
total_size: Some(data.as_ref().len() as u32),
total_size: Some(*total_size),
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
buf.write_all(data.as_ref()).unwrap();
@ -643,14 +661,11 @@ impl<'a, ByteBuf: 'a> UtMetadata<ByteBuf> {
"expected key total_size to be present in UtMetadata \"data\" message"
))
})?;
if remaining.len() != total_size as usize {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"remaining bytes len {} != total_size {}",
remaining.len(),
total_size
)));
}
Ok(UtMetadata::Data(message.piece, ByteBuf::from(remaining)))
Ok(UtMetadata::Data {
piece: message.piece,
total_size,
data: ByteBuf::from(remaining),
})
}
// reject
2 => {
@ -730,16 +745,6 @@ impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf
where
ByteBuf: Deserialize<'a> + From<&'a [u8]>,
{
{
use std::io::Write;
let mut f = std::fs::OpenOptions::new()
.create(true)
.write(true)
.open("/tmp/msg")
.unwrap();
f.write_all(buf).unwrap();
}
use crate::serde_bencode_de::from_bytes;
let emsg_id = buf.get(0).copied().ok_or_else(|| {
@ -759,17 +764,8 @@ impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf
MY_EXTENDED_UT_METADATA => {
Ok(ExtendedMessage::UtMetadata(UtMetadata::deserialize(&buf)?))
}
other => Ok(ExtendedMessage::Dyn(emsg_id, from_bytes(&buf)?)),
_ => Ok(ExtendedMessage::Dyn(emsg_id, from_bytes(&buf)?)),
}
// match self {
// ExtendedMessage::Dyn(v, msg) => {
// crate::bencode_value::dyn_from_bytes(buf)
// }
// ExtendedMessage::Handshake(h) => {
// crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap()
// }
// }
}
}
@ -777,11 +773,17 @@ impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf
pub struct YourIP(pub IpAddr);
impl Serialize for YourIP {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
todo!()
match self.0 {
IpAddr::V4(ipv4) => {
let buf = ipv4.octets();
serializer.serialize_bytes(&buf)
}
IpAddr::V6(_) => todo!(),
}
}
}
@ -823,7 +825,7 @@ impl<'de> Deserialize<'de> for YourIP {
}
}
#[derive(Deserialize, Serialize, Debug)]
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
pub m: HashMap<ByteBuf, u8>,
@ -839,7 +841,23 @@ pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
pub ipv4: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reqq: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata_size: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub complete_ago: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub upload_only: Option<u32>,
}
impl ExtendedHandshake<ByteBuf<'static>> {
pub fn new() -> Self {
let mut features = HashMap::new();
features.insert(ByteBuf(b"ut_metadata"), MY_EXTENDED_UT_METADATA);
Self {
m: features,
..Default::default()
}
}
}
impl<ByteBuf: Eq + std::hash::Hash> ExtendedHandshake<ByteBuf> {
@ -874,21 +892,25 @@ where
ipv4: self.ipv4.clone_to_owned(),
reqq: self.reqq,
metadata_size: self.metadata_size,
complete_ago: self.complete_ago,
upload_only: self.upload_only,
}
}
}
#[cfg(test)]
mod tests {
use std::{net::SocketAddr, str::FromStr};
use std::{fs::File, io::Read, net::SocketAddr, str::FromStr};
use log::info;
use parking_lot::{Mutex, RwLock};
use parking_lot::RwLock;
use tokio::sync::mpsc::UnboundedSender;
use crate::{
lengths::ceil_div_u64,
peer_connection::{PeerConnection, PeerConnectionHandler, WriterRequest},
peer_id::generate_peer_id,
torrent_metainfo::TorrentMetaV1Borrowed,
};
use std::sync::Once;
@ -913,30 +935,46 @@ mod tests {
let peer_id = [
1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
];
let b = dbg!(Handshake::new(info_hash, peer_id).serialize());
assert_eq!(b.len(), 20 + 20 + 8 + 19 + 1);
let mut buf = Vec::new();
Handshake::new(info_hash, peer_id).serialize(&mut buf);
assert_eq!(buf.len(), 20 + 20 + 8 + 19 + 1);
}
#[test]
fn test_extended_serialize() {
let feats = HashMap::new();
let msg =
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
m: feats,
p: None,
v: None,
yourip: None,
ipv6: None,
ipv4: None,
reqq: None,
metadata_size: None,
}));
let msg = Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new()));
let mut out = Vec::new();
msg.serialize(&mut out, None);
msg.serialize(&mut out, None).unwrap();
dbg!(out);
}
/// Round-trip check: a captured extended handshake must re-serialize to the
/// exact bytes it was parsed from.
///
/// On mismatch the re-serialized bytes are dumped under `/tmp` so they can be
/// compared with the fixture, and the test panics.
#[test]
fn test_deserialize_serialize_extended_is_same() {
    const FIXTURE: &str = "resources/test/extended-handshake.bin";
    const DUMP: &str = "/tmp/test_deserialize_serialize_extended_is_same";

    let buf = std::fs::read(FIXTURE).unwrap();

    let (msg, size) = MessageBorrowed::deserialize(&buf).unwrap();
    assert_eq!(size, buf.len());

    let mut write_buf = Vec::new();
    msg.serialize(&mut write_buf, None).unwrap();

    if buf != write_buf {
        // `fs::write` creates the file or truncates an existing one, so a
        // shorter dump never keeps stale trailing bytes from an earlier run.
        std::fs::write(DUMP, &write_buf).unwrap();
        panic!("resources/test/extended-handshake.bin did not serialize exactly the same. Dumped to /tmp/test_deserialize_serialize_extended_is_same, you can compare with resources/test/extended-handshake.bin")
    }
}
#[tokio::test]
async fn test_connect_to_local_qbittorrent() {
init_logging();
@ -961,6 +999,34 @@ mod tests {
/// Logs every incoming message. For `ut_metadata` data messages the payload
/// is written to `/tmp/torrent`; once `total_size` bytes are on disk the file
/// is parsed as bencode and printed with `dbg!`.
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()> {
    info!("received message: {:?}", msg);

    if let Message::Extended(ExtendedMessage::UtMetadata(UtMetadata::Data {
        piece,
        total_size,
        data,
    })) = msg
    {
        // This assumes pieces arrive in the order they were requested, so
        // piece 0 starts a fresh file (dropping leftovers from an earlier
        // run) and every later piece is appended.
        let mut options = std::fs::OpenOptions::new();
        options.create(true);
        if piece == 0 {
            options.write(true).truncate(true);
        } else {
            options.append(true);
        }
        let mut f = options.open("/tmp/torrent").unwrap();
        f.write_all(&data).unwrap();

        // Completion is decided by the announced total size rather than by a
        // short final piece: when the metadata length is an exact multiple of
        // the chunk size the last piece is full-sized too.
        let written = f.metadata().unwrap().len();
        if written >= u64::from(total_size) {
            let buf = std::fs::read("/tmp/torrent").unwrap();
            let torrent: BencodeValue<ByteBuf> =
                crate::bencode_value::dyn_from_bytes(&buf).unwrap();
            dbg!(torrent);
        }
    }
    Ok(())
}
@ -975,10 +1041,24 @@ mod tests {
.write()
.replace(extended_handshake.clone_to_owned());
self.tx
.send(WriterRequest::Message(Message::Extended(
ExtendedMessage::UtMetadata(UtMetadata::Request(0)),
)))
.unwrap()
.send(WriterRequest::Message(Message::Unchoke))
.unwrap();
self.tx
.send(WriterRequest::Message(Message::Interested))
.unwrap();
let total_metadata_chunks = ceil_div_u64(
extended_handshake.metadata_size.unwrap() as u64,
CHUNK_SIZE as u64,
);
for i in 0..total_metadata_chunks {
self.tx
.send(WriterRequest::Message(Message::Extended(
ExtendedMessage::UtMetadata(UtMetadata::Request(i as u32)),
)))
.unwrap()
}
}
}

View file

@ -66,6 +66,33 @@ pub struct PeerConnection<H> {
// }
// }
/// Reads until one complete message can be deserialized from the front of
/// `$read_buf`, and evaluates to the `(message, size)` pair returned by
/// `MessageBorrowed::deserialize`.
///
/// * `$conn` - the reader that `.read(..).await` is called on.
/// * `$read_buf` - the byte buffer; it is grown when it is too small.
/// * `$read_so_far` - count of valid bytes at the start of `$read_buf`; it is
///   advanced by every successful read.
///
/// The expansion uses `.await`, `?`, `anyhow::bail!` and `return Err(..)`, so
/// it can only be used inside an async body whose error type accepts those.
macro_rules! read_one {
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{
let (extended, size) = loop {
// Only the bytes received so far are offered to the parser.
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
Ok((msg, size)) => break (msg, size),
// NOTE(review): `d` looks like the number of additional bytes the
// parser needs - confirm against `MessageDeserializeError`.
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
// Grow the buffer when it cannot hold `d` more bytes, then expose
// the whole capacity as zero-filled space to read into.
if $read_buf.len() < $read_so_far + d {
$read_buf.reserve(d);
$read_buf.resize($read_buf.capacity(), 0);
}
let size = $conn
.read(&mut $read_buf[$read_so_far..])
.await
.context("error reading from peer")?;
// A zero-length read is treated as the peer having disconnected.
if size == 0 {
anyhow::bail!("disconnected while reading, read so far: {}", $read_so_far)
}
$read_so_far += size;
}
// Any other parse error is returned from the enclosing function.
Err(e) => return Err(e.into()),
}
};
(extended, size)
}};
}
impl<H: PeerConnectionHandler> PeerConnection<H> {
pub fn new(addr: SocketAddr, info_hash: [u8; 20], peer_id: [u8; 20], handler: H) -> Self {
PeerConnection {
@ -87,10 +114,14 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let mut conn = tokio::net::TcpStream::connect(self.addr)
.await
.context("error connecting")?;
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let handshake = Handshake::new(self.info_hash, self.peer_id);
conn.write_all(&handshake.serialize())
handshake.serialize(&mut write_buf);
conn.write_all(&write_buf)
.await
.context("error writing handshake")?;
write_buf.clear();
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = conn
.read(&mut read_buf)
@ -121,34 +152,15 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
read_so_far -= size;
if supports_extended {
// Read extended handshake
// I wasn't able to extract that into a function.
// TODO: extract into a macro.
let (extended, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => break (msg, size),
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = conn
.read(&mut read_buf[read_so_far..])
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!(
"disconnected while reading, read so far: {}",
read_so_far
)
}
read_so_far += size;
}
Err(e) => return Err(e.into()),
}
};
let my_extended =
Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new()));
my_extended.serialize(&mut write_buf, None).unwrap();
conn.write_all(&write_buf)
.await
.context("error writing extended handshake")?;
write_buf.clear();
let (extended, size) = read_one!(conn, read_buf, read_so_far);
match extended {
Message::Extended(ExtendedMessage::Handshake(h)) => {
trace!("received from {}: {:?}", self.addr, &h);
@ -167,13 +179,15 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let (mut read_half, mut write_half) = tokio::io::split(conn);
let writer = async move {
let mut buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let keep_alive_interval = Duration::from_secs(120);
if self.handler.get_have_bytes() > 0 {
if let Some(len) = self.handler.serialize_bitfield_message_to_buf(&mut buf) {
if let Some(len) = self
.handler
.serialize_bitfield_message_to_buf(&mut write_buf)
{
write_half
.write_all(&buf[..len])
.write_all(&write_buf[..len])
.await
.context("error writing bitfield to peer")?;
debug!("sent bitfield to {}", self.addr);
@ -193,16 +207,16 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let len = match &req {
WriterRequest::Message(msg) => {
msg.serialize(&mut buf, extended_handshake.as_ref())?
msg.serialize(&mut write_buf, extended_handshake.as_ref())?
}
WriterRequest::ReadChunkRequest(chunk) => {
// this whole section is an optimization
buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0);
let preamble_len = serialize_piece_preamble(&chunk, &mut buf);
write_buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0);
let preamble_len = serialize_piece_preamble(&chunk, &mut write_buf);
let full_len = preamble_len + chunk.size as usize;
buf.resize(full_len, 0);
write_buf.resize(full_len, 0);
self.handler
.read_chunk(chunk, &mut buf[preamble_len..])
.read_chunk(chunk, &mut write_buf[preamble_len..])
.with_context(|| format!("error reading chunk {:?}", chunk))?;
uploaded_add = Some(chunk.size);
full_len
@ -212,9 +226,10 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
debug!("sending to {}: {:?}, length={}", self.addr, &req, len);
write_half
.write_all(&buf[..len])
.write_all(&write_buf[..len])
.await
.context("error writing the message to peer")?;
write_buf.clear();
if let Some(uploaded_add) = uploaded_add {
self.handler.on_uploaded_bytes(uploaded_add)
@ -228,33 +243,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let reader = async move {
loop {
let (message, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => {
break (msg, size);
}
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = read_half
.read(&mut read_buf[read_so_far..])
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!(
"disconnected while reading, read so far: {}",
read_so_far
)
}
read_so_far += size;
}
Err(e) => return Err(e.into()),
}
};
let (message, size) = read_one!(read_half, read_buf, read_so_far);
trace!("received from {}: {:?}", self.addr, &message);
self.handler

View file

@ -1,5 +1,9 @@
use std::collections::BTreeMap;
use serde::{Serialize, Serializer};
use crate::buffers::ByteString;
#[derive(Debug)]
pub enum SerErrorKind {
Other(anyhow::Error),
@ -56,27 +60,37 @@ impl std::fmt::Display for SerError {
/// Serializer state that writes bencode output into `writer`.
struct BencodeSerializer<W: std::io::Write> {
/// Destination for all serialized bytes.
writer: W,
/// When true, `write_bytes` emits the raw bytes without the `<len>:` prefix.
/// `new` leaves it off; map-key serialization switches it on for a scratch
/// serializer so the key bytes can be captured without the prefix.
hack_no_bytestring_prefix: bool,
}
impl<W: std::io::Write> BencodeSerializer<W> {
/// Wraps `writer` in a serializer with the byte-string prefix hack disabled.
pub fn new(writer: W) -> Self {
    let hack_no_bytestring_prefix = false;
    Self {
        writer,
        hack_no_bytestring_prefix,
    }
}
/// Writes `buf` to the underlying writer verbatim, converting any I/O error
/// into a `SerError`.
fn write_raw(&mut self, buf: &[u8]) -> Result<(), SerError> {
    match self.writer.write_all(buf) {
        Ok(()) => Ok(()),
        Err(io_err) => Err(SerError::from_err_with_ser(io_err, &self)),
    }
}
/// Writes pre-built format arguments to the underlying writer, converting
/// any I/O error into a `SerError`.
fn write_fmt(&mut self, fmt: core::fmt::Arguments<'_>) -> Result<(), SerError> {
    match self.writer.write_fmt(fmt) {
        Ok(()) => Ok(()),
        Err(io_err) => Err(SerError::from_err_with_ser(io_err, &self)),
    }
}
fn write_byte(&mut self, byte: u8) -> Result<(), SerError> {
self.writer
.write_all(&[byte])
.map_err(|e| SerError::from_err_with_ser(e, &self))
self.write_raw(&[byte])
}
/// Writes `number` as a bencode integer, i.e. `i<number>e`.
fn write_number<N: std::fmt::Display>(&mut self, number: N) -> Result<(), SerError> {
    // `write!` expands to `self.write_fmt(format_args!(..))`.
    write!(self, "i{}e", number)
}
fn write_bytes(&mut self, bytes: &[u8]) -> Result<(), SerError> {
self.write_fmt(format_args!("{}:", bytes.len()))?;
self.writer
.write_all(bytes)
.map_err(|e| SerError::from_err_with_ser(e, &self))
if !self.hack_no_bytestring_prefix {
self.write_fmt(format_args!("{}:", bytes.len()))?;
}
self.write_raw(bytes)
}
}
@ -162,6 +176,8 @@ impl<'ser, W: std::io::Write> serde::ser::SerializeTupleVariant for SerializeTup
/// In-progress map serialization: entries are buffered in `tmp` and only
/// written to the parent serializer when `end` is called.
struct SerializeMap<'ser, W: std::io::Write> {
/// Parent serializer that receives the finished entries.
ser: &'ser mut BencodeSerializer<W>,
/// Buffered entries: key bytes -> already-serialized value bytes. A `BTreeMap`
/// iterates in key order, so entries are written out sorted by key.
tmp: BTreeMap<ByteString, ByteString>,
/// Key from the latest `serialize_key` call, held until its value arrives.
last_key: Option<ByteString>,
}
impl<'ser, W: std::io::Write> serde::ser::SerializeMap for SerializeMap<'ser, W> {
type Ok = ();
@ -172,23 +188,39 @@ impl<'ser, W: std::io::Write> serde::ser::SerializeMap for SerializeMap<'ser, W>
where
T: serde::Serialize,
{
key.serialize(&mut *self.ser)
let mut buf = Vec::new();
let mut ser = BencodeSerializer::new(&mut buf);
ser.hack_no_bytestring_prefix = true;
key.serialize(&mut ser)?;
self.last_key.replace(ByteString::from(buf));
Ok(())
// key.serialize(&mut *self.ser);
}
fn serialize_value<T: ?Sized>(&mut self, value: &T) -> Result<(), Self::Error>
where
T: serde::Serialize,
{
value.serialize(&mut *self.ser)
let mut buf = Vec::new();
let mut ser = BencodeSerializer::new(&mut buf);
value.serialize(&mut ser)?;
self.tmp
.insert(self.last_key.take().unwrap(), ByteString::from(buf));
Ok(())
}
fn end(self) -> Result<Self::Ok, Self::Error> {
for (key, value) in self.tmp {
self.ser.write_bytes(&key)?;
self.ser.write_raw(&value)?;
}
self.ser.write_byte(b'e')
}
}
/// In-progress struct serialization: fields are buffered in `tmp` and only
/// written to the parent serializer when `end` is called.
struct SerializeStruct<'ser, W: std::io::Write> {
/// Parent serializer that receives the finished fields.
ser: &'ser mut BencodeSerializer<W>,
/// Buffered fields: field name -> already-serialized value bytes. A `BTreeMap`
/// iterates in key order, so fields are written out sorted by name.
tmp: BTreeMap<&'static str, ByteString>,
}
impl<'ser, W: std::io::Write> serde::ser::SerializeStruct for SerializeStruct<'ser, W> {
type Ok = ();
@ -203,11 +235,18 @@ impl<'ser, W: std::io::Write> serde::ser::SerializeStruct for SerializeStruct<'s
where
T: serde::Serialize,
{
self.ser.write_bytes(key.as_bytes())?;
value.serialize(&mut *self.ser)
let mut buf = Vec::new();
let mut ser = BencodeSerializer::new(&mut buf);
value.serialize(&mut ser)?;
self.tmp.insert(key, ByteString::from(buf));
Ok(())
}
/// Writes every buffered field in name order, then the closing `e`.
///
/// Field names are written as bencode byte strings; the values were fully
/// encoded when they were buffered and are copied through unchanged.
fn end(self) -> Result<Self::Ok, Self::Error> {
    for (key, value) in self.tmp {
        self.ser.write_bytes(key.as_bytes())?;
        self.ser.write_raw(&value)?;
    }
    self.ser.write_byte(b'e')
}
}
@ -412,7 +451,11 @@ impl<'ser, W: std::io::Write> Serializer for &'ser mut BencodeSerializer<W> {
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
self.write_byte(b'd')?;
Ok(SerializeMap { ser: self })
Ok(SerializeMap {
ser: self,
tmp: Default::default(),
last_key: None,
})
}
fn serialize_struct(
@ -421,7 +464,10 @@ impl<'ser, W: std::io::Write> Serializer for &'ser mut BencodeSerializer<W> {
_len: usize,
) -> Result<Self::SerializeStruct, Self::Error> {
self.write_byte(b'd')?;
Ok(SerializeStruct { ser: self })
Ok(SerializeStruct {
ser: self,
tmp: Default::default(),
})
}
fn serialize_struct_variant(
@ -439,7 +485,7 @@ pub fn bencode_serialize_to_writer<T: Serialize, W: std::io::Write>(
value: T,
writer: &mut W,
) -> Result<(), SerError> {
let mut serializer = BencodeSerializer { writer };
let mut serializer = BencodeSerializer::new(writer);
value.serialize(&mut serializer)?;
Ok(())
}