Add read/write timeouts and avoid stuck peers
This commit is contained in:
parent
ae847ce99c
commit
9e8f235cb4
5 changed files with 59 additions and 55 deletions
|
|
@ -36,6 +36,7 @@ pub enum WriterRequest {
|
|||
#[derive(Default, Copy, Clone)]
|
||||
pub struct PeerConnectionOptions {
|
||||
pub connect_timeout: Option<Duration>,
|
||||
pub read_write_timeout: Option<Duration>,
|
||||
pub keep_alive_interval: Option<Duration>,
|
||||
}
|
||||
|
||||
|
|
@ -48,36 +49,21 @@ pub struct PeerConnection<H> {
|
|||
spawner: BlockingSpawner,
|
||||
}
|
||||
|
||||
// async fn read_one<'a, R: AsyncReadExt + Unpin>(
|
||||
// mut reader: R,
|
||||
// read_buf: &'a mut Vec<u8>,
|
||||
// read_so_far: &mut usize,
|
||||
// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> {
|
||||
// loop {
|
||||
// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) {
|
||||
// Ok((msg, size)) => return Ok((msg, size)),
|
||||
// Err(MessageDeserializeError::NotEnoughData(d, _)) => {
|
||||
// if read_buf.len() < *read_so_far + d {
|
||||
// read_buf.reserve(d);
|
||||
// read_buf.resize(read_buf.capacity(), 0);
|
||||
// }
|
||||
|
||||
// let size = reader
|
||||
// .read(&mut read_buf[*read_so_far..])
|
||||
// .await
|
||||
// .context("error reading from peer")?;
|
||||
// if size == 0 {
|
||||
// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far)
|
||||
// }
|
||||
// *read_so_far += size;
|
||||
// }
|
||||
// Err(e) => return Err(e.into()),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
async fn with_timeout<T, E>(
|
||||
timeout_value: Duration,
|
||||
fut: impl std::future::Future<Output = Result<T, E>>,
|
||||
) -> anyhow::Result<T>
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
timeout(timeout_value, fut)
|
||||
.await
|
||||
.with_context(|| format!("timeout at {timeout_value:?}"))?
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
macro_rules! read_one {
|
||||
($conn:ident, $read_buf:ident, $read_so_far:ident) => {{
|
||||
($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{
|
||||
let (extended, size) = loop {
|
||||
match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) {
|
||||
Ok((msg, size)) => break (msg, size),
|
||||
|
|
@ -87,8 +73,7 @@ macro_rules! read_one {
|
|||
$read_buf.resize($read_buf.capacity(), 0);
|
||||
}
|
||||
|
||||
let size = $conn
|
||||
.read(&mut $read_buf[$read_so_far..])
|
||||
let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..]))
|
||||
.await
|
||||
.context("error reading from peer")?;
|
||||
if size == 0 {
|
||||
|
|
@ -130,29 +115,31 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
) -> anyhow::Result<()> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let mut conn = match timeout(
|
||||
self.options
|
||||
.connect_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10)),
|
||||
tokio::net::TcpStream::connect(self.addr),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(conn) => conn.context("error connecting")?,
|
||||
Err(_) => anyhow::bail!("timeout connecting to {}", self.addr),
|
||||
};
|
||||
|
||||
let rwtimeout = self
|
||||
.options
|
||||
.read_write_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10));
|
||||
|
||||
let connect_timeout = self
|
||||
.options
|
||||
.connect_timeout
|
||||
.unwrap_or_else(|| Duration::from_secs(10));
|
||||
|
||||
let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr))
|
||||
.await
|
||||
.context("error connecting")?;
|
||||
|
||||
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
|
||||
let handshake = Handshake::new(self.info_hash, self.peer_id);
|
||||
handshake.serialize(&mut write_buf);
|
||||
conn.write_all(&write_buf)
|
||||
with_timeout(rwtimeout, conn.write_all(&write_buf))
|
||||
.await
|
||||
.context("error writing handshake")?;
|
||||
write_buf.clear();
|
||||
|
||||
let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2];
|
||||
let mut read_so_far = conn
|
||||
.read(&mut read_buf)
|
||||
let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf))
|
||||
.await
|
||||
.context("error reading handshake")?;
|
||||
if read_so_far == 0 {
|
||||
|
|
@ -188,12 +175,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
&my_extended
|
||||
);
|
||||
my_extended.serialize(&mut write_buf, None).unwrap();
|
||||
conn.write_all(&write_buf)
|
||||
with_timeout(rwtimeout, conn.write_all(&write_buf))
|
||||
.await
|
||||
.context("error writing extended handshake")?;
|
||||
write_buf.clear();
|
||||
|
||||
let (extended, size) = read_one!(conn, read_buf, read_so_far);
|
||||
let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout);
|
||||
match extended {
|
||||
Message::Extended(ExtendedMessage::Handshake(h)) => {
|
||||
trace!("received from {}: {:?}", self.addr, &h);
|
||||
|
|
@ -222,8 +209,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
.handler
|
||||
.serialize_bitfield_message_to_buf(&mut write_buf)
|
||||
{
|
||||
write_half
|
||||
.write_all(&write_buf[..len])
|
||||
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
|
||||
.await
|
||||
.context("error writing bitfield to peer")?;
|
||||
debug!("sent bitfield to {}", self.addr);
|
||||
|
|
@ -256,7 +242,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
self.handler
|
||||
.read_chunk(chunk, &mut write_buf[preamble_len..])
|
||||
})
|
||||
.with_context(|| format!("error reading chunk {:?}", chunk))?;
|
||||
.with_context(|| format!("error reading chunk {chunk:?}"))?;
|
||||
|
||||
uploaded_add = Some(chunk.size);
|
||||
full_len
|
||||
|
|
@ -265,8 +251,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
debug!("sending to {}: {:?}, length={}", self.addr, &req, len);
|
||||
|
||||
write_half
|
||||
.write_all(&write_buf[..len])
|
||||
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len]))
|
||||
.await
|
||||
.context("error writing the message to peer")?;
|
||||
write_buf.clear();
|
||||
|
|
@ -283,7 +268,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
|
|||
|
||||
let reader = async move {
|
||||
loop {
|
||||
let (message, size) = read_one!(read_half, read_buf, read_so_far);
|
||||
let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout);
|
||||
trace!("received from {}: {:?}", self.addr, &message);
|
||||
|
||||
self.handler
|
||||
|
|
|
|||
|
|
@ -355,6 +355,10 @@ impl Session {
|
|||
builder.peer_connect_timeout(t);
|
||||
}
|
||||
|
||||
if let Some(t) = opts.peer_opts.unwrap_or(self.peer_opts).read_write_timeout {
|
||||
builder.peer_read_write_timeout(t);
|
||||
}
|
||||
|
||||
let handle = match builder
|
||||
.start_manager()
|
||||
.context("error starting torrent manager")
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ use crate::{
|
|||
struct TorrentManagerOptions {
|
||||
force_tracker_interval: Option<Duration>,
|
||||
peer_connect_timeout: Option<Duration>,
|
||||
peer_read_write_timeout: Option<Duration>,
|
||||
only_files: Option<Vec<usize>>,
|
||||
peer_id: Option<Id20>,
|
||||
overwrite: bool,
|
||||
|
|
@ -90,6 +91,11 @@ impl TorrentManagerBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn peer_read_write_timeout(&mut self, timeout: Duration) -> &mut Self {
|
||||
self.options.peer_read_write_timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn start_manager(self) -> anyhow::Result<TorrentManagerHandle> {
|
||||
TorrentManager::start(
|
||||
self.info,
|
||||
|
|
@ -256,6 +262,7 @@ impl TorrentManager {
|
|||
#[allow(clippy::needless_update)]
|
||||
let state_options = TorrentStateOptions {
|
||||
peer_connect_timeout: options.peer_connect_timeout,
|
||||
peer_read_write_timeout: options.peer_read_write_timeout,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -225,6 +225,7 @@ impl StatsSnapshot {
|
|||
#[derive(Default)]
|
||||
pub struct TorrentStateOptions {
|
||||
pub peer_connect_timeout: Option<Duration>,
|
||||
pub peer_read_write_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
pub struct TorrentState {
|
||||
|
|
@ -286,6 +287,7 @@ impl TorrentState {
|
|||
loop {
|
||||
let (addr, out_rx) = peer_queue_rx.recv().await.unwrap();
|
||||
|
||||
let permit = state.peer_semaphore.acquire().await.unwrap();
|
||||
match state.locked.write().peers.states.get_mut(&addr) {
|
||||
Some(s @ PeerState::Queued) => *s = PeerState::Connecting,
|
||||
s => {
|
||||
|
|
@ -294,8 +296,6 @@ impl TorrentState {
|
|||
}
|
||||
};
|
||||
|
||||
state.peer_semaphore.acquire().await.unwrap().forget();
|
||||
|
||||
let handler = PeerHandler {
|
||||
addr,
|
||||
state: state.clone(),
|
||||
|
|
@ -303,6 +303,7 @@ impl TorrentState {
|
|||
};
|
||||
let options = PeerConnectionOptions {
|
||||
connect_timeout: state.options.peer_connect_timeout,
|
||||
read_write_timeout: state.options.peer_read_write_timeout,
|
||||
..Default::default()
|
||||
};
|
||||
let peer_connection = PeerConnection::new(
|
||||
|
|
@ -313,7 +314,9 @@ impl TorrentState {
|
|||
Some(options),
|
||||
spawner,
|
||||
);
|
||||
spawn(format!("manage_peer({})", addr), async move {
|
||||
|
||||
permit.forget();
|
||||
spawn(format!("manage_peer({addr})"), async move {
|
||||
if let Err(e) = peer_connection.manage_peer(out_rx).await {
|
||||
debug!("error managing peer {}: {:#}", addr, e)
|
||||
};
|
||||
|
|
|
|||
|
|
@ -70,6 +70,10 @@ struct Opts {
|
|||
#[clap(long = "peer-connect-timeout")]
|
||||
peer_connect_timeout: Option<ParsedDuration>,
|
||||
|
||||
/// The connect timeout, e.g. 1s, 1.5s, 100ms etc.
|
||||
#[clap(long = "peer-read-write-timeout")]
|
||||
peer_read_write_timeout: Option<ParsedDuration>,
|
||||
|
||||
/// How many threads to spawn for the executor.
|
||||
#[clap(short = 't', long)]
|
||||
worker_threads: Option<usize>,
|
||||
|
|
@ -200,6 +204,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
peer_id: None,
|
||||
peer_opts: Some(PeerConnectionOptions {
|
||||
connect_timeout: opts.peer_connect_timeout.map(|d| d.0),
|
||||
read_write_timeout: opts.peer_read_write_timeout.map(|d| d.0),
|
||||
..Default::default()
|
||||
}),
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue