Extract ReadBuf logic into a separate struct

This commit is contained in:
Igor Katson 2023-12-29 20:33:37 -05:00
parent 09252c0397
commit d5d98aff60
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
5 changed files with 120 additions and 89 deletions

View file

@ -31,6 +31,7 @@ pub mod http_api;
pub mod http_api_client; pub mod http_api_client;
mod peer_connection; mod peer_connection;
mod peer_info_reader; mod peer_info_reader;
mod read_buf;
mod session; mod session;
mod spawn_utils; mod spawn_utils;
mod torrent_state; mod torrent_state;

View file

@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id}
use parking_lot::RwLock; use parking_lot::RwLock;
use peer_binary_protocol::{ use peer_binary_protocol::{
extended::{handshake::ExtendedHandshake, ExtendedMessage}, extended::{handshake::ExtendedHandshake, ExtendedMessage},
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError, serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::serde_as; use serde_with::serde_as;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::trace; use tracing::trace;
use crate::spawn_utils::BlockingSpawner; use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner};
pub trait PeerConnectionHandler { pub trait PeerConnectionHandler {
fn on_connected(&self, _connection_time: Duration) {} fn on_connected(&self, _connection_time: Duration) {}
@ -100,9 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
pub async fn manage_peer_incoming( pub async fn manage_peer_incoming(
&self, &self,
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
// How many bytes into read buffer have we read already. read_buf: ReadBuf,
read_so_far: usize,
read_buf: Vec<u8>,
handshake: Handshake<ByteString>, handshake: Handshake<ByteString>,
mut conn: tokio::net::TcpStream, mut conn: tokio::net::TcpStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
@ -140,7 +137,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.manage_peer( self.manage_peer(
h_supports_extended, h_supports_extended,
read_so_far,
read_buf, read_buf,
write_buf, write_buf,
conn, conn,
@ -153,7 +149,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
&self, &self,
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>, outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
let rwtimeout = self let rwtimeout = self
@ -180,16 +175,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.context("error writing handshake")?; .context("error writing handshake")?;
write_buf.clear(); write_buf.clear();
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; let mut read_buf = ReadBuf::new();
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf)) let h = read_buf
.read_handshake(&mut conn, rwtimeout)
.await .await
.context("error reading handshake")?; .context("error reading handshake")?;
if read_so_far == 0 {
anyhow::bail!("bad handshake");
}
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
let h_supports_extended = h.supports_extended(); let h_supports_extended = h.supports_extended();
trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id))); trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id)));
if h.info_hash != self.info_hash.0 { if h.info_hash != self.info_hash.0 {
@ -202,14 +192,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.handler.on_handshake(h)?; self.handler.on_handshake(h)?;
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
self.manage_peer( self.manage_peer(
h_supports_extended, h_supports_extended,
read_so_far,
read_buf, read_buf,
write_buf, write_buf,
conn, conn,
@ -221,14 +205,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
async fn manage_peer( async fn manage_peer(
&self, &self,
handshake_supports_extended: bool, handshake_supports_extended: bool,
// How many bytes into read_buf is there of peer-sent-data. mut read_buf: ReadBuf,
mut read_so_far: usize,
mut read_buf: Vec<u8>,
mut write_buf: Vec<u8>, mut write_buf: Vec<u8>,
mut conn: tokio::net::TcpStream, mut conn: tokio::net::TcpStream,
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
let rwtimeout = self let rwtimeout = self
@ -328,47 +309,23 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let reader = async move { let reader = async move {
loop { loop {
let (message, size) = loop { read_buf
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) { .read_message(&mut read_half, rwtimeout, |message| {
Ok((msg, size)) => break (msg, size), trace!("received: {:?}", &message);
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
if read_buf.len() < read_so_far + d { if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
read_buf.reserve(d); *extended_handshake_ref.write() = Some(h.clone_to_owned());
read_buf.resize(read_buf.capacity(), 0); self.handler.on_extended_handshake(h)?;
} trace!("remembered extended handshake for future serializing");
let size = with_timeout( } else {
rwtimeout, self.handler
read_half.read(&mut read_buf[read_so_far..]), .on_received_message(message)
) .context("error in handler.on_received_message()")?;
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!(
"disconnected while reading, read so far: {}",
read_so_far
)
}
read_so_far += size;
} }
Err(e) => return Err(e.into()), Ok(())
} })
}; .await
trace!("received: {:?}", &message); .context("error reading message")?;
if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
*extended_handshake_ref.write() = Some(h.clone_to_owned());
self.handler.on_extended_handshake(h)?;
trace!("remembered extended handshake for future serializing");
} else {
self.handler
.on_received_message(message)
.context("error in handler.on_received_message()")?;
}
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
} }
// For type inference. // For type inference.

View file

@ -0,0 +1,88 @@
use std::time::Duration;
use crate::peer_connection::with_timeout;
use anyhow::Context;
use buffers::ByteBuf;
use peer_binary_protocol::{
Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN,
};
use tokio::io::AsyncReadExt;
pub struct ReadBuf {
buf: Vec<u8>,
read_so_far: usize,
last_size: usize,
}
impl ReadBuf {
pub fn new() -> Self {
Self {
buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
read_so_far: 0,
last_size: 0,
}
}
fn prepare_for_read(&mut self) {
if self.read_so_far > self.last_size {
self.buf.copy_within(self.last_size..self.read_so_far, 0);
}
self.read_so_far -= self.last_size;
self.last_size = 0;
}
// This MUST be run as the first operation on the buffer.
pub async fn read_handshake(
&mut self,
mut conn: impl AsyncReadExt + Unpin,
timeout: Duration,
) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf))
.await
.context("error reading handshake")?;
if self.read_so_far == 0 {
anyhow::bail!("bad handshake");
}
let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
self.last_size = size;
Ok(h)
}
pub async fn read_message(
&mut self,
mut conn: impl AsyncReadExt + Unpin,
timeout: Duration,
on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
self.prepare_for_read();
loop {
let need_additional_bytes =
match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) {
Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
Ok((msg, size)) => {
self.last_size = size;
// Rust's borrow checker can't do this early return. So we are using a callback instead.
// return Ok(msg);
on_message(msg)?;
return Ok(());
}
Err(e) => return Err(e.into()),
};
if self.buf.len() < self.read_so_far + need_additional_bytes {
self.buf.reserve(need_additional_bytes);
self.buf.resize(self.buf.capacity(), 0);
}
let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..]))
.await
.context("error reading from peer")?;
if size == 0 {
anyhow::bail!(
"disconnected while reading, read so far: {}",
self.read_so_far
)
}
self.read_so_far += size;
}
}
}

View file

@ -27,20 +27,18 @@ use librqbit_core::{
}, },
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; use peer_binary_protocol::Handshake;
use reqwest::Url; use reqwest::Url;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_with::serde_as; use serde_with::serde_as;
use tokio::{ use tokio::net::{TcpListener, TcpStream};
io::AsyncReadExt,
net::{TcpListener, TcpStream},
};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use tracing::{debug, error, error_span, info, trace, warn, Instrument};
use crate::{ use crate::{
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
peer_connection::{with_timeout, PeerConnectionOptions}, peer_connection::PeerConnectionOptions,
read_buf::ReadBuf,
spawn_utils::BlockingSpawner, spawn_utils::BlockingSpawner,
torrent_state::{ torrent_state::{
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
@ -371,9 +369,8 @@ async fn create_tcp_listener(
pub(crate) struct CheckedIncomingConnection { pub(crate) struct CheckedIncomingConnection {
pub addr: SocketAddr, pub addr: SocketAddr,
pub stream: tokio::net::TcpStream, pub stream: tokio::net::TcpStream,
pub read_buf: Vec<u8>, pub read_buf: ReadBuf,
pub handshake: Handshake<ByteString>, pub handshake: Handshake<ByteString>,
pub read_so_far: usize,
} }
impl Session { impl Session {
@ -515,16 +512,11 @@ impl Session {
.read_write_timeout .read_write_timeout
.unwrap_or_else(|| Duration::from_secs(10)); .unwrap_or_else(|| Duration::from_secs(10));
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; let mut read_buf = ReadBuf::new();
let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf)) let h = read_buf
.read_handshake(&mut stream, rwtimeout)
.await .await
.context("error reading handshake")?; .context("error reading handshake")?;
if read_so_far == 0 {
anyhow::bail!("bad handshake");
}
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
trace!("received handshake from {addr}: {:?}", h); trace!("received handshake from {addr}: {:?}", h);
if h.peer_id == self.peer_id.0 { if h.peer_id == self.peer_id.0 {
@ -545,11 +537,6 @@ impl Session {
let handshake = h.clone_to_owned(); let handshake = h.clone_to_owned();
if read_so_far > size {
read_buf.copy_within(size..read_so_far, 0);
}
read_so_far -= size;
return Ok(( return Ok((
live, live,
CheckedIncomingConnection { CheckedIncomingConnection {
@ -557,7 +544,6 @@ impl Session {
stream, stream,
handshake, handshake,
read_buf, read_buf,
read_so_far,
}, },
)); ));
} }

View file

@ -457,7 +457,6 @@ impl TorrentStateLive {
r = requester => {r} r = requester => {r}
r = peer_connection.manage_peer_incoming( r = peer_connection.manage_peer_incoming(
rx, rx,
checked_peer.read_so_far,
checked_peer.read_buf, checked_peer.read_buf,
checked_peer.handshake, checked_peer.handshake,
checked_peer.stream checked_peer.stream