diff --git a/Cargo.lock b/Cargo.lock index 5f02ebe..a38293f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1293,6 +1293,7 @@ version = "5.6.4" dependencies = [ "anyhow", "async-stream", + "async-trait", "axum 0.7.5", "backoff", "base64 0.21.7", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index ba6a09d..21512a2 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -76,6 +76,7 @@ memmap2 = { version = "0.9.4" } rand_distr = { version = "0.4.3", optional = true } lru = { version = "0.12.3", optional = true } +async-trait = "0.1.80" [dev-dependencies] futures = { version = "0.3" } diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index c588135..39b946f 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -19,6 +19,7 @@ use tracing::trace; use crate::{read_buf::ReadBuf, spawn_utils::BlockingSpawner}; +#[async_trait::async_trait] pub trait PeerConnectionHandler { fn on_connected(&self, _connection_time: Duration) {} fn get_have_bytes(&self) -> u64; @@ -28,7 +29,7 @@ pub trait PeerConnectionHandler { &self, extended_handshake: &ExtendedHandshake, ) -> anyhow::Result<()>; - fn on_received_message(&self, msg: Message>) -> anyhow::Result<()>; + async fn on_received_message(&self, msg: Message>) -> anyhow::Result<()>; fn on_uploaded_bytes(&self, bytes: u32); fn read_chunk(&self, chunk: &ChunkInfo, buf: &mut [u8]) -> anyhow::Result<()>; } @@ -360,6 +361,7 @@ impl PeerConnection { } else { self.handler .on_received_message(message) + .await .context("error in handler.on_received_message()")?; } } diff --git a/crates/librqbit/src/peer_info_reader/mod.rs b/crates/librqbit/src/peer_info_reader/mod.rs index 2f11a5e..be3bf31 100644 --- a/crates/librqbit/src/peer_info_reader/mod.rs +++ b/crates/librqbit/src/peer_info_reader/mod.rs @@ -141,6 +141,7 @@ struct Handler { locked: RwLock>, } +#[async_trait::async_trait] impl PeerConnectionHandler for Handler { fn get_have_bytes(&self) -> u64 { 0 @@ -157,7 +158,7 @@ impl PeerConnectionHandler for Handler { Ok(()) } - fn on_received_message(&self, msg: Message>) -> anyhow::Result<()> { + async fn on_received_message(&self, msg: Message>) -> anyhow::Result<()> { trace!("{}: received message: {:?}", self.addr, msg); if let Message::Extended(ExtendedMessage::UtMetadata(UtMetadata::Data { diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 51b2792..7c7eb41 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -59,6 +59,7 @@ use buffers::{ByteBuf, ByteBufOwned}; use clone_to_owned::CloneToOwned; use futures::{stream::FuturesUnordered, StreamExt}; use librqbit_core::{ + constants::CHUNK_SIZE, hash_id::Id20, lengths::{ChunkInfo, Lengths, ValidPieceIndex}, spawn_utils::spawn_with_cancel, @@ -155,6 +156,10 @@ pub struct TorrentStateOptions { pub peer_read_write_timeout: Option, } +struct DiskWriteWorkItem { + work: Box, +} + pub struct TorrentStateLive { peers: PeerStates, meta: Arc, @@ -179,6 +184,8 @@ pub struct TorrentStateLive { up_speed_estimator: SpeedEstimator, cancellation_token: CancellationToken, + disk_work_tx: tokio::sync::mpsc::Sender, + pub(crate) streams: Arc, } @@ -210,6 +217,10 @@ impl TorrentStateLive { pri }; + // 8MB per torrent of disk buffering. + let (disk_work_tx, mut disk_work_rx) = + tokio::sync::mpsc::channel(8 * 1024 * 1024 / CHUNK_SIZE as usize); + let state = Arc::new(TorrentStateLive { meta: paused.info.clone(), peers: Default::default(), @@ -236,8 +247,19 @@ impl TorrentStateLive { per_piece_locks: (0..lengths.total_pieces()) .map(|_| RwLock::new(())) .collect(), + disk_work_tx, }); + state.spawn( + error_span!(parent: state.meta.span.clone(), "disk_writer"), + async move { + while let Some(work_item) = disk_work_rx.recv().await { + tokio::task::spawn_blocking(work_item.work); + } + Ok(()) + }, + ); + state.spawn( error_span!(parent: state.meta.span.clone(), "speed_estimator_updater"), { @@ -802,6 +824,7 @@ struct PeerHandler { tx: PeerTx, } +#[async_trait::async_trait] impl<'a> PeerConnectionHandler for &'a PeerHandler { fn on_connected(&self, connection_time: Duration) { self.counters @@ -812,7 +835,8 @@ impl<'a> PeerConnectionHandler for &'a PeerHandler { .total_time_connecting_ms .fetch_add(connection_time.as_millis() as u64, Ordering::Relaxed); } - fn on_received_message(&self, message: Message>) -> anyhow::Result<()> { + + async fn on_received_message(&self, message: Message>) -> anyhow::Result<()> { match message { Message::Request(request) => { self.on_download_request(request) @@ -824,7 +848,10 @@ impl<'a> PeerConnectionHandler for &'a PeerHandler { Message::Choke => self.on_i_am_choked(), Message::Unchoke => self.on_i_am_unchoked(), Message::Interested => self.on_peer_interested(), - Message::Piece(piece) => self.on_received_piece(piece).context("on_received_piece")?, + Message::Piece(piece) => self + .on_received_piece(piece) + .await + .context("on_received_piece")?, Message::KeepAlive => { trace!("keepalive received"); } @@ -1302,7 +1329,7 @@ impl PeerHandler { self.requests_sem.add_permits(128); } - fn on_received_piece(&self, piece: Piece) -> anyhow::Result<()> { + async fn on_received_piece(&self, piece: Piece>) -> anyhow::Result<()> { let piece_index = self .state .lengths @@ -1510,7 +1537,12 @@ impl PeerHandler { } }) }; - tokio::task::spawn_blocking(work); + self.state + .disk_work_tx + .send(DiskWriteWorkItem { + work: Box::new(work), + }) + .await?; } else { self.state .meta