Screwing around with extended messages

This commit is contained in:
Igor Katson 2021-07-02 13:00:46 +01:00
parent d722f0edcb
commit 302e95649d
5 changed files with 476 additions and 96 deletions

View file

@ -1,8 +1,13 @@
use std::collections::HashMap;
use std::{
collections::HashMap,
io::Write,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
use bincode::Options;
use byteorder::{ByteOrder, BE};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::{
bencode_value::BencodeValue,
@ -10,6 +15,8 @@ use crate::{
clone_to_owned::CloneToOwned,
constants::CHUNK_SIZE,
lengths::ChunkInfo,
serde_bencode_de::BencodeDeserializer,
serde_bencode_ser,
};
const INTEGER_LEN: usize = 4;
@ -41,6 +48,8 @@ const MSGID_REQUEST: u8 = 6;
const MSGID_PIECE: u8 = 7;
const MSGID_EXTENDED: u8 = 20;
const MY_EXTENDED_UT_METADATA: u8 = 0;
#[derive(Debug)]
pub enum MessageDeserializeError {
NotEnoughData(usize, &'static str),
@ -255,7 +264,11 @@ where
Message::Extended(_) => (0, MSGID_EXTENDED),
}
}
pub fn serialize(&self, out: &mut Vec<u8>) -> usize {
pub fn serialize(
&self,
out: &mut Vec<u8>,
peer_extended_handshake: Option<&ExtendedHandshake<ByteString>>,
) -> anyhow::Result<usize> {
let (lp, msg_id) = self.len_prefix_and_msg_id();
out.resize(PREAMBLE_LEN, 0);
@ -272,17 +285,17 @@ where
debug_assert_eq!((&out[PREAMBLE_LEN..]).len(), 12);
ser.serialize_into(&mut out[PREAMBLE_LEN..], request)
.unwrap();
MSG_LEN
Ok(MSG_LEN)
}
Message::Bitfield(b) => {
let block_len = b.as_ref().len();
let msg_len = PREAMBLE_LEN + block_len;
out.resize(msg_len, 0);
(&mut out[PREAMBLE_LEN..PREAMBLE_LEN + block_len]).copy_from_slice(b.as_ref());
msg_len
Ok(msg_len)
}
Message::Choke | Message::Unchoke | Message::Interested | Message::NotInterested => {
PREAMBLE_LEN
Ok(PREAMBLE_LEN)
}
Message::Piece(p) => {
let block_len = p.block.as_ref().len();
@ -291,23 +304,23 @@ where
out.resize(msg_len, 0);
let tmp = &mut out[PREAMBLE_LEN..];
p.serialize(&mut tmp[..payload_len]);
msg_len
Ok(msg_len)
}
Message::KeepAlive => {
// the len prefix was already written out to buf
4
Ok(4)
}
Message::Have(v) => {
let msg_len = PREAMBLE_LEN + 4;
out.resize(msg_len, 0);
BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
msg_len
Ok(msg_len)
}
Message::Extended(e) => {
e.serialize(out);
e.serialize(out, peer_extended_handshake);
let msg_size = out.len();
BE::write_u32(&mut out[..4], msg_size as u32);
msg_size
Ok(msg_size)
}
}
}
@ -496,6 +509,9 @@ impl<'a> Handshake<'a> {
peer_id,
}
}
pub fn supports_extended(&self) -> bool {
self.reserved[5] & 0x10 > 0
}
fn bopts() -> impl bincode::Options {
bincode::DefaultOptions::new()
}
@ -535,28 +551,184 @@ impl Request {
}
#[derive(Debug)]
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
Handshake(ExtendedHandshake<ByteBuf>),
Dyn(u8, BencodeValue<ByteBuf>),
pub enum UtMetadata<ByteBuf> {
Request(u32),
Data(u32, ByteBuf),
Reject(u32),
}
impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
fn serialize(&self, out: &mut Vec<u8>) {
impl<ByteBuf: CloneToOwned> CloneToOwned for UtMetadata<ByteBuf> {
type Target = UtMetadata<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self {
ExtendedMessage::Dyn(msg_id, v) => {
out.push(*msg_id);
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out).unwrap()
UtMetadata::Request(req) => UtMetadata::Request(*req),
UtMetadata::Data(piece, data) => UtMetadata::Data(*piece, data.clone_to_owned()),
UtMetadata::Reject(piece) => UtMetadata::Reject(*piece),
}
}
}
impl<'a, ByteBuf: 'a> UtMetadata<ByteBuf> {
fn serialize(&self, buf: &mut Vec<u8>)
where
ByteBuf: AsRef<[u8]>,
{
#[derive(Serialize)]
struct Message {
msg_type: u32,
piece: u32,
#[serde(skip_serializing_if = "Option::is_none")]
total_size: Option<u32>,
}
match self {
UtMetadata::Request(piece) => {
let message = Message {
msg_type: 0,
piece: *piece,
total_size: None,
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap()
}
ExtendedMessage::Handshake(h) => {
out.push(0);
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap()
UtMetadata::Data(piece, data) => {
let message = Message {
msg_type: 1,
piece: *piece,
total_size: Some(data.as_ref().len() as u32),
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
buf.write_all(data.as_ref()).unwrap();
}
UtMetadata::Reject(piece) => {
let message = Message {
msg_type: 2,
piece: *piece,
total_size: None,
};
serde_bencode_ser::bencode_serialize_to_writer(message, buf).unwrap();
}
}
}
fn deserialize<'de>(mut buf: &'de [u8]) -> Result<Self, MessageDeserializeError>
fn deserialize(buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
where
ByteBuf: Deserialize<'de> + From<&'de [u8]>,
ByteBuf: From<&'a [u8]>,
{
let mut de = BencodeDeserializer::new_from_buf(buf);
#[derive(Deserialize)]
struct Message {
msg_type: u32,
piece: u32,
total_size: Option<u32>,
}
let message =
Message::deserialize(&mut de).map_err(|e| MessageDeserializeError::Other(e.into()))?;
let remaining = de.into_remaining();
match message.msg_type {
// request
0 => {
if !remaining.is_empty() {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"trailing bytes when decoding UtMetadata"
)));
}
Ok(UtMetadata::Request(message.piece))
}
// data
1 => {
let total_size = message.total_size.ok_or_else(|| {
MessageDeserializeError::Other(anyhow::anyhow!(
"expected key total_size to be present in UtMetadata \"data\" message"
))
})?;
if remaining.len() != total_size as usize {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"remaining bytes len {} != total_size {}",
remaining.len(),
total_size
)));
}
Ok(UtMetadata::Data(message.piece, ByteBuf::from(remaining)))
}
// reject
2 => {
if !remaining.is_empty() {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"trailing bytes when decoding UtMetadata"
)));
}
Ok(UtMetadata::Reject(message.piece))
}
other => {
return Err(MessageDeserializeError::Other(anyhow::anyhow!(
"unrecognized ut_metadata message type {}",
other
)))
}
}
}
}
#[derive(Debug)]
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
Handshake(ExtendedHandshake<ByteBuf>),
UtMetadata(UtMetadata<ByteBuf>),
Dyn(u8, BencodeValue<ByteBuf>),
}
impl<ByteBuf> CloneToOwned for ExtendedMessage<ByteBuf>
where
ByteBuf: CloneToOwned + std::hash::Hash + Eq,
<ByteBuf as CloneToOwned>::Target: std::hash::Hash + Eq,
{
type Target = ExtendedMessage<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
match self {
ExtendedMessage::Handshake(h) => ExtendedMessage::Handshake(h.clone_to_owned()),
ExtendedMessage::Dyn(u, d) => ExtendedMessage::Dyn(*u, d.clone_to_owned()),
ExtendedMessage::UtMetadata(m) => ExtendedMessage::UtMetadata(m.clone_to_owned()),
}
}
}
impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
fn serialize(
&self,
out: &mut Vec<u8>,
extended_handshake: Option<&ExtendedHandshake<ByteString>>,
) -> anyhow::Result<()>
where
ByteBuf: AsRef<[u8]>,
{
match self {
ExtendedMessage::Dyn(msg_id, v) => {
out.push(*msg_id);
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out)?;
}
ExtendedMessage::Handshake(h) => {
out.push(0);
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out)?;
}
ExtendedMessage::UtMetadata(u) => {
let h = extended_handshake.ok_or_else(|| {
anyhow::anyhow!("need peer's handshake to serialize ut_metadata")
})?;
let emsg_id = h
.get_msgid(b"ut_metadata")
.ok_or_else(|| anyhow::anyhow!("peer doesn't support ut_metadata"))?;
out.push(emsg_id);
u.serialize(out);
}
}
Ok(())
}
fn deserialize(mut buf: &'a [u8]) -> Result<Self, MessageDeserializeError>
where
ByteBuf: Deserialize<'a> + From<&'a [u8]>,
{
{
use std::io::Write;
@ -583,10 +755,13 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
})?;
match emsg_id {
// handshake
0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)),
other => Ok(ExtendedMessage::Dyn(other, from_bytes(&buf)?)),
MY_EXTENDED_UT_METADATA => {
Ok(ExtendedMessage::UtMetadata(UtMetadata::deserialize(&buf)?))
}
other => Ok(ExtendedMessage::Dyn(emsg_id, from_bytes(&buf)?)),
}
// match self {
// ExtendedMessage::Dyn(v, msg) => {
// crate::bencode_value::dyn_from_bytes(buf)
@ -598,31 +773,130 @@ impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
}
}
#[derive(Debug, Clone, Copy)]
pub struct YourIP(pub IpAddr);
impl Serialize for YourIP {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
todo!()
}
}
impl<'de> Deserialize<'de> for YourIP {
fn deserialize<D>(de: D) -> Result<YourIP, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor {}
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = YourIP;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "expecting 4 bytes of ipv4 or 16 bytes of ipv6")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() == 4 {
return Ok(YourIP(IpAddr::V4(Ipv4Addr::new(v[0], v[1], v[2], v[3]))));
} else if v.len() == 16 {
return Ok(YourIP(IpAddr::V6(Ipv6Addr::new(
BE::read_u16(&v[..2]),
BE::read_u16(&v[2..4]),
BE::read_u16(&v[4..6]),
BE::read_u16(&v[6..8]),
BE::read_u16(&v[8..10]),
BE::read_u16(&v[10..12]),
BE::read_u16(&v[12..14]),
BE::read_u16(&v[14..]),
))));
}
Err(E::custom("expected 4 or 16 byte address"))
}
}
de.deserialize_bytes(Visitor {})
}
}
#[derive(Deserialize, Serialize, Debug)]
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
pub m: HashMap<ByteBuf, BencodeValue<ByteBuf>>,
pub m: HashMap<ByteBuf, u8>,
#[serde(skip_serializing_if = "Option::is_none")]
pub p: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub v: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub yourip: Option<ByteBuf>,
pub yourip: Option<YourIP>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv6: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv4: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reqq: Option<u32>,
pub metadata_size: Option<u32>,
}
impl<ByteBuf: Eq + std::hash::Hash> ExtendedHandshake<ByteBuf> {
fn get_msgid(&self, msg_type: &[u8]) -> Option<u8>
where
ByteBuf: AsRef<[u8]>,
{
self.m.iter().find_map(|(k, v)| {
if k.as_ref() == msg_type {
Some(*v)
} else {
None
}
})
}
}
impl<ByteBuf> CloneToOwned for ExtendedHandshake<ByteBuf>
where
ByteBuf: CloneToOwned + Eq + std::hash::Hash,
<ByteBuf as CloneToOwned>::Target: Eq + std::hash::Hash,
{
type Target = ExtendedHandshake<<ByteBuf as CloneToOwned>::Target>;
fn clone_to_owned(&self) -> Self::Target {
ExtendedHandshake {
m: self.m.clone_to_owned(),
p: self.p,
v: self.v.clone_to_owned(),
yourip: self.yourip,
ipv6: self.ipv6.clone_to_owned(),
ipv4: self.ipv4.clone_to_owned(),
reqq: self.reqq,
metadata_size: self.metadata_size,
}
}
}
#[cfg(test)]
mod tests {
use std::{net::SocketAddr, str::FromStr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use log::info;
use parking_lot::{Mutex, RwLock};
use tokio::sync::mpsc::UnboundedSender;
use crate::peer_id::generate_peer_id;
use crate::{
peer_connection::{PeerConnection, PeerConnectionHandler, WriterRequest},
peer_id::generate_peer_id,
};
use std::sync::Once;
static LOG_INIT: Once = std::sync::Once::new();
fn init_logging() {
LOG_INIT.call_once(pretty_env_logger::init)
}
fn decode_info_hash(hash_str: &str) -> [u8; 20] {
let mut hash_arr = [0u8; 20];
@ -645,9 +919,7 @@ mod tests {
#[test]
fn test_extended_serialize() {
let mut feats = HashMap::new();
feats.insert("whatever".as_bytes().into(), BencodeValue::Integer(1));
let feats = HashMap::new();
let msg =
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
m: feats,
@ -657,66 +929,72 @@ mod tests {
ipv6: None,
ipv4: None,
reqq: None,
metadata_size: None,
}));
let mut out = Vec::new();
msg.serialize(&mut out);
msg.serialize(&mut out, None);
dbg!(out);
}
#[tokio::test]
async fn test_connect_to_local_qbittorrent() {
let mut stream =
tokio::net::TcpStream::connect(SocketAddr::from_str("127.0.0.1:27311").unwrap())
.await
.unwrap();
init_logging();
struct Handler {
ehandshake: RwLock<Option<ExtendedHandshake<ByteString>>>,
tx: UnboundedSender<WriterRequest>,
}
impl PeerConnectionHandler for Handler {
fn get_have_bytes(&self) -> u64 {
0
}
fn serialize_bitfield_message_to_buf(&self, _buf: &mut Vec<u8>) -> Option<usize> {
None
}
fn on_handshake(&self, handshake: Handshake) {
info!("received handshake: {:?}", handshake)
}
fn on_received_message(&self, msg: Message<ByteBuf<'_>>) -> anyhow::Result<()> {
info!("received message: {:?}", msg);
Ok(())
}
fn on_uploaded_bytes(&self, _bytes: u32) {}
fn read_chunk(&self, _chunk: &ChunkInfo, _buf: &mut [u8]) -> anyhow::Result<()> {
panic!("dude, why are you requesting chunks")
}
fn on_extended_handshake(&self, extended_handshake: &ExtendedHandshake<ByteBuf>) {
self.ehandshake
.write()
.replace(extended_handshake.clone_to_owned());
self.tx
.send(WriterRequest::Message(Message::Extended(
ExtendedMessage::UtMetadata(UtMetadata::Request(0)),
)))
.unwrap()
}
}
let addr = SocketAddr::from_str("127.0.0.1:27311").unwrap();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let handler = Handler {
tx,
ehandshake: RwLock::new(None),
};
let peer_id = generate_peer_id();
let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7");
dbg!(info_hash);
let handshake = dbg!(Handshake::new(info_hash, peer_id));
let mut write_buf = Vec::<u8>::new();
let h = handshake.serialize();
let conn = PeerConnection::new(addr, info_hash, peer_id, handler);
let mut read_buf = vec![0u8; 16384];
// tx.send(WriterRequest::Message(Message::Extended(ExtendedMessage)));
stream.write_all(&h).await.unwrap();
let read_bytes = stream.read(&mut read_buf).await.unwrap();
let (handshake, hlen) = Handshake::deserialize(&read_buf[..read_bytes]).unwrap();
dbg!(handshake);
read_buf.copy_within(hlen..read_bytes, 0);
let mut read_so_far = read_bytes - hlen;
loop {
let (message, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => {
break (msg, size);
}
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = stream.read(&mut read_buf[read_so_far..]).await.unwrap();
if size == 0 {
panic!("size == 0, disconnected")
}
read_so_far += size;
}
Err(e) => Err(e).unwrap(),
}
};
dbg!(message, size);
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
}
conn.manage_peer(rx).await.unwrap();
}
}