diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index e93801e..51b2792 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -109,6 +109,7 @@ use super::{ ManagedTorrentInfo, }; +#[derive(Debug)] struct InflightPiece { peer: PeerHandle, started: Instant, @@ -214,7 +215,7 @@ impl TorrentStateLive { peers: Default::default(), locked: RwLock::new(TorrentStateLocked { chunks: Some(paused.chunk_tracker), - // TODO: move under per_piece_locks + // TODO: move under per_piece_locks? inflight_pieces: Default::default(), file_priorities, fatal_errors_tx: Some(fatal_errors_tx), @@ -1051,14 +1052,20 @@ impl PeerHandler { // heuristic for "too slow peer" if elapsed.as_secs_f64() > my_avg_time.as_secs_f64() * threshold { - debug!( - "will steal piece {} from {}: elapsed time {:?}, my avg piece time: {:?}", - idx, piece_req.peer, elapsed, my_avg_time - ); - let old = piece_req.peer; - piece_req.peer = self.addr; - piece_req.started = Instant::now(); - (*idx, old) + // If the piece is locked and someone is active writing to disk, don't steal it. + if let Some(_g) = self.state.per_piece_locks[idx.get_usize()].try_write() { + debug!( + "will steal piece {} from {}: elapsed time {:?}, my avg piece time: {:?}", + idx, piece_req.peer, elapsed, my_avg_time + ); + let old = piece_req.peer; + piece_req.peer = self.addr; + piece_req.started = Instant::now(); + (*idx, old) + } else { + warn!(?idx, ?piece_req, "attempted to steal but peer was writing"); + return None; + } } else { return None; } @@ -1349,8 +1356,10 @@ impl PeerHandler { ) -> anyhow::Result<()> { let index = piece.index; - let full_piece_download_time = { - let mut g = state.lock_write("mark_chunk_downloaded"); + let ppl_guard = { + let g = state.lock_read("check_steal"); + + let ppl = state.per_piece_locks[piece.index as usize].read(); match g.inflight_pieces.get(&chunk_info.piece_index) { Some(InflightPiece { peer, .. }) if *peer == addr => {} @@ -1370,10 +1379,26 @@ impl PeerHandler { } }; + ppl + }; + + // While we hold per piece lock, noone can steal it. + // So we can proceed writing knowing that the piece is ours now and will still be by the time + // the write is finished. + match state.file_ops().write_chunk(addr, piece, chunk_info) { + Ok(()) => {} + Err(e) => { + error!("FATAL: error writing chunk to disk: {:?}", e); + return state.on_fatal_error(e); + } + }; + + let full_piece_download_time = { + let mut g = state.lock_write("mark_chunk_downloaded"); let chunk_marking_result = g.get_chunks_mut()?.mark_chunk_downloaded(piece); trace!(?piece, chunk_marking_result=?chunk_marking_result); - let full_piece_download_time = match chunk_marking_result { + match chunk_marking_result { Some(ChunkMarkingResult::Completed) => { trace!("piece={} done, will write and checksum", piece.index); // This will prevent others from stealing it. @@ -1395,24 +1420,13 @@ impl PeerHandler { piece ); } - }; - - // By this time we reach here, no other peer can request this piece. All others, even if they steal pieces would - // have fallen off above in one of the defensive checks. - - // Not being able to write to storage is a fatal error. You need to unpause the - // torrent to recover from it. - match state.file_ops().write_chunk(addr, piece, chunk_info) { - Ok(()) => {} - Err(e) => { - error!("FATAL: error writing chunk to disk: {:?}", e); - return state.on_fatal_error(e); - } - }; - - full_piece_download_time + } }; + // We don't care about per piece lock anymore, as it's removed from inflight pieces. + // It shouldn't impact perf anyway, but dropping just in case. + drop(ppl_guard); + let full_piece_download_time = match full_piece_download_time { Some(t) => t, None => return Ok(()), diff --git a/crates/librqbit/src/torrent_state/utils.rs b/crates/librqbit/src/torrent_state/utils.rs index 3323cba..72e5118 100644 --- a/crates/librqbit/src/torrent_state/utils.rs +++ b/crates/librqbit/src/torrent_state/utils.rs @@ -74,6 +74,7 @@ mod timed_existence { fn drop(&mut self) { let elapsed = self.started.elapsed(); let reason = self.reason; + tracing::trace!(name=%self.reason, ?elapsed, "dropping guard"); if elapsed > MAX { warn!("elapsed on lock {reason:?}: {elapsed:?}") } @@ -96,10 +97,12 @@ mod timed_existence { pub fn timeit(name: impl std::fmt::Display, f: impl FnOnce() -> R) -> R { let now = Instant::now(); + tracing::trace!(%name, "starting"); let r = f(); + tracing::trace!(%name, "done"); let elapsed = now.elapsed(); if elapsed > MAX { - warn!("elapsed on \"{name:}\": {elapsed:?}") + warn!(%name, ?elapsed, max = ?MAX, "elapsed > MAX"); } r }