Add read/write timeouts and avoid stuck peers

This commit is contained in:
Igor Katson 2022-12-04 14:34:21 +00:00
parent ae847ce99c
commit 9e8f235cb4
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
5 changed files with 59 additions and 55 deletions

View file

@ -36,6 +36,7 @@ pub enum WriterRequest {
#[derive(Default, Copy, Clone)]
pub struct PeerConnectionOptions {
pub connect_timeout: Option<Duration>,
pub read_write_timeout: Option<Duration>,
pub keep_alive_interval: Option<Duration>,
}
@ -48,36 +49,21 @@ pub struct PeerConnection<H> {
spawner: BlockingSpawner,
}
// async fn read_one<'a, R: AsyncReadExt + Unpin>(
// mut reader: R,
// read_buf: &'a mut Vec<u8>,
// read_so_far: &mut usize,
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
// loop {
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
// Ok((msg, size)) => return Ok((msg, size)),
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
// if read_buf.len() < *read_so_far + d {
// read_buf.reserve(d);
// read_buf.resize(read_buf.capacity(), 0);
// }
// let size = reader
// .read(&mut read_buf[*read_so_far..])
// .await
// .context("error reading from peer")?;
// if size == 0 {
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
// }
// *read_so_far += size;
// }
// Err(e) => return Err(e.into()),
// }
// }
// }
async fn with_timeout<T, E>(
timeout_value: Duration,
fut: impl std::future::Future<Output = Result<T, E>>,
) -> anyhow::Result<T>
where
E: Into<anyhow::Error>,
{
timeout(timeout_value, fut)
.await
.with_context(|| format!("timeout at {timeout_value:?}"))?
.map_err(|e| e.into())
}
macro_rules! read_one {
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{
($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
let (extended, size) = loop {
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
Ok((msg, size)) => break (msg, size),
@ -87,8 +73,7 @@ macro_rules! read_one {
$read_buf.resize($read_buf.capacity(), 0);
}
let size = $conn
.read(&mut $read_buf[$read_so_far..])
let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
.await
.context("error reading from peer")?;
if size == 0 {
@ -130,29 +115,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
let mut conn = match timeout(
self.options
.connect_timeout
.unwrap_or_else(|| Duration::from_secs(10)),
tokio::net::TcpStream::connect(self.addr),
)
.await
{
Ok(conn) => conn.context("error connecting")?,
Err(_) => anyhow::bail!("timeout connecting to {}", self.addr),
};
let rwtimeout = self
.options
.read_write_timeout
.unwrap_or_else(|| Duration::from_secs(10));
let connect_timeout = self
.options
.connect_timeout
.unwrap_or_else(|| Duration::from_secs(10));
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
.await
.context("error connecting")?;
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let handshake = Handshake::new(self.info_hash, self.peer_id);
handshake.serialize(&mut write_buf);
conn.write_all(&write_buf)
with_timeout(rwtimeout, conn.write_all(&write_buf))
.await
.context("error writing handshake")?;
write_buf.clear();
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = conn
.read(&mut read_buf)
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
.await
.context("error reading handshake")?;
if read_so_far == 0 {
@ -188,12 +175,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
&my_extended
);
my_extended.serialize(&mut write_buf, None).unwrap();
conn.write_all(&write_buf)
with_timeout(rwtimeout, conn.write_all(&write_buf))
.await
.context("error writing extended handshake")?;
write_buf.clear();
let (extended, size) = read_one!(conn, read_buf, read_so_far);
let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout);
match extended {
Message::Extended(ExtendedMessage::Handshake(h)) => {
trace!("received from {}: {:?}", self.addr, &h);
@ -222,8 +209,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.handler
.serialize_bitfield_message_to_buf(&mut write_buf)
{
write_half
.write_all(&write_buf[..len])
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.await
.context("error writing bitfield to peer")?;
debug!("sent bitfield to {}", self.addr);
@ -256,7 +242,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.handler
.read_chunk(chunk, &mut write_buf[preamble_len..])
})
.with_context(|| format!("error reading chunk {:?}", chunk))?;
.with_context(|| format!("error reading chunk {chunk:?}"))?;
uploaded_add = Some(chunk.size);
full_len
@ -265,8 +251,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
debug!("sending to {}: {:?}, length={}", self.addr, &req, len);
write_half
.write_all(&write_buf[..len])
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.await
.context("error writing the message to peer")?;
write_buf.clear();
@ -283,7 +268,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let reader = async move {
loop {
let (message, size) = read_one!(read_half, read_buf, read_so_far);
let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
trace!("received from {}: {:?}", self.addr, &message);
self.handler

View file

@ -355,6 +355,10 @@ impl Session {
builder.peer_connect_timeout(t);
}
if let Some(t) = opts.peer_opts.unwrap_or(self.peer_opts).read_write_timeout {
builder.peer_read_write_timeout(t);
}
let handle = match builder
.start_manager()
.context("error starting torrent manager")

View file

@ -32,6 +32,7 @@ use crate::{
struct TorrentManagerOptions {
force_tracker_interval: Option<Duration>,
peer_connect_timeout: Option<Duration>,
peer_read_write_timeout: Option<Duration>,
only_files: Option<Vec<usize>>,
peer_id: Option<Id20>,
overwrite: bool,
@ -90,6 +91,11 @@ impl TorrentManagerBuilder {
self
}
pub fn peer_read_write_timeout(&mut self, timeout: Duration) -> &mut Self {
self.options.peer_read_write_timeout = Some(timeout);
self
}
pub fn start_manager(self) -> anyhow::Result<TorrentManagerHandle> {
TorrentManager::start(
self.info,
@ -256,6 +262,7 @@ impl TorrentManager {
#[allow(clippy::needless_update)]
let state_options = TorrentStateOptions {
peer_connect_timeout: options.peer_connect_timeout,
peer_read_write_timeout: options.peer_read_write_timeout,
..Default::default()
};

View file

@ -225,6 +225,7 @@ impl StatsSnapshot {
#[derive(Default)]
pub struct TorrentStateOptions {
pub peer_connect_timeout: Option<Duration>,
pub peer_read_write_timeout: Option<Duration>,
}
pub struct TorrentState {
@ -286,6 +287,7 @@ impl TorrentState {
loop {
let (addr, out_rx) = peer_queue_rx.recv().await.unwrap();
let permit = state.peer_semaphore.acquire().await.unwrap();
match state.locked.write().peers.states.get_mut(&addr) {
Some(s @ PeerState::Queued) => *s = PeerState::Connecting,
s => {
@ -294,8 +296,6 @@ impl TorrentState {
}
};
state.peer_semaphore.acquire().await.unwrap().forget();
let handler = PeerHandler {
addr,
state: state.clone(),
@ -303,6 +303,7 @@ impl TorrentState {
};
let options = PeerConnectionOptions {
connect_timeout: state.options.peer_connect_timeout,
read_write_timeout: state.options.peer_read_write_timeout,
..Default::default()
};
let peer_connection = PeerConnection::new(
@ -313,7 +314,9 @@ impl TorrentState {
Some(options),
spawner,
);
spawn(format!("manage_peer({})", addr), async move {
permit.forget();
spawn(format!("manage_peer({addr})"), async move {
if let Err(e) = peer_connection.manage_peer(out_rx).await {
debug!("error managing peer {}: {:#}", addr, e)
};

View file

@ -70,6 +70,10 @@ struct Opts {
#[clap(long = "peer-connect-timeout")]
peer_connect_timeout: Option<ParsedDuration>,
/// The connect timeout, e.g. 1s, 1.5s, 100ms etc.
#[clap(long = "peer-read-write-timeout")]
peer_read_write_timeout: Option<ParsedDuration>,
/// How many threads to spawn for the executor.
#[clap(short = 't', long)]
worker_threads: Option<usize>,
@ -200,6 +204,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
peer_id: None,
peer_opts: Some(PeerConnectionOptions {
connect_timeout: opts.peer_connect_timeout.map(|d| d.0),
read_write_timeout: opts.peer_read_write_timeout.map(|d| d.0),
..Default::default()
}),
};