diff --git a/crates/librqbit/Cargo.lock b/crates/librqbit/Cargo.lock index 34ec560..f9f28ee 100644 --- a/crates/librqbit/Cargo.lock +++ b/crates/librqbit/Cargo.lock @@ -166,7 +166,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a77162240fd97248d19a564a565eb563a3f592b386e4136fb300909e67dddca" dependencies = [ "commoncrypto", - "hex", + "hex 0.3.2", "openssl", "winapi", ] @@ -439,6 +439,12 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "805026a5d0141ffc30abb3be3173848ad46a1b1664fe632428479619a3644d77" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "0.2.4" @@ -601,6 +607,7 @@ dependencies = [ "byteorder", "crypto-hash", "futures", + "hex 0.4.3", "log", "openssl", "parking_lot", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 106bdca..7eb09ec 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -39,6 +39,7 @@ futures = "0.3" [dev-dependencies] futures = {version = "0.3"} pretty_env_logger = "0.4" +hex = "0.4" [profile.dev] panic = "abort" diff --git a/crates/librqbit/src/peer_binary_protocol.rs b/crates/librqbit/src/peer_binary_protocol.rs index 38bd277..1626a42 100644 --- a/crates/librqbit/src/peer_binary_protocol.rs +++ b/crates/librqbit/src/peer_binary_protocol.rs @@ -56,6 +56,7 @@ pub enum MessageDeserializeError { len_prefix: u32, name: &'static str, }, + Other(anyhow::Error), } pub fn serialize_piece_preamble(chunk: &ChunkInfo, mut buf: &mut [u8]) -> usize { @@ -145,6 +146,7 @@ impl std::fmt::Display for MessageDeserializeError { "error deserializing {} (msg_id={}, len_prefix={}): {:?}", name, msg_id, len_prefix, error ), + MessageDeserializeError::Other(e) => write!(f, "{}", e), } } } @@ -158,6 +160,12 @@ impl std::error::Error for MessageDeserializeError { } } +impl From for MessageDeserializeError { + fn from(e: anyhow::Error) -> Self { + MessageDeserializeError::Other(e) + } +} + #[derive(Debug)] pub enum Message { Request(Request), @@ -228,7 +236,7 @@ impl<'a> std::fmt::Debug for Bitfield<'a> { impl Message where - ByteBuf: AsRef<[u8]> + std::hash::Hash + Eq, + ByteBuf: AsRef<[u8]> + std::hash::Hash + Eq + Serialize, { pub fn len_prefix_and_msg_id(&self) -> (u32, u8) { match self { @@ -244,7 +252,7 @@ where ), Message::KeepAlive => (LEN_PREFIX_KEEPALIVE, 0), Message::Have(_) => (LEN_PREFIX_HAVE, MSGID_HAVE), - Message::Extended(_) => todo!(), + Message::Extended(_) => (0, MSGID_EXTENDED), } } pub fn serialize(&self, out: &mut Vec) -> usize { @@ -295,14 +303,19 @@ where BE::write_u32(&mut out[PREAMBLE_LEN..], *v); msg_len } - Message::Extended(_) => todo!(), + Message::Extended(e) => { + e.serialize(out); + let msg_size = out.len(); + BE::write_u32(&mut out[..4], msg_size as u32); + msg_size + } } } pub fn deserialize<'a>( buf: &'a [u8], ) -> Result<(Message, usize), MessageDeserializeError> where - ByteBuf: From<&'a [u8]> + 'a, + ByteBuf: From<&'a [u8]> + 'a + Deserialize<'a>, { let len_prefix = match buf.get(0..4) { Some(bytes) => byteorder::BigEndian::read_u32(bytes), @@ -426,6 +439,27 @@ where )), } } + MSGID_EXTENDED => { + if len_prefix <= 6 { + return Err(MessageDeserializeError::IncorrectLenPrefix { + expected: 6, + received: len_prefix, + msg_id, + }); + } + // TODO: NO clue why - 1 here. Empirically figured out. + let expected_len = len_prefix as usize - 1; + match rest.get(..expected_len) { + Some(b) => Ok(( + Message::Extended(ExtendedMessage::deserialize(&b)?), + PREAMBLE_LEN + expected_len, + )), + None => Err(MessageDeserializeError::NotEnoughData( + expected_len - rest.len(), + "extended", + )), + } + } msg_id => Err(MessageDeserializeError::UnsupportedMessageId(msg_id)), } } @@ -448,9 +482,16 @@ fn bopts() -> impl bincode::Options { impl<'a> Handshake<'a> { pub fn new(info_hash: [u8; 20], peer_id: [u8; 20]) -> Handshake<'static> { debug_assert_eq!(PSTR_BT1.len(), 19); + + let mut reserved: u64 = 0; + // supports extended messaging + reserved |= 1 << 20; + let mut reserved_arr = [0u8; 8]; + BE::write_u64(&mut reserved_arr, reserved); + Handshake { pstr: PSTR_BT1, - reserved: [0; 8], + reserved: reserved_arr, info_hash, peer_id, } @@ -495,19 +536,100 @@ impl Request { #[derive(Debug)] pub enum ExtendedMessage { - Dyn(BencodeValue), - Unimplemented(PhantomData), + Handshake(ExtendedHandshake), + Dyn(u8, BencodeValue), } -struct ExtendedHandshake { - m: HashMap>, - p: Option, - v: Option, - // _phantom: PhantomData<&'a ()>, +impl ExtendedMessage { + fn serialize(&self, out: &mut Vec) { + match self { + ExtendedMessage::Dyn(msg_id, v) => { + out.push(*msg_id); + crate::serde_bencode_ser::bencode_serialize_to_writer(v, out).unwrap() + } + ExtendedMessage::Handshake(h) => { + out.push(0); + crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap() + } + } + } + + fn deserialize<'de>(mut buf: &'de [u8]) -> Result + where + ByteBuf: Deserialize<'de> + From<&'de [u8]>, + { + { + use std::io::Write; + let mut f = std::fs::OpenOptions::new() + .create(true) + .write(true) + .open("/tmp/msg") + .unwrap(); + f.write_all(buf).unwrap(); + } + + use crate::serde_bencode_de::from_bytes; + + let emsg_id = buf.get(0).copied().ok_or_else(|| { + MessageDeserializeError::Other(anyhow::anyhow!( + "cannot deserialize extended message: can't read first byte" + )) + })?; + + buf = &buf.get(1..).ok_or_else(|| { + MessageDeserializeError::Other(anyhow::anyhow!( + "cannot deserialize extended message: buffer empty" + )) + })?; + + match emsg_id { + // handshake + 0 => Ok(ExtendedMessage::Handshake(from_bytes(&buf)?)), + other => Ok(ExtendedMessage::Dyn(other, from_bytes(&buf)?)), + } + // match self { + // ExtendedMessage::Dyn(v, msg) => { + // crate::bencode_value::dyn_from_bytes(buf) + // } + // ExtendedMessage::Handshake(h) => { + // crate::serde_bencode_ser::bencode_serialize_to_writer(h, out).unwrap() + // } + // } + } +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct ExtendedHandshake { + #[serde(bound(deserialize = "ByteBuf: From<&'de [u8]>"))] + pub m: HashMap>, + #[serde(skip_serializing_if = "Option::is_none")] + pub p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub v: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub yourip: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ipv6: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ipv4: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reqq: Option, } #[cfg(test)] mod tests { + use std::{io::Write, net::SocketAddr, ptr::read, str::FromStr}; + + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use crate::peer_id::generate_peer_id; + + fn decode_info_hash(hash_str: &str) -> [u8; 20] { + let mut hash_arr = [0u8; 20]; + hex::decode_to_slice(hash_str, &mut hash_arr).unwrap(); + hash_arr + } + use super::*; #[test] fn test_handshake_serialize() { @@ -520,4 +642,81 @@ mod tests { let b = dbg!(Handshake::new(info_hash, peer_id).serialize()); assert_eq!(b.len(), 20 + 20 + 8 + 19 + 1); } + + #[test] + fn test_extended_serialize() { + let mut feats = HashMap::new(); + feats.insert("whatever".as_bytes().into(), BencodeValue::Integer(1)); + + let msg = + Message::>::Extended(ExtendedMessage::Handshake(ExtendedHandshake { + m: feats, + p: None, + v: None, + yourip: None, + ipv6: None, + ipv4: None, + reqq: None, + })); + + let mut out = Vec::new(); + msg.serialize(&mut out); + dbg!(out); + } + + #[tokio::test] + async fn test_connect_to_local_qbittorrent() { + let mut stream = + tokio::net::TcpStream::connect(SocketAddr::from_str("127.0.0.1:27311").unwrap()) + .await + .unwrap(); + let peer_id = generate_peer_id(); + let info_hash = decode_info_hash("9905f844e5d8787ecd5e08fb46b2eb0a42c131d7"); + dbg!(info_hash); + let handshake = dbg!(Handshake::new(info_hash, peer_id)); + + let mut write_buf = Vec::::new(); + let h = handshake.serialize(); + + let mut read_buf = vec![0u8; 16384]; + + stream.write_all(&h).await.unwrap(); + + let read_bytes = stream.read(&mut read_buf).await.unwrap(); + let (handshake, hlen) = Handshake::deserialize(&read_buf[..read_bytes]).unwrap(); + dbg!(handshake); + + read_buf.copy_within(hlen..read_bytes, 0); + let mut read_so_far = read_bytes - hlen; + + loop { + let (message, size) = loop { + match MessageBorrowed::deserialize(&read_buf[..read_so_far]) { + Ok((msg, size)) => { + break (msg, size); + } + Err(MessageDeserializeError::NotEnoughData(d, _)) => { + if read_buf.len() < read_so_far + d { + read_buf.reserve(d); + read_buf.resize(read_buf.capacity(), 0); + } + + let size = stream.read(&mut read_buf[read_so_far..]).await.unwrap(); + if size == 0 { + panic!("size == 0, disconnected") + } + read_so_far += size; + } + Err(e) => Err(e).unwrap(), + } + }; + + dbg!(message, size); + + if read_so_far > size { + read_buf.copy_within(size..read_so_far, 0); + } + read_so_far -= size; + } + } } diff --git a/crates/librqbit/src/peer_id.rs b/crates/librqbit/src/peer_id.rs index a519965..633f1b4 100644 --- a/crates/librqbit/src/peer_id.rs +++ b/crates/librqbit/src/peer_id.rs @@ -43,3 +43,14 @@ pub enum PeerId { pub fn try_decode_peer_id(p: [u8; 20]) -> Option { Some(PeerId::AzureusStyle(try_decode_azureus_style(&p)?)) } + +pub fn generate_peer_id() -> [u8; 20] { + let mut peer_id = [0u8; 20]; + + let u = uuid::Uuid::new_v4(); + (&mut peer_id[4..20]).copy_from_slice(&u.as_bytes()[..]); + + (&mut peer_id[..8]).copy_from_slice(b"-rQ0001-"); + + peer_id +} diff --git a/crates/librqbit/src/peer_state.rs b/crates/librqbit/src/peer_state.rs index 150b147..f6f979f 100644 --- a/crates/librqbit/src/peer_state.rs +++ b/crates/librqbit/src/peer_state.rs @@ -40,7 +40,7 @@ pub struct LivePeerState { impl LivePeerState { pub fn new(peer_id: [u8; 20]) -> Self { LivePeerState { - peer_id: peer_id, + peer_id, i_am_choked: true, peer_interested: false, bitfield: None, diff --git a/crates/librqbit/src/serde_bencode_de.rs b/crates/librqbit/src/serde_bencode_de.rs index 702d50f..70d690a 100644 --- a/crates/librqbit/src/serde_bencode_de.rs +++ b/crates/librqbit/src/serde_bencode_de.rs @@ -55,7 +55,13 @@ impl<'de> BencodeDeserializer<'de> { .map_err(|e| Error::new_from_err(e).set_context(self))?; let bytes_start = length_delim + 1; let bytes_end = bytes_start + length; - let bytes = &self.buf[bytes_start..bytes_end]; + let bytes = &self.buf.get(bytes_start..bytes_end).ok_or_else(|| { + Error::custom(format!( + "could not get byte range {}..{}, data in the buffer: {:?}", + bytes_start, bytes_end, &self.buf + )) + .set_context(self) + })?; let rem = self.buf.get(bytes_end..).unwrap_or_default(); self.buf = rem; Ok(bytes) @@ -86,7 +92,14 @@ where T: serde::de::Deserialize<'a>, { let mut de = BencodeDeserializer::new_from_buf(buf); - Ok(T::deserialize(&mut de)?) + let v = T::deserialize(&mut de)?; + if !de.buf.is_empty() { + anyhow::bail!( + "deserialized successfully, but {} bytes remaining", + de.buf.len() + ) + } + Ok(v) } #[derive(Debug)] diff --git a/crates/librqbit/src/torrent_manager.rs b/crates/librqbit/src/torrent_manager.rs index eb9fe73..d482e07 100644 --- a/crates/librqbit/src/torrent_manager.rs +++ b/crates/librqbit/src/torrent_manager.rs @@ -21,6 +21,7 @@ use crate::{ file_ops::FileOps, http_api::make_and_run_http_api, lengths::Lengths, + peer_id::generate_peer_id, spawn_utils::{spawn, BlockingSpawner}, speed_estimator::SpeedEstimator, torrent_metainfo::TorrentMetaV1Owned, @@ -104,13 +105,6 @@ struct TorrentManager { force_tracker_interval: Option, } -fn generate_peer_id() -> [u8; 20] { - let mut peer_id = [0u8; 20]; - let u = uuid::Uuid::new_v4(); - (&mut peer_id[..16]).copy_from_slice(&u.as_bytes()[..]); - peer_id -} - fn make_lengths(torrent: &TorrentMetaV1Owned) -> anyhow::Result { let total_length = torrent.info.iter_file_lengths().sum(); Lengths::new(total_length, torrent.info.piece_length, None)