102 lines
3.6 KiB
Rust
102 lines
3.6 KiB
Rust
use std::time::Duration;
|
|
|
|
use crate::peer_connection::with_timeout;
|
|
use anyhow::Context;
|
|
use buffers::ByteBuf;
|
|
use peer_binary_protocol::{
|
|
Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN,
|
|
};
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
/// Accumulates bytes read from a peer connection and deserializes messages out of them.
///
/// The methods rely on `processed <= filled <= buf.len()`:
/// `buf[processed..filled]` is data that has been received but not yet deserialized.
pub struct ReadBuf {
    /// Backing storage. Only the prefix `buf[..filled]` holds bytes received from the peer.
    buf: Vec<u8>,
    /// Count of bytes at the front of `buf` that came from the connection.
    /// The next read writes into `buf[filled..]`.
    filled: usize,
    /// Count of bytes at the front of `buf` that were already deserialized successfully.
    processed: usize,
}
|
|
|
|
impl ReadBuf {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2],
|
|
filled: 0,
|
|
processed: 0,
|
|
}
|
|
}
|
|
|
|
fn prepare_for_read(&mut self, need_additional_bytes: usize) {
|
|
// Ensure the buffer starts from the to-be-deserialized message.
|
|
if self.processed > 0 {
|
|
if self.filled > self.processed {
|
|
self.buf.copy_within(self.processed..self.filled, 0);
|
|
}
|
|
self.filled -= self.processed;
|
|
self.processed = 0;
|
|
}
|
|
|
|
// Ensure we have enough capacity to deserialize the message.
|
|
if self.buf.len() < self.filled + need_additional_bytes {
|
|
self.buf.reserve(need_additional_bytes);
|
|
self.buf.resize(self.buf.capacity(), 0);
|
|
}
|
|
}
|
|
|
|
// Read the BT handshake.
|
|
// This MUST be run as the first operation on the buffer.
|
|
pub async fn read_handshake(
|
|
&mut self,
|
|
mut conn: impl AsyncReadExt + Unpin,
|
|
timeout: Duration,
|
|
) -> anyhow::Result<Handshake<ByteBuf<'_>>> {
|
|
self.filled = with_timeout(timeout, conn.read(&mut self.buf))
|
|
.await
|
|
.context("error reading handshake")?;
|
|
if self.filled == 0 {
|
|
anyhow::bail!("peer disconnected while reading handshake");
|
|
}
|
|
let (h, size) = Handshake::deserialize(&self.buf[..self.filled]).map_err(|e| {
|
|
anyhow::anyhow!(
|
|
"error deserializing handshake: {:?} hadshake data {:?}",
|
|
e,
|
|
&self.buf[..self.filled.min(19)]
|
|
)
|
|
})?;
|
|
self.processed = size;
|
|
Ok(h)
|
|
}
|
|
|
|
// Read a message into the buffer, try to deserialize it and call the callback on it.
|
|
pub async fn read_message(
|
|
&mut self,
|
|
mut conn: impl AsyncReadExt + Unpin,
|
|
timeout: Duration,
|
|
) -> anyhow::Result<MessageBorrowed<'_>> {
|
|
loop {
|
|
let need_additional_bytes =
|
|
match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) {
|
|
Err(MessageDeserializeError::NotEnoughData(d, _)) => d,
|
|
Ok((msg, size)) => {
|
|
self.processed += size;
|
|
|
|
// Rust's borrow checker can't do this early return so resort to unsafe.
|
|
// This erases the lifetime so that it's happy.
|
|
let msg: MessageBorrowed<'_> =
|
|
unsafe { std::mem::transmute(msg as MessageBorrowed<'_>) };
|
|
return Ok(msg);
|
|
}
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
self.prepare_for_read(need_additional_bytes);
|
|
debug_assert!(!self.buf[self.filled..].is_empty());
|
|
let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..]))
|
|
.await
|
|
.context("error reading from peer")?;
|
|
if size == 0 {
|
|
anyhow::bail!("disconnected while reading, read so far: {}", self.filled)
|
|
}
|
|
self.filled += size;
|
|
}
|
|
}
|
|
}
|