diff --git a/Cargo.lock b/Cargo.lock index 09eec36..67b00e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1307,6 +1307,7 @@ version = "6.0.0" dependencies = [ "anyhow", "async-stream", + "async-trait", "axum 0.7.5", "backoff", "base64 0.21.7", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 03ec357..955e8aa 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -31,7 +31,12 @@ sha1w = { path = "../sha1w", default-features = false, package = "librqbit-sha1- dht = { path = "../dht", package = "librqbit-dht", version = "5.0.4" } librqbit-upnp = { path = "../upnp", version = "0.1.0" } -tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1", features = [ + "macros", + "rt-multi-thread", + "fs", + "io-util", +] } axum = { version = "0.7.4" } tower-http = { version = "0.5", features = ["cors", "trace"] } tokio-stream = "0.1" @@ -79,6 +84,7 @@ memmap2 = { version = "0.9.4" } lru = { version = "0.12.3", optional = true } mime_guess = { version = "2.0.5", default-features = false } tokio-socks = "0.5.2" +async-trait = "0.1.81" [build-dependencies] anyhow = "1" diff --git a/crates/librqbit/examples/custom_storage.rs b/crates/librqbit/examples/custom_storage.rs index f85fd37..2be41c9 100644 --- a/crates/librqbit/examples/custom_storage.rs +++ b/crates/librqbit/examples/custom_storage.rs @@ -71,7 +71,7 @@ async fn main() -> anyhow::Result<()> { Default::default(), SessionOptions { disable_dht_persistence: true, - persistence: false, + persistence: None, listen_port_range: None, enable_upnp_port_forwarding: false, ..Default::default() diff --git a/crates/librqbit/src/api.rs b/crates/librqbit/src/api.rs index 4a253ae..9ed239f 100644 --- a/crates/librqbit/src/api.rs +++ b/crates/librqbit/src/api.rs @@ -114,16 +114,18 @@ impl Api { Ok(Default::default()) } - pub fn api_torrent_action_forget(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_forget(&self, idx: TorrentId) -> Result { self.session .delete(idx, false) + .await .context("error forgetting torrent")?; Ok(Default::default()) } - pub fn api_torrent_action_delete(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_delete(&self, idx: TorrentId) -> Result { self.session .delete(idx, true) + .await .context("error deleting torrent with files")?; Ok(Default::default()) } diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index bb69635..699ce5d 100644 --- a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -382,14 +382,14 @@ impl HttpApi { State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_forget(idx).map(axum::Json) + state.api_torrent_action_forget(idx).await.map(axum::Json) } async fn torrent_action_delete( State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_delete(idx).map(axum::Json) + state.api_torrent_action_delete(idx).await.map(axum::Json) } #[derive(Deserialize)] diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 8990e0e..1f66a3b 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -39,6 +39,7 @@ mod peer_connection; mod peer_info_reader; mod read_buf; mod session; +mod session_persistence; mod spawn_utils; pub mod storage; mod stream_connect; @@ -53,7 +54,7 @@ pub use dht; pub use peer_connection::PeerConnectionOptions; pub use session::{ AddTorrent, AddTorrentOptions, AddTorrentResponse, ListOnlyResponse, Session, SessionOptions, - SUPPORTED_SCHEMES, + SessionPersistenceConfig, SUPPORTED_SCHEMES, }; pub use spawn_utils::spawn as librqbit_spawn; pub use torrent_state::{ diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 7526002..7d5b0a4 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -1,11 +1,9 @@ use std::{ - any::TypeId, borrow::Cow, collections::{HashMap, HashSet}, - io::{BufReader, BufWriter, Read}, + io::Read, net::SocketAddr, path::{Path, PathBuf}, - str::FromStr, sync::Arc, time::Duration, }; @@ -15,6 +13,9 @@ use crate::{ merge_streams::merge_streams, peer_connection::PeerConnectionOptions, read_buf::ReadBuf, + session_persistence::{ + json::JsonSessionPersistenceStore, BoxSessionPersistenceStore, SessionPersistenceStore, + }, spawn_utils::BlockingSpawner, storage::{ filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage, @@ -27,7 +28,7 @@ use crate::{ ManagedTorrentInfo, }; use anyhow::{bail, Context}; -use bencode::{bencode_serialize_to_writer, BencodeDeserializer}; +use bencode::bencode_serialize_to_writer; use buffers::{ByteBuf, ByteBufOwned, ByteBufT}; use bytes::Bytes; use clone_to_owned::CloneToOwned; @@ -48,7 +49,7 @@ use librqbit_core::{ }; use parking_lot::RwLock; use peer_binary_protocol::Handshake; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::StreamExt; use tokio_util::sync::{CancellationToken, DropGuard}; @@ -103,130 +104,15 @@ impl SessionDatabase { } let idx = self.next_id; self.torrents.insert(idx, torrent); - self.next_id += 1; + self.next_id = self.next_id.max(idx) + 1; idx } - - fn serialize(&self) -> SerializedSessionDatabase { - SerializedSessionDatabase { - torrents: self - .torrents - .iter() - // We don't support serializing / deserializing of other storage types. - .filter(|(_, torrent)| { - torrent - .storage_factory - .is_type_id(TypeId::of::()) - }) - .map(|(id, torrent)| { - ( - *id, - SerializedTorrent { - trackers: torrent - .info() - .trackers - .iter() - .map(|u| u.to_string()) - .collect(), - info_hash: torrent.info_hash().as_string(), - // TODO: this could take up too much space / time / resources to write on interval. - // Store this outside the JSON file - // - // torrent_bytes: torrent.info.torrent_bytes.clone(), - torrent_bytes: Bytes::new(), - info: torrent.info().info.clone(), - only_files: torrent.only_files().clone(), - is_paused: torrent - .with_state(|s| matches!(s, ManagedTorrentState::Paused(_))), - output_folder: torrent.info().options.output_folder.clone(), - }, - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize)] -struct SerializedTorrent { - info_hash: String, - #[serde( - serialize_with = "serialize_torrent", - deserialize_with = "deserialize_torrent" - )] - info: TorrentMetaV1Info, - #[serde( - serialize_with = "serialize_torrent_bytes", - deserialize_with = "deserialize_torrent_bytes", - default - )] - torrent_bytes: Bytes, - trackers: HashSet, - output_folder: PathBuf, - only_files: Option>, - is_paused: bool, -} - -fn serialize_torrent( - t: &TorrentMetaV1Info, - serializer: S, -) -> Result -where - S: Serializer, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::ser::Error; - let mut writer = Vec::new(); - bencode_serialize_to_writer(t, &mut writer).map_err(S::Error::custom)?; - let s = general_purpose::STANDARD_NO_PAD.encode(&writer); - s.serialize(serializer) -} - -fn deserialize_torrent<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::de::Error; - let s = String::deserialize(deserializer)?; - let b = general_purpose::STANDARD_NO_PAD - .decode(s) - .map_err(D::Error::custom)?; - TorrentMetaV1Info::::deserialize(&mut BencodeDeserializer::new_from_buf(&b)) - .map_err(D::Error::custom) -} - -fn serialize_torrent_bytes(t: &Bytes, serializer: S) -> Result -where - S: Serializer, -{ - use base64::{engine::general_purpose, Engine as _}; - let s = general_purpose::STANDARD_NO_PAD.encode(t); - s.serialize(serializer) -} - -fn deserialize_torrent_bytes<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::de::Error; - let s = String::deserialize(deserializer)?; - let b = general_purpose::STANDARD_NO_PAD - .decode(s) - .map_err(D::Error::custom)?; - Ok(b.into()) -} - -#[derive(Serialize, Deserialize)] -struct SerializedSessionDatabase { - torrents: HashMap, } pub struct Session { peer_id: Id20, dht: Option, - persistence_filename: PathBuf, + persistence: Option>, peer_opts: PeerConnectionOptions, spawner: BlockingSpawner, db: RwLock, @@ -463,6 +349,11 @@ impl<'a> AddTorrent<'a> { } } +pub enum SessionPersistenceConfig { + /// The filename for persistence. By default uses an OS-specific folder. + Json { folder: Option }, +} + #[derive(Default)] pub struct SessionOptions { /// Turn on to disable DHT. @@ -476,9 +367,7 @@ pub struct SessionOptions { /// Turn on to dump session contents into a file periodically, so that on next start /// all remembered torrents will continue where they left off. - pub persistence: bool, - /// The filename for persistence. By default uses an OS-specific folder. - pub persistence_filename: Option, + pub persistence: Option, /// The peer ID to use. If not specified, a random one will be generated. pub peer_id: Option, @@ -557,11 +446,6 @@ impl Session { Self::new_with_opts(default_output_folder, SessionOptions::default()) } - pub fn default_persistence_filename() -> anyhow::Result { - let dir = get_configuration_directory("session")?; - Ok(dir.data_dir().join("session.json")) - } - pub fn cancellation_token(&self) -> &CancellationToken { &self.cancellation_token } @@ -576,15 +460,16 @@ impl Session { let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let token = CancellationToken::new(); - let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range { - let (l, p) = create_tcp_listener(port_range) - .await - .context("error listening on TCP")?; - info!("Listening on 0.0.0.0:{p} for incoming peer connections"); - (Some(l), Some(p)) - } else { - (None, None) - }; + let (tcp_listener, tcp_listen_port) = + if let Some(port_range) = opts.listen_port_range.clone() { + let (l, p) = create_tcp_listener(port_range) + .await + .context("error listening on TCP")?; + info!("Listening on 0.0.0.0:{p} for incoming peer connections"); + (Some(l), Some(p)) + } else { + (None, None) + }; let dht = if opts.disable_dht { None @@ -606,11 +491,36 @@ impl Session { Some(dht) }; let peer_opts = opts.peer_opts.unwrap_or_default(); - let persistence_filename = match opts.persistence_filename { - Some(filename) => filename, - None if !opts.persistence => PathBuf::new(), - None => Self::default_persistence_filename()?, - }; + + async fn persistence_factory( + opts: &SessionOptions, + ) -> anyhow::Result> { + pub fn default_persistence_folder() -> anyhow::Result { + let dir = get_configuration_directory("session")?; + Ok(dir.data_dir().to_owned()) + } + + match &opts.persistence { + Some(SessionPersistenceConfig::Json { folder }) => { + let folder = match folder.as_ref() { + Some(f) => f.clone(), + None => default_persistence_folder()?, + }; + + Ok(Some(Box::new( + JsonSessionPersistenceStore::new(folder) + .await + .context("error initializing JsonSessionPersistenceStore")?, + ))) + } + None => Ok(None), + } + } + + let persistence = persistence_factory(&opts) + .await + .context("error initializing session persistence store")?; + let spawner = BlockingSpawner::default(); let (disk_write_tx, disk_write_rx) = opts @@ -646,7 +556,7 @@ impl Session { let stream_connector = Arc::new(StreamConnector::from(proxy_config)); let session = Arc::new(Self { - persistence_filename, + persistence, peer_id, dht, peer_opts, @@ -688,18 +598,34 @@ impl Session { } } - if opts.persistence { - info!( - "will use {:?} for session persistence", - session.persistence_filename - ); - if let Some(parent) = session.persistence_filename.parent() { - std::fs::create_dir_all(parent).with_context(|| { - format!("couldn't create directory {:?} for session storage", parent) - })?; + if let Some(persistence) = session.persistence.as_ref() { + info!("will use {persistence:?} for session persistence"); + + let mut ps = persistence.stream_all().await?; + let mut added_all = false; + let mut futs = FuturesUnordered::new(); + + while !added_all || !futs.is_empty() { + tokio::select! { + Some(res) = futs.next(), if !futs.is_empty() => { + if let Err(e) = res { + error!("error adding torrent to session: {e:?}"); + } + }, + st = ps.next(), if !added_all => { + if let Some(st) = st { + let (id, st) = st?; + let span = error_span!("add_torrent", info_hash=?st.info_hash()); + let (add_torrent, mut opts) = st.into_add_torrent()?; + opts.preferred_id = Some(id); + let fut = session.add_torrent(add_torrent, Some(opts)).instrument(span); + futs.push(fut); + } else { + added_all = true; + } + }, + } } - let persistence_task = session.clone().task_persistence(); - session.spawn(error_span!("session_persistence"), persistence_task); } Ok(session) @@ -707,29 +633,6 @@ impl Session { .boxed() } - async fn task_persistence(self: Arc) -> anyhow::Result<()> { - // Populate initial from the state filename - if let Err(e) = self.populate_from_stored().await { - error!("could not populate session from stored file: {:?}", e); - } - - let session = Arc::downgrade(&self); - drop(self); - - loop { - tokio::time::sleep(Duration::from_secs(10)).await; - let session = match session.upgrade() { - Some(s) => s, - None => break, - }; - if let Err(e) = session.dump_to_disk() { - error!("error dumping session to disk: {:?}", e); - } - } - - Ok(()) - } - async fn check_incoming_connection( &self, addr: SocketAddr, @@ -868,102 +771,6 @@ impl Session { tokio::time::sleep(Duration::from_secs(1)).await; } - async fn populate_from_stored(self: &Arc) -> anyhow::Result<()> { - let mut rdr = match std::fs::File::open(&self.persistence_filename) { - Ok(f) => BufReader::new(f), - Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()), - Err(e) => { - return Err(e).context(format!( - "error opening session file {:?}", - self.persistence_filename - )) - } - }; - let db: SerializedSessionDatabase = - serde_json::from_reader(&mut rdr).context("error deserializing session database")?; - let mut futures = Vec::new(); - for (id, storrent) in db.torrents.into_iter() { - let trackers: Vec = storrent - .trackers - .into_iter() - .map(|t| ByteBufOwned::from(t.into_bytes())) - .collect(); - - let torrent_bytes = storrent.torrent_bytes; - - let add_torrent = if !torrent_bytes.is_empty() { - AddTorrent::TorrentFileBytes(torrent_bytes) - } else { - let info_hash = Id20::from_str(&storrent.info_hash)?; - debug!(?info_hash, "torrent added before 6.1.0, need to readd"); - let info = TorrentMetaV1Owned { - announce: trackers.first().cloned(), - announce_list: vec![trackers], - info: storrent.info, - comment: None, - created_by: None, - encoding: None, - publisher: None, - publisher_url: None, - creation_date: None, - info_hash, - }; - AddTorrent::TorrentInfo(Box::new(info)) - }; - - futures.push({ - let session = self.clone(); - async move { - session - .add_torrent( - add_torrent, - Some(AddTorrentOptions { - paused: storrent.is_paused, - output_folder: Some( - storrent - .output_folder - .to_str() - .context("broken path")? - .to_owned(), - ), - only_files: storrent.only_files, - overwrite: true, - preferred_id: Some(id), - ..Default::default() - }), - ) - .await - .map_err(|e| { - error!("error adding torrent from stored session: {:?}", e); - e - }) - } - }); - } - futures::future::join_all(futures).await; - Ok(()) - } - - fn dump_to_disk(&self) -> anyhow::Result<()> { - let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap()); - let mut tmp = BufWriter::new( - std::fs::OpenOptions::new() - .create(true) - .truncate(true) - .write(true) - .open(&tmp_filename) - .with_context(|| format!("error opening {:?}", tmp_filename))?, - ); - let serialized = self.db.read().serialize(); - serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?; - drop(tmp); - - std::fs::rename(&tmp_filename, &self.persistence_filename) - .context("error renaming persistence file")?; - trace!(filename=?self.persistence_filename, "wrote persistence"); - Ok(()) - } - /// Run a callback given the currently managed torrents. pub fn with_torrents( &self, @@ -1256,13 +1063,17 @@ impl Session { { return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone())); } - let next_id = g.torrents.len(); + let next_id = g.next_id; let managed_torrent = builder.build(error_span!(parent: None, "torrent", id = next_id))?; let id = g.add_torrent(managed_torrent.clone(), opts.preferred_id); (managed_torrent, id) }; + if let Some(p) = self.persistence.as_ref() { + p.store(id, &managed_torrent).await?; + } + // Merge "initial_peers" and "peer_rx" into one stream. let peer_rx = merge_two_optional_streams( if !initial_peers.is_empty() { @@ -1289,7 +1100,7 @@ impl Session { self.db.read().torrents.get(&id).cloned() } - pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { + pub async fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { let removed = self .db .write() @@ -1301,6 +1112,12 @@ impl Session { debug!("error pausing torrent before deletion: {e:?}") } + if let Some(p) = self.persistence.as_ref() { + if let Err(e) = p.delete(id).await { + error!(error=?e, "error deleting torrent from database"); + } + } + let storage = removed .with_state_mut(|s| match s.take() { ManagedTorrentState::Initializing(p) => p.files.take().ok(), diff --git a/crates/librqbit/src/session_persistence/json.rs b/crates/librqbit/src/session_persistence/json.rs new file mode 100644 index 0000000..b95c927 --- /dev/null +++ b/crates/librqbit/src/session_persistence/json.rs @@ -0,0 +1,215 @@ +use std::{any::TypeId, collections::HashMap, path::PathBuf}; + +use crate::{ + session::TorrentId, storage::filesystem::FilesystemStorageFactory, + torrent_state::ManagedTorrentHandle, ManagedTorrentState, +}; +use anyhow::{bail, Context}; +use async_trait::async_trait; +use futures::{stream::BoxStream, StreamExt}; +use itertools::Itertools; +use librqbit_core::Id20; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tracing::{info, trace, warn}; + +use super::{SerializedTorrent, SessionPersistenceStore}; + +#[derive(Serialize, Deserialize, Default)] +struct SerializedSessionDatabase { + torrents: HashMap, +} + +pub struct JsonSessionPersistenceStore { + output_folder: PathBuf, + db_filename: PathBuf, + db_content: tokio::sync::RwLock, +} + +impl std::fmt::Debug for JsonSessionPersistenceStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSON database: {:?}", self.db_filename) + } +} + +impl JsonSessionPersistenceStore { + pub async fn new(output_folder: PathBuf) -> anyhow::Result { + let db_filename = output_folder.join("session.json"); + info!("will use {:?} for session persistence", db_filename); + tokio::fs::create_dir_all(&output_folder) + .await + .with_context(|| { + format!( + "couldn't create directory {:?} for session storage", + output_folder + ) + })?; + + let db = match tokio::fs::File::open(&db_filename).await { + Ok(f) => { + let mut buf = Vec::new(); + let mut rdr = tokio::io::BufReader::new(f); + rdr.read_to_end(&mut buf).await?; + + serde_json::from_reader(&buf[..]).context("error deserializing session database")? + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Default::default(), + Err(e) => { + return Err(e).context(format!("error opening session file {:?}", db_filename)) + } + }; + + Ok(Self { + db_filename, + output_folder, + db_content: tokio::sync::RwLock::new(db), + }) + } + + async fn flush(&self) -> anyhow::Result<()> { + let tmp_filename = format!("{}.tmp", self.db_filename.to_str().unwrap()); + let mut tmp = tokio::fs::OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(&tmp_filename) + .await + .with_context(|| format!("error opening {:?}", tmp_filename))?; + + let mut buf = Vec::new(); + serde_json::to_writer(&mut buf, &*self.db_content.read().await) + .context("error serializing")?; + tmp.write_all(&buf) + .await + .with_context(|| format!("error writing {tmp_filename:?}"))?; + + tokio::fs::rename(&tmp_filename, &self.db_filename) + .await + .context("error renaming persistence file")?; + trace!(filename=?self.db_filename, "wrote persistence"); + Ok(()) + } + + fn torrent_bytes_filename(&self, info_hash: &Id20) -> PathBuf { + self.output_folder.join(format!("{:?}.torrent", info_hash)) + } +} + +#[async_trait] +impl SessionPersistenceStore for JsonSessionPersistenceStore { + async fn next_id(&self) -> anyhow::Result { + Ok(self + .db_content + .read() + .await + .torrents + .keys() + .copied() + .max() + .map(|max| max + 1) + .unwrap_or(0)) + } + + async fn store(&self, id: TorrentId, torrent: &ManagedTorrentHandle) -> anyhow::Result<()> { + if !torrent + .storage_factory + .is_type_id(TypeId::of::()) + { + bail!("storages other than FilesystemStorageFactory are not supported"); + } + + let st = SerializedTorrent { + trackers: torrent + .info() + .trackers + .iter() + .map(|u| u.to_string()) + .collect(), + info_hash: torrent.info_hash(), + // we don't serialize this here, but to a file instead. + torrent_bytes: Default::default(), + only_files: torrent.only_files().clone(), + is_paused: torrent.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))), + output_folder: torrent.info().options.output_folder.clone(), + }; + + if !torrent.info().torrent_bytes.is_empty() { + let torrent_bytes_file = self.torrent_bytes_filename(&torrent.info_hash()); + match tokio::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(&torrent_bytes_file) + .await + { + Ok(mut f) => { + if let Err(e) = f.write_all(&torrent.info().torrent_bytes).await { + warn!(error=?e, file=?torrent_bytes_file, "error writing torrent bytes") + } + } + Err(e) => { + warn!(error=?e, file=?torrent_bytes_file, "error opening torrent bytes file") + } + } + } + + self.db_content.write().await.torrents.insert(id, st); + self.flush().await?; + + Ok(()) + } + + async fn delete(&self, id: TorrentId) -> anyhow::Result<()> { + if let Some(t) = self.db_content.write().await.torrents.remove(&id) { + self.flush().await?; + let tf = self.torrent_bytes_filename(&t.info_hash); + if let Err(e) = tokio::fs::remove_file(&tf).await { + warn!(error=?e, filename=?tf, "error removing torrent file"); + } + } + + Ok(()) + } + + async fn get(&self, id: TorrentId) -> anyhow::Result { + let mut st = self + .db_content + .read() + .await + .torrents + .get(&id) + .cloned() + .context("no torrent found")?; + let mut buf = Vec::new(); + let torrent_bytes_filename = self.torrent_bytes_filename(&st.info_hash); + let mut torrent_bytes_file = match tokio::fs::File::open(&torrent_bytes_filename).await { + Ok(f) => f, + Err(e) => { + warn!(error=?e, filename=?torrent_bytes_filename, "error opening torrent bytes file"); + return Ok(st); + } + }; + if let Err(e) = torrent_bytes_file.read_to_end(&mut buf).await { + warn!(error=?e, filename=?torrent_bytes_filename, "error reading torrent bytes file"); + } else { + st.torrent_bytes = buf.into(); + } + return Ok(st); + } + + async fn stream_all( + &self, + ) -> anyhow::Result>> { + let all_ids = self + .db_content + .read() + .await + .torrents + .keys() + .copied() + .collect_vec(); + Ok(futures::stream::iter(all_ids) + .then(move |id| async move { self.get(id).await.map(move |st| (id, st)) }) + .boxed()) + } +} diff --git a/crates/librqbit/src/session_persistence/mod.rs b/crates/librqbit/src/session_persistence/mod.rs new file mode 100644 index 0000000..c3ee7a2 --- /dev/null +++ b/crates/librqbit/src/session_persistence/mod.rs @@ -0,0 +1,87 @@ +pub mod json; + +use std::{collections::HashSet, path::PathBuf}; + +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use futures::stream::BoxStream; +use librqbit_core::magnet::Magnet; +use librqbit_core::Id20; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{ + session::TorrentId, torrent_state::ManagedTorrentHandle, AddTorrent, AddTorrentOptions, +}; + +#[derive(Serialize, Deserialize, Clone)] +pub struct SerializedTorrent { + #[serde( + serialize_with = "serialize_info_hash", + deserialize_with = "deserialize_info_hash" + )] + info_hash: Id20, + #[serde(skip)] + torrent_bytes: Bytes, + trackers: HashSet, + output_folder: PathBuf, + only_files: Option>, + is_paused: bool, +} + +impl SerializedTorrent { + pub fn info_hash(&self) -> &Id20 { + &self.info_hash + } + pub fn into_add_torrent(self) -> anyhow::Result<(AddTorrent<'static>, AddTorrentOptions)> { + let add_torrent = if !self.torrent_bytes.is_empty() { + AddTorrent::TorrentFileBytes(self.torrent_bytes) + } else { + let magnet = + Magnet::from_id20(self.info_hash, self.trackers.into_iter().collect()).to_string(); + AddTorrent::from_url(magnet) + }; + + let opts = AddTorrentOptions { + paused: self.is_paused, + output_folder: Some( + self.output_folder + .to_str() + .context("broken path")? + .to_owned(), + ), + only_files: self.only_files, + overwrite: true, + ..Default::default() + }; + + Ok((add_torrent, opts)) + } +} + +#[async_trait] +pub trait SessionPersistenceStore: core::fmt::Debug + Send + Sync { + async fn next_id(&self) -> anyhow::Result; + async fn store(&self, id: TorrentId, torrent: &ManagedTorrentHandle) -> anyhow::Result<()>; + async fn delete(&self, id: TorrentId) -> anyhow::Result<()>; + async fn get(&self, id: TorrentId) -> anyhow::Result; + async fn stream_all( + &self, + ) -> anyhow::Result>>; +} + +pub type BoxSessionPersistenceStore = Box; + +fn serialize_info_hash(id: &Id20, serializer: S) -> Result +where + S: Serializer, +{ + id.as_string().serialize(serializer) +} + +fn deserialize_info_hash<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + Id20::deserialize(deserializer) +} diff --git a/crates/librqbit/src/tests/e2e.rs b/crates/librqbit/src/tests/e2e.rs index dedc2e8..6485a55 100644 --- a/crates/librqbit/src/tests/e2e.rs +++ b/crates/librqbit/src/tests/e2e.rs @@ -66,8 +66,7 @@ async fn test_e2e_download() { disable_dht: true, disable_dht_persistence: true, dht_config: None, - persistence: false, - persistence_filename: None, + persistence: None, peer_id: Some(peer_id), peer_opts: None, listen_port_range: Some(15100..17000), @@ -150,8 +149,7 @@ async fn test_e2e_download() { disable_dht: true, disable_dht_persistence: true, dht_config: None, - persistence: false, - persistence_filename: None, + persistence: None, listen_port_range: None, enable_upnp_port_forwarding: false, ..Default::default() @@ -230,7 +228,7 @@ async fn test_e2e_download() { } info!("handle is completed"); - session.delete(id, false).unwrap(); + session.delete(id, false).await.unwrap(); info!("deleted handle"); diff --git a/crates/librqbit/src/tests/e2e_stream.rs b/crates/librqbit/src/tests/e2e_stream.rs index a2ab08e..61ec741 100644 --- a/crates/librqbit/src/tests/e2e_stream.rs +++ b/crates/librqbit/src/tests/e2e_stream.rs @@ -28,7 +28,7 @@ async fn e2e_stream() -> anyhow::Result<()> { crate::SessionOptions { disable_dht: true, peer_id: Some(TestPeerMetadata::good().as_peer_id()), - persistence: false, + persistence: None, listen_port_range: Some(16001..16100), enable_upnp_port_forwarding: false, ..Default::default() @@ -72,7 +72,7 @@ async fn e2e_stream() -> anyhow::Result<()> { client_dir.path().into(), crate::SessionOptions { disable_dht: true, - persistence: false, + persistence: None, peer_id: Some(TestPeerMetadata::good().as_peer_id()), listen_port_range: None, enable_upnp_port_forwarding: false, diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 782c370..e5799b3 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -13,7 +13,7 @@ use librqbit::{ }, tracing_subscriber_config_utils::{init_logging, InitLoggingOptions}, AddTorrent, AddTorrentOptions, AddTorrentResponse, Api, ListOnlyResponse, - PeerConnectionOptions, Session, SessionOptions, TorrentStatsState, + PeerConnectionOptions, Session, SessionOptions, SessionPersistenceConfig, TorrentStatsState, }; use size_format::SizeFormatterBinary as SF; use tracing::{error, error_span, info, trace_span, warn}; @@ -132,9 +132,13 @@ struct ServerStartOptions { long = "disable-persistence", help = "Disable server persistence. It will not read or write its state to disk." )] + + /// Disable session persistence. disable_persistence: bool, - #[arg(long = "persistence-filename")] - persistence_filename: Option, + + /// The folder to store session data in. By default uses OS specific folder. + #[arg(long = "persistence-folder")] + persistence_folder: Option, } #[derive(Parser)] @@ -297,8 +301,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { disable_dht_persistence: opts.disable_dht_persistence, dht_config: None, // This will be overriden by "server start" below if needed. - persistence: false, - persistence_filename: None, + persistence: None, peer_id: None, peer_opts: Some(PeerConnectionOptions { connect_timeout: Some(opts.peer_connect_timeout), @@ -389,9 +392,11 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { match &opts.subcommand { SubCommand::Server(server_opts) => match &server_opts.subcommand { ServerSubcommand::Start(start_opts) => { - sopts.persistence = !start_opts.disable_persistence; - sopts.persistence_filename = - start_opts.persistence_filename.clone().map(PathBuf::from); + if !start_opts.disable_persistence { + sopts.persistence = Some(SessionPersistenceConfig::Json { + folder: start_opts.persistence_folder.clone().map(PathBuf::from), + }) + } let session = Session::new_with_opts(PathBuf::from(&start_opts.output_folder), sopts)