Add read/write timeouts and avoid stuck peers

This commit is contained in:
Igor Katson 2022-12-04 14:34:21 +00:00
parent ae847ce99c
commit 9e8f235cb4
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
5 changed files with 59 additions and 55 deletions

View file

@ -36,6 +36,7 @@ pub enum WriterRequest {
#[derive(Default, Copy, Clone)] #[derive(Default, Copy, Clone)]
pub struct PeerConnectionOptions { pub struct PeerConnectionOptions {
pub connect_timeout: Option<Duration>, pub connect_timeout: Option<Duration>,
pub read_write_timeout: Option<Duration>,
pub keep_alive_interval: Option<Duration>, pub keep_alive_interval: Option<Duration>,
} }
@ -48,36 +49,21 @@ pub struct PeerConnection<H> {
spawner: BlockingSpawner, spawner: BlockingSpawner,
} }
// async fn read_one<'a, R: AsyncReadExt + Unpin>( async fn with_timeout<T, E>(
// mut reader: R, timeout_value: Duration,
// read_buf: &'a mut Vec<u8>, fut: impl std::future::Future<Output = Result<T, E>>,
// read_so_far: &mut usize, ) -> anyhow::Result<T>
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> { where
// loop { E: Into<anyhow::Error>,
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) { {
// Ok((msg, size)) => return Ok((msg, size)), timeout(timeout_value, fut)
// Err(MessageDeserializeError::NotEnoughData(d, _)) => { .await
// if read_buf.len() < *read_so_far + d { .with_context(|| format!("timeout at {timeout_value:?}"))?
// read_buf.reserve(d); .map_err(|e| e.into())
// read_buf.resize(read_buf.capacity(), 0); }
// }
// let size = reader
// .read(&mut read_buf[*read_so_far..])
// .await
// .context("error reading from peer")?;
// if size == 0 {
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
// }
// *read_so_far += size;
// }
// Err(e) => return Err(e.into()),
// }
// }
// }
macro_rules! read_one { macro_rules! read_one {
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{ ($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
let (extended, size) = loop { let (extended, size) = loop {
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) { match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
Ok((msg, size)) => break (msg, size), Ok((msg, size)) => break (msg, size),
@ -87,8 +73,7 @@ macro_rules! read_one {
$read_buf.resize($read_buf.capacity(), 0); $read_buf.resize($read_buf.capacity(), 0);
} }
let size = $conn let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
.read(&mut $read_buf[$read_so_far..])
.await .await
.context("error reading from peer")?; .context("error reading from peer")?;
if size == 0 { if size == 0 {
@ -130,29 +115,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
let mut conn = match timeout(
self.options let rwtimeout = self
.connect_timeout .options
.unwrap_or_else(|| Duration::from_secs(10)), .read_write_timeout
tokio::net::TcpStream::connect(self.addr), .unwrap_or_else(|| Duration::from_secs(10));
)
.await let connect_timeout = self
{ .options
Ok(conn) => conn.context("error connecting")?, .connect_timeout
Err(_) => anyhow::bail!("timeout connecting to {}", self.addr), .unwrap_or_else(|| Duration::from_secs(10));
};
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
.await
.context("error connecting")?;
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let handshake = Handshake::new(self.info_hash, self.peer_id); let handshake = Handshake::new(self.info_hash, self.peer_id);
handshake.serialize(&mut write_buf); handshake.serialize(&mut write_buf);
conn.write_all(&write_buf) with_timeout(rwtimeout, conn.write_all(&write_buf))
.await .await
.context("error writing handshake")?; .context("error writing handshake")?;
write_buf.clear(); write_buf.clear();
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
let mut read_so_far = conn let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
.read(&mut read_buf)
.await .await
.context("error reading handshake")?; .context("error reading handshake")?;
if read_so_far == 0 { if read_so_far == 0 {
@ -188,12 +175,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
&my_extended &my_extended
); );
my_extended.serialize(&mut write_buf, None).unwrap(); my_extended.serialize(&mut write_buf, None).unwrap();
conn.write_all(&write_buf) with_timeout(rwtimeout, conn.write_all(&write_buf))
.await .await
.context("error writing extended handshake")?; .context("error writing extended handshake")?;
write_buf.clear(); write_buf.clear();
let (extended, size) = read_one!(conn, read_buf, read_so_far); let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout);
match extended { match extended {
Message::Extended(ExtendedMessage::Handshake(h)) => { Message::Extended(ExtendedMessage::Handshake(h)) => {
trace!("received from {}: {:?}", self.addr, &h); trace!("received from {}: {:?}", self.addr, &h);
@ -222,8 +209,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.handler .handler
.serialize_bitfield_message_to_buf(&mut write_buf) .serialize_bitfield_message_to_buf(&mut write_buf)
{ {
write_half with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.write_all(&write_buf[..len])
.await .await
.context("error writing bitfield to peer")?; .context("error writing bitfield to peer")?;
debug!("sent bitfield to {}", self.addr); debug!("sent bitfield to {}", self.addr);
@ -256,7 +242,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.handler self.handler
.read_chunk(chunk, &mut write_buf[preamble_len..]) .read_chunk(chunk, &mut write_buf[preamble_len..])
}) })
.with_context(|| format!("error reading chunk {:?}", chunk))?; .with_context(|| format!("error reading chunk {chunk:?}"))?;
uploaded_add = Some(chunk.size); uploaded_add = Some(chunk.size);
full_len full_len
@ -265,8 +251,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
debug!("sending to {}: {:?}, length={}", self.addr, &req, len); debug!("sending to {}: {:?}, length={}", self.addr, &req, len);
write_half with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
.write_all(&write_buf[..len])
.await .await
.context("error writing the message to peer")?; .context("error writing the message to peer")?;
write_buf.clear(); write_buf.clear();
@ -283,7 +268,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let reader = async move { let reader = async move {
loop { loop {
let (message, size) = read_one!(read_half, read_buf, read_so_far); let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
trace!("received from {}: {:?}", self.addr, &message); trace!("received from {}: {:?}", self.addr, &message);
self.handler self.handler

View file

@ -355,6 +355,10 @@ impl Session {
builder.peer_connect_timeout(t); builder.peer_connect_timeout(t);
} }
if let Some(t) = opts.peer_opts.unwrap_or(self.peer_opts).read_write_timeout {
builder.peer_read_write_timeout(t);
}
let handle = match builder let handle = match builder
.start_manager() .start_manager()
.context("error starting torrent manager") .context("error starting torrent manager")

View file

@ -32,6 +32,7 @@ use crate::{
struct TorrentManagerOptions { struct TorrentManagerOptions {
force_tracker_interval: Option<Duration>, force_tracker_interval: Option<Duration>,
peer_connect_timeout: Option<Duration>, peer_connect_timeout: Option<Duration>,
peer_read_write_timeout: Option<Duration>,
only_files: Option<Vec<usize>>, only_files: Option<Vec<usize>>,
peer_id: Option<Id20>, peer_id: Option<Id20>,
overwrite: bool, overwrite: bool,
@ -90,6 +91,11 @@ impl TorrentManagerBuilder {
self self
} }
/// Sets the peer read/write timeout in the builder's options.
///
/// Stores `timeout` as `Some(timeout)` in `options.peer_read_write_timeout`
/// and returns `&mut Self` so that further builder calls can be chained.
pub fn peer_read_write_timeout(&mut self, timeout: Duration) -> &mut Self {
self.options.peer_read_write_timeout = Some(timeout);
self
}
pub fn start_manager(self) -> anyhow::Result<TorrentManagerHandle> { pub fn start_manager(self) -> anyhow::Result<TorrentManagerHandle> {
TorrentManager::start( TorrentManager::start(
self.info, self.info,
@ -256,6 +262,7 @@ impl TorrentManager {
#[allow(clippy::needless_update)] #[allow(clippy::needless_update)]
let state_options = TorrentStateOptions { let state_options = TorrentStateOptions {
peer_connect_timeout: options.peer_connect_timeout, peer_connect_timeout: options.peer_connect_timeout,
peer_read_write_timeout: options.peer_read_write_timeout,
..Default::default() ..Default::default()
}; };

View file

@ -225,6 +225,7 @@ impl StatsSnapshot {
#[derive(Default)] #[derive(Default)]
pub struct TorrentStateOptions { pub struct TorrentStateOptions {
pub peer_connect_timeout: Option<Duration>, pub peer_connect_timeout: Option<Duration>,
pub peer_read_write_timeout: Option<Duration>,
} }
pub struct TorrentState { pub struct TorrentState {
@ -286,6 +287,7 @@ impl TorrentState {
loop { loop {
let (addr, out_rx) = peer_queue_rx.recv().await.unwrap(); let (addr, out_rx) = peer_queue_rx.recv().await.unwrap();
let permit = state.peer_semaphore.acquire().await.unwrap();
match state.locked.write().peers.states.get_mut(&addr) { match state.locked.write().peers.states.get_mut(&addr) {
Some(s @ PeerState::Queued) => *s = PeerState::Connecting, Some(s @ PeerState::Queued) => *s = PeerState::Connecting,
s => { s => {
@ -294,8 +296,6 @@ impl TorrentState {
} }
}; };
state.peer_semaphore.acquire().await.unwrap().forget();
let handler = PeerHandler { let handler = PeerHandler {
addr, addr,
state: state.clone(), state: state.clone(),
@ -303,6 +303,7 @@ impl TorrentState {
}; };
let options = PeerConnectionOptions { let options = PeerConnectionOptions {
connect_timeout: state.options.peer_connect_timeout, connect_timeout: state.options.peer_connect_timeout,
read_write_timeout: state.options.peer_read_write_timeout,
..Default::default() ..Default::default()
}; };
let peer_connection = PeerConnection::new( let peer_connection = PeerConnection::new(
@ -313,7 +314,9 @@ impl TorrentState {
Some(options), Some(options),
spawner, spawner,
); );
spawn(format!("manage_peer({})", addr), async move {
permit.forget();
spawn(format!("manage_peer({addr})"), async move {
if let Err(e) = peer_connection.manage_peer(out_rx).await { if let Err(e) = peer_connection.manage_peer(out_rx).await {
debug!("error managing peer {}: {:#}", addr, e) debug!("error managing peer {}: {:#}", addr, e)
}; };

View file

@ -70,6 +70,10 @@ struct Opts {
#[clap(long = "peer-connect-timeout")] #[clap(long = "peer-connect-timeout")]
peer_connect_timeout: Option<ParsedDuration>, peer_connect_timeout: Option<ParsedDuration>,
/// The peer read/write timeout, e.g. 1s, 1.5s, 100ms etc.
#[clap(long = "peer-read-write-timeout")]
peer_read_write_timeout: Option<ParsedDuration>,
/// How many threads to spawn for the executor. /// How many threads to spawn for the executor.
#[clap(short = 't', long)] #[clap(short = 't', long)]
worker_threads: Option<usize>, worker_threads: Option<usize>,
@ -200,6 +204,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
peer_id: None, peer_id: None,
peer_opts: Some(PeerConnectionOptions { peer_opts: Some(PeerConnectionOptions {
connect_timeout: opts.peer_connect_timeout.map(|d| d.0), connect_timeout: opts.peer_connect_timeout.map(|d| d.0),
read_write_timeout: opts.peer_read_write_timeout.map(|d| d.0),
..Default::default() ..Default::default()
}), }),
}; };