Extract ReadBuf logic into a separate struct
This commit is contained in:
parent
09252c0397
commit
d5d98aff60
5 changed files with 120 additions and 89 deletions
|
|
@ -31,6 +31,7 @@ pub mod http_api;
|
||||||
pub mod http_api_client;
|
pub mod http_api_client;
|
||||||
mod peer_connection;
|
mod peer_connection;
|
||||||
mod peer_info_reader;
|
mod peer_info_reader;
|
||||||
|
mod read_buf;
|
||||||
mod session;
|
mod session;
|
||||||
mod spawn_utils;
|
mod spawn_utils;
|
||||||
mod torrent_state;
|
mod torrent_state;
|
||||||
|
|
|
||||||
|
|
@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id}
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use peer_binary_protocol::{
|
use peer_binary_protocol::{
|
||||||
extended::{handshake::ExtendedHandshake, ExtendedMessage},
|
extended::{handshake::ExtendedHandshake, ExtendedMessage},
|
||||||
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError,
|
serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
||||||
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_with::serde_as;
|
use serde_with::serde_as;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
|
||||||
use crate::spawn_utils::BlockingSpawner;
|
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner};
|
||||||
|
|
||||||
pub trait PeerConnectionHandler {
|
pub trait PeerConnectionHandler {
|
||||||
fn on_connected(&self, _connection_time: Duration) {}
|
fn on_connected(&self, _connection_time: Duration) {}
|
||||||
|
|
@ -100,9 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
pub async fn manage_peer_incoming(
|
pub async fn manage_peer_incoming(
|
||||||
&self,
|
&self,
|
||||||
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||||
// How many bytes into read buffer have we read already.
|
read_buf: ReadBuf,
|
||||||
read_so_far: usize,
|
|
||||||
read_buf: Vec<u8>,
|
|
||||||
handshake: Handshake<ByteString>,
|
handshake: Handshake<ByteString>,
|
||||||
mut conn: tokio::net::TcpStream,
|
mut conn: tokio::net::TcpStream,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
|
|
@ -140,7 +137,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
|
|
||||||
self.manage_peer(
|
self.manage_peer(
|
||||||
h_supports_extended,
|
h_supports_extended,
|
||||||
read_so_far,
|
|
||||||
read_buf,
|
read_buf,
|
||||||
write_buf,
|
write_buf,
|
||||||
conn,
|
conn,
|
||||||
|
|
@ -153,7 +149,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
&self,
|
&self,
|
||||||
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
let rwtimeout = self
|
let rwtimeout = self
|
||||||
|
|
@ -180,16 +175,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
.context("error writing handshake")?;
|
.context("error writing handshake")?;
|
||||||
write_buf.clear();
|
write_buf.clear();
|
||||||
|
|
||||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
let mut read_buf = ReadBuf::new();
|
||||||
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
|
let h = read_buf
|
||||||
|
.read_handshake(&mut conn, rwtimeout)
|
||||||
.await
|
.await
|
||||||
.context("error reading handshake")?;
|
.context("error reading handshake")?;
|
||||||
if read_so_far == 0 {
|
|
||||||
anyhow::bail!("bad handshake");
|
|
||||||
}
|
|
||||||
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
|
|
||||||
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
|
||||||
|
|
||||||
let h_supports_extended = h.supports_extended();
|
let h_supports_extended = h.supports_extended();
|
||||||
trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id)));
|
trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id)));
|
||||||
if h.info_hash != self.info_hash.0 {
|
if h.info_hash != self.info_hash.0 {
|
||||||
|
|
@ -202,14 +192,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
|
|
||||||
self.handler.on_handshake(h)?;
|
self.handler.on_handshake(h)?;
|
||||||
|
|
||||||
if read_so_far > size {
|
|
||||||
read_buf.copy_within(size..read_so_far, 0);
|
|
||||||
}
|
|
||||||
read_so_far -= size;
|
|
||||||
|
|
||||||
self.manage_peer(
|
self.manage_peer(
|
||||||
h_supports_extended,
|
h_supports_extended,
|
||||||
read_so_far,
|
|
||||||
read_buf,
|
read_buf,
|
||||||
write_buf,
|
write_buf,
|
||||||
conn,
|
conn,
|
||||||
|
|
@ -221,14 +205,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
async fn manage_peer(
|
async fn manage_peer(
|
||||||
&self,
|
&self,
|
||||||
handshake_supports_extended: bool,
|
handshake_supports_extended: bool,
|
||||||
// How many bytes into read_buf is there of peer-sent-data.
|
mut read_buf: ReadBuf,
|
||||||
mut read_so_far: usize,
|
|
||||||
mut read_buf: Vec<u8>,
|
|
||||||
mut write_buf: Vec<u8>,
|
mut write_buf: Vec<u8>,
|
||||||
mut conn: tokio::net::TcpStream,
|
mut conn: tokio::net::TcpStream,
|
||||||
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
let rwtimeout = self
|
let rwtimeout = self
|
||||||
|
|
@ -328,47 +309,23 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
||||||
|
|
||||||
let reader = async move {
|
let reader = async move {
|
||||||
loop {
|
loop {
|
||||||
let (message, size) = loop {
|
read_buf
|
||||||
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
|
.read_message(&mut read_half, rwtimeout, |message| {
|
||||||
Ok((msg, size)) => break (msg, size),
|
trace!("received: {:?}", &message);
|
||||||
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
|
||||||
if read_buf.len() < read_so_far + d {
|
if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
|
||||||
read_buf.reserve(d);
|
*extended_handshake_ref.write() = Some(h.clone_to_owned());
|
||||||
read_buf.resize(read_buf.capacity(), 0);
|
self.handler.on_extended_handshake(h)?;
|
||||||
}
|
trace!("remembered extended handshake for future serializing");
|
||||||
let size = with_timeout(
|
} else {
|
||||||
rwtimeout,
|
self.handler
|
||||||
read_half.read(&mut read_buf[read_so_far..]),
|
.on_received_message(message)
|
||||||
)
|
.context("error in handler.on_received_message()")?;
|
||||||
.await
|
|
||||||
.context("error reading from peer")?;
|
|
||||||
if size == 0 {
|
|
||||||
anyhow::bail!(
|
|
||||||
"disconnected while reading, read so far: {}",
|
|
||||||
read_so_far
|
|
||||||
)
|
|
||||||
}
|
|
||||||
read_so_far += size;
|
|
||||||
}
|
}
|
||||||
Err(e) => return Err(e.into()),
|
Ok(())
|
||||||
}
|
})
|
||||||
};
|
.await
|
||||||
trace!("received: {:?}", &message);
|
.context("error reading message")?;
|
||||||
|
|
||||||
if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
|
|
||||||
*extended_handshake_ref.write() = Some(h.clone_to_owned());
|
|
||||||
self.handler.on_extended_handshake(h)?;
|
|
||||||
trace!("remembered extended handshake for future serializing");
|
|
||||||
} else {
|
|
||||||
self.handler
|
|
||||||
.on_received_message(message)
|
|
||||||
.context("error in handler.on_received_message()")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if read_so_far > size {
|
|
||||||
read_buf.copy_within(size..read_so_far, 0);
|
|
||||||
}
|
|
||||||
read_so_far -= size;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For type inference.
|
// For type inference.
|
||||||
|
|
|
||||||
88
crates/librqbit/src/read_buf.rs
Normal file
88
crates/librqbit/src/read_buf.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::peer_connection::with_timeout;
|
||||||
|
use anyhow::Context;
|
||||||
|
use buffers::ByteBuf;
|
||||||
|
use peer_binary_protocol::{
|
||||||
|
Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN,
|
||||||
|
};
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
|
|
||||||
|
pub struct ReadBuf {
|
||||||
|
buf: Vec<u8>,
|
||||||
|
read_so_far: usize,
|
||||||
|
last_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReadBuf {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
|
||||||
|
read_so_far: 0,
|
||||||
|
last_size: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_for_read(&mut self) {
|
||||||
|
if self.read_so_far > self.last_size {
|
||||||
|
self.buf.copy_within(self.last_size..self.read_so_far, 0);
|
||||||
|
}
|
||||||
|
self.read_so_far -= self.last_size;
|
||||||
|
self.last_size = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This MUST be run as the first operation on the buffer.
|
||||||
|
pub async fn read_handshake(
|
||||||
|
&mut self,
|
||||||
|
mut conn: impl AsyncReadExt + Unpin,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
|
||||||
|
self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf))
|
||||||
|
.await
|
||||||
|
.context("error reading handshake")?;
|
||||||
|
if self.read_so_far == 0 {
|
||||||
|
anyhow::bail!("bad handshake");
|
||||||
|
}
|
||||||
|
let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far])
|
||||||
|
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
||||||
|
self.last_size = size;
|
||||||
|
Ok(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn read_message(
|
||||||
|
&mut self,
|
||||||
|
mut conn: impl AsyncReadExt + Unpin,
|
||||||
|
timeout: Duration,
|
||||||
|
on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.prepare_for_read();
|
||||||
|
loop {
|
||||||
|
let need_additional_bytes =
|
||||||
|
match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) {
|
||||||
|
Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
|
||||||
|
Ok((msg, size)) => {
|
||||||
|
self.last_size = size;
|
||||||
|
// Rust's borrow checker can't do this early return. So we are using a callback instead.
|
||||||
|
// return Ok(msg);
|
||||||
|
on_message(msg)?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
if self.buf.len() < self.read_so_far + need_additional_bytes {
|
||||||
|
self.buf.reserve(need_additional_bytes);
|
||||||
|
self.buf.resize(self.buf.capacity(), 0);
|
||||||
|
}
|
||||||
|
let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..]))
|
||||||
|
.await
|
||||||
|
.context("error reading from peer")?;
|
||||||
|
if size == 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"disconnected while reading, read so far: {}",
|
||||||
|
self.read_so_far
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.read_so_far += size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -27,20 +27,18 @@ use librqbit_core::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN};
|
use peer_binary_protocol::Handshake;
|
||||||
use reqwest::Url;
|
use reqwest::Url;
|
||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use serde_with::serde_as;
|
use serde_with::serde_as;
|
||||||
use tokio::{
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
io::AsyncReadExt,
|
|
||||||
net::{TcpListener, TcpStream},
|
|
||||||
};
|
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{debug, error, error_span, info, trace, warn, Instrument};
|
use tracing::{debug, error, error_span, info, trace, warn, Instrument};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
|
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
|
||||||
peer_connection::{with_timeout, PeerConnectionOptions},
|
peer_connection::PeerConnectionOptions,
|
||||||
|
read_buf::ReadBuf,
|
||||||
spawn_utils::BlockingSpawner,
|
spawn_utils::BlockingSpawner,
|
||||||
torrent_state::{
|
torrent_state::{
|
||||||
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
|
ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive,
|
||||||
|
|
@ -371,9 +369,8 @@ async fn create_tcp_listener(
|
||||||
pub(crate) struct CheckedIncomingConnection {
|
pub(crate) struct CheckedIncomingConnection {
|
||||||
pub addr: SocketAddr,
|
pub addr: SocketAddr,
|
||||||
pub stream: tokio::net::TcpStream,
|
pub stream: tokio::net::TcpStream,
|
||||||
pub read_buf: Vec<u8>,
|
pub read_buf: ReadBuf,
|
||||||
pub handshake: Handshake<ByteString>,
|
pub handshake: Handshake<ByteString>,
|
||||||
pub read_so_far: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
|
|
@ -515,16 +512,11 @@ impl Session {
|
||||||
.read_write_timeout
|
.read_write_timeout
|
||||||
.unwrap_or_else(|| Duration::from_secs(10));
|
.unwrap_or_else(|| Duration::from_secs(10));
|
||||||
|
|
||||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
let mut read_buf = ReadBuf::new();
|
||||||
let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf))
|
let h = read_buf
|
||||||
|
.read_handshake(&mut stream, rwtimeout)
|
||||||
.await
|
.await
|
||||||
.context("error reading handshake")?;
|
.context("error reading handshake")?;
|
||||||
if read_so_far == 0 {
|
|
||||||
anyhow::bail!("bad handshake");
|
|
||||||
}
|
|
||||||
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
|
|
||||||
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
|
||||||
|
|
||||||
trace!("received handshake from {addr}: {:?}", h);
|
trace!("received handshake from {addr}: {:?}", h);
|
||||||
|
|
||||||
if h.peer_id == self.peer_id.0 {
|
if h.peer_id == self.peer_id.0 {
|
||||||
|
|
@ -545,11 +537,6 @@ impl Session {
|
||||||
|
|
||||||
let handshake = h.clone_to_owned();
|
let handshake = h.clone_to_owned();
|
||||||
|
|
||||||
if read_so_far > size {
|
|
||||||
read_buf.copy_within(size..read_so_far, 0);
|
|
||||||
}
|
|
||||||
read_so_far -= size;
|
|
||||||
|
|
||||||
return Ok((
|
return Ok((
|
||||||
live,
|
live,
|
||||||
CheckedIncomingConnection {
|
CheckedIncomingConnection {
|
||||||
|
|
@ -557,7 +544,6 @@ impl Session {
|
||||||
stream,
|
stream,
|
||||||
handshake,
|
handshake,
|
||||||
read_buf,
|
read_buf,
|
||||||
read_so_far,
|
|
||||||
},
|
},
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -457,7 +457,6 @@ impl TorrentStateLive {
|
||||||
r = requester => {r}
|
r = requester => {r}
|
||||||
r = peer_connection.manage_peer_incoming(
|
r = peer_connection.manage_peer_incoming(
|
||||||
rx,
|
rx,
|
||||||
checked_peer.read_so_far,
|
|
||||||
checked_peer.read_buf,
|
checked_peer.read_buf,
|
||||||
checked_peer.handshake,
|
checked_peer.handshake,
|
||||||
checked_peer.stream
|
checked_peer.stream
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue