From d5d98aff6060aeac690c558f4e93b23d7d831b8e Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Fri, 29 Dec 2023 20:33:37 -0500 Subject: [PATCH] Extract ReadBuf logic into a separate struct --- crates/librqbit/src/lib.rs | 1 + crates/librqbit/src/peer_connection.rs | 89 +++++-------------- crates/librqbit/src/read_buf.rs | 88 ++++++++++++++++++ crates/librqbit/src/session.rs | 30 ++----- crates/librqbit/src/torrent_state/live/mod.rs | 1 - 5 files changed, 120 insertions(+), 89 deletions(-) create mode 100644 crates/librqbit/src/read_buf.rs diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 434bf27..38091bf 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -31,6 +31,7 @@ pub mod http_api; pub mod http_api_client; mod peer_connection; mod peer_info_reader; +mod read_buf; mod session; mod spawn_utils; mod torrent_state; diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 1f4bfeb..e565ede 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id} use parking_lot::RwLock; use peer_binary_protocol::{ extended::{handshake::ExtendedHandshake, ExtendedMessage}, - serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, - MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, + serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use tokio::time::timeout; use tracing::trace; -use crate::spawn_utils::BlockingSpawner; +use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; pub trait PeerConnectionHandler { fn on_connected(&self, _connection_time: Duration) {} @@ -100,9 +99,7 @@ impl PeerConnection { pub async fn manage_peer_incoming( &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, - // How many bytes into read buffer have we read already. - read_so_far: usize, - read_buf: Vec, + read_buf: ReadBuf, handshake: Handshake, mut conn: tokio::net::TcpStream, ) -> anyhow::Result<()> { @@ -140,7 +137,6 @@ impl PeerConnection { self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -153,7 +149,6 @@ impl PeerConnection { &self, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -180,16 +175,11 @@ impl PeerConnection { .context("error writing handshake")?; write_buf.clear(); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut conn, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - let h_supports_extended = h.supports_extended(); trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id))); if h.info_hash != self.info_hash.0 { @@ -202,14 +192,8 @@ impl PeerConnection { self.handler.on_handshake(h)?; - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - self.manage_peer( h_supports_extended, - read_so_far, read_buf, write_buf, conn, @@ -221,14 +205,11 @@ impl PeerConnection { async fn manage_peer( &self, handshake_supports_extended: bool, - // How many bytes into read_buf is there of peer-sent-data. - mut read_so_far: usize, - mut read_buf: Vec, + mut read_buf: ReadBuf, mut write_buf: Vec, mut conn: tokio::net::TcpStream, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { - use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -328,47 +309,23 @@ impl PeerConnection { let reader = async move { loop { - let (message, size) = loop { - match MessageBorrowed::deserialize(&read_buf[..read_so_far]) { - Ok((msg, size)) => break (msg, size), - Err(MessageDeserializeError::NotEnoughData(d, _)) => { - if read_buf.len() < read_so_far + d { - read_buf.reserve(d); - read_buf.resize(read_buf.capacity(), 0); - } - let size = with_timeout( - rwtimeout, - read_half.read(&mut read_buf[read_so_far..]), - ) - .await - .context("error reading from peer")?; - if size == 0 { - anyhow::bail!( - "disconnected while reading, read so far: {}", - read_so_far - ) - } - read_so_far += size; + read_buf + .read_message(&mut read_half, rwtimeout, |message| { + trace!("received: {:?}", &message); + + if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { + *extended_handshake_ref.write() = Some(h.clone_to_owned()); + self.handler.on_extended_handshake(h)?; + trace!("remembered extended handshake for future serializing"); + } else { + self.handler + .on_received_message(message) + .context("error in handler.on_received_message()")?; } - Err(e) => return Err(e.into()), - } - }; - trace!("received: {:?}", &message); - - if let Message::Extended(ExtendedMessage::Handshake(h)) = &message { - *extended_handshake_ref.write() = Some(h.clone_to_owned()); - self.handler.on_extended_handshake(h)?; - trace!("remembered extended handshake for future serializing"); - } else { - self.handler - .on_received_message(message) - .context("error in handler.on_received_message()")?; - } - - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; + Ok(()) + }) + .await + .context("error reading message")?; } // For type inference. diff --git a/crates/librqbit/src/read_buf.rs b/crates/librqbit/src/read_buf.rs new file mode 100644 index 0000000..eef884e --- /dev/null +++ b/crates/librqbit/src/read_buf.rs @@ -0,0 +1,88 @@ +use std::time::Duration; + +use crate::peer_connection::with_timeout; +use anyhow::Context; +use buffers::ByteBuf; +use peer_binary_protocol::{ + Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN, +}; +use tokio::io::AsyncReadExt; + +pub struct ReadBuf { + buf: Vec, + read_so_far: usize, + last_size: usize, +} + +impl ReadBuf { + pub fn new() -> Self { + Self { + buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2], + read_so_far: 0, + last_size: 0, + } + } + + fn prepare_for_read(&mut self) { + if self.read_so_far > self.last_size { + self.buf.copy_within(self.last_size..self.read_so_far, 0); + } + self.read_so_far -= self.last_size; + self.last_size = 0; + } + + // This MUST be run as the first operation on the buffer. + pub async fn read_handshake( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + ) -> anyhow::Result>> { + self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf)) + .await + .context("error reading handshake")?; + if self.read_so_far == 0 { + anyhow::bail!("bad handshake"); + } + let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far]) + .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; + self.last_size = size; + Ok(h) + } + + pub async fn read_message( + &mut self, + mut conn: impl AsyncReadExt + Unpin, + timeout: Duration, + on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>, + ) -> anyhow::Result<()> { + self.prepare_for_read(); + loop { + let need_additional_bytes = + match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) { + Err(MessageDeserializeError::NotEnoughData(d, _)) => d, + Ok((msg, size)) => { + self.last_size = size; + // Rust's borrow checker can't do this early return. So we are using a callback instead. + // return Ok(msg); + on_message(msg)?; + return Ok(()); + } + Err(e) => return Err(e.into()), + }; + if self.buf.len() < self.read_so_far + need_additional_bytes { + self.buf.reserve(need_additional_bytes); + self.buf.resize(self.buf.capacity(), 0); + } + let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..])) + .await + .context("error reading from peer")?; + if size == 0 { + anyhow::bail!( + "disconnected while reading, read so far: {}", + self.read_so_far + ) + } + self.read_so_far += size; + } + } +} diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 46642aa..0f1e7c6 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -27,20 +27,18 @@ use librqbit_core::{ }, }; use parking_lot::RwLock; -use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; +use peer_binary_protocol::Handshake; use reqwest::Url; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_with::serde_as; -use tokio::{ - io::AsyncReadExt, - net::{TcpListener, TcpStream}, -}; +use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, - peer_connection::{with_timeout, PeerConnectionOptions}, + peer_connection::PeerConnectionOptions, + read_buf::ReadBuf, spawn_utils::BlockingSpawner, torrent_state::{ ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, @@ -371,9 +369,8 @@ async fn create_tcp_listener( pub(crate) struct CheckedIncomingConnection { pub addr: SocketAddr, pub stream: tokio::net::TcpStream, - pub read_buf: Vec, + pub read_buf: ReadBuf, pub handshake: Handshake, - pub read_so_far: usize, } impl Session { @@ -515,16 +512,11 @@ impl Session { .read_write_timeout .unwrap_or_else(|| Duration::from_secs(10)); - let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf)) + let mut read_buf = ReadBuf::new(); + let h = read_buf + .read_handshake(&mut stream, rwtimeout) .await .context("error reading handshake")?; - if read_so_far == 0 { - anyhow::bail!("bad handshake"); - } - let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) - .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; - trace!("received handshake from {addr}: {:?}", h); if h.peer_id == self.peer_id.0 { @@ -545,11 +537,6 @@ impl Session { let handshake = h.clone_to_owned(); - if read_so_far > size { - read_buf.copy_within(size..read_so_far, 0); - } - read_so_far -= size; - return Ok(( live, CheckedIncomingConnection { @@ -557,7 +544,6 @@ impl Session { stream, handshake, read_buf, - read_so_far, }, )); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 72a7db5..83c459a 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -457,7 +457,6 @@ impl TorrentStateLive { r = requester => {r} r = peer_connection.manage_peer_incoming( rx, - checked_peer.read_so_far, checked_peer.read_buf, checked_peer.handshake, checked_peer.stream