diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index f5f3c31..af35a02 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -36,6 +36,7 @@ pub enum WriterRequest { #[derive(Default, Copy, Clone)] pub struct PeerConnectionOptions { pub connect_timeout: Option, + pub read_write_timeout: Option, pub keep_alive_interval: Option, } @@ -48,36 +49,21 @@ pub struct PeerConnection { spawner: BlockingSpawner, } -// async fn read_one<'a, R: AsyncReadExt + Unpin>( -// mut reader: R, -// read_buf: &'a mut Vec, -// read_so_far: &mut usize, -// ) -> anyhow::Result<(MessageBorrowed<'a>, usize)> { -// loop { -// match MessageBorrowed::deserialize(&read_buf[..*read_so_far]) { -// Ok((msg, size)) => return Ok((msg, size)), -// Err(MessageDeserializeError::NotEnoughData(d, _)) => { -// if read_buf.len() < *read_so_far + d { -// read_buf.reserve(d); -// read_buf.resize(read_buf.capacity(), 0); -// } - -// let size = reader -// .read(&mut read_buf[*read_so_far..]) -// .await -// .context("error reading from peer")?; -// if size == 0 { -// anyhow::bail!("disconnected while reading, read so far: {}", *read_so_far) -// } -// *read_so_far += size; -// } -// Err(e) => return Err(e.into()), -// } -// } -// } +async fn with_timeout( + timeout_value: Duration, + fut: impl std::future::Future>, +) -> anyhow::Result +where + E: Into, +{ + timeout(timeout_value, fut) + .await + .with_context(|| format!("timeout at {timeout_value:?}"))? + .map_err(|e| e.into()) +} macro_rules! read_one { - ($conn:ident, $read_buf:ident, $read_so_far:ident) => {{ + ($conn:ident, $read_buf:ident, $read_so_far:ident, $rwtimeout:ident) => {{ let (extended, size) = loop { match MessageBorrowed::deserialize(&$read_buf[..$read_so_far]) { Ok((msg, size)) => break (msg, size), @@ -87,8 +73,7 @@ macro_rules! read_one { $read_buf.resize($read_buf.capacity(), 0); } - let size = $conn - .read(&mut $read_buf[$read_so_far..]) + let size = with_timeout($rwtimeout, $conn.read(&mut $read_buf[$read_so_far..])) .await .context("error reading from peer")?; if size == 0 { @@ -130,29 +115,31 @@ impl PeerConnection { ) -> anyhow::Result<()> { use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; - let mut conn = match timeout( - self.options - .connect_timeout - .unwrap_or_else(|| Duration::from_secs(10)), - tokio::net::TcpStream::connect(self.addr), - ) - .await - { - Ok(conn) => conn.context("error connecting")?, - Err(_) => anyhow::bail!("timeout connecting to {}", self.addr), - }; + + let rwtimeout = self + .options + .read_write_timeout + .unwrap_or_else(|| Duration::from_secs(10)); + + let connect_timeout = self + .options + .connect_timeout + .unwrap_or_else(|| Duration::from_secs(10)); + + let mut conn = with_timeout(connect_timeout, tokio::net::TcpStream::connect(self.addr)) + .await + .context("error connecting")?; let mut write_buf = Vec::::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); let handshake = Handshake::new(self.info_hash, self.peer_id); handshake.serialize(&mut write_buf); - conn.write_all(&write_buf) + with_timeout(rwtimeout, conn.write_all(&write_buf)) .await .context("error writing handshake")?; write_buf.clear(); let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; - let mut read_so_far = conn - .read(&mut read_buf) + let mut read_so_far = with_timeout(rwtimeout, conn.read(&mut read_buf)) .await .context("error reading handshake")?; if read_so_far == 0 { @@ -188,12 +175,12 @@ impl PeerConnection { &my_extended ); my_extended.serialize(&mut write_buf, None).unwrap(); - conn.write_all(&write_buf) + with_timeout(rwtimeout, conn.write_all(&write_buf)) .await .context("error writing extended handshake")?; write_buf.clear(); - let (extended, size) = read_one!(conn, read_buf, read_so_far); + let (extended, size) = read_one!(conn, read_buf, read_so_far, rwtimeout); match extended { Message::Extended(ExtendedMessage::Handshake(h)) => { trace!("received from {}: {:?}", self.addr, &h); @@ -222,8 +209,7 @@ impl PeerConnection { .handler .serialize_bitfield_message_to_buf(&mut write_buf) { - write_half - .write_all(&write_buf[..len]) + with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) .await .context("error writing bitfield to peer")?; debug!("sent bitfield to {}", self.addr); @@ -256,7 +242,7 @@ impl PeerConnection { self.handler .read_chunk(chunk, &mut write_buf[preamble_len..]) }) - .with_context(|| format!("error reading chunk {:?}", chunk))?; + .with_context(|| format!("error reading chunk {chunk:?}"))?; uploaded_add = Some(chunk.size); full_len @@ -265,8 +251,7 @@ impl PeerConnection { debug!("sending to {}: {:?}, length={}", self.addr, &req, len); - write_half - .write_all(&write_buf[..len]) + with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) .await .context("error writing the message to peer")?; write_buf.clear(); @@ -283,7 +268,7 @@ impl PeerConnection { let reader = async move { loop { - let (message, size) = read_one!(read_half, read_buf, read_so_far); + let (message, size) = read_one!(read_half, read_buf, read_so_far, rwtimeout); trace!("received from {}: {:?}", self.addr, &message); self.handler diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 11916ce..b6a43b7 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -355,6 +355,10 @@ impl Session { builder.peer_connect_timeout(t); } + if let Some(t) = opts.peer_opts.unwrap_or(self.peer_opts).read_write_timeout { + builder.peer_read_write_timeout(t); + } + let handle = match builder .start_manager() .context("error starting torrent manager") diff --git a/crates/librqbit/src/torrent_manager.rs b/crates/librqbit/src/torrent_manager.rs index 6bb8dc7..0e9b7f8 100644 --- a/crates/librqbit/src/torrent_manager.rs +++ b/crates/librqbit/src/torrent_manager.rs @@ -32,6 +32,7 @@ use crate::{ struct TorrentManagerOptions { force_tracker_interval: Option, peer_connect_timeout: Option, + peer_read_write_timeout: Option, only_files: Option>, peer_id: Option, overwrite: bool, @@ -90,6 +91,11 @@ impl TorrentManagerBuilder { self } + pub fn peer_read_write_timeout(&mut self, timeout: Duration) -> &mut Self { + self.options.peer_read_write_timeout = Some(timeout); + self + } + pub fn start_manager(self) -> anyhow::Result { TorrentManager::start( self.info, @@ -256,6 +262,7 @@ impl TorrentManager { #[allow(clippy::needless_update)] let state_options = TorrentStateOptions { peer_connect_timeout: options.peer_connect_timeout, + peer_read_write_timeout: options.peer_read_write_timeout, ..Default::default() }; diff --git a/crates/librqbit/src/torrent_state.rs b/crates/librqbit/src/torrent_state.rs index a3271ee..aa52918 100644 --- a/crates/librqbit/src/torrent_state.rs +++ b/crates/librqbit/src/torrent_state.rs @@ -225,6 +225,7 @@ impl StatsSnapshot { #[derive(Default)] pub struct TorrentStateOptions { pub peer_connect_timeout: Option, + pub peer_read_write_timeout: Option, } pub struct TorrentState { @@ -286,6 +287,7 @@ impl TorrentState { loop { let (addr, out_rx) = peer_queue_rx.recv().await.unwrap(); + let permit = state.peer_semaphore.acquire().await.unwrap(); match state.locked.write().peers.states.get_mut(&addr) { Some(s @ PeerState::Queued) => *s = PeerState::Connecting, s => { @@ -294,8 +296,6 @@ impl TorrentState { } }; - state.peer_semaphore.acquire().await.unwrap().forget(); - let handler = PeerHandler { addr, state: state.clone(), @@ -303,6 +303,7 @@ impl TorrentState { }; let options = PeerConnectionOptions { connect_timeout: state.options.peer_connect_timeout, + read_write_timeout: state.options.peer_read_write_timeout, ..Default::default() }; let peer_connection = PeerConnection::new( @@ -313,7 +314,9 @@ impl TorrentState { Some(options), spawner, ); - spawn(format!("manage_peer({})", addr), async move { + + permit.forget(); + spawn(format!("manage_peer({addr})"), async move { if let Err(e) = peer_connection.manage_peer(out_rx).await { debug!("error managing peer {}: {:#}", addr, e) }; diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index e656750..03aebdb 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -70,6 +70,10 @@ struct Opts { #[clap(long = "peer-connect-timeout")] peer_connect_timeout: Option, + /// The connect timeout, e.g. 1s, 1.5s, 100ms etc. + #[clap(long = "peer-read-write-timeout")] + peer_read_write_timeout: Option, + /// How many threads to spawn for the executor. #[clap(short = 't', long)] worker_threads: Option, @@ -200,6 +204,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> peer_id: None, peer_opts: Some(PeerConnectionOptions { connect_timeout: opts.peer_connect_timeout.map(|d| d.0), + read_write_timeout: opts.peer_read_write_timeout.map(|d| d.0), ..Default::default() }), };