A tiny optimisation to not memcpy the buffer as often if we have many messages already buffered
This commit is contained in:
parent
d5d98aff60
commit
8ee98548f5
1 changed files with 35 additions and 27 deletions
|
|
@ -10,58 +10,72 @@ use tokio::io::AsyncReadExt;
|
|||
|
||||
pub struct ReadBuf {
|
||||
buf: Vec<u8>,
|
||||
read_so_far: usize,
|
||||
last_size: usize,
|
||||
// How many bytes into the buffer we have read from the connection.
|
||||
// New reads should go past this.
|
||||
filled: usize,
|
||||
// How many bytes have we successfully deserialized.
|
||||
processed: usize,
|
||||
}
|
||||
|
||||
impl ReadBuf {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
|
||||
read_so_far: 0,
|
||||
last_size: 0,
|
||||
filled: 0,
|
||||
processed: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_for_read(&mut self) {
|
||||
if self.read_so_far > self.last_size {
|
||||
self.buf.copy_within(self.last_size..self.read_so_far, 0);
|
||||
fn prepare_for_read(&mut self, need_additional_bytes: usize) {
|
||||
// Ensure the buffer starts from the to-be-deserialized message.
|
||||
if self.processed > 0 {
|
||||
if self.filled > self.processed {
|
||||
self.buf.copy_within(self.processed..self.filled, 0);
|
||||
}
|
||||
self.filled -= self.processed;
|
||||
self.processed = 0;
|
||||
}
|
||||
|
||||
// Ensure we have enough capacity to deserialize the message.
|
||||
if self.buf.len() < self.filled + need_additional_bytes {
|
||||
self.buf.reserve(need_additional_bytes);
|
||||
self.buf.resize(self.buf.capacity(), 0);
|
||||
}
|
||||
self.read_so_far -= self.last_size;
|
||||
self.last_size = 0;
|
||||
}
|
||||
|
||||
// Read the BT handshake.
|
||||
// This MUST be run as the first operation on the buffer.
|
||||
pub async fn read_handshake(
|
||||
&mut self,
|
||||
mut conn: impl AsyncReadExt + Unpin,
|
||||
timeout: Duration,
|
||||
) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
|
||||
self.read_so_far = with_timeout(timeout, conn.read(&mut self.buf))
|
||||
self.filled = with_timeout(timeout, conn.read(&mut self.buf))
|
||||
.await
|
||||
.context("error reading handshake")?;
|
||||
if self.read_so_far == 0 {
|
||||
anyhow::bail!("bad handshake");
|
||||
if self.filled == 0 {
|
||||
anyhow::bail!("peer disconnected while reading handshake");
|
||||
}
|
||||
let (h, size) = Handshake::deserialize(&self.buf[..self.read_so_far])
|
||||
let (h, size) = Handshake::deserialize(&self.buf[..self.filled])
|
||||
.map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?;
|
||||
self.last_size = size;
|
||||
self.processed = size;
|
||||
Ok(h)
|
||||
}
|
||||
|
||||
// Read a message into the buffer, try to deserialize it and call the callback on it.
|
||||
// We can't return the message because of a borrow checker issue.
|
||||
pub async fn read_message(
|
||||
&mut self,
|
||||
mut conn: impl AsyncReadExt + Unpin,
|
||||
timeout: Duration,
|
||||
on_message: impl for<'a> FnOnce(MessageBorrowed<'a>) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.prepare_for_read();
|
||||
loop {
|
||||
let need_additional_bytes =
|
||||
match MessageBorrowed::deserialize(&self.buf[..self.read_so_far]) {
|
||||
match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) {
|
||||
Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
|
||||
Ok((msg, size)) => {
|
||||
self.last_size = size;
|
||||
self.processed += size;
|
||||
// Rust's borrow checker can't do this early return. So we are using a callback instead.
|
||||
// return Ok(msg);
|
||||
on_message(msg)?;
|
||||
|
|
@ -69,20 +83,14 @@ impl ReadBuf {
|
|||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
if self.buf.len() < self.read_so_far + need_additional_bytes {
|
||||
self.buf.reserve(need_additional_bytes);
|
||||
self.buf.resize(self.buf.capacity(), 0);
|
||||
}
|
||||
let size = with_timeout(timeout, conn.read(&mut self.buf[self.read_so_far..]))
|
||||
self.prepare_for_read(need_additional_bytes);
|
||||
let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..]))
|
||||
.await
|
||||
.context("error reading from peer")?;
|
||||
if size == 0 {
|
||||
anyhow::bail!(
|
||||
"disconnected while reading, read so far: {}",
|
||||
self.read_so_far
|
||||
)
|
||||
anyhow::bail!("disconnected while reading, read so far: {}", self.filled)
|
||||
}
|
||||
self.read_so_far += size;
|
||||
self.filled += size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue