From 09252c039731b4540357a5218804c4729ac0805d Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Fri, 29 Dec 2023 18:54:08 -0500 Subject: [PATCH 1/3] Remove "read_one" macro --- crates/librqbit/src/peer_connection.rs | 57 ++++++++++++-------------- crates/librqbit/src/session.rs | 14 ++++++- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 9a429bf..1f4bfeb 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -76,32 +76,6 @@ where } } -macro_rules! read_one { - ($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{ - let (extended, size) = loop { - match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) { - Ok((msg, size)) => break (msg, size), - Err(MessageDeserializeError::NotEnoughData(d, _)) => { - if $read_buf.len() < $read_so_far + d { - $read_buf.reserve(d); - $read_buf.resize($read_buf.capacity(), 0); - } - - let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..])) - .await - .context("error reading from peer")?; - if size == 0 { - anyhow::bail!("disconnected while reading, read so far: {}", $read_so_far) - } - $read_so_far += size; - } - Err(e) => return Err(e.into()), - } - }; - (extended, size) - }}; -} - impl PeerConnection { pub fn new( addr: SocketAddr, @@ -354,7 +328,31 @@ impl PeerConnection { let reader = async move { loop { - let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout); + let (message, size) = loop { + match MessageBorrowed::deserialize(&read_buf[..read_so_far]) { + Ok((msg, size)) => break (msg, size), + Err(MessageDeserializeError::NotEnoughData(d, _)) => { + if read_buf.len() < read_so_far + d { + read_buf.reserve(d); + read_buf.resize(read_buf.capacity(), 0); + } + let size = with_timeout( + rwtimeout, + read_half.read(&mut read_buf[read_so_far..]), + ) + .await + .context("error reading from peer")?; + if size == 0 { + anyhow::bail!( + "disconnected while reading, read so far: {}", + read_so_far + ) + } + read_so_far += size; + } + Err(e) => return Err(e.into()), + } + }; trace!("received: {:?}", &message); if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { @@ -378,7 +376,7 @@ impl PeerConnection { Ok::<_, anyhow::Error>(()) }; - let r = tokio::select! { + tokio::select! { r = reader => { trace!("reader is done, exiting"); r @@ -387,7 +385,6 @@ impl PeerConnection { trace!("writer is done, exiting"); r } - }; - r + } } } diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index c825988..46642aa 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{bail, Context}; use bencode::{bencode_serialize_to_writer, BencodeDeserializer}; -use buffers::{ByteBufT, ByteString}; +use buffers::{ByteBuf, ByteBufT, ByteString}; use clone_to_owned::CloneToOwned; use dht::{ Dht, DhtBuilder, DhtConfig, Id20, PersistentDht, PersistentDhtConfig, RequestPeersStream, @@ -22,7 +22,9 @@ use librqbit_core::{ magnet::Magnet, peer_id::generate_peer_id, spawn_utils::spawn_with_cancel, - torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, + torrent_metainfo::{ + torrent_from_bytes as bencode_torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned, + }, }; use parking_lot::RwLock; use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; @@ -49,6 +51,14 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"]; pub type TorrentId = usize; +fn torrent_from_bytes(bytes: &[u8]) -> anyhow::Result { + debug!( + "all fields in torrent: {:#?}", + bencode::dyn_from_bytes::(bytes) + ); + bencode_torrent_from_bytes(bytes) +} + #[derive(Default)] pub struct SessionDatabase { next_id: TorrentId, From d5d98aff6060aeac690c558f4e93b23d7d831b8e Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Fri, 29 Dec 2023 20:33:37 -0500 Subject: [PATCH 2/3] Extract ReadBuf logic into a separate struct --- crates/librqbit/src/lib.rs | 1 + crates/librqbit/src/peer_connection.rs | 89 +++++-------------- crates/librqbit/src/read_buf.rs | 88 ++++++++++++++++++ crates/librqbit/src/session.rs | 30 ++----- crates/librqbit/src/torrent_state/live/mod.rs | 1 - 5 files changed, 120 insertions(+), 89 deletions(-) create mode 100644 crates/librqbit/src/read_buf.rs diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 434bf27..38091bf 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -31,6 +31,7 @@ pub mod http_api; pub mod http_api_client; mod peer_connection; mod peer_info_reader; +mod read_buf; mod session; mod spawn_utils; mod torrent_state; diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 1f4bfeb..e565ede 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id} use parking_lot::RwLock; use peer_binary_protocol::{ extended::{handshake::ExtendedHandshake, ExtendedMessage}, - serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, - MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, + serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use tokio::time::timeout; use tracing::trace; -use crate::spawn_utils::BlockingSpawner; +use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; pub trait PeerConnectionHandler { fn on_connected(&self, _connection_time: Duration) {} @@ -100,9 +99,7 @@ impl PeerConnection { pub async fn manage_peer_incoming( &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, - // How many bytes into read buffer have we read already. - read_so_far: usize, - read_buf: Vec, + read_buf: ReadBuf, handshake: Handshake, mut conn: tokio::net::TcpStream, ) -> anyhow::Result<()> { @@ -140,7 +137,6 @@ impl PeerConnection { self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -153,7 +149,6 @@ impl PeerConnection { &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -180,16 +175,11 @@ impl PeerConnection { .context("error writing handshake")?; write_buf.clear(); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut conn, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - let h_supports_extended = h.supports_extended(); trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id))); if h.info_hash != self.info_hash.0 { @@ -202,14 +192,8 @@ impl PeerConnection { self.handler.on_handshake(h)?; - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -221,14 +205,11 @@ impl PeerConnection { async fn manage_peer( &self, handshake_supports_extended: bool, - // How many bytes into read_buf is there of peer-sent-data. - mut read_so_far: usize, - mut read_buf: Vec, + mut read_buf: ReadBuf, mut write_buf: Vec, mut conn: tokio::net::TcpStream, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -328,47 +309,23 @@ impl PeerConnection { let reader = async move { loop { - let (message, size) = loop { - match MessageBorrowed::deserialize(&read_buf[..read_so_far]) { - Ok((msg, size)) => break (msg, size), - Err(MessageDeserializeError::NotEnoughData(d, _)) => { - if read_buf.len() < read_so_far + d { - read_buf.reserve(d); - read_buf.resize(read_buf.capacity(), 0); - } - let size = with_timeout( - rwtimeout, - read_half.read(&mut read_buf[read_so_far..]), - ) - .await - .context("error reading from peer")?; - if size == 0 { - anyhow::bail!( - "disconnected while reading, read so far: {}", - read_so_far - ) - } - read_so_far += size; + read_buf + .read_message(&mut read_half, rwtimeout, |message| { + trace!("received: {:?}", &message); + + if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { + *extended_handshake_ref.write() = Some(h.clone_to_owned()); + self.handler.on_extended_handshake(h)?; + trace!("remembered extended handshake for future serializing"); + } else { + self.handler + .on_received_message(message) + .context("error in handler.on_received_message()")?; } - Err(e) => return Err(e.into()), - } - }; - trace!("received: {:?}", &message); - - if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { - *extended_handshake_ref.write() = Some(h.clone_to_owned()); - self.handler.on_extended_handshake(h)?; - trace!("remembered extended handshake for future serializing"); - } else { - self.handler - .on_received_message(message) - .context("error in handler.on_received_message()")?; - } - - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; + Ok(()) + }) + .await + .context("error reading message")?; } // For type inference. diff --git a/crates/librqbit/src/read_buf.rs b/crates/librqbit/src/read_buf.rs new file mode 100644 index 0000000..eef884e --- /dev/null +++ b/crates/librqbit/src/read_buf.rs @@ -0,0 +1,88 @@ +use std::time::Duration; + +use crate::peer_connection::with_timeout; +use anyhow::Context; +use buffers::ByteBuf; +use peer_binary_protocol::{ + Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN, +}; +use tokio::io::AsyncReadExt; + +pub struct ReadBuf { + buf: Vec, + read_so_far: usize, + last_size: usize, +} + +impl ReadBuf { + pub fn new() -> Self { + Self { + buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2], + read_so_far: 0, + last_size: 0, + } + } + + fn prepare_for_read(&mut self) { + if self.read_so_far > self.last_size { + self.buf.copy_within(self.last_size..self.read_so_far, 0); + } + self.read_so_far -= self.last_size; + self.last_size = 0; + } + + // This MUST be run as the first operation on the buffer. + pub async fn read_handshake( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + ) -> anyhow::Result>> { + self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf)) + .await + .context("error reading handshake")?; + if self.read_so_far == 0 { + anyhow::bail!("bad handshake"); + } + let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far]) + .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; + self.last_size = size; + Ok(h) + } + + pub async fn read_message( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>, + ) -> anyhow::Result<()> { + self.prepare_for_read(); + loop { + let need_additional_bytes = + match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) { + Err(MessageDeserializeError::NotEnoughData(d, _)) => d, + Ok((msg, size)) => { + self.last_size = size; + // Rust's borrow checker can't do this early return. So we are using a callback instead. + // return Ok(msg); + on_message(msg)?; + return Ok(()); + } + Err(e) => return Err(e.into()), + }; + if self.buf.len() < self.read_so_far + need_additional_bytes { + self.buf.reserve(need_additional_bytes); + self.buf.resize(self.buf.capacity(), 0); + } + let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..])) + .await + .context("error reading from peer")?; + if size == 0 { + anyhow::bail!( + "disconnected while reading, read so far: {}", + self.read_so_far + ) + } + self.read_so_far += size; + } + } +} diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 46642aa..0f1e7c6 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -27,20 +27,18 @@ use librqbit_core::{ }, }; use parking_lot::RwLock; -use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; +use peer_binary_protocol::Handshake; use reqwest::Url; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_with::serde_as; -use tokio::{ - io::AsyncReadExt, - net::{TcpListener, TcpStream}, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, - peer_connection::{with_timeout, PeerConnectionOptions}, + peer_connection::PeerConnectionOptions, + read_buf::ReadBuf, spawn_utils::BlockingSpawner, torrent_state::{ ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, @@ -371,9 +369,8 @@ async fn create_tcp_listener( pub(crate) struct CheckedIncomingConnection { pub addr: SocketAddr, pub stream: tokio::net::TcpStream, - pub read_buf: Vec, + pub read_buf: ReadBuf, pub handshake: Handshake, - pub read_so_far: usize, } impl Session { @@ -515,16 +512,11 @@ impl Session { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut stream, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - trace!("received handshake from {addr}: {:?}", h); if h.peer_id == self.peer_id.0 { @@ -545,11 +537,6 @@ impl Session { let handshake = h.clone_to_owned(); - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - return Ok(( live, CheckedIncomingConnection { @@ -557,7 +544,6 @@ impl Session { stream, handshake, read_buf, - read_so_far, }, )); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 72a7db5..83c459a 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -457,7 +457,6 @@ impl TorrentStateLive { r = requester => {r} r = peer_connection.manage_peer_incoming( rx, - checked_peer.read_so_far, checked_peer.read_buf, checked_peer.handshake, checked_peer.stream From 8ee98548f54552307f394b5aec253c177756833d Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Fri, 29 Dec 2023 20:56:25 -0500 Subject: [PATCH 3/3] A tiny optimisation to not memcpy the buffer as often if we have many messages already buffered --- crates/librqbit/src/read_buf.rs | 62 +++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/crates/librqbit/src/read_buf.rs b/crates/librqbit/src/read_buf.rs index eef884e..5867db7 100644 --- a/crates/librqbit/src/read_buf.rs +++ b/crates/librqbit/src/read_buf.rs @@ -10,58 +10,72 @@ use tokio::io::AsyncReadExt; pub struct ReadBuf { buf: Vec, - read_so_far: usize, - last_size: usize, + // How many bytes into the buffer we have read from the connection. + // New reads should go past this. + filled: usize, + // How many bytes have we successfully deserialized. + processed: usize, } impl ReadBuf { pub fn new() -> Self { Self { buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2], - read_so_far: 0, - last_size: 0, + filled: 0, + processed: 0, } } - fn prepare_for_read(&mut self) { - if self.read_so_far > self.last_size { - self.buf.copy_within(self.last_size..self.read_so_far, 0); + fn prepare_for_read(&mut self, need_additional_bytes: usize) { + // Ensure the buffer starts from the to-be-deserialized message. + if self.processed > 0 { + if self.filled > self.processed { + self.buf.copy_within(self.processed..self.filled, 0); + } + self.filled -= self.processed; + self.processed = 0; + } + + // Ensure we have enough capacity to deserialize the message. + if self.buf.len() < self.filled + need_additional_bytes { + self.buf.reserve(need_additional_bytes); + self.buf.resize(self.buf.capacity(), 0); } - self.read_so_far -= self.last_size; - self.last_size = 0; } + // Read the BT handshake. // This MUST be run as the first operation on the buffer. pub async fn read_handshake( &mut self, mut conn: impl AsyncReadExt + Unpin, timeout: Duration, ) -> anyhow::Result>> { - self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf)) + self.filled = with_timeout(timeout, conn.read(&mut self.buf)) .await .context("error reading handshake")?; - if self.read_so_far == 0 { - anyhow::bail!("bad handshake"); + if self.filled == 0 { + anyhow::bail!("peer disconnected while reading handshake"); } - let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far]) + let (h, size) = Handshake::deserialize(&self.buf[..self.filled]) .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - self.last_size = size; + self.processed = size; Ok(h) } + // Read a message into the buffer, try to deserialize it and call the callback on it. + // We can't return the message because of a borrow checker issue. pub async fn read_message( &mut self, mut conn: impl AsyncReadExt + Unpin, timeout: Duration, on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>, ) -> anyhow::Result<()> { - self.prepare_for_read(); loop { let need_additional_bytes = - match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) { + match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) { Err(MessageDeserializeError::NotEnoughData(d, _)) => d, Ok((msg, size)) => { - self.last_size = size; + self.processed += size; // Rust's borrow checker can't do this early return. So we are using a callback instead. // return Ok(msg); on_message(msg)?; @@ -69,20 +83,14 @@ impl ReadBuf { } Err(e) => return Err(e.into()), }; - if self.buf.len() < self.read_so_far + need_additional_bytes { - self.buf.reserve(need_additional_bytes); - self.buf.resize(self.buf.capacity(), 0); - } - let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..])) + self.prepare_for_read(need_additional_bytes); + let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..])) .await .context("error reading from peer")?; if size == 0 { - anyhow::bail!( - "disconnected while reading, read so far: {}", - self.read_so_far - ) + anyhow::bail!("disconnected while reading, read so far: {}", self.filled) } - self.read_so_far += size; + self.filled += size; } } }