Extract ReadBuf logic into a separate struct
This commit is contained in:
parent
09252c0397
commit
d5d98aff60
5 changed files with 120 additions and 89 deletions
|
|
@ -10,15 +10,14 @@ use librqbit_core::{id20::Id20, lengths::ChunkInfo, peer_id::try_decode_peer_id}
|
|||
use parking_lot::RwLock;
|
||||
use peer_binary_protocol::{
|
||||
extended::{handshake::ExtendedHandshake, ExtendedMessage},
|
||||
serialize_piece_preamble, Handshake, Message, MessageBorrowed, MessageDeserializeError,
|
||||
MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
||||
serialize_piece_preamble, Handshake, Message, MessageOwned, PIECE_MESSAGE_DEFAULT_LEN,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
use tokio::time::timeout;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::spawn_utils::BlockingSpawner;
|
||||
use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner};
|
||||
|
||||
pub trait PeerConnectionHandler {
|
||||
fn on_connected(&self, _connection_time: Duration) {}
|
||||
|
|
@ -100,9 +99,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
pub async fn manage_peer_incoming(
|
||||
&self,
|
||||
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||
// How many bytes into read buffer have we read already.
|
||||
read_so_far: usize,
|
||||
read_buf: Vec<u8>,
|
||||
read_buf: ReadBuf,
|
||||
handshake: Handshake<ByteString>,
|
||||
mut conn: tokio::net::TcpStream,
|
||||
) -> anyhow::Result<()> {
|
||||
|
|
@ -140,7 +137,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
self.manage_peer(
|
||||
h_supports_extended,
|
||||
read_so_far,
|
||||
read_buf,
|
||||
write_buf,
|
||||
conn,
|
||||
|
|
@ -153,7 +149,6 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
&self,
|
||||
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||
) -> anyhow::Result<()> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
let rwtimeout = self
|
||||
|
|
@ -180,16 +175,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
.context("error writing handshake")?;
|
||||
write_buf.clear();
|
||||
|
||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
||||
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
|
||||
let mut read_buf = ReadBuf::new();
|
||||
let h = read_buf
|
||||
.read_handshake(&mut conn, rwtimeout)
|
||||
.await
|
||||
.context("error reading handshake")?;
|
||||
if read_so_far == 0 {
|
||||
anyhow::bail!("bad handshake");
|
||||
}
|
||||
let (h, size) = Handshake::deserialize(&read_buf[..read_so_far])
|
||||
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
||||
|
||||
let h_supports_extended = h.supports_extended();
|
||||
trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id)));
|
||||
if h.info_hash != self.info_hash.0 {
|
||||
|
|
@ -202,14 +192,8 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
self.handler.on_handshake(h)?;
|
||||
|
||||
if read_so_far > size {
|
||||
read_buf.copy_within(size..read_so_far, 0);
|
||||
}
|
||||
read_so_far -= size;
|
||||
|
||||
self.manage_peer(
|
||||
h_supports_extended,
|
||||
read_so_far,
|
||||
read_buf,
|
||||
write_buf,
|
||||
conn,
|
||||
|
|
@ -221,14 +205,11 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
async fn manage_peer(
|
||||
&self,
|
||||
handshake_supports_extended: bool,
|
||||
// How many bytes into read_buf is there of peer-sent-data.
|
||||
mut read_so_far: usize,
|
||||
mut read_buf: Vec<u8>,
|
||||
mut read_buf: ReadBuf,
|
||||
mut write_buf: Vec<u8>,
|
||||
mut conn: tokio::net::TcpStream,
|
||||
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
|
||||
) -> anyhow::Result<()> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
let rwtimeout = self
|
||||
|
|
@ -328,47 +309,23 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
let reader = async move {
|
||||
loop {
|
||||
let (message, size) = loop {
|
||||
match MessageBorrowed::deserialize(&read_buf[..read_so_far]) {
|
||||
Ok((msg, size)) => break (msg, size),
|
||||
Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
||||
if read_buf.len() < read_so_far + d {
|
||||
read_buf.reserve(d);
|
||||
read_buf.resize(read_buf.capacity(), 0);
|
||||
}
|
||||
let size = with_timeout(
|
||||
rwtimeout,
|
||||
read_half.read(&mut read_buf[read_so_far..]),
|
||||
)
|
||||
.await
|
||||
.context("error reading from peer")?;
|
||||
if size == 0 {
|
||||
anyhow::bail!(
|
||||
"disconnected while reading, read so far: {}",
|
||||
read_so_far
|
||||
)
|
||||
}
|
||||
read_so_far += size;
|
||||
read_buf
|
||||
.read_message(&mut read_half, rwtimeout, |message| {
|
||||
trace!("received: {:?}", &message);
|
||||
|
||||
if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
|
||||
*extended_handshake_ref.write() = Some(h.clone_to_owned());
|
||||
self.handler.on_extended_handshake(h)?;
|
||||
trace!("remembered extended handshake for future serializing");
|
||||
} else {
|
||||
self.handler
|
||||
.on_received_message(message)
|
||||
.context("error in handler.on_received_message()")?;
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
};
|
||||
trace!("received: {:?}", &message);
|
||||
|
||||
if let Message::Extended(ExtendedMessage::Handshake(h)) = &message {
|
||||
*extended_handshake_ref.write() = Some(h.clone_to_owned());
|
||||
self.handler.on_extended_handshake(h)?;
|
||||
trace!("remembered extended handshake for future serializing");
|
||||
} else {
|
||||
self.handler
|
||||
.on_received_message(message)
|
||||
.context("error in handler.on_received_message()")?;
|
||||
}
|
||||
|
||||
if read_so_far > size {
|
||||
read_buf.copy_within(size..read_so_far, 0);
|
||||
}
|
||||
read_so_far -= size;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.context("error reading message")?;
|
||||
}
|
||||
|
||||
// For type inference.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue