diff --git a/crates/dht/examples/dht.rs b/crates/dht/examples/dht.rs index 586a707..108122d 100644 --- a/crates/dht/examples/dht.rs +++ b/crates/dht/examples/dht.rs @@ -12,7 +12,9 @@ async fn main() -> anyhow::Result<()> { .nth(1) .expect("first argument should be a magnet link"); let magnet = Magnet::parse(&magnet).unwrap(); - let info_hash = magnet.as_id20().context("Supplied magnet link didn't contain a BTv1 infohash")?; + let info_hash = magnet + .as_id20() + .context("Supplied magnet link didn't contain a BTv1 infohash")?; tracing_subscriber::fmt::init(); diff --git a/crates/librqbit/src/storage/example.rs b/crates/librqbit/src/storage/example.rs new file mode 100644 index 0000000..ad353b1 --- /dev/null +++ b/crates/librqbit/src/storage/example.rs @@ -0,0 +1,97 @@ +use std::{collections::HashMap, path::Path}; + +use anyhow::Context; +use librqbit_core::lengths::{Lengths, ValidPieceIndex}; +use parking_lot::RwLock; + +use crate::type_aliases::FileInfos; + +use super::TorrentStorage; + +struct InMemoryPiece { + bytes: Box<[u8]>, +} + +impl InMemoryPiece { + fn new(l: &Lengths) -> Self { + let v = vec![0; l.default_piece_length() as usize].into_boxed_slice(); + Self { bytes: v } + } +} + +pub struct InMemoryExampleStorage { + lengths: Lengths, + file_infos: FileInfos, + map: RwLock>, + // TODO: chunk tracker - rename to PieceTracker and extract chunks out of it (only keep pieces) + // this sucker here would track chunks, and the storage above too. +} + +impl InMemoryExampleStorage { + pub fn new(lengths: Lengths, file_infos: FileInfos) -> anyhow::Result { + // Max memory 128MiB. Make it tunable + let max_pieces = 128 * 1024 * 1024 / lengths.default_piece_length(); + if max_pieces == 0 { + anyhow::bail!("pieces too large"); + } + + Ok(Self { + lengths, + file_infos, + map: RwLock::new(HashMap::new()), + }) + } +} + +impl TorrentStorage for InMemoryExampleStorage { + fn pread_exact(&self, file_id: usize, offset: u64, buf: &mut [u8]) -> anyhow::Result<()> { + let fi = &self.file_infos[file_id]; + let abs_offset = fi.offset_in_torrent + offset; + let piece_id: u32 = (abs_offset / self.lengths.default_piece_length() as u64).try_into()?; + let piece_offset: usize = + (abs_offset % self.lengths.default_piece_length() as u64).try_into()?; + let piece_id = self.lengths.validate_piece_index(piece_id).context("bug")?; + + let g = self.map.read(); + let inmp = g.get(&piece_id).context("piece expired")?; + buf.copy_from_slice(&inmp.bytes[piece_offset..(piece_offset + buf.len())]); + Ok(()) + } + + fn pwrite_all(&self, file_id: usize, offset: u64, buf: &[u8]) -> anyhow::Result<()> { + let fi = &self.file_infos[file_id]; + let abs_offset = fi.offset_in_torrent + offset; + let piece_id: u32 = (abs_offset / self.lengths.default_piece_length() as u64).try_into()?; + let piece_offset: usize = + (abs_offset % self.lengths.default_piece_length() as u64).try_into()?; + let piece_id = self.lengths.validate_piece_index(piece_id).context("bug")?; + let mut g = self.map.write(); + let inmp = g + .entry(piece_id) + .or_insert_with(|| InMemoryPiece::new(&self.lengths)); + inmp.bytes[piece_offset..(piece_offset + buf.len())].copy_from_slice(buf); + Ok(()) + } + + fn remove_file(&self, _file_id: usize, _filename: &Path) -> anyhow::Result<()> { + Ok(()) + } + + fn ensure_file_length(&self, _file_id: usize, _length: u64) -> anyhow::Result<()> { + Ok(()) + } + + fn take(&self) -> anyhow::Result> { + let map = { + let mut g = self.map.write(); + let mut repl = HashMap::new(); + std::mem::swap(&mut *g, &mut repl); + repl + }; + Ok(Box::new(Self { + lengths: self.lengths, + map: RwLock::new(map), + file_infos: self.file_infos.clone(), + })) + } +} diff --git a/crates/librqbit/src/storage/filesystem.rs b/crates/librqbit/src/storage/filesystem.rs new file mode 100644 index 0000000..836922e --- /dev/null +++ b/crates/librqbit/src/storage/filesystem.rs @@ -0,0 +1,106 @@ +use std::{ + fs::OpenOptions, + io::{Read, Seek, SeekFrom, Write}, + path::{Path, PathBuf}, +}; + +use anyhow::Context; + +use crate::{opened_file::OpenedFile, torrent_state::ManagedTorrentInfo}; + +use super::{StorageFactory, TorrentStorage}; + +pub struct FilesystemStorageFactory { + pub output_folder: PathBuf, + pub allow_overwrite: bool, +} + +impl StorageFactory for FilesystemStorageFactory { + fn init_storage(&self, meta: &ManagedTorrentInfo) -> anyhow::Result> { + let mut files = Vec::::new(); + for file_details in meta.info.iter_file_details(&meta.lengths)? { + let mut full_path = self.output_folder.clone(); + let relative_path = file_details + .filename + .to_pathbuf() + .context("error converting file to path")?; + full_path.push(relative_path); + + std::fs::create_dir_all(full_path.parent().context("bug: no parent")?)?; + let file = if self.allow_overwrite { + OpenOptions::new() + .create(true) + .truncate(false) + .read(true) + .write(true) + .open(&full_path) + .with_context(|| format!("error opening {full_path:?} in read/write mode"))? + } else { + // create_new does not seem to work with read(true), so calling this twice. + OpenOptions::new() + .create_new(true) + .write(true) + .open(&full_path) + .with_context(|| format!("error creating {:?}", &full_path))?; + OpenOptions::new().read(true).write(true).open(&full_path)? + }; + files.push(OpenedFile::new(file)); + } + Ok(Box::new(FilesystemStorage { + output_folder: self.output_folder.clone(), + opened_files: files, + })) + } + + fn output_folder(&self) -> Option<&Path> { + Some(&self.output_folder) + } +} + +pub struct FilesystemStorage { + output_folder: PathBuf, + opened_files: Vec, +} + +impl TorrentStorage for FilesystemStorage { + fn pread_exact(&self, file_id: usize, offset: u64, buf: &mut [u8]) -> anyhow::Result<()> { + let mut g = self + .opened_files + .get(file_id) + .context("no such file")? + .file + .lock(); + g.seek(SeekFrom::Start(offset))?; + Ok(g.read_exact(buf)?) + } + + fn pwrite_all(&self, file_id: usize, offset: u64, buf: &[u8]) -> anyhow::Result<()> { + let mut g = self + .opened_files + .get(file_id) + .context("no such file")? + .file + .lock(); + g.seek(SeekFrom::Start(offset))?; + Ok(g.write_all(buf)?) + } + + fn remove_file(&self, _file_id: usize, filename: &Path) -> anyhow::Result<()> { + Ok(std::fs::remove_file(self.output_folder.join(filename))?) + } + + fn ensure_file_length(&self, file_id: usize, len: u64) -> anyhow::Result<()> { + Ok(self.opened_files[file_id].file.lock().set_len(len)?) + } + + fn take(&self) -> anyhow::Result> { + Ok(Box::new(Self { + opened_files: self + .opened_files + .iter() + .map(|f| f.take_clone()) + .collect::>>()?, + output_folder: self.output_folder.clone(), + })) + } +} diff --git a/crates/librqbit/src/storage/mod.rs b/crates/librqbit/src/storage/mod.rs new file mode 100644 index 0000000..738d376 --- /dev/null +++ b/crates/librqbit/src/storage/mod.rs @@ -0,0 +1,52 @@ +pub mod example; +pub mod filesystem; + +use std::path::Path; + +use crate::torrent_state::ManagedTorrentInfo; + +pub trait StorageFactory: Send + Sync { + fn init_storage(&self, info: &ManagedTorrentInfo) -> anyhow::Result>; + + fn output_folder(&self) -> Option<&Path> { + None + } +} + +pub trait TorrentStorage: Send + Sync { + fn pread_exact(&self, file_id: usize, offset: u64, buf: &mut [u8]) -> anyhow::Result<()>; + + fn pwrite_all(&self, file_id: usize, offset: u64, buf: &[u8]) -> anyhow::Result<()>; + + fn remove_file(&self, file_id: usize, filename: &Path) -> anyhow::Result<()>; + + fn ensure_file_length(&self, file_id: usize, length: u64) -> anyhow::Result<()>; + + fn take(&self) -> anyhow::Result>; + + fn output_folder(&self) -> Option<&Path> { + None + } +} + +impl TorrentStorage for Box { + fn pread_exact(&self, file_id: usize, offset: u64, buf: &mut [u8]) -> anyhow::Result<()> { + (**self).pread_exact(file_id, offset, buf) + } + + fn pwrite_all(&self, file_id: usize, offset: u64, buf: &[u8]) -> anyhow::Result<()> { + (**self).pwrite_all(file_id, offset, buf) + } + + fn remove_file(&self, file_id: usize, filename: &Path) -> anyhow::Result<()> { + (**self).remove_file(file_id, filename) + } + + fn ensure_file_length(&self, file_id: usize, length: u64) -> anyhow::Result<()> { + (**self).ensure_file_length(file_id, length) + } + + fn take(&self) -> anyhow::Result> { + (**self).take() + } +} diff --git a/crates/librqbit/src/torrent_state/mod.rs b/crates/librqbit/src/torrent_state/mod.rs index 6c71db6..54310a0 100644 --- a/crates/librqbit/src/torrent_state/mod.rs +++ b/crates/librqbit/src/torrent_state/mod.rs @@ -36,7 +36,7 @@ use tracing::warn; use crate::chunk_tracker::ChunkTracker; use crate::file_info::FileInfo; use crate::spawn_utils::BlockingSpawner; -use crate::storage::FilesystemStorageFactory; +use crate::storage::filesystem::FilesystemStorageFactory; use crate::storage::StorageFactory; use crate::torrent_state::stats::LiveStats; use crate::type_aliases::FileInfos; @@ -547,6 +547,11 @@ impl ManagedTorrentBuilder { self } + pub fn storage_factory(&mut self, factory: Box) -> &mut Self { + self.storage = Some(ManagedTorrentBuilderStorage::Custom(factory)); + self + } + pub fn force_tracker_interval(&mut self, force_tracker_interval: Duration) -> &mut Self { self.force_tracker_interval = Some(force_tracker_interval); self diff --git a/crates/librqbit_core/src/hash_id.rs b/crates/librqbit_core/src/hash_id.rs index 1a27f97..d8bd20b 100644 --- a/crates/librqbit_core/src/hash_id.rs +++ b/crates/librqbit_core/src/hash_id.rs @@ -69,8 +69,8 @@ impl FromStr for Id { fn from_str(s: &str) -> Result { let mut out = [0u8; N]; - if s.len() != N*2 { - anyhow::bail!("expected a hex string of length {}", N*2) + if s.len() != N * 2 { + anyhow::bail!("expected a hex string of length {}", N * 2) }; hex::decode_to_slice(s, &mut out)?; Ok(Id(out)) @@ -97,8 +97,9 @@ impl<'de, const N: usize> Deserialize<'de> for Id { type Value = Id; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a byte array of length ") - .and_then(|_| formatter.write_fmt(format_args!("{}", N))) + formatter + .write_str("a byte array of length ") + .and_then(|_| formatter.write_fmt(format_args!("{}", N))) } fn visit_str(self, v: &str) -> Result @@ -135,7 +136,7 @@ impl<'de, const N: usize> Deserialize<'de> for Id { } } - deserializer.deserialize_any(IdVisitor{}) + deserializer.deserialize_any(IdVisitor {}) } } @@ -165,8 +166,8 @@ pub type Id32 = Id<32>; #[cfg(test)] mod tests { - use std::str::FromStr; use super::*; + use std::str::FromStr; #[test] fn test_set_bit_range() { @@ -183,5 +184,4 @@ mod tests { let str = "06f04cc728bef957a658876ef807f0514e4d715392969998efef584d2c3e435e"; let _ih = Id32::from_str(str).unwrap(); } - -} \ No newline at end of file +}