diff --git a/crates/librqbit/src/http_api/handlers/mod.rs b/crates/librqbit/src/http_api/handlers/mod.rs index 5094974..ff3b43d 100644 --- a/crates/librqbit/src/http_api/handlers/mod.rs +++ b/crates/librqbit/src/http_api/handlers/mod.rs @@ -26,7 +26,7 @@ async fn h_api_root(parts: Parts) -> impl IntoResponse { .headers .get("Accept") .and_then(|h| h.to_str().ok()) - .map_or(false, |h| h.contains("text/html")) + .is_some_and(|h| h.contains("text/html")) { return Redirect::temporary("./web/").into_response(); } diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 3d54cf2..523da17 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -88,6 +88,16 @@ where } } +struct ManagePeerArgs { + handshake_supports_extended: bool, + read_buf: ReadBuf, + write_buf: Vec, + read: R, + write: W, + outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, + have_broadcast: tokio::sync::broadcast::Receiver, +} + impl PeerConnection { pub fn new( addr: SocketAddr, @@ -147,18 +157,21 @@ impl PeerConnection { .context("error writing handshake")?; write_buf.clear(); - let h_supports_extended = handshake.supports_extended(); + let handshake_supports_extended = handshake.supports_extended(); self.handler.on_handshake(handshake)?; - self.manage_peer( - h_supports_extended, + let (read, write) = conn.into_split(); + + self.manage_peer(ManagePeerArgs { + handshake_supports_extended, read_buf, write_buf, - conn, + read, + write, outgoing_chan, have_broadcast, - ) + }) .await } @@ -179,26 +192,26 @@ impl PeerConnection { .unwrap_or_else(|| Duration::from_secs(10)); let now = Instant::now(); - let conn = self.connector.connect(self.addr); - let mut conn = with_timeout(connect_timeout, conn) - .await - .context("error connecting")?; + let (mut read, mut write) = + with_timeout(connect_timeout, self.connector.connect(self.addr)) + .await + .context("error connecting")?; self.handler.on_connected(now.elapsed()); let mut write_buf = Vec::::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); let handshake = Handshake::new(self.info_hash, self.peer_id); handshake.serialize(&mut write_buf); - with_timeout(rwtimeout, conn.write_all(&write_buf)) + with_timeout(rwtimeout, write.write_all(&write_buf)) .await .context("error writing handshake")?; write_buf.clear(); let mut read_buf = ReadBuf::new(); let h = read_buf - .read_handshake(&mut conn, rwtimeout) + .read_handshake(&mut read, rwtimeout) .await .context("error reading handshake")?; - let h_supports_extended = h.supports_extended(); + let handshake_supports_extended = h.supports_extended(); trace!( peer_id=?Id20::new(h.peer_id), decoded_id=?try_decode_peer_id(Id20::new(h.peer_id)), @@ -214,26 +227,35 @@ impl PeerConnection { self.handler.on_handshake(h)?; - self.manage_peer( - h_supports_extended, + self.manage_peer(ManagePeerArgs { + handshake_supports_extended, read_buf, write_buf, - conn, + read, + write, outgoing_chan, have_broadcast, - ) + }) .await } async fn manage_peer( &self, - handshake_supports_extended: bool, - mut read_buf: ReadBuf, - mut write_buf: Vec, - mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, - mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, - mut have_broadcast: tokio::sync::broadcast::Receiver, + args: ManagePeerArgs< + impl tokio::io::AsyncRead + Send + Unpin, + impl tokio::io::AsyncWrite + Send + Unpin, + >, ) -> anyhow::Result<()> { + let ManagePeerArgs { + handshake_supports_extended, + mut read_buf, + mut write_buf, + mut read, + mut write, + mut outgoing_chan, + mut have_broadcast, + } = args; + use tokio::io::AsyncWriteExt; let rwtimeout = self @@ -256,14 +278,12 @@ impl PeerConnection { my_extended .serialize(&mut write_buf, &Default::default) .unwrap(); - with_timeout(rwtimeout, conn.write_all(&write_buf)) + with_timeout(rwtimeout, write.write_all(&write_buf)) .await .context("error writing extended handshake")?; write_buf.clear(); } - let (mut read_half, mut write_half) = tokio::io::split(conn); - let writer = async move { let keep_alive_interval = self .options @@ -274,14 +294,14 @@ impl PeerConnection { let len = self .handler .serialize_bitfield_message_to_buf(&mut write_buf)?; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing bitfield to peer")?; trace!("sent bitfield"); } let len = MessageOwned::Unchoke.serialize(&mut write_buf, &Default::default)?; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing unchoke")?; trace!("sent unchoke"); @@ -378,7 +398,7 @@ impl PeerConnection { } }; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing the message to peer")?; @@ -395,7 +415,7 @@ impl PeerConnection { let reader = async move { loop { let message = read_buf - .read_message(&mut read_half, rwtimeout) + .read_message(&mut read, rwtimeout) .await .context("error reading message")?; trace!("received: {:?}", &message); diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 6606f7a..0f6d2ab 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -996,7 +996,7 @@ impl Session { name, } = add_res; - let private = metadata.as_ref().map_or(false, |m| m.info.private); + let private = metadata.as_ref().is_some_and(|m| m.info.private); let make_peer_rx = || { self.make_peer_rx( diff --git a/crates/librqbit/src/stream_connect.rs b/crates/librqbit/src/stream_connect.rs index 11e91e0..fed786d 100644 --- a/crates/librqbit/src/stream_connect.rs +++ b/crates/librqbit/src/stream_connect.rs @@ -30,10 +30,13 @@ impl SocksProxyConfig { async fn connect( &self, addr: SocketAddr, - ) -> anyhow::Result { + ) -> anyhow::Result<( + impl tokio::io::AsyncRead + Unpin, + impl tokio::io::AsyncWrite + Unpin, + )> { let proxy_addr = (self.host.as_str(), self.port); - if let Some((username, password)) = self.username_password.as_ref() { + let stream = if let Some((username, password)) = self.username_password.as_ref() { tokio_socks::tcp::Socks5Stream::connect_with_password( proxy_addr, addr, @@ -41,12 +44,14 @@ impl SocksProxyConfig { password.as_str(), ) .await - .context("error connecting to proxy") + .context("error connecting to proxy")? } else { tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr) .await - .context("error connecting to proxy") - } + .context("error connecting to proxy")? + }; + + Ok(tokio::io::split(stream)) } } @@ -61,22 +66,23 @@ impl From> for StreamConnector { } } -pub(crate) trait AsyncReadWrite: - tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin -{ -} - -impl AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {} - impl StreamConnector { - pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result> { + pub async fn connect( + &self, + addr: SocketAddr, + ) -> anyhow::Result<( + Box, + Box, + )> { if let Some(proxy) = self.proxy_config.as_ref() { - return Ok(Box::new(proxy.connect(addr).await?)); + let (r, w) = proxy.connect(addr).await?; + return Ok((Box::new(r), Box::new(w))); } - Ok(Box::new( - tokio::net::TcpStream::connect(addr) - .await - .context("error connecting")?, - )) + + let (r, w) = tokio::net::TcpStream::connect(addr) + .await + .context("error connecting")? + .into_split(); + Ok((Box::new(r), Box::new(w))) } } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 0db812a..d7f9748 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -1253,10 +1253,7 @@ impl PeerHandler { /// /// If this returns, an existing in-flight piece was marked to be ours. fn try_steal_old_slow_piece(&self, threshold: f64) -> Option { - let my_avg_time = match self.counters.average_piece_download_time() { - Some(t) => t, - None => return None, - }; + let my_avg_time = self.counters.average_piece_download_time()?; let (stolen_idx, from_peer) = { let mut g = self.state.lock_write("try_steal_old_slow_piece"); diff --git a/crates/librqbit/src/torrent_state/live/peer/mod.rs b/crates/librqbit/src/torrent_state/live/peer/mod.rs index fc1d04f..5d80650 100644 --- a/crates/librqbit/src/torrent_state/live/peer/mod.rs +++ b/crates/librqbit/src/torrent_state/live/peer/mod.rs @@ -263,8 +263,6 @@ impl LivePeerState { } pub fn has_full_torrent(&self, total_pieces: usize) -> bool { - self.bitfield - .get(0..total_pieces) - .map_or(false, |s| s.all()) + self.bitfield.get(0..total_pieces).is_some_and(|s| s.all()) } }