Add read/write timeouts and avoid stuck peers
This commit is contained in:
parent
ae847ce99c
commit
9e8f235cb4
5 changed files with 59 additions and 55 deletions
|
|
@ -36,6 +36,7 @@ pub enum WriterRequest {
|
|||
#[derive(Default, Copy, Clone)]
|
||||
pub struct PeerConnectionOptions {
|
||||
pub connect_timeout: Option<Duration>,
|
||||
pub read_write_timeout: Option<Duration>,
|
||||
pub keep_alive_interval: Option<Duration>,
|
||||
}
|
||||
|
||||
|
|
@ -48,36 +49,21 @@ pub struct PeerConnection<H> {
|
|||
spawner: BlockingSpawner,
|
||||
}
|
||||
|
||||
// async fn read_one<'a, R: AsyncReadExt + Unpin>(
|
||||
// mut reader: R,
|
||||
// read_buf: &'a mut Vec<u8>,
|
||||
// read_so_far: &mut usize,
|
||||
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
|
||||
// loop {
|
||||
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
|
||||
// Ok((msg, size)) => return Ok((msg, size)),
|
||||
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
||||
// if read_buf.len() < *read_so_far + d {
|
||||
// read_buf.reserve(d);
|
||||
// read_buf.resize(read_buf.capacity(), 0);
|
||||
// }
|
||||
|
||||
// let size = reader
|
||||
// .read(&mut read_buf[*read_so_far..])
|
||||
// .await
|
||||
// .context("error reading from peer")?;
|
||||
// if size == 0 {
|
||||
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
|
||||
// }
|
||||
// *read_so_far += size;
|
||||
// }
|
||||
// Err(e) => return Err(e.into()),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
async fn with_timeout<T, E>(
|
||||
timeout_value: Duration,
|
||||
fut: impl std::future::Future<Output = Result<T, E>>,
|
||||
) -> anyhow::Result<T>
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
timeout(timeout_value, fut)
|
||||
.await
|
||||
.with_context(|| format!("timeout at {timeout_value:?}"))?
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
macro_rules! read_one {
|
||||
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{
|
||||
($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
|
||||
let (extended, size) = loop {
|
||||
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
|
||||
Ok((msg, size)) => break (msg, size),
|
||||
|
|
@ -87,8 +73,7 @@ macro_rules! read_one {
|
|||
$read_buf.resize($read_buf.capacity(), 0);
|
||||
}
|
||||
|
||||
let size = $conn
|
||||
.read(&mut $read_buf[$read_so_far..])
|
||||
let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
|
||||
.await
|
||||
.context("error reading from peer")?;
|
||||
if size == 0 {
|
||||
|
|
@ -130,29 +115,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
) -> anyhow::Result<()> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let mut conn = match timeout(
|
||||
self.options
|
||||
.connect_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10)),
|
||||
tokio::net::TcpStream::connect(self.addr),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(conn) => conn.context("error connecting")?,
|
||||
Err(_) => anyhow::bail!("timeout connecting to {}", self.addr),
|
||||
};
|
||||
|
||||
let rwtimeout = self
|
||||
.options
|
||||
.read_write_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10));
|
||||
|
||||
let connect_timeout = self
|
||||
.options
|
||||
.connect_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10));
|
||||
|
||||
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
|
||||
.await
|
||||
.context("error connecting")?;
|
||||
|
||||
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
|
||||
let handshake = Handshake::new(self.info_hash, self.peer_id);
|
||||
handshake.serialize(&mut write_buf);
|
||||
conn.write_all(&write_buf)
|
||||
with_timeout(rwtimeout, conn.write_all(&write_buf))
|
||||
.await
|
||||
.context("error writing handshake")?;
|
||||
write_buf.clear();
|
||||
|
||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
||||
let mut read_so_far = conn
|
||||
.read(&mut read_buf)
|
||||
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
|
||||
.await
|
||||
.context("error reading handshake")?;
|
||||
if read_so_far == 0 {
|
||||
|
|
@ -188,12 +175,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
&my_extended
|
||||
);
|
||||
my_extended.serialize(&mut write_buf, None).unwrap();
|
||||
conn.write_all(&write_buf)
|
||||
with_timeout(rwtimeout, conn.write_all(&write_buf))
|
||||
.await
|
||||
.context("error writing extended handshake")?;
|
||||
write_buf.clear();
|
||||
|
||||
let (extended, size) = read_one!(conn, read_buf, read_so_far);
|
||||
let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout);
|
||||
match extended {
|
||||
Message::Extended(ExtendedMessage::Handshake(h)) => {
|
||||
trace!("received from {}: {:?}", self.addr, &h);
|
||||
|
|
@ -222,8 +209,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
.handler
|
||||
.serialize_bitfield_message_to_buf(&mut write_buf)
|
||||
{
|
||||
write_half
|
||||
.write_all(&write_buf[..len])
|
||||
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
|
||||
.await
|
||||
.context("error writing bitfield to peer")?;
|
||||
debug!("sent bitfield to {}", self.addr);
|
||||
|
|
@ -256,7 +242,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
self.handler
|
||||
.read_chunk(chunk, &mut write_buf[preamble_len..])
|
||||
})
|
||||
.with_context(|| format!("error reading chunk {:?}", chunk))?;
|
||||
.with_context(|| format!("error reading chunk {chunk:?}"))?;
|
||||
|
||||
uploaded_add = Some(chunk.size);
|
||||
full_len
|
||||
|
|
@ -265,8 +251,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
debug!("sending to {}: {:?}, length={}", self.addr, &req, len);
|
||||
|
||||
write_half
|
||||
.write_all(&write_buf[..len])
|
||||
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
|
||||
.await
|
||||
.context("error writing the message to peer")?;
|
||||
write_buf.clear();
|
||||
|
|
@ -283,7 +268,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
let reader = async move {
|
||||
loop {
|
||||
let (message, size) = read_one!(read_half, read_buf, read_so_far);
|
||||
let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
|
||||
trace!("received from {}: {:?}", self.addr, &message);
|
||||
|
||||
self.handler
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue