Screwing around with extended messages

This commit is contained in:
Igor Katson 2021-07-02 13:00:46 +01:00
parent d722f0edcb
commit 302e95649d
5 changed files with 476 additions and 96 deletions

View file

@ -28,6 +28,22 @@ where
} }
} }
impl CloneToOwned for u8 {
type Target = u8;
fn clone_to_owned(&self) -> Self::Target {
*self
}
}
impl CloneToOwned for u32 {
type Target = u32;
fn clone_to_owned(&self) -> Self::Target {
*self
}
}
impl<K, V> CloneToOwned for HashMap<K, V> impl<K, V> CloneToOwned for HashMap<K, V>
where where
K: CloneToOwned, K: CloneToOwned,

View file

@ -1,8 +1,13 @@
use std::collections::HashMap; use std::{
collections::HashMap,
io::Write,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
use bincode::Options; use bincode::Options;
use byteorder::{ByteOrder, BE}; use byteorder::{ByteOrder, BE};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use crate::{ use crate::{
bencode_value::BencodeValue, bencode_value::BencodeValue,
@ -10,6 +15,8 @@ use crate::{
clone_to_owned::CloneToOwned, clone_to_owned::CloneToOwned,
constants::CHUNK_SIZE, constants::CHUNK_SIZE,
lengths::ChunkInfo, lengths::ChunkInfo,
serde_bencode_de::BencodeDeserializer,
serde_bencode_ser,
}; };
const INTEGER_LEN: usize = 4; const INTEGER_LEN: usize = 4;
@ -41,6 +48,8 @@ const MSGID_REQUEST: u8 = 6;
const MSGID_PIECE: u8 = 7; const MSGID_PIECE: u8 = 7;
const MSGID_EXTENDED: u8 = 20; const MSGID_EXTENDED: u8 = 20;
const MY_EXTENDED_UT_METADATA: u8 = 0;
#[derive(Debug)] #[derive(Debug)]
pub enum MessageDeserializeError { pub enum MessageDeserializeError {
NotEnoughData(usize, &'static str), NotEnoughData(usize, &'static str),
@ -255,7 +264,11 @@ where
Message::Extended(_) => (0, MSGID_EXTENDED), Message::Extended(_) => (0, MSGID_EXTENDED),
} }
} }
pub fn serialize(&self, out: &mut Vec<u8>) -> usize { pub fn serialize(
&self,
out: &mut Vec<u8>,
peer_extended_handshake: Option<&ExtendedHandshake<ByteString>>,
) -> anyhow::Result<usize> {
let (lp, msg_id) = self.len_prefix_and_msg_id(); let (lp, msg_id) = self.len_prefix_and_msg_id();
out.resize(PREAMBLE_LEN, 0); out.resize(PREAMBLE_LEN, 0);
@ -272,17 +285,17 @@ where
debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12); debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12);
ser.serialize_into(&mut out[PREAMBLE_LEN..], request) ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
.unwrap(); .unwrap();
MSG_LEN Ok(MSG_LEN)
} }
Message::Bitfield(b) => { Message::Bitfield(b) => {
let block_len = b.as_ref().len(); let block_len = b.as_ref().len();
let msg_len = PREAMBLE_LEN + block_len; let msg_len = PREAMBLE_LEN + block_len;
out.resize(msg_len, 0); out.resize(msg_len, 0);
(&mut out[PREAMBLE_LEN..PREAMBLE_LEN + block_len]).copy_from_slice(b.as_ref()); (&mut out[PREAMBLE_LEN..PREAMBLE_LEN + block_len]).copy_from_slice(b.as_ref());
msg_len Ok(msg_len)
} }
Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => { Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => {
PREAMBLE_LEN Ok(PREAMBLE_LEN)
} }
Message::Piece(p) => { Message::Piece(p) => {
let block_len = p.block.as_ref().len(); let block_len = p.block.as_ref().len();
@ -291,23 +304,23 @@ where
out.resize(msg_len, 0); out.resize(msg_len, 0);
let tmp = &mut out[PREAMBLE_LEN..]; let tmp = &mut out[PREAMBLE_LEN..];
p.serialize(&mut tmp[..payload_len]); p.serialize(&mut tmp[..payload_len]);
msg_len Ok(msg_len)
} }
Message::KeepAlive => { Message::KeepAlive => {
// the len prefix was already written out to buf // the len prefix was already written out to buf
4 Ok(4)
} }
Message::Have(v) => { Message::Have(v) => {
let msg_len = PREAMBLE_LEN + 4; let msg_len = PREAMBLE_LEN + 4;
out.resize(msg_len, 0); out.resize(msg_len, 0);
BE::write_u32(&mut out[PREAMBLE_LEN..], *v); BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
msg_len Ok(msg_len)
} }
Message::Extended(e) => { Message::Extended(e) => {
e.serialize(out); e.serialize(out, peer_extended_handshake);
let msg_size = out.len(); let msg_size = out.len();
BE::write_u32(&mut out[..4], msg_size as u32); BE::write_u32(&mut out[..4], msg_size as u32);
msg_size Ok(msg_size)
} }
} }
} }
@ -496,6 +509,9 @@ impl<'a> Handshake<'a> {
peer_id, peer_id,
} }
} }
pub fn supports_extended(&self) -> bool {
self.reserved[5] & 0x10 > 0
}
fn bopts() -> impl bincode::Options { fn bopts() -> impl bincode::Options {
bincode::DefaultOptions::new() bincode::DefaultOptions::new()
} }
@ -535,28 +551,184 @@ impl Request {
} }
#[derive(Debug)] #[derive(Debug)]
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> { pub enum UtMetadata<ByteBuf> {
Handshake(ExtendedHandshake<ByteBuf>), Request(u32),
Dyn(u8, BencodeValue<ByteBuf>), Data(u32, ByteBuf),
Reject(u32),
} }
impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> { impl<ByteBuf: CloneToOwned> CloneToOwned for UtMetadata<ByteBuf> {
fn serialize(&self, out: &mut Vec<u8>) { type Target = UtMetadata<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self { match self {
ExtendedMessage::Dyn(msg_id, v) => { UtMetadata::Request(req) => UtMetadata::Request(*req),
out.push(*msg_id); UtMetadata::Data(piece, data) => UtMetadata::Data(*piece, data.clone_to_owned()),
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out).unwrap() UtMetadata::Reject(piece) => UtMetadata::Reject(*piece),
}
}
}
impl<'a, ByteBuf: 'a> UtMetadata<ByteBuf> {
fn serialize(&self, buf: &mut Vec<u8>)
where
ByteBuf: AsRef<[u8]>,
{
#[derive(Serialize)]
struct Message {
msg_type: u32,
piece: u32,
#[serde(skip_serializing_if = "Option::is_none")]
total_size: Option<u32>,
}
match self {
UtMetadata::Request(piece) => {
let message = Message {
msg_type: 0,
piece: *piece,
total_size: None,
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap()
} }
ExtendedMessage::Handshake(h) => { UtMetadata::Data(piece, data) => {
out.push(0); let message = Message {
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap() msg_type: 1,
piece: *piece,
total_size: Some(data.as_ref().len() as u32),
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
buf.write_all(data.as_ref()).unwrap();
}
UtMetadata::Reject(piece) => {
let message = Message {
msg_type: 2,
piece: *piece,
total_size: None,
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
} }
} }
} }
fn deserialize(buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
fn deserialize<'de>(mut buf: &'de [u8]) -> Result<Self, MessageDeserializeError>
where where
ByteBuf: Deserialize<'de> + From<&'de [u8]>, ByteBuf: From<&'a [u8]>,
{
let mut de = BencodeDeserializer::new_from_buf(buf);
#[derive(Deserialize)]
struct Message {
msg_type: u32,
piece: u32,
total_size: Option<u32>,
}
let message =
Message::deserialize(&mut de).map_err(|e| MessageDeserializeError::Other(e.into()))?;
let remaining = de.into_remaining();
match message.msg_type {
// request
0 => {
if !remaining.is_empty() {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"trailing bytes when decoding UtMetadata"
)));
}
Ok(UtMetadata::Request(message.piece))
}
// data
1 => {
let total_size = message.total_size.ok_or_else(|| {
MessageDeserializeError::Other(anyhow::anyhow!(
"expected key total_size to be present in UtMetadata \"data\" message"
))
})?;
if remaining.len() != total_size as usize {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"remaining bytes len {} != total_size {}",
remaining.len(),
total_size
)));
}
Ok(UtMetadata::Data(message.piece, ByteBuf::from(remaining)))
}
// reject
2 => {
if !remaining.is_empty() {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"trailing bytes when decoding UtMetadata"
)));
}
Ok(UtMetadata::Reject(message.piece))
}
other => {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"unrecognized ut_metadata message type {}",
other
)))
}
}
}
}
#[derive(Debug)]
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
Handshake(ExtendedHandshake<ByteBuf>),
UtMetadata(UtMetadata<ByteBuf>),
Dyn(u8, BencodeValue<ByteBuf>),
}
impl<ByteBuf> CloneToOwned for ExtendedMessage<ByteBuf>
where
ByteBuf: CloneToOwned + std::hash::Hash + Eq,
<ByteBuf as CloneToOwned>::Target: std::hash::Hash + Eq,
{
type Target = ExtendedMessage<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self {
ExtendedMessage::Handshake(h) => ExtendedMessage::Handshake(h.clone_to_owned()),
ExtendedMessage::Dyn(u, d) => ExtendedMessage::Dyn(*u, d.clone_to_owned()),
ExtendedMessage::UtMetadata(m) => ExtendedMessage::UtMetadata(m.clone_to_owned()),
}
}
}
impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
fn serialize(
&self,
out: &mut Vec<u8>,
extended_handshake: Option<&ExtendedHandshake<ByteString>>,
) -> anyhow::Result<()>
where
ByteBuf: AsRef<[u8]>,
{
match self {
ExtendedMessage::Dyn(msg_id, v) => {
out.push(*msg_id);
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out)?;
}
ExtendedMessage::Handshake(h) => {
out.push(0);
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out)?;
}
ExtendedMessage::UtMetadata(u) => {
let h = extended_handshake.ok_or_else(|| {
anyhow::anyhow!("need peer's handshake to serialize ut_metadata")
})?;
let emsg_id = h
.get_msgid(b"ut_metadata")
.ok_or_else(|| anyhow::anyhow!("peer doesn't support ut_metadata"))?;
out.push(emsg_id);
u.serialize(out);
}
}
Ok(())
}
fn deserialize(mut buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
where
ByteBuf: Deserialize<'a> + From<&'a [u8]>,
{ {
{ {
use std::io::Write; use std::io::Write;
@ -583,10 +755,13 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
})?; })?;
match emsg_id { match emsg_id {
// handshake
0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)), 0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)),
other => Ok(ExtendedMessage::Dyn(other, from_bytes(&buf)?)), MY_EXTENDED_UT_METADATA => {
Ok(ExtendedMessage::UtMetadata(UtMetadata::deserialize(&buf)?))
}
other => Ok(ExtendedMessage::Dyn(emsg_id, from_bytes(&buf)?)),
} }
// match self { // match self {
// ExtendedMessage::Dyn(v, msg) => { // ExtendedMessage::Dyn(v, msg) => {
// crate::bencode_value::dyn_from_bytes(buf) // crate::bencode_value::dyn_from_bytes(buf)
@ -598,31 +773,130 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
} }
} }
#[derive(Debug, Clone, Copy)]
pub struct YourIP(pub IpAddr);
impl Serialize for YourIP {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
todo!()
}
}
impl<'de> Deserialize<'de> for YourIP {
fn deserialize<D>(de: D) -> Result<YourIP, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor {}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = YourIP;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "expecting 4 bytes of ipv4 or 16 bytes of ipv6")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() == 4 {
return Ok(YourIP(IpAddr::V4(Ipv4Addr::new(v[0], v[1], v[2], v[3]))));
} else if v.len() == 16 {
return Ok(YourIP(IpAddr::V6(Ipv6Addr::new(
BE::read_u16(&v[..2]),
BE::read_u16(&v[2..4]),
BE::read_u16(&v[4..6]),
BE::read_u16(&v[6..8]),
BE::read_u16(&v[8..10]),
BE::read_u16(&v[10..12]),
BE::read_u16(&v[12..14]),
BE::read_u16(&v[14..]),
))));
}
Err(E::custom("expected 4 or 16 byte address"))
}
}
de.deserialize_bytes(Visitor {})
}
}
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> { pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))] #[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
pub m: HashMap<ByteBuf, BencodeValue<ByteBuf>>, pub m: HashMap<ByteBuf, u8>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub p: Option<u32>, pub p: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub v: Option<ByteBuf>, pub v: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub yourip: Option<ByteBuf>, pub yourip: Option<YourIP>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub ipv6: Option<ByteBuf>, pub ipv6: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub ipv4: Option<ByteBuf>, pub ipv4: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reqq: Option<u32>, pub reqq: Option<u32>,
pub metadata_size: Option<u32>,
}
impl<ByteBuf: Eq + std::hash::Hash> ExtendedHandshake<ByteBuf> {
fn get_msgid(&self, msg_type: &[u8]) -> Option<u8>
where
ByteBuf: AsRef<[u8]>,
{
self.m.iter().find_map(|(k, v)| {
if k.as_ref() == msg_type {
Some(*v)
} else {
None
}
})
}
}
impl<ByteBuf> CloneToOwned for ExtendedHandshake<ByteBuf>
where
ByteBuf: CloneToOwned + Eq + std::hash::Hash,
<ByteBuf as CloneToOwned>::Target: Eq + std::hash::Hash,
{
type Target = ExtendedHandshake<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
ExtendedHandshake {
m: self.m.clone_to_owned(),
p: self.p,
v: self.v.clone_to_owned(),
yourip: self.yourip,
ipv6: self.ipv6.clone_to_owned(),
ipv4: self.ipv4.clone_to_owned(),
reqq: self.reqq,
metadata_size: self.metadata_size,
}
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{net::SocketAddr, str::FromStr}; use std::{net::SocketAddr, str::FromStr};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use log::info;
use parking_lot::{Mutex, RwLock};
use tokio::sync::mpsc::UnboundedSender;
use crate::peer_id::generate_peer_id; use crate::{
peer_connection::{PeerConnection, PeerConnectionHandler, WriterRequest},
peer_id::generate_peer_id,
};
use std::sync::Once;
static LOG_INIT: Once = std::sync::Once::new();
fn init_logging() {
LOG_INIT.call_once(pretty_env_logger::init)
}
fn decode_info_hash(hash_str: &str) -> [u8; 20] { fn decode_info_hash(hash_str: &str) -> [u8; 20] {
let mut hash_arr = [0u8; 20]; let mut hash_arr = [0u8; 20];
@ -645,9 +919,7 @@ mod tests {
#[test] #[test]
fn test_extended_serialize() { fn test_extended_serialize() {
let mut feats = HashMap::new(); let feats = HashMap::new();
feats.insert("whatever".as_bytes().into(), BencodeValue::Integer(1));
let msg = let msg =
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake { Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
m: feats, m: feats,
@ -657,66 +929,72 @@ mod tests {
ipv6: None, ipv6: None,
ipv4: None, ipv4: None,
reqq: None, reqq: None,
metadata_size: None,
})); }));
let mut out = Vec::new(); let mut out = Vec::new();
msg.serialize(&mut out); msg.serialize(&mut out, None);
dbg!(out); dbg!(out);
} }
#[tokio::test] #[tokio::test]
async fn test_connect_to_local_qbittorrent() { async fn test_connect_to_local_qbittorrent() {
let mut stream = init_logging();
tokio::net::TcpStream::connect(SocketAddr::from_str("127.0.0.1:27311").unwrap())
.await struct Handler {
.unwrap(); ehandshake: RwLock<Option<ExtendedHandshake<ByteString>>>,
tx: UnboundedSender<WriterRequest>,
}
impl PeerConnectionHandler for Handler {
fn get_have_bytes(&self) -> u64 {
0
}
fn serialize_bitfield_message_to_buf(&self, _buf: &mut Vec<u8>) -> Option<usize> {
None
}
fn on_handshake(&self, handshake: Handshake) {
info!("received handshake: {:?}", handshake)
}
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()> {
info!("received message: {:?}", msg);
Ok(())
}
fn on_uploaded_bytes(&self, _bytes: u32) {}
fn read_chunk(&self, _chunk: &ChunkInfo, _buf: &mut [u8]) -> anyhow::Result<()> {
panic!("dude, why are you requesting chunks")
}
fn on_extended_handshake(&self, extended_handshake: &ExtendedHandshake<ByteBuf>) {
self.ehandshake
.write()
.replace(extended_handshake.clone_to_owned());
self.tx
.send(WriterRequest::Message(Message::Extended(
ExtendedMessage::UtMetadata(UtMetadata::Request(0)),
)))
.unwrap()
}
}
let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let handler = Handler {
tx,
ehandshake: RwLock::new(None),
};
let peer_id = generate_peer_id(); let peer_id = generate_peer_id();
let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7"); let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7");
dbg!(info_hash);
let handshake = dbg!(Handshake::new(info_hash, peer_id));
let mut write_buf = Vec::<u8>::new(); let conn = PeerConnection::new(addr, info_hash, peer_id, handler);
let h = handshake.serialize();
let mut read_buf = vec![0u8; 16384]; // tx.send(WriterRequest::Message(Message::Extended(ExtendedMessage)));
stream.write_all(&h).await.unwrap(); conn.manage_peer(rx).await.unwrap();
let read_bytes = stream.read(&mut read_buf).await.unwrap();
let (handshake, hlen) = Handshake::deserialize(&read_buf[..read_bytes]).unwrap();
dbg!(handshake);
read_buf.copy_within(hlen..read_bytes, 0);
let mut read_so_far = read_bytes - hlen;
loop {
let (message, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => {
break (msg, size);
}
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = stream.read(&mut read_buf[read_so_far..]).await.unwrap();
if size == 0 {
panic!("size == 0, disconnected")
}
read_so_far += size;
}
Err(e) => Err(e).unwrap(),
}
};
dbg!(message, size);
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
}
} }
} }

View file

@ -2,14 +2,15 @@ use std::{net::SocketAddr, time::Duration};
use anyhow::Context; use anyhow::Context;
use log::{debug, trace}; use log::{debug, trace};
use tokio::time::timeout; use tokio::{io::AsyncReadExt, time::timeout};
use crate::{ use crate::{
buffers::ByteBuf, buffers::{ByteBuf, ByteString},
clone_to_owned::CloneToOwned,
lengths::ChunkInfo, lengths::ChunkInfo,
peer_binary_protocol::{ peer_binary_protocol::{
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, serialize_piece_preamble, ExtendedHandshake, ExtendedMessage, Handshake, Message,
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, MessageBorrowed, MessageDeserializeError, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
}, },
peer_id::try_decode_peer_id, peer_id::try_decode_peer_id,
}; };
@ -18,6 +19,7 @@ pub trait PeerConnectionHandler {
fn get_have_bytes(&self) -> u64; fn get_have_bytes(&self) -> u64;
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize>; fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize>;
fn on_handshake(&self, handshake: Handshake); fn on_handshake(&self, handshake: Handshake);
fn on_extended_handshake(&self, extended_handshake: &ExtendedHandshake<ByteBuf>);
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()>; fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()>;
fn on_uploaded_bytes(&self, bytes: u32); fn on_uploaded_bytes(&self, bytes: u32);
fn read_chunk(&self, chunk: &ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()>; fn read_chunk(&self, chunk: &ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()>;
@ -36,6 +38,34 @@ pub struct PeerConnection<H> {
peer_id: [u8; 20], peer_id: [u8; 20],
} }
// async fn read_one<'a, R: AsyncReadExt + Unpin>(
// mut reader: R,
// read_buf: &'a mut Vec<u8>,
// read_so_far: &mut usize,
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
// loop {
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
// Ok((msg, size)) => return Ok((msg, size)),
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
// if read_buf.len() < *read_so_far + d {
// read_buf.reserve(d);
// read_buf.resize(read_buf.capacity(), 0);
// }
// let size = reader
// .read(&mut read_buf[*read_so_far..])
// .await
// .context("error reading from peer")?;
// if size == 0 {
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
// }
// *read_so_far += size;
// }
// Err(e) => return Err(e.into()),
// }
// }
// }
impl<H: PeerConnectionHandler> PeerConnection<H> { impl<H: PeerConnectionHandler> PeerConnection<H> {
pub fn new(addr: SocketAddr, info_hash: [u8; 20], peer_id: [u8; 20], handler: H) -> Self { pub fn new(addr: SocketAddr, info_hash: [u8; 20], peer_id: [u8; 20], handler: H) -> Self {
PeerConnection { PeerConnection {
@ -62,17 +92,16 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.await .await
.context("error writing handshake")?; .context("error writing handshake")?;
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let read_bytes = conn let mut read_so_far = conn
.read(&mut read_buf) .read(&mut read_buf)
.await .await
.context("error reading handshake")?; .context("error reading handshake")?;
if read_bytes == 0 { if read_so_far == 0 {
anyhow::bail!("bad handshake"); anyhow::bail!("bad handshake");
} }
let (h, hlen) = Handshake::deserialize(&read_buf[..read_bytes]) let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
let mut read_so_far = 0usize;
debug!( debug!(
"connected peer {}: {:?}", "connected peer {}: {:?}",
self.addr, self.addr,
@ -82,11 +111,57 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
anyhow::bail!("info hash does not match"); anyhow::bail!("info hash does not match");
} }
self.handler.on_handshake(h); let mut extended_handshake: Option<ExtendedHandshake<ByteString>> = None;
let supports_extended = h.supports_extended();
if read_bytes > hlen { self.handler.on_handshake(h);
read_buf.copy_within(hlen..read_bytes, 0); if read_so_far > size {
read_so_far = read_bytes - hlen; read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
if supports_extended {
// Read extended handshake
// I wasn't able to extract that into a function.
// TODO: extract into a macro.
let (extended, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => break (msg, size),
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = conn
.read(&mut read_buf[read_so_far..])
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!(
"disconnected while reading, read so far: {}",
read_so_far
)
}
read_so_far += size;
}
Err(e) => return Err(e.into()),
}
};
match extended {
Message::Extended(ExtendedMessage::Handshake(h)) => {
trace!("received from {}: {:?}", self.addr, &h);
self.handler.on_extended_handshake(&h);
extended_handshake = Some(h.clone_to_owned())
}
other => anyhow::bail!("expected extended handshake, but got {:?}", other),
};
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
} }
let (mut read_half, mut write_half) = tokio::io::split(conn); let (mut read_half, mut write_half) = tokio::io::split(conn);
@ -117,7 +192,9 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let mut uploaded_add = None; let mut uploaded_add = None;
let len = match &req { let len = match &req {
WriterRequest::Message(msg) => msg.serialize(&mut buf), WriterRequest::Message(msg) => {
msg.serialize(&mut buf, extended_handshake.as_ref())?
}
WriterRequest::ReadChunkRequest(chunk) => { WriterRequest::ReadChunkRequest(chunk) => {
// this whole section is an optimization // this whole section is an optimization
buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0); buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0);

View file

@ -21,6 +21,9 @@ impl<'de> BencodeDeserializer<'de> {
torrent_info_digest: None, torrent_info_digest: None,
} }
} }
pub fn into_remaining(self) -> &'de [u8] {
self.buf
}
fn parse_integer(&mut self) -> Result<i64, Error> { fn parse_integer(&mut self) -> Result<i64, Error> {
match self.buf.iter().copied().position(|e| e == b'e') { match self.buf.iter().copied().position(|e| e == b'e') {
Some(end) => { Some(end) => {

View file

@ -498,7 +498,7 @@ impl PeerConnectionHandler for PeerHandler {
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize> { fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize> {
let g = self.state.locked.read(); let g = self.state.locked.read();
let msg = Message::Bitfield(ByteBuf(g.chunks.get_have_pieces().as_raw_slice())); let msg = Message::Bitfield(ByteBuf(g.chunks.get_have_pieces().as_raw_slice()));
let len = msg.serialize(buf); let len = msg.serialize(buf, None).unwrap();
debug!("sending to {}: {:?}, length={}", self.addr, &msg, len); debug!("sending to {}: {:?}, length={}", self.addr, &msg, len);
Some(len) Some(len)
} }
@ -517,6 +517,12 @@ impl PeerConnectionHandler for PeerHandler {
fn read_chunk(&self, chunk: &crate::lengths::ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()> { fn read_chunk(&self, chunk: &crate::lengths::ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()> {
self.state.file_ops().read_chunk(self.addr, chunk, buf) self.state.file_ops().read_chunk(self.addr, chunk, buf)
} }
fn on_extended_handshake(
&self,
extended_handshake: &crate::peer_binary_protocol::ExtendedHandshake<ByteBuf>,
) {
}
} }
impl PeerHandler { impl PeerHandler {