diff --git a/TODO.md b/TODO.md index 1282f7a..c9c555d 100644 --- a/TODO.md +++ b/TODO.md @@ -34,6 +34,7 @@ incoming peers: - [ ] error managing peer: expected extended handshake, but got Bitfield(<94 bytes>) +- [ ] do not announce when merely listing the torrent someday: - [x] cancellation from the client-side for the lib (i.e. stop the torrent manager) diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 5605efe..4450d63 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -7,6 +7,7 @@ use anyhow::{bail, Context}; use buffers::{ByteBuf, ByteString}; use clone_to_owned::CloneToOwned; use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id}; +use parking_lot::RwLock; use peer_binary_protocol::{ extended::{handshake::ExtendedHandshake, ExtendedMessage}, serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, @@ -261,33 +262,19 @@ impl PeerConnection { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); - let mut extended_handshake: Option> = None; + let extended_handshake: RwLock>> = RwLock::new(None); + let extended_handshake_ref = &extended_handshake; let supports_extended = handshake_supports_extended; if supports_extended { let my_extended = Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new())); trace!("sending extended handshake: {:?}", &my_extended); - my_extended.serialize(&mut write_buf, None).unwrap(); + my_extended.serialize(&mut write_buf, &|| None).unwrap(); with_timeout(rwtimeout, conn.write_all(&write_buf)) .await .context("error writing extended handshake")?; write_buf.clear(); - - let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout); - match extended { - Message::Extended(ExtendedMessage::Handshake(h)) => { - trace!("received: {:?}", &h); - self.handler.on_extended_handshake(&h)?; - extended_handshake = Some(h.clone_to_owned()) - } - other => anyhow::bail!("expected extended handshake, but got {:?}", other), - }; - - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; } let (mut read_half, mut write_half) = tokio::io::split(conn); @@ -320,9 +307,12 @@ impl PeerConnection { let mut uploaded_add = None; let len = match &req { - WriterRequest::Message(msg) => { - msg.serialize(&mut write_buf, extended_handshake.as_ref())? - } + WriterRequest::Message(msg) => msg.serialize(&mut write_buf, &|| { + extended_handshake_ref + .read() + .as_ref() + .and_then(|e| e.ut_metadata()) + })?, WriterRequest::ReadChunkRequest(chunk) => { // this whole section is an optimization write_buf.resize(PIECE_MESSAGE_DEFAULT_LEN, 0); @@ -366,9 +356,15 @@ impl PeerConnection { let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout); trace!("received: {:?}", &message); - self.handler - .on_received_message(message) - .context("error in handler.on_received_message()")?; + if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { + *extended_handshake_ref.write() = Some(h.clone_to_owned()); + self.handler.on_extended_handshake(h)?; + trace!("remembered extended handshake for future serializing"); + } else { + self.handler + .on_received_message(message) + .context("error in handler.on_received_message()")?; + } if read_so_far > size { read_buf.copy_within(size..read_so_far, 0); diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 6b19825..306661c 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -838,7 +838,7 @@ impl<'a> PeerConnectionHandler for &'a PeerHandler { fn serialize_bitfield_message_to_buf(&self, buf: &mut Vec) -> anyhow::Result { let g = self.state.lock_read("serialize_bitfield_message_to_buf"); let msg = Message::Bitfield(ByteBuf(g.get_chunks()?.get_have_pieces().as_raw_slice())); - let len = msg.serialize(buf, None)?; + let len = msg.serialize(buf, &|| None)?; trace!("sending: {:?}, length={}", &msg, len); Ok(len) } diff --git a/crates/peer_binary_protocol/src/extended/handshake.rs b/crates/peer_binary_protocol/src/extended/handshake.rs index db41803..62f0955 100644 --- a/crates/peer_binary_protocol/src/extended/handshake.rs +++ b/crates/peer_binary_protocol/src/extended/handshake.rs @@ -59,6 +59,13 @@ impl ExtendedHandshake { } }) } + + pub fn ut_metadata(&self) -> Option + where + ByteBuf: AsRef<[u8]>, + { + self.get_msgid(b"ut_metadata") + } } impl CloneToOwned for ExtendedHandshake diff --git a/crates/peer_binary_protocol/src/extended/mod.rs b/crates/peer_binary_protocol/src/extended/mod.rs index 47bb530..9ad8f32 100644 --- a/crates/peer_binary_protocol/src/extended/mod.rs +++ b/crates/peer_binary_protocol/src/extended/mod.rs @@ -1,7 +1,6 @@ use bencode::bencode_serialize_to_writer; use bencode::from_bytes; use bencode::BencodeValue; -use buffers::ByteString; use clone_to_owned::CloneToOwned; use serde::{Deserialize, Serialize}; @@ -41,7 +40,7 @@ impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage, - extended_handshake: Option<&ExtendedHandshake>, + extended_handshake_ut_metadata: &dyn Fn() -> Option, ) -> anyhow::Result<()> where ByteBuf: AsRef<[u8]>, @@ -56,12 +55,9 @@ impl<'a, ByteBuf: 'a + std::hash::Hash + Eq + Serialize> ExtendedMessage { - let h = extended_handshake.ok_or_else(|| { + let emsg_id = extended_handshake_ut_metadata().ok_or_else(|| { anyhow::anyhow!("need peer's handshake to serialize ut_metadata") })?; - let emsg_id = h - .get_msgid(b"ut_metadata") - .ok_or_else(|| anyhow::anyhow!("peer doesn't support ut_metadata"))?; out.push(emsg_id); u.serialize(out); } diff --git a/crates/peer_binary_protocol/src/lib.rs b/crates/peer_binary_protocol/src/lib.rs index 11171f7..b99edc5 100644 --- a/crates/peer_binary_protocol/src/lib.rs +++ b/crates/peer_binary_protocol/src/lib.rs @@ -11,7 +11,7 @@ use clone_to_owned::CloneToOwned; use librqbit_core::{constants::CHUNK_SIZE, id20::Id20, lengths::ChunkInfo}; use serde::{Deserialize, Serialize}; -use self::extended::{handshake::ExtendedHandshake, ExtendedMessage}; +use self::extended::ExtendedMessage; const INTEGER_LEN: usize = 4; const MSGID_LEN: usize = 1; @@ -258,7 +258,7 @@ where pub fn serialize( &self, out: &mut Vec, - peer_extended_handshake: Option<&ExtendedHandshake>, + extended_handshake_ut_metadata: &dyn Fn() -> Option, ) -> anyhow::Result { let (lp, msg_id) = self.len_prefix_and_msg_id(); @@ -308,7 +308,7 @@ where Ok(msg_len) } Message::Extended(e) => { - e.serialize(out, peer_extended_handshake)?; + e.serialize(out, extended_handshake_ut_metadata)?; let msg_size = out.len(); // no fucking idea why +1, but I tweaked that for it all to match up // with real messages. @@ -576,6 +576,8 @@ impl Request { #[cfg(test)] mod tests { + use crate::extended::handshake::ExtendedHandshake; + use super::*; #[test] fn test_handshake_serialize() { @@ -594,7 +596,7 @@ mod tests { fn test_extended_serialize() { let msg = Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new())); let mut out = Vec::new(); - msg.serialize(&mut out, None).unwrap(); + msg.serialize(&mut out, &|| None).unwrap(); dbg!(out); } @@ -610,7 +612,7 @@ mod tests { let (msg, size) = MessageBorrowed::deserialize(&buf).unwrap(); assert_eq!(size, buf.len()); let mut write_buf = Vec::new(); - msg.serialize(&mut write_buf, None).unwrap(); + msg.serialize(&mut write_buf, &|| None).unwrap(); if buf != write_buf { { use std::io::Write;