use std::time::Duration; use crate::peer_connection::with_timeout; use anyhow::Context; use buffers::ByteBuf; use peer_binary_protocol::{ Handshake, MessageBorrowed, MessageDeserializeError, PIECE_MESSAGE_DEFAULT_LEN, }; use tokio::io::AsyncReadExt; pub struct ReadBuf { buf: Vec, // How many bytes into the buffer we have read from the connection. // New reads should go past this. filled: usize, // How many bytes have we successfully deserialized. processed: usize, } impl ReadBuf { pub fn new() -> Self { Self { buf: vec![0; PIECE_MESSAGE_DEFAULT_LEN * 2], filled: 0, processed: 0, } } fn prepare_for_read(&mut self, need_additional_bytes: usize) { // Ensure the buffer starts from the to-be-deserialized message. if self.processed > 0 { if self.filled > self.processed { self.buf.copy_within(self.processed..self.filled, 0); } self.filled -= self.processed; self.processed = 0; } // Ensure we have enough capacity to deserialize the message. if self.buf.len() < self.filled + need_additional_bytes { self.buf.reserve(need_additional_bytes); self.buf.resize(self.buf.capacity(), 0); } } // Read the BT handshake. // This MUST be run as the first operation on the buffer. pub async fn read_handshake( &mut self, mut conn: impl AsyncReadExt + Unpin, timeout: Duration, ) -> anyhow::Result>> { self.filled = with_timeout(timeout, conn.read(&mut self.buf)) .await .context("error reading handshake")?; if self.filled == 0 { anyhow::bail!("peer disconnected while reading handshake"); } let (h, size) = Handshake::deserialize(&self.buf[..self.filled]).map_err(|e| { anyhow::anyhow!( "error deserializing handshake: {:?} hadshake data {:?}", e, &self.buf[..self.filled.min(19)] ) })?; self.processed = size; Ok(h) } // Read a message into the buffer, try to deserialize it and call the callback on it. pub async fn read_message( &mut self, mut conn: impl AsyncReadExt + Unpin, timeout: Duration, ) -> anyhow::Result> { loop { let need_additional_bytes = match MessageBorrowed::deserialize(&self.buf[self.processed..self.filled]) { Err(MessageDeserializeError::NotEnoughData(d, _)) => d, Ok((msg, size)) => { self.processed += size; // Rust's borrow checker can't do this early return so resort to unsafe. // This erases the lifetime so that it's happy. let msg: MessageBorrowed<'_> = unsafe { std::mem::transmute(msg as MessageBorrowed<'_>) }; return Ok(msg); } Err(e) => return Err(e.into()), }; self.prepare_for_read(need_additional_bytes); debug_assert!(!self.buf[self.filled..].is_empty()); let size = with_timeout(timeout, conn.read(&mut self.buf[self.filled..])) .await .context("error reading from peer")?; if size == 0 { anyhow::bail!("disconnected while reading, read so far: {}", self.filled) } self.filled += size; } } }