diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index 3d54cf2..85c48bf 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -151,11 +151,14 @@ impl PeerConnection { self.handler.on_handshake(handshake)?; + let (read, write) = conn.into_split(); + self.manage_peer( h_supports_extended, read_buf, write_buf, - conn, + read, + write, outgoing_chan, have_broadcast, ) @@ -179,23 +182,23 @@ impl PeerConnection { .unwrap_or_else(|| Duration::from_secs(10)); let now = Instant::now(); - let conn = self.connector.connect(self.addr); - let mut conn = with_timeout(connect_timeout, conn) - .await - .context("error connecting")?; + let (mut read, mut write) = + with_timeout(connect_timeout, self.connector.connect(self.addr)) + .await + .context("error connecting")?; self.handler.on_connected(now.elapsed()); let mut write_buf = Vec::::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); let handshake = Handshake::new(self.info_hash, self.peer_id); handshake.serialize(&mut write_buf); - with_timeout(rwtimeout, conn.write_all(&write_buf)) + with_timeout(rwtimeout, write.write_all(&write_buf)) .await .context("error writing handshake")?; write_buf.clear(); let mut read_buf = ReadBuf::new(); let h = read_buf - .read_handshake(&mut conn, rwtimeout) + .read_handshake(&mut read, rwtimeout) .await .context("error reading handshake")?; let h_supports_extended = h.supports_extended(); @@ -218,7 +221,8 @@ impl PeerConnection { h_supports_extended, read_buf, write_buf, - conn, + read, + write, outgoing_chan, have_broadcast, ) @@ -230,7 +234,8 @@ impl PeerConnection { handshake_supports_extended: bool, mut read_buf: ReadBuf, mut write_buf: Vec, - mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, + mut read: impl tokio::io::AsyncRead + Unpin + Send, + mut write: impl tokio::io::AsyncWrite + Unpin + Send, mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, mut have_broadcast: tokio::sync::broadcast::Receiver, ) -> anyhow::Result<()> { @@ -256,14 +261,12 @@ impl PeerConnection { my_extended .serialize(&mut write_buf, &Default::default) .unwrap(); - with_timeout(rwtimeout, conn.write_all(&write_buf)) + with_timeout(rwtimeout, write.write_all(&write_buf)) .await .context("error writing extended handshake")?; write_buf.clear(); } - let (mut read_half, mut write_half) = tokio::io::split(conn); - let writer = async move { let keep_alive_interval = self .options @@ -274,14 +277,14 @@ impl PeerConnection { let len = self .handler .serialize_bitfield_message_to_buf(&mut write_buf)?; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing bitfield to peer")?; trace!("sent bitfield"); } let len = MessageOwned::Unchoke.serialize(&mut write_buf, &Default::default)?; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing unchoke")?; trace!("sent unchoke"); @@ -378,7 +381,7 @@ impl PeerConnection { } }; - with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) + with_timeout(rwtimeout, write.write_all(&write_buf[..len])) .await .context("error writing the message to peer")?; @@ -395,7 +398,7 @@ impl PeerConnection { let reader = async move { loop { let message = read_buf - .read_message(&mut read_half, rwtimeout) + .read_message(&mut read, rwtimeout) .await .context("error reading message")?; trace!("received: {:?}", &message); diff --git a/crates/librqbit/src/stream_connect.rs b/crates/librqbit/src/stream_connect.rs index 11e91e0..fed786d 100644 --- a/crates/librqbit/src/stream_connect.rs +++ b/crates/librqbit/src/stream_connect.rs @@ -30,10 +30,13 @@ impl SocksProxyConfig { async fn connect( &self, addr: SocketAddr, - ) -> anyhow::Result { + ) -> anyhow::Result<( + impl tokio::io::AsyncRead + Unpin, + impl tokio::io::AsyncWrite + Unpin, + )> { let proxy_addr = (self.host.as_str(), self.port); - if let Some((username, password)) = self.username_password.as_ref() { + let stream = if let Some((username, password)) = self.username_password.as_ref() { tokio_socks::tcp::Socks5Stream::connect_with_password( proxy_addr, addr, @@ -41,12 +44,14 @@ impl SocksProxyConfig { password.as_str(), ) .await - .context("error connecting to proxy") + .context("error connecting to proxy")? } else { tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr) .await - .context("error connecting to proxy") - } + .context("error connecting to proxy")? + }; + + Ok(tokio::io::split(stream)) } } @@ -61,22 +66,23 @@ impl From> for StreamConnector { } } -pub(crate) trait AsyncReadWrite: - tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin -{ -} - -impl AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {} - impl StreamConnector { - pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result> { + pub async fn connect( + &self, + addr: SocketAddr, + ) -> anyhow::Result<( + Box, + Box, + )> { if let Some(proxy) = self.proxy_config.as_ref() { - return Ok(Box::new(proxy.connect(addr).await?)); + let (r, w) = proxy.connect(addr).await?; + return Ok((Box::new(r), Box::new(w))); } - Ok(Box::new( - tokio::net::TcpStream::connect(addr) - .await - .context("error connecting")?, - )) + + let (r, w) = tokio::net::TcpStream::connect(addr) + .await + .context("error connecting")? + .into_split(); + Ok((Box::new(r), Box::new(w))) } }