Screwing around with extended messages
This commit is contained in:
parent
d722f0edcb
commit
302e95649d
5 changed files with 476 additions and 96 deletions
|
|
@ -28,6 +28,22 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CloneToOwned for u8 {
|
||||||
|
type Target = u8;
|
||||||
|
|
||||||
|
fn clone_to_owned(&self) -> Self::Target {
|
||||||
|
*self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CloneToOwned for u32 {
|
||||||
|
type Target = u32;
|
||||||
|
|
||||||
|
fn clone_to_owned(&self) -> Self::Target {
|
||||||
|
*self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<K, V> CloneToOwned for HashMap<K, V>
|
impl<K, V> CloneToOwned for HashMap<K, V>
|
||||||
where
|
where
|
||||||
K: CloneToOwned,
|
K: CloneToOwned,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,13 @@
|
||||||
use std::collections::HashMap;
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
io::Write,
|
||||||
|
marker::PhantomData,
|
||||||
|
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||||
|
};
|
||||||
|
|
||||||
use bincode::Options;
|
use bincode::Options;
|
||||||
use byteorder::{ByteOrder, BE};
|
use byteorder::{ByteOrder, BE};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
bencode_value::BencodeValue,
|
bencode_value::BencodeValue,
|
||||||
|
|
@ -10,6 +15,8 @@ use crate::{
|
||||||
clone_to_owned::CloneToOwned,
|
clone_to_owned::CloneToOwned,
|
||||||
constants::CHUNK_SIZE,
|
constants::CHUNK_SIZE,
|
||||||
lengths::ChunkInfo,
|
lengths::ChunkInfo,
|
||||||
|
serde_bencode_de::BencodeDeserializer,
|
||||||
|
serde_bencode_ser,
|
||||||
};
|
};
|
||||||
|
|
||||||
const INTEGER_LEN: usize = 4;
|
const INTEGER_LEN: usize = 4;
|
||||||
|
|
@ -41,6 +48,8 @@ const MSGID_REQUEST: u8 = 6;
|
||||||
const MSGID_PIECE: u8 = 7;
|
const MSGID_PIECE: u8 = 7;
|
||||||
const MSGID_EXTENDED: u8 = 20;
|
const MSGID_EXTENDED: u8 = 20;
|
||||||
|
|
||||||
|
const MY_EXTENDED_UT_METADATA: u8 = 0;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum MessageDeserializeError {
|
pub enum MessageDeserializeError {
|
||||||
NotEnoughData(usize, &'static str),
|
NotEnoughData(usize, &'static str),
|
||||||
|
|
@ -255,7 +264,11 @@ where
|
||||||
Message::Extended(_) => (0, MSGID_EXTENDED),
|
Message::Extended(_) => (0, MSGID_EXTENDED),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn serialize(&self, out: &mut Vec<u8>) -> usize {
|
pub fn serialize(
|
||||||
|
&self,
|
||||||
|
out: &mut Vec<u8>,
|
||||||
|
peer_extended_handshake: Option<&ExtendedHandshake<ByteString>>,
|
||||||
|
) -> anyhow::Result<usize> {
|
||||||
let (lp, msg_id) = self.len_prefix_and_msg_id();
|
let (lp, msg_id) = self.len_prefix_and_msg_id();
|
||||||
|
|
||||||
out.resize(PREAMBLE_LEN, 0);
|
out.resize(PREAMBLE_LEN, 0);
|
||||||
|
|
@ -272,17 +285,17 @@ where
|
||||||
debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12);
|
debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12);
|
||||||
ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
|
ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
MSG_LEN
|
Ok(MSG_LEN)
|
||||||
}
|
}
|
||||||
Message::Bitfield(b) => {
|
Message::Bitfield(b) => {
|
||||||
let block_len = b.as_ref().len();
|
let block_len = b.as_ref().len();
|
||||||
let msg_len = PREAMBLE_LEN + block_len;
|
let msg_len = PREAMBLE_LEN + block_len;
|
||||||
out.resize(msg_len, 0);
|
out.resize(msg_len, 0);
|
||||||
(&mut out[PREAMBLE_LEN..PREAMBLE_LEN + block_len]).copy_from_slice(b.as_ref());
|
(&mut out[PREAMBLE_LEN..PREAMBLE_LEN + block_len]).copy_from_slice(b.as_ref());
|
||||||
msg_len
|
Ok(msg_len)
|
||||||
}
|
}
|
||||||
Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => {
|
Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => {
|
||||||
PREAMBLE_LEN
|
Ok(PREAMBLE_LEN)
|
||||||
}
|
}
|
||||||
Message::Piece(p) => {
|
Message::Piece(p) => {
|
||||||
let block_len = p.block.as_ref().len();
|
let block_len = p.block.as_ref().len();
|
||||||
|
|
@ -291,23 +304,23 @@ where
|
||||||
out.resize(msg_len, 0);
|
out.resize(msg_len, 0);
|
||||||
let tmp = &mut out[PREAMBLE_LEN..];
|
let tmp = &mut out[PREAMBLE_LEN..];
|
||||||
p.serialize(&mut tmp[..payload_len]);
|
p.serialize(&mut tmp[..payload_len]);
|
||||||
msg_len
|
Ok(msg_len)
|
||||||
}
|
}
|
||||||
Message::KeepAlive => {
|
Message::KeepAlive => {
|
||||||
// the len prefix was already written out to buf
|
// the len prefix was already written out to buf
|
||||||
4
|
Ok(4)
|
||||||
}
|
}
|
||||||
Message::Have(v) => {
|
Message::Have(v) => {
|
||||||
let msg_len = PREAMBLE_LEN + 4;
|
let msg_len = PREAMBLE_LEN + 4;
|
||||||
out.resize(msg_len, 0);
|
out.resize(msg_len, 0);
|
||||||
BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
|
BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
|
||||||
msg_len
|
Ok(msg_len)
|
||||||
}
|
}
|
||||||
Message::Extended(e) => {
|
Message::Extended(e) => {
|
||||||
e.serialize(out);
|
e.serialize(out, peer_extended_handshake);
|
||||||
let msg_size = out.len();
|
let msg_size = out.len();
|
||||||
BE::write_u32(&mut out[..4], msg_size as u32);
|
BE::write_u32(&mut out[..4], msg_size as u32);
|
||||||
msg_size
|
Ok(msg_size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -496,6 +509,9 @@ impl<'a> Handshake<'a> {
|
||||||
peer_id,
|
peer_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn supports_extended(&self) -> bool {
|
||||||
|
self.reserved[5] & 0x10 > 0
|
||||||
|
}
|
||||||
fn bopts() -> impl bincode::Options {
|
fn bopts() -> impl bincode::Options {
|
||||||
bincode::DefaultOptions::new()
|
bincode::DefaultOptions::new()
|
||||||
}
|
}
|
||||||
|
|
@ -535,28 +551,184 @@ impl Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
|
pub enum UtMetadata<ByteBuf> {
|
||||||
Handshake(ExtendedHandshake<ByteBuf>),
|
Request(u32),
|
||||||
Dyn(u8, BencodeValue<ByteBuf>),
|
Data(u32, ByteBuf),
|
||||||
|
Reject(u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
|
impl<ByteBuf: CloneToOwned> CloneToOwned for UtMetadata<ByteBuf> {
|
||||||
fn serialize(&self, out: &mut Vec<u8>) {
|
type Target = UtMetadata<<ByteBuf as CloneToOwned>::Target>;
|
||||||
|
|
||||||
|
fn clone_to_owned(&self) -> Self::Target {
|
||||||
match self {
|
match self {
|
||||||
ExtendedMessage::Dyn(msg_id, v) => {
|
UtMetadata::Request(req) => UtMetadata::Request(*req),
|
||||||
out.push(*msg_id);
|
UtMetadata::Data(piece, data) => UtMetadata::Data(*piece, data.clone_to_owned()),
|
||||||
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out).unwrap()
|
UtMetadata::Reject(piece) => UtMetadata::Reject(*piece),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, ByteBuf: 'a> UtMetadata<ByteBuf> {
|
||||||
|
fn serialize(&self, buf: &mut Vec<u8>)
|
||||||
|
where
|
||||||
|
ByteBuf: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Message {
|
||||||
|
msg_type: u32,
|
||||||
|
piece: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
total_size: Option<u32>,
|
||||||
|
}
|
||||||
|
match self {
|
||||||
|
UtMetadata::Request(piece) => {
|
||||||
|
let message = Message {
|
||||||
|
msg_type: 0,
|
||||||
|
piece: *piece,
|
||||||
|
total_size: None,
|
||||||
|
};
|
||||||
|
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap()
|
||||||
}
|
}
|
||||||
ExtendedMessage::Handshake(h) => {
|
UtMetadata::Data(piece, data) => {
|
||||||
out.push(0);
|
let message = Message {
|
||||||
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap()
|
msg_type: 1,
|
||||||
|
piece: *piece,
|
||||||
|
total_size: Some(data.as_ref().len() as u32),
|
||||||
|
};
|
||||||
|
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
|
||||||
|
buf.write_all(data.as_ref()).unwrap();
|
||||||
|
}
|
||||||
|
UtMetadata::Reject(piece) => {
|
||||||
|
let message = Message {
|
||||||
|
msg_type: 2,
|
||||||
|
piece: *piece,
|
||||||
|
total_size: None,
|
||||||
|
};
|
||||||
|
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn deserialize(buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
|
||||||
fn deserialize<'de>(mut buf: &'de [u8]) -> Result<Self, MessageDeserializeError>
|
|
||||||
where
|
where
|
||||||
ByteBuf: Deserialize<'de> + From<&'de [u8]>,
|
ByteBuf: From<&'a [u8]>,
|
||||||
|
{
|
||||||
|
let mut de = BencodeDeserializer::new_from_buf(buf);
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Message {
|
||||||
|
msg_type: u32,
|
||||||
|
piece: u32,
|
||||||
|
total_size: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let message =
|
||||||
|
Message::deserialize(&mut de).map_err(|e| MessageDeserializeError::Other(e.into()))?;
|
||||||
|
let remaining = de.into_remaining();
|
||||||
|
|
||||||
|
match message.msg_type {
|
||||||
|
// request
|
||||||
|
0 => {
|
||||||
|
if !remaining.is_empty() {
|
||||||
|
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
|
||||||
|
"trailing bytes when decoding UtMetadata"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(UtMetadata::Request(message.piece))
|
||||||
|
}
|
||||||
|
// data
|
||||||
|
1 => {
|
||||||
|
let total_size = message.total_size.ok_or_else(|| {
|
||||||
|
MessageDeserializeError::Other(anyhow::anyhow!(
|
||||||
|
"expected key total_size to be present in UtMetadata \"data\" message"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
if remaining.len() != total_size as usize {
|
||||||
|
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
|
||||||
|
"remaining bytes len {} != total_size {}",
|
||||||
|
remaining.len(),
|
||||||
|
total_size
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(UtMetadata::Data(message.piece, ByteBuf::from(remaining)))
|
||||||
|
}
|
||||||
|
// reject
|
||||||
|
2 => {
|
||||||
|
if !remaining.is_empty() {
|
||||||
|
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
|
||||||
|
"trailing bytes when decoding UtMetadata"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(UtMetadata::Reject(message.piece))
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
|
||||||
|
"unrecognized ut_metadata message type {}",
|
||||||
|
other
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
|
||||||
|
Handshake(ExtendedHandshake<ByteBuf>),
|
||||||
|
UtMetadata(UtMetadata<ByteBuf>),
|
||||||
|
Dyn(u8, BencodeValue<ByteBuf>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ByteBuf> CloneToOwned for ExtendedMessage<ByteBuf>
|
||||||
|
where
|
||||||
|
ByteBuf: CloneToOwned + std::hash::Hash + Eq,
|
||||||
|
<ByteBuf as CloneToOwned>::Target: std::hash::Hash + Eq,
|
||||||
|
{
|
||||||
|
type Target = ExtendedMessage<<ByteBuf as CloneToOwned>::Target>;
|
||||||
|
|
||||||
|
fn clone_to_owned(&self) -> Self::Target {
|
||||||
|
match self {
|
||||||
|
ExtendedMessage::Handshake(h) => ExtendedMessage::Handshake(h.clone_to_owned()),
|
||||||
|
ExtendedMessage::Dyn(u, d) => ExtendedMessage::Dyn(*u, d.clone_to_owned()),
|
||||||
|
ExtendedMessage::UtMetadata(m) => ExtendedMessage::UtMetadata(m.clone_to_owned()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
|
||||||
|
fn serialize(
|
||||||
|
&self,
|
||||||
|
out: &mut Vec<u8>,
|
||||||
|
extended_handshake: Option<&ExtendedHandshake<ByteString>>,
|
||||||
|
) -> anyhow::Result<()>
|
||||||
|
where
|
||||||
|
ByteBuf: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
ExtendedMessage::Dyn(msg_id, v) => {
|
||||||
|
out.push(*msg_id);
|
||||||
|
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out)?;
|
||||||
|
}
|
||||||
|
ExtendedMessage::Handshake(h) => {
|
||||||
|
out.push(0);
|
||||||
|
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out)?;
|
||||||
|
}
|
||||||
|
ExtendedMessage::UtMetadata(u) => {
|
||||||
|
let h = extended_handshake.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("need peer's handshake to serialize ut_metadata")
|
||||||
|
})?;
|
||||||
|
let emsg_id = h
|
||||||
|
.get_msgid(b"ut_metadata")
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("peer doesn't support ut_metadata"))?;
|
||||||
|
out.push(emsg_id);
|
||||||
|
u.serialize(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize(mut buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
|
||||||
|
where
|
||||||
|
ByteBuf: Deserialize<'a> + From<&'a [u8]>,
|
||||||
{
|
{
|
||||||
{
|
{
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
@ -583,10 +755,13 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
match emsg_id {
|
match emsg_id {
|
||||||
// handshake
|
|
||||||
0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)),
|
0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)),
|
||||||
other => Ok(ExtendedMessage::Dyn(other, from_bytes(&buf)?)),
|
MY_EXTENDED_UT_METADATA => {
|
||||||
|
Ok(ExtendedMessage::UtMetadata(UtMetadata::deserialize(&buf)?))
|
||||||
|
}
|
||||||
|
other => Ok(ExtendedMessage::Dyn(emsg_id, from_bytes(&buf)?)),
|
||||||
}
|
}
|
||||||
|
|
||||||
// match self {
|
// match self {
|
||||||
// ExtendedMessage::Dyn(v, msg) => {
|
// ExtendedMessage::Dyn(v, msg) => {
|
||||||
// crate::bencode_value::dyn_from_bytes(buf)
|
// crate::bencode_value::dyn_from_bytes(buf)
|
||||||
|
|
@ -598,31 +773,130 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct YourIP(pub IpAddr);
|
||||||
|
|
||||||
|
impl Serialize for YourIP {
|
||||||
|
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for YourIP {
|
||||||
|
fn deserialize<D>(de: D) -> Result<YourIP, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
struct Visitor {}
|
||||||
|
impl<'de> serde::de::Visitor<'de> for Visitor {
|
||||||
|
type Value = YourIP;
|
||||||
|
|
||||||
|
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(f, "expecting 4 bytes of ipv4 or 16 bytes of ipv6")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
if v.len() == 4 {
|
||||||
|
return Ok(YourIP(IpAddr::V4(Ipv4Addr::new(v[0], v[1], v[2], v[3]))));
|
||||||
|
} else if v.len() == 16 {
|
||||||
|
return Ok(YourIP(IpAddr::V6(Ipv6Addr::new(
|
||||||
|
BE::read_u16(&v[..2]),
|
||||||
|
BE::read_u16(&v[2..4]),
|
||||||
|
BE::read_u16(&v[4..6]),
|
||||||
|
BE::read_u16(&v[6..8]),
|
||||||
|
BE::read_u16(&v[8..10]),
|
||||||
|
BE::read_u16(&v[10..12]),
|
||||||
|
BE::read_u16(&v[12..14]),
|
||||||
|
BE::read_u16(&v[14..]),
|
||||||
|
))));
|
||||||
|
}
|
||||||
|
Err(E::custom("expected 4 or 16 byte address"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
de.deserialize_bytes(Visitor {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug)]
|
#[derive(Deserialize, Serialize, Debug)]
|
||||||
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
|
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
|
||||||
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
|
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
|
||||||
pub m: HashMap<ByteBuf, BencodeValue<ByteBuf>>,
|
pub m: HashMap<ByteBuf, u8>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub p: Option<u32>,
|
pub p: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub v: Option<ByteBuf>,
|
pub v: Option<ByteBuf>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub yourip: Option<ByteBuf>,
|
pub yourip: Option<YourIP>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub ipv6: Option<ByteBuf>,
|
pub ipv6: Option<ByteBuf>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub ipv4: Option<ByteBuf>,
|
pub ipv4: Option<ByteBuf>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub reqq: Option<u32>,
|
pub reqq: Option<u32>,
|
||||||
|
pub metadata_size: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ByteBuf: Eq + std::hash::Hash> ExtendedHandshake<ByteBuf> {
|
||||||
|
fn get_msgid(&self, msg_type: &[u8]) -> Option<u8>
|
||||||
|
where
|
||||||
|
ByteBuf: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
self.m.iter().find_map(|(k, v)| {
|
||||||
|
if k.as_ref() == msg_type {
|
||||||
|
Some(*v)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ByteBuf> CloneToOwned for ExtendedHandshake<ByteBuf>
|
||||||
|
where
|
||||||
|
ByteBuf: CloneToOwned + Eq + std::hash::Hash,
|
||||||
|
<ByteBuf as CloneToOwned>::Target: Eq + std::hash::Hash,
|
||||||
|
{
|
||||||
|
type Target = ExtendedHandshake<<ByteBuf as CloneToOwned>::Target>;
|
||||||
|
|
||||||
|
fn clone_to_owned(&self) -> Self::Target {
|
||||||
|
ExtendedHandshake {
|
||||||
|
m: self.m.clone_to_owned(),
|
||||||
|
p: self.p,
|
||||||
|
v: self.v.clone_to_owned(),
|
||||||
|
yourip: self.yourip,
|
||||||
|
ipv6: self.ipv6.clone_to_owned(),
|
||||||
|
ipv4: self.ipv4.clone_to_owned(),
|
||||||
|
reqq: self.reqq,
|
||||||
|
metadata_size: self.metadata_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::{net::SocketAddr, str::FromStr};
|
use std::{net::SocketAddr, str::FromStr};
|
||||||
|
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use log::info;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
|
|
||||||
use crate::peer_id::generate_peer_id;
|
use crate::{
|
||||||
|
peer_connection::{PeerConnection, PeerConnectionHandler, WriterRequest},
|
||||||
|
peer_id::generate_peer_id,
|
||||||
|
};
|
||||||
|
use std::sync::Once;
|
||||||
|
|
||||||
|
static LOG_INIT: Once = std::sync::Once::new();
|
||||||
|
|
||||||
|
fn init_logging() {
|
||||||
|
LOG_INIT.call_once(pretty_env_logger::init)
|
||||||
|
}
|
||||||
|
|
||||||
fn decode_info_hash(hash_str: &str) -> [u8; 20] {
|
fn decode_info_hash(hash_str: &str) -> [u8; 20] {
|
||||||
let mut hash_arr = [0u8; 20];
|
let mut hash_arr = [0u8; 20];
|
||||||
|
|
@ -645,9 +919,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_extended_serialize() {
|
fn test_extended_serialize() {
|
||||||
let mut feats = HashMap::new();
|
let feats = HashMap::new();
|
||||||
feats.insert("whatever".as_bytes().into(), BencodeValue::Integer(1));
|
|
||||||
|
|
||||||
let msg =
|
let msg =
|
||||||
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
|
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
|
||||||
m: feats,
|
m: feats,
|
||||||
|
|
@ -657,66 +929,72 @@ mod tests {
|
||||||
ipv6: None,
|
ipv6: None,
|
||||||
ipv4: None,
|
ipv4: None,
|
||||||
reqq: None,
|
reqq: None,
|
||||||
|
metadata_size: None,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::new();
|
||||||
msg.serialize(&mut out);
|
msg.serialize(&mut out, None);
|
||||||
dbg!(out);
|
dbg!(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_connect_to_local_qbittorrent() {
|
async fn test_connect_to_local_qbittorrent() {
|
||||||
let mut stream =
|
init_logging();
|
||||||
tokio::net::TcpStream::connect(SocketAddr::from_str("127.0.0.1:27311").unwrap())
|
|
||||||
.await
|
struct Handler {
|
||||||
.unwrap();
|
ehandshake: RwLock<Option<ExtendedHandshake<ByteString>>>,
|
||||||
|
tx: UnboundedSender<WriterRequest>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PeerConnectionHandler for Handler {
|
||||||
|
fn get_have_bytes(&self) -> u64 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize_bitfield_message_to_buf(&self, _buf: &mut Vec<u8>) -> Option<usize> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_handshake(&self, handshake: Handshake) {
|
||||||
|
info!("received handshake: {:?}", handshake)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()> {
|
||||||
|
info!("received message: {:?}", msg);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_uploaded_bytes(&self, _bytes: u32) {}
|
||||||
|
|
||||||
|
fn read_chunk(&self, _chunk: &ChunkInfo, _buf: &mut [u8]) -> anyhow::Result<()> {
|
||||||
|
panic!("dude, why are you requesting chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_extended_handshake(&self, extended_handshake: &ExtendedHandshake<ByteBuf>) {
|
||||||
|
self.ehandshake
|
||||||
|
.write()
|
||||||
|
.replace(extended_handshake.clone_to_owned());
|
||||||
|
self.tx
|
||||||
|
.send(WriterRequest::Message(Message::Extended(
|
||||||
|
ExtendedMessage::UtMetadata(UtMetadata::Request(0)),
|
||||||
|
)))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap();
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let handler = Handler {
|
||||||
|
tx,
|
||||||
|
ehandshake: RwLock::new(None),
|
||||||
|
};
|
||||||
let peer_id = generate_peer_id();
|
let peer_id = generate_peer_id();
|
||||||
let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7");
|
let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7");
|
||||||
dbg!(info_hash);
|
|
||||||
let handshake = dbg!(Handshake::new(info_hash, peer_id));
|
|
||||||
|
|
||||||
let mut write_buf = Vec::<u8>::new();
|
let conn = PeerConnection::new(addr, info_hash, peer_id, handler);
|
||||||
let h = handshake.serialize();
|
|
||||||
|
|
||||||
let mut read_buf = vec![0u8; 16384];
|
// tx.send(WriterRequest::Message(Message::Extended(ExtendedMessage)));
|
||||||
|
|
||||||
stream.write_all(&h).await.unwrap();
|
conn.manage_peer(rx).await.unwrap();
|
||||||
|
|
||||||
let read_bytes = stream.read(&mut read_buf).await.unwrap();
|
|
||||||
let (handshake, hlen) = Handshake::deserialize(&read_buf[..read_bytes]).unwrap();
|
|
||||||
dbg!(handshake);
|
|
||||||
|
|
||||||
read_buf.copy_within(hlen..read_bytes, 0);
|
|
||||||
let mut read_so_far = read_bytes - hlen;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let (message, size) = loop {
|
|
||||||
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
|
|
||||||
Ok((msg, size)) => {
|
|
||||||
break (msg, size);
|
|
||||||
}
|
|
||||||
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
|
||||||
if read_buf.len() < read_so_far + d {
|
|
||||||
read_buf.reserve(d);
|
|
||||||
read_buf.resize(read_buf.capacity(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
let size = stream.read(&mut read_buf[read_so_far..]).await.unwrap();
|
|
||||||
if size == 0 {
|
|
||||||
panic!("size == 0, disconnected")
|
|
||||||
}
|
|
||||||
read_so_far += size;
|
|
||||||
}
|
|
||||||
Err(e) => Err(e).unwrap(),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
dbg!(message, size);
|
|
||||||
|
|
||||||
if read_so_far > size {
|
|
||||||
read_buf.copy_within(size..read_so_far, 0);
|
|
||||||
}
|
|
||||||
read_so_far -= size;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,15 @@ use std::{net::SocketAddr, time::Duration};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
use tokio::time::timeout;
|
use tokio::{io::AsyncReadExt, time::timeout};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
buffers::ByteBuf,
|
buffers::{ByteBuf, ByteString},
|
||||||
|
clone_to_owned::CloneToOwned,
|
||||||
lengths::ChunkInfo,
|
lengths::ChunkInfo,
|
||||||
peer_binary_protocol::{
|
peer_binary_protocol::{
|
||||||
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError,
|
serialize_piece_preamble, ExtendedHandshake, ExtendedMessage, Handshake, Message,
|
||||||
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
MessageBorrowed, MessageDeserializeError, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
||||||
},
|
},
|
||||||
peer_id::try_decode_peer_id,
|
peer_id::try_decode_peer_id,
|
||||||
};
|
};
|
||||||
|
|
@ -18,6 +19,7 @@ pub trait PeerConnectionHandler {
|
||||||
fn get_have_bytes(&self) -> u64;
|
fn get_have_bytes(&self) -> u64;
|
||||||
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize>;
|
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize>;
|
||||||
fn on_handshake(&self, handshake: Handshake);
|
fn on_handshake(&self, handshake: Handshake);
|
||||||
|
fn on_extended_handshake(&self, extended_handshake: &ExtendedHandshake<ByteBuf>);
|
||||||
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()>;
|
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()>;
|
||||||
fn on_uploaded_bytes(&self, bytes: u32);
|
fn on_uploaded_bytes(&self, bytes: u32);
|
||||||
fn read_chunk(&self, chunk: &ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()>;
|
fn read_chunk(&self, chunk: &ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()>;
|
||||||
|
|
@ -36,6 +38,34 @@ pub struct PeerConnection<H> {
|
||||||
peer_id: [u8; 20],
|
peer_id: [u8; 20],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// async fn read_one<'a, R: AsyncReadExt + Unpin>(
|
||||||
|
// mut reader: R,
|
||||||
|
// read_buf: &'a mut Vec<u8>,
|
||||||
|
// read_so_far: &mut usize,
|
||||||
|
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
|
||||||
|
// loop {
|
||||||
|
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
|
||||||
|
// Ok((msg, size)) => return Ok((msg, size)),
|
||||||
|
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
||||||
|
// if read_buf.len() < *read_so_far + d {
|
||||||
|
// read_buf.reserve(d);
|
||||||
|
// read_buf.resize(read_buf.capacity(), 0);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let size = reader
|
||||||
|
// .read(&mut read_buf[*read_so_far..])
|
||||||
|
// .await
|
||||||
|
// .context("error reading from peer")?;
|
||||||
|
// if size == 0 {
|
||||||
|
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
|
||||||
|
// }
|
||||||
|
// *read_so_far += size;
|
||||||
|
// }
|
||||||
|
// Err(e) => return Err(e.into()),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
impl<H: PeerConnectionHandler> PeerConnection<H> {
|
impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
pub fn new(addr: SocketAddr, info_hash: [u8; 20], peer_id: [u8; 20], handler: H) -> Self {
|
pub fn new(addr: SocketAddr, info_hash: [u8; 20], peer_id: [u8; 20], handler: H) -> Self {
|
||||||
PeerConnection {
|
PeerConnection {
|
||||||
|
|
@ -62,17 +92,16 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
.await
|
.await
|
||||||
.context("error writing handshake")?;
|
.context("error writing handshake")?;
|
||||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
||||||
let read_bytes = conn
|
let mut read_so_far = conn
|
||||||
.read(&mut read_buf)
|
.read(&mut read_buf)
|
||||||
.await
|
.await
|
||||||
.context("error reading handshake")?;
|
.context("error reading handshake")?;
|
||||||
if read_bytes == 0 {
|
if read_so_far == 0 {
|
||||||
anyhow::bail!("bad handshake");
|
anyhow::bail!("bad handshake");
|
||||||
}
|
}
|
||||||
let (h, hlen) = Handshake::deserialize(&read_buf[..read_bytes])
|
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
|
||||||
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
||||||
|
|
||||||
let mut read_so_far = 0usize;
|
|
||||||
debug!(
|
debug!(
|
||||||
"connected peer {}: {:?}",
|
"connected peer {}: {:?}",
|
||||||
self.addr,
|
self.addr,
|
||||||
|
|
@ -82,11 +111,57 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
anyhow::bail!("info hash does not match");
|
anyhow::bail!("info hash does not match");
|
||||||
}
|
}
|
||||||
|
|
||||||
self.handler.on_handshake(h);
|
let mut extended_handshake: Option<ExtendedHandshake<ByteString>> = None;
|
||||||
|
let supports_extended = h.supports_extended();
|
||||||
|
|
||||||
if read_bytes > hlen {
|
self.handler.on_handshake(h);
|
||||||
read_buf.copy_within(hlen..read_bytes, 0);
|
if read_so_far > size {
|
||||||
read_so_far = read_bytes - hlen;
|
read_buf.copy_within(size..read_so_far, 0);
|
||||||
|
}
|
||||||
|
read_so_far -= size;
|
||||||
|
|
||||||
|
if supports_extended {
|
||||||
|
// Read extended handshake
|
||||||
|
// I wasn't able to extract that into a function.
|
||||||
|
// TODO: extract into a macro.
|
||||||
|
let (extended, size) = loop {
|
||||||
|
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
|
||||||
|
Ok((msg, size)) => break (msg, size),
|
||||||
|
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
||||||
|
if read_buf.len() < read_so_far + d {
|
||||||
|
read_buf.reserve(d);
|
||||||
|
read_buf.resize(read_buf.capacity(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let size = conn
|
||||||
|
.read(&mut read_buf[read_so_far..])
|
||||||
|
.await
|
||||||
|
.context("error reading from peer")?;
|
||||||
|
if size == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"disconnected while reading, read so far: {}",
|
||||||
|
read_so_far
|
||||||
|
)
|
||||||
|
}
|
||||||
|
read_so_far += size;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match extended {
|
||||||
|
Message::Extended(ExtendedMessage::Handshake(h)) => {
|
||||||
|
trace!("received from {}: {:?}", self.addr, &h);
|
||||||
|
self.handler.on_extended_handshake(&h);
|
||||||
|
extended_handshake = Some(h.clone_to_owned())
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("expected extended handshake, but got {:?}", other),
|
||||||
|
};
|
||||||
|
|
||||||
|
if read_so_far > size {
|
||||||
|
read_buf.copy_within(size..read_so_far, 0);
|
||||||
|
}
|
||||||
|
read_so_far -= size;
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut read_half, mut write_half) = tokio::io::split(conn);
|
let (mut read_half, mut write_half) = tokio::io::split(conn);
|
||||||
|
|
@ -117,7 +192,9 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
let mut uploaded_add = None;
|
let mut uploaded_add = None;
|
||||||
|
|
||||||
let len = match &req {
|
let len = match &req {
|
||||||
WriterRequest::Message(msg) => msg.serialize(&mut buf),
|
WriterRequest::Message(msg) => {
|
||||||
|
msg.serialize(&mut buf, extended_handshake.as_ref())?
|
||||||
|
}
|
||||||
WriterRequest::ReadChunkRequest(chunk) => {
|
WriterRequest::ReadChunkRequest(chunk) => {
|
||||||
// this whole section is an optimization
|
// this whole section is an optimization
|
||||||
buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0);
|
buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0);
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,9 @@ impl<'de> BencodeDeserializer<'de> {
|
||||||
torrent_info_digest: None,
|
torrent_info_digest: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn into_remaining(self) -> &'de [u8] {
|
||||||
|
self.buf
|
||||||
|
}
|
||||||
fn parse_integer(&mut self) -> Result<i64, Error> {
|
fn parse_integer(&mut self) -> Result<i64, Error> {
|
||||||
match self.buf.iter().copied().position(|e| e == b'e') {
|
match self.buf.iter().copied().position(|e| e == b'e') {
|
||||||
Some(end) => {
|
Some(end) => {
|
||||||
|
|
|
||||||
|
|
@ -498,7 +498,7 @@ impl PeerConnectionHandler for PeerHandler {
|
||||||
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize> {
|
fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec<u8>) -> Option<usize> {
|
||||||
let g = self.state.locked.read();
|
let g = self.state.locked.read();
|
||||||
let msg = Message::Bitfield(ByteBuf(g.chunks.get_have_pieces().as_raw_slice()));
|
let msg = Message::Bitfield(ByteBuf(g.chunks.get_have_pieces().as_raw_slice()));
|
||||||
let len = msg.serialize(buf);
|
let len = msg.serialize(buf, None).unwrap();
|
||||||
debug!("sending to {}: {:?}, length={}", self.addr, &msg, len);
|
debug!("sending to {}: {:?}, length={}", self.addr, &msg, len);
|
||||||
Some(len)
|
Some(len)
|
||||||
}
|
}
|
||||||
|
|
@ -517,6 +517,12 @@ impl PeerConnectionHandler for PeerHandler {
|
||||||
fn read_chunk(&self, chunk: &crate::lengths::ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()> {
|
fn read_chunk(&self, chunk: &crate::lengths::ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()> {
|
||||||
self.state.file_ops().read_chunk(self.addr, chunk, buf)
|
self.state.file_ops().read_chunk(self.addr, chunk, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn on_extended_handshake(
|
||||||
|
&self,
|
||||||
|
extended_handshake: &crate::peer_binary_protocol::ExtendedHandshake<ByteBuf>,
|
||||||
|
) {
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PeerHandler {
|
impl PeerHandler {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue