Can decode extended messages now

This commit is contained in:
Igor Katson 2021-07-02 01:38:07 +01:00
parent 5f60f9e1b4
commit e666f063ff
7 changed files with 248 additions and 23 deletions

View file

@ -166,7 +166,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a77162240fd97248d19a564a565eb563a3f592b386e4136fb300909e67dddca"
dependencies = [
"commoncrypto",
"hex",
"hex 0.3.2",
"openssl",
"winapi",
]
@ -439,6 +439,12 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "805026a5d0141ffc30abb3be3173848ad46a1b1664fe632428479619a3644d77"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "http"
version = "0.2.4"
@ -601,6 +607,7 @@ dependencies = [
"byteorder",
"crypto-hash",
"futures",
"hex 0.4.3",
"log",
"openssl",
"parking_lot",

View file

@ -39,6 +39,7 @@ futures = "0.3"
[dev-dependencies]
futures = {version = "0.3"}
pretty_env_logger = "0.4"
hex = "0.4"
[profile.dev]
panic = "abort"

View file

@ -56,6 +56,7 @@ pub enum MessageDeserializeError {
len_prefix: u32,
name: &'static str,
},
Other(anyhow::Error),
}
pub fn serialize_piece_preamble(chunk: &ChunkInfo, mut buf: &mut [u8]) -> usize {
@ -145,6 +146,7 @@ impl std::fmt::Display for MessageDeserializeError {
"error deserializing {} (msg_id={}, len_prefix={}): {:?}",
name, msg_id, len_prefix, error
),
MessageDeserializeError::Other(e) => write!(f, "{}", e),
}
}
}
@ -158,6 +160,12 @@ impl std::error::Error for MessageDeserializeError {
}
}
impl From<anyhow::Error> for MessageDeserializeError {
fn from(e: anyhow::Error) -> Self {
MessageDeserializeError::Other(e)
}
}
#[derive(Debug)]
pub enum Message<ByteBuf: std::hash::Hash + Eq> {
Request(Request),
@ -228,7 +236,7 @@ impl<'a> std::fmt::Debug for Bitfield<'a> {
impl<ByteBuf> Message<ByteBuf>
where
ByteBuf: AsRef<[u8]> + std::hash::Hash + Eq,
ByteBuf: AsRef<[u8]> + std::hash::Hash + Eq + Serialize,
{
pub fn len_prefix_and_msg_id(&self) -> (u32, u8) {
match self {
@ -244,7 +252,7 @@ where
),
Message::KeepAlive => (LEN_PREFIX_KEEPALIVE, 0),
Message::Have(_) => (LEN_PREFIX_HAVE, MSGID_HAVE),
Message::Extended(_) => todo!(),
Message::Extended(_) => (0, MSGID_EXTENDED),
}
}
pub fn serialize(&self, out: &mut Vec<u8>) -> usize {
@ -295,14 +303,19 @@ where
BE::write_u32(&mut out[PREAMBLE_LEN..], *v);
msg_len
}
Message::Extended(_) => todo!(),
Message::Extended(e) => {
e.serialize(out);
let msg_size = out.len();
BE::write_u32(&mut out[..4], msg_size as u32);
msg_size
}
}
}
pub fn deserialize<'a>(
buf: &'a [u8],
) -> Result<(Message<ByteBuf>, usize), MessageDeserializeError>
where
ByteBuf: From<&'a [u8]> + 'a,
ByteBuf: From<&'a [u8]> + 'a + Deserialize<'a>,
{
let len_prefix = match buf.get(0..4) {
Some(bytes) => byteorder::BigEndian::read_u32(bytes),
@ -426,6 +439,27 @@ where
)),
}
}
MSGID_EXTENDED => {
if len_prefix <= 6 {
return Err(MessageDeserializeError::IncorrectLenPrefix {
expected: 6,
received: len_prefix,
msg_id,
});
}
// TODO: NO clue why - 1 here. Empirically figured out.
let expected_len = len_prefix as usize - 1;
match rest.get(..expected_len) {
Some(b) => Ok((
Message::Extended(ExtendedMessage::deserialize(&b)?),
PREAMBLE_LEN + expected_len,
)),
None => Err(MessageDeserializeError::NotEnoughData(
expected_len - rest.len(),
"extended",
)),
}
}
msg_id => Err(MessageDeserializeError::UnsupportedMessageId(msg_id)),
}
}
@ -448,9 +482,16 @@ fn bopts() -> impl bincode::Options {
impl<'a> Handshake<'a> {
pub fn new(info_hash: [u8; 20], peer_id: [u8; 20]) -> Handshake<'static> {
debug_assert_eq!(PSTR_BT1.len(), 19);
let mut reserved: u64 = 0;
// supports extended messaging
reserved |= 1 << 20;
let mut reserved_arr = [0u8; 8];
BE::write_u64(&mut reserved_arr, reserved);
Handshake {
pstr: PSTR_BT1,
reserved: [0; 8],
reserved: reserved_arr,
info_hash,
peer_id,
}
@ -495,19 +536,100 @@ impl Request {
#[derive(Debug)]
pub enum ExtendedMessage<ByteBuf: std::hash::Hash + Eq> {
Dyn(BencodeValue<ByteBuf>),
Unimplemented(PhantomData<ByteBuf>),
Handshake(ExtendedHandshake<ByteBuf>),
Dyn(u8, BencodeValue<ByteBuf>),
}
struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
m: HashMap<ByteBuf, BencodeValue<ByteBuf>>,
p: Option<u32>,
v: Option<ByteBuf>,
// _phantom: PhantomData<&'a ()>,
impl<ByteBuf: std::hash::Hash + Eq + Serialize> ExtendedMessage<ByteBuf> {
fn serialize(&self, out: &mut Vec<u8>) {
match self {
ExtendedMessage::Dyn(msg_id, v) => {
out.push(*msg_id);
crate::serde_bencode_ser::bencode_serialize_to_writer(v, out).unwrap()
}
ExtendedMessage::Handshake(h) => {
out.push(0);
crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap()
}
}
}
fn deserialize<'de>(mut buf: &'de [u8]) -> Result<Self, MessageDeserializeError>
where
ByteBuf: Deserialize<'de> + From<&'de [u8]>,
{
{
use std::io::Write;
let mut f = std::fs::OpenOptions::new()
.create(true)
.write(true)
.open("/tmp/msg")
.unwrap();
f.write_all(buf).unwrap();
}
use crate::serde_bencode_de::from_bytes;
let emsg_id = buf.get(0).copied().ok_or_else(|| {
MessageDeserializeError::Other(anyhow::anyhow!(
"cannot deserialize extended message: can't read first byte"
))
})?;
buf = &buf.get(1..).ok_or_else(|| {
MessageDeserializeError::Other(anyhow::anyhow!(
"cannot deserialize extended message: buffer empty"
))
})?;
match emsg_id {
// handshake
0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)),
other => Ok(ExtendedMessage::Dyn(other, from_bytes(&buf)?)),
}
// match self {
// ExtendedMessage::Dyn(v, msg) => {
// crate::bencode_value::dyn_from_bytes(buf)
// }
// ExtendedMessage::Handshake(h) => {
// crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap()
// }
// }
}
}
#[derive(Deserialize, Serialize, Debug)]
pub struct ExtendedHandshake<ByteBuf: Eq + std::hash::Hash> {
#[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))]
pub m: HashMap<ByteBuf, BencodeValue<ByteBuf>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub p: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub v: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub yourip: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv6: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ipv4: Option<ByteBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reqq: Option<u32>,
}
#[cfg(test)]
mod tests {
use std::{io::Write, net::SocketAddr, ptr::read, str::FromStr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::peer_id::generate_peer_id;
fn decode_info_hash(hash_str: &str) -> [u8; 20] {
let mut hash_arr = [0u8; 20];
hex::decode_to_slice(hash_str, &mut hash_arr).unwrap();
hash_arr
}
use super::*;
#[test]
fn test_handshake_serialize() {
@ -520,4 +642,81 @@ mod tests {
let b = dbg!(Handshake::new(info_hash, peer_id).serialize());
assert_eq!(b.len(), 20 + 20 + 8 + 19 + 1);
}
#[test]
fn test_extended_serialize() {
let mut feats = HashMap::new();
feats.insert("whatever".as_bytes().into(), BencodeValue::Integer(1));
let msg =
Message::<ByteBuf<'static>>::Extended(ExtendedMessage::Handshake(ExtendedHandshake {
m: feats,
p: None,
v: None,
yourip: None,
ipv6: None,
ipv4: None,
reqq: None,
}));
let mut out = Vec::new();
msg.serialize(&mut out);
dbg!(out);
}
#[tokio::test]
async fn test_connect_to_local_qbittorrent() {
let mut stream =
tokio::net::TcpStream::connect(SocketAddr::from_str("127.0.0.1:27311").unwrap())
.await
.unwrap();
let peer_id = generate_peer_id();
let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7");
dbg!(info_hash);
let handshake = dbg!(Handshake::new(info_hash, peer_id));
let mut write_buf = Vec::<u8>::new();
let h = handshake.serialize();
let mut read_buf = vec![0u8; 16384];
stream.write_all(&h).await.unwrap();
let read_bytes = stream.read(&mut read_buf).await.unwrap();
let (handshake, hlen) = Handshake::deserialize(&read_buf[..read_bytes]).unwrap();
dbg!(handshake);
read_buf.copy_within(hlen..read_bytes, 0);
let mut read_so_far = read_bytes - hlen;
loop {
let (message, size) = loop {
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
Ok((msg, size)) => {
break (msg, size);
}
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d {
read_buf.reserve(d);
read_buf.resize(read_buf.capacity(), 0);
}
let size = stream.read(&mut read_buf[read_so_far..]).await.unwrap();
if size == 0 {
panic!("size == 0, disconnected")
}
read_so_far += size;
}
Err(e) => Err(e).unwrap(),
}
};
dbg!(message, size);
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
}
}
}

View file

@ -43,3 +43,14 @@ pub enum PeerId {
pub fn try_decode_peer_id(p: [u8; 20]) -> Option<PeerId> {
Some(PeerId::AzureusStyle(try_decode_azureus_style(&p)?))
}
pub fn generate_peer_id() -> [u8; 20] {
let mut peer_id = [0u8; 20];
let u = uuid::Uuid::new_v4();
(&mut peer_id[4..20]).copy_from_slice(&u.as_bytes()[..]);
(&mut peer_id[..8]).copy_from_slice(b"-rQ0001-");
peer_id
}

View file

@ -40,7 +40,7 @@ pub struct LivePeerState {
impl LivePeerState {
pub fn new(peer_id: [u8; 20]) -> Self {
LivePeerState {
peer_id: peer_id,
peer_id,
i_am_choked: true,
peer_interested: false,
bitfield: None,

View file

@ -55,7 +55,13 @@ impl<'de> BencodeDeserializer<'de> {
.map_err(|e| Error::new_from_err(e).set_context(self))?;
let bytes_start = length_delim + 1;
let bytes_end = bytes_start + length;
let bytes = &self.buf[bytes_start..bytes_end];
let bytes = &self.buf.get(bytes_start..bytes_end).ok_or_else(|| {
Error::custom(format!(
"could not get byte range {}..{}, data in the buffer: {:?}",
bytes_start, bytes_end, &self.buf
))
.set_context(self)
})?;
let rem = self.buf.get(bytes_end..).unwrap_or_default();
self.buf = rem;
Ok(bytes)
@ -86,7 +92,14 @@ where
T: serde::de::Deserialize<'a>,
{
let mut de = BencodeDeserializer::new_from_buf(buf);
Ok(T::deserialize(&mut de)?)
let v = T::deserialize(&mut de)?;
if !de.buf.is_empty() {
anyhow::bail!(
"deserialized successfully, but {} bytes remaining",
de.buf.len()
)
}
Ok(v)
}
#[derive(Debug)]

View file

@ -21,6 +21,7 @@ use crate::{
file_ops::FileOps,
http_api::make_and_run_http_api,
lengths::Lengths,
peer_id::generate_peer_id,
spawn_utils::{spawn, BlockingSpawner},
speed_estimator::SpeedEstimator,
torrent_metainfo::TorrentMetaV1Owned,
@ -104,13 +105,6 @@ struct TorrentManager {
force_tracker_interval: Option<Duration>,
}
fn generate_peer_id() -> [u8; 20] {
let mut peer_id = [0u8; 20];
let u = uuid::Uuid::new_v4();
(&mut peer_id[..16]).copy_from_slice(&u.as_bytes()[..]);
peer_id
}
fn make_lengths(torrent: &TorrentMetaV1Owned) -> anyhow::Result<Lengths> {
let total_length = torrent.info.iter_file_lengths().sum();
Lengths::new(total_length, torrent.info.piece_length, None)