Fixing up initialization to allow passing in custom storages

This commit is contained in:
Igor Katson 2024-04-30 08:55:00 +01:00
parent 1b49257019
commit 42bbf84ea5
8 changed files with 128 additions and 90 deletions

View file

@ -36,6 +36,8 @@ use tracing::warn;
use crate::chunk_tracker::ChunkTracker;
use crate::file_info::FileInfo;
use crate::spawn_utils::BlockingSpawner;
use crate::storage::FilesystemStorageFactory;
use crate::storage::StorageFactory;
use crate::torrent_state::stats::LiveStats;
use crate::type_aliases::FileInfos;
use crate::type_aliases::PeerStream;
@ -89,13 +91,11 @@ pub(crate) struct ManagedTorrentOptions {
pub force_tracker_interval: Option<Duration>,
pub peer_connect_timeout: Option<Duration>,
pub peer_read_write_timeout: Option<Duration>,
pub overwrite: bool,
}
pub struct ManagedTorrentInfo {
pub info: TorrentMetaV1Info<ByteBufOwned>,
pub info_hash: Id20,
pub out_dir: PathBuf,
pub(crate) spawner: BlockingSpawner,
pub trackers: HashSet<String>,
pub peer_id: Id20,
@ -107,6 +107,7 @@ pub struct ManagedTorrentInfo {
pub struct ManagedTorrent {
pub info: Arc<ManagedTorrentInfo>,
storage_factory: Box<dyn StorageFactory>,
locked: RwLock<ManagedTorrentLocked>,
}
@ -267,7 +268,7 @@ impl ManagedTorrent {
error_span!(parent: span.clone(), "initialize_and_start"),
token.clone(),
async move {
match init.check().await {
match init.check(&*self.storage_factory).await {
Ok(paused) => {
let mut g = t.locked.write();
if let ManagedTorrentState::Initializing(_) = &g.state {
@ -461,18 +462,42 @@ impl ManagedTorrent {
}
}
enum ManagedTorrentBuilderStorage {
Filesystem {
overwrite: bool,
output_folder: PathBuf,
},
Custom(Box<dyn StorageFactory>),
}
impl ManagedTorrentBuilderStorage {
fn build(self) -> anyhow::Result<Box<dyn StorageFactory>> {
let s = match self {
ManagedTorrentBuilderStorage::Filesystem {
overwrite,
output_folder,
} => Box::new(FilesystemStorageFactory {
output_folder,
allow_overwrite: overwrite,
}),
ManagedTorrentBuilderStorage::Custom(s) => s,
};
Ok(s)
}
}
pub struct ManagedTorrentBuilder {
info: TorrentMetaV1Info<ByteBufOwned>,
info_hash: Id20,
output_folder: PathBuf,
force_tracker_interval: Option<Duration>,
peer_connect_timeout: Option<Duration>,
peer_read_write_timeout: Option<Duration>,
only_files: Option<Vec<usize>>,
trackers: Vec<String>,
peer_id: Option<Id20>,
overwrite: bool,
spawner: Option<BlockingSpawner>,
deferred_build_errors: Vec<String>,
storage: Option<ManagedTorrentBuilderStorage>,
}
impl ManagedTorrentBuilder {
@ -484,15 +509,19 @@ impl ManagedTorrentBuilder {
Self {
info,
info_hash,
output_folder: output_folder.as_ref().into(),
spawner: None,
force_tracker_interval: None,
peer_connect_timeout: None,
peer_read_write_timeout: None,
only_files: None,
deferred_build_errors: Default::default(),
trackers: Default::default(),
peer_id: None,
overwrite: false,
// default is filesystem to keep the old API unchanged for now
storage: Some(ManagedTorrentBuilderStorage::Filesystem {
overwrite: false,
output_folder: output_folder.as_ref().to_owned(),
}),
}
}
@ -506,8 +535,15 @@ impl ManagedTorrentBuilder {
self
}
pub fn overwrite(&mut self, overwrite: bool) -> &mut Self {
self.overwrite = overwrite;
pub fn overwrite(&mut self, new_overwrite: bool) -> &mut Self {
match self.storage.as_mut() {
Some(ManagedTorrentBuilderStorage::Filesystem { overwrite, .. }) => {
*overwrite = new_overwrite
}
_ => self
.deferred_build_errors
.push("overwrite() called when storage factory was not filesystem".to_owned()),
}
self
}
@ -537,25 +573,33 @@ impl ManagedTorrentBuilder {
}
pub(crate) fn build(self, span: tracing::Span) -> anyhow::Result<ManagedTorrentHandle> {
if !self.deferred_build_errors.is_empty() {
anyhow::bail!("Errors: {}", self.deferred_build_errors.join(";"))
}
let lengths = Lengths::from_torrent(&self.info)?;
let file_infos = self
.info
.iter_file_details(&lengths)?
.map(|fd| {
Ok::<_, anyhow::Error>(FileInfo {
filename: self.output_folder.join(fd.filename.to_pathbuf()?),
relative_filename: fd.filename.to_pathbuf()?,
offset_in_torrent: fd.offset,
piece_range: fd.pieces,
len: fd.len,
})
})
.collect::<anyhow::Result<Vec<FileInfo>>>()?;
let storage_factory = self
.storage
.context("by the time build() is called you must set storage factory")?
.build()?;
let info = Arc::new(ManagedTorrentInfo {
span,
file_infos,
info: self.info,
info_hash: self.info_hash,
out_dir: self.output_folder,
trackers: self.trackers.into_iter().collect(),
spawner: self.spawner.unwrap_or_default(),
peer_id: self.peer_id.unwrap_or_else(generate_peer_id),
@ -564,9 +608,9 @@ impl ManagedTorrentBuilder {
force_tracker_interval: self.force_tracker_interval,
peer_connect_timeout: self.peer_connect_timeout,
peer_read_write_timeout: self.peer_read_write_timeout,
overwrite: self.overwrite,
},
});
let initializing = Arc::new(TorrentStateInitializing::new(
info.clone(),
self.only_files.clone(),
@ -576,6 +620,7 @@ impl ManagedTorrentBuilder {
state: ManagedTorrentState::Initializing(initializing),
only_files: self.only_files,
}),
storage_factory,
info,
}))
}