Merge pull request #320 from ikatson/stream-traits

[perf] don't use tokio::io::split
This commit is contained in:
Igor Katson 2025-01-30 14:11:36 +00:00 committed by GitHub
commit ed32b899ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 78 additions and 57 deletions

View file

@ -26,7 +26,7 @@ async fn h_api_root(parts: Parts) -> impl IntoResponse {
.headers .headers
.get("Accept") .get("Accept")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map_or(false, |h| h.contains("text/html")) .is_some_and(|h| h.contains("text/html"))
{ {
return Redirect::temporary("./web/").into_response(); return Redirect::temporary("./web/").into_response();
} }

View file

@ -88,6 +88,16 @@ where
} }
} }
struct ManagePeerArgs<R, W> {
handshake_supports_extended: bool,
read_buf: ReadBuf,
write_buf: Vec<u8>,
read: R,
write: W,
outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
have_broadcast: tokio::sync::broadcast::Receiver<ValidPieceIndex>,
}
impl<H: PeerConnectionHandler> PeerConnection<H> { impl<H: PeerConnectionHandler> PeerConnection<H> {
pub fn new( pub fn new(
addr: SocketAddr, addr: SocketAddr,
@ -147,18 +157,21 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.context("error writing handshake")?; .context("error writing handshake")?;
write_buf.clear(); write_buf.clear();
let h_supports_extended = handshake.supports_extended(); let handshake_supports_extended = handshake.supports_extended();
self.handler.on_handshake(handshake)?; self.handler.on_handshake(handshake)?;
self.manage_peer( let (read, write) = conn.into_split();
h_supports_extended,
self.manage_peer(ManagePeerArgs {
handshake_supports_extended,
read_buf, read_buf,
write_buf, write_buf,
conn, read,
write,
outgoing_chan, outgoing_chan,
have_broadcast, have_broadcast,
) })
.await .await
} }
@ -179,26 +192,26 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
.unwrap_or_else(|| Duration::from_secs(10)); .unwrap_or_else(|| Duration::from_secs(10));
let now = Instant::now(); let now = Instant::now();
let conn = self.connector.connect(self.addr); let (mut read, mut write) =
let mut conn = with_timeout(connect_timeout, conn) with_timeout(connect_timeout, self.connector.connect(self.addr))
.await .await
.context("error connecting")?; .context("error connecting")?;
self.handler.on_connected(now.elapsed()); self.handler.on_connected(now.elapsed());
let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); let mut write_buf = Vec::<u8>::with_capacity(PIECE_MESSAGE_DEFAULT_LEN);
let handshake = Handshake::new(self.info_hash, self.peer_id); let handshake = Handshake::new(self.info_hash, self.peer_id);
handshake.serialize(&mut write_buf); handshake.serialize(&mut write_buf);
with_timeout(rwtimeout, conn.write_all(&write_buf)) with_timeout(rwtimeout, write.write_all(&write_buf))
.await .await
.context("error writing handshake")?; .context("error writing handshake")?;
write_buf.clear(); write_buf.clear();
let mut read_buf = ReadBuf::new(); let mut read_buf = ReadBuf::new();
let h = read_buf let h = read_buf
.read_handshake(&mut conn, rwtimeout) .read_handshake(&mut read, rwtimeout)
.await .await
.context("error reading handshake")?; .context("error reading handshake")?;
let h_supports_extended = h.supports_extended(); let handshake_supports_extended = h.supports_extended();
trace!( trace!(
peer_id=?Id20::new(h.peer_id), peer_id=?Id20::new(h.peer_id),
decoded_id=?try_decode_peer_id(Id20::new(h.peer_id)), decoded_id=?try_decode_peer_id(Id20::new(h.peer_id)),
@ -214,26 +227,35 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
self.handler.on_handshake(h)?; self.handler.on_handshake(h)?;
self.manage_peer( self.manage_peer(ManagePeerArgs {
h_supports_extended, handshake_supports_extended,
read_buf, read_buf,
write_buf, write_buf,
conn, read,
write,
outgoing_chan, outgoing_chan,
have_broadcast, have_broadcast,
) })
.await .await
} }
async fn manage_peer( async fn manage_peer(
&self, &self,
handshake_supports_extended: bool, args: ManagePeerArgs<
mut read_buf: ReadBuf, impl tokio::io::AsyncRead + Send + Unpin,
mut write_buf: Vec<u8>, impl tokio::io::AsyncWrite + Send + Unpin,
mut conn: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, >,
mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver<WriterRequest>,
mut have_broadcast: tokio::sync::broadcast::Receiver<ValidPieceIndex>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let ManagePeerArgs {
handshake_supports_extended,
mut read_buf,
mut write_buf,
mut read,
mut write,
mut outgoing_chan,
mut have_broadcast,
} = args;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
let rwtimeout = self let rwtimeout = self
@ -256,14 +278,12 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
my_extended my_extended
.serialize(&mut write_buf, &Default::default) .serialize(&mut write_buf, &Default::default)
.unwrap(); .unwrap();
with_timeout(rwtimeout, conn.write_all(&write_buf)) with_timeout(rwtimeout, write.write_all(&write_buf))
.await .await
.context("error writing extended handshake")?; .context("error writing extended handshake")?;
write_buf.clear(); write_buf.clear();
} }
let (mut read_half, mut write_half) = tokio::io::split(conn);
let writer = async move { let writer = async move {
let keep_alive_interval = self let keep_alive_interval = self
.options .options
@ -274,14 +294,14 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let len = self let len = self
.handler .handler
.serialize_bitfield_message_to_buf(&mut write_buf)?; .serialize_bitfield_message_to_buf(&mut write_buf)?;
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) with_timeout(rwtimeout, write.write_all(&write_buf[..len]))
.await .await
.context("error writing bitfield to peer")?; .context("error writing bitfield to peer")?;
trace!("sent bitfield"); trace!("sent bitfield");
} }
let len = MessageOwned::Unchoke.serialize(&mut write_buf, &Default::default)?; let len = MessageOwned::Unchoke.serialize(&mut write_buf, &Default::default)?;
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) with_timeout(rwtimeout, write.write_all(&write_buf[..len]))
.await .await
.context("error writing unchoke")?; .context("error writing unchoke")?;
trace!("sent unchoke"); trace!("sent unchoke");
@ -378,7 +398,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
} }
}; };
with_timeout(rwtimeout, write_half.write_all(&write_buf[..len])) with_timeout(rwtimeout, write.write_all(&write_buf[..len]))
.await .await
.context("error writing the message to peer")?; .context("error writing the message to peer")?;
@ -395,7 +415,7 @@ impl<H: PeerConnectionHandler> PeerConnection<H> {
let reader = async move { let reader = async move {
loop { loop {
let message = read_buf let message = read_buf
.read_message(&mut read_half, rwtimeout) .read_message(&mut read, rwtimeout)
.await .await
.context("error reading message")?; .context("error reading message")?;
trace!("received: {:?}", &message); trace!("received: {:?}", &message);

View file

@ -996,7 +996,7 @@ impl Session {
name, name,
} = add_res; } = add_res;
let private = metadata.as_ref().map_or(false, |m| m.info.private); let private = metadata.as_ref().is_some_and(|m| m.info.private);
let make_peer_rx = || { let make_peer_rx = || {
self.make_peer_rx( self.make_peer_rx(

View file

@ -30,10 +30,13 @@ impl SocksProxyConfig {
async fn connect( async fn connect(
&self, &self,
addr: SocketAddr, addr: SocketAddr,
) -> anyhow::Result<impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> { ) -> anyhow::Result<(
impl tokio::io::AsyncRead + Unpin,
impl tokio::io::AsyncWrite + Unpin,
)> {
let proxy_addr = (self.host.as_str(), self.port); let proxy_addr = (self.host.as_str(), self.port);
if let Some((username, password)) = self.username_password.as_ref() { let stream = if let Some((username, password)) = self.username_password.as_ref() {
tokio_socks::tcp::Socks5Stream::connect_with_password( tokio_socks::tcp::Socks5Stream::connect_with_password(
proxy_addr, proxy_addr,
addr, addr,
@ -41,12 +44,14 @@ impl SocksProxyConfig {
password.as_str(), password.as_str(),
) )
.await .await
.context("error connecting to proxy") .context("error connecting to proxy")?
} else { } else {
tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr) tokio_socks::tcp::Socks5Stream::connect(proxy_addr, addr)
.await .await
.context("error connecting to proxy") .context("error connecting to proxy")?
} };
Ok(tokio::io::split(stream))
} }
} }
@ -61,22 +66,23 @@ impl From<Option<SocksProxyConfig>> for StreamConnector {
} }
} }
pub(crate) trait AsyncReadWrite:
tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin
{
}
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin {}
impl StreamConnector { impl StreamConnector {
pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<Box<dyn AsyncReadWrite>> { pub async fn connect(
&self,
addr: SocketAddr,
) -> anyhow::Result<(
Box<dyn tokio::io::AsyncRead + Send + Unpin>,
Box<dyn tokio::io::AsyncWrite + Send + Unpin>,
)> {
if let Some(proxy) = self.proxy_config.as_ref() { if let Some(proxy) = self.proxy_config.as_ref() {
return Ok(Box::new(proxy.connect(addr).await?)); let (r, w) = proxy.connect(addr).await?;
return Ok((Box::new(r), Box::new(w)));
} }
Ok(Box::new(
tokio::net::TcpStream::connect(addr) let (r, w) = tokio::net::TcpStream::connect(addr)
.await .await
.context("error connecting")?, .context("error connecting")?
)) .into_split();
Ok((Box::new(r), Box::new(w)))
} }
} }

View file

@ -1253,10 +1253,7 @@ impl PeerHandler {
/// ///
/// If this returns, an existing in-flight piece was marked to be ours. /// If this returns, an existing in-flight piece was marked to be ours.
fn try_steal_old_slow_piece(&self, threshold: f64) -> Option<ValidPieceIndex> { fn try_steal_old_slow_piece(&self, threshold: f64) -> Option<ValidPieceIndex> {
let my_avg_time = match self.counters.average_piece_download_time() { let my_avg_time = self.counters.average_piece_download_time()?;
Some(t) => t,
None => return None,
};
let (stolen_idx, from_peer) = { let (stolen_idx, from_peer) = {
let mut g = self.state.lock_write("try_steal_old_slow_piece"); let mut g = self.state.lock_write("try_steal_old_slow_piece");

View file

@ -263,8 +263,6 @@ impl LivePeerState {
} }
pub fn has_full_torrent(&self, total_pieces: usize) -> bool { pub fn has_full_torrent(&self, total_pieces: usize) -> bool {
self.bitfield self.bitfield.get(0..total_pieces).is_some_and(|s| s.all())
.get(0..total_pieces)
.map_or(false, |s| s.all())
} }
} }