diff --git a/.cargo/config b/.cargo/config index 1b76467..0aca9d9 100644 --- a/.cargo/config +++ b/.cargo/config @@ -1,5 +1,2 @@ [target.arm-unknown-linux-gnueabihf] -rustflags = ["-l", "atomic"] - -[build] -rustflags = ["--cfg", "tokio_unstable"] \ No newline at end of file +rustflags = ["-l", "atomic"] \ No newline at end of file diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 6733eb9..ddf82dc 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -1,5 +1,13 @@ use std::{ - borrow::Cow, collections::HashMap, io::Read, net::SocketAddr, path::PathBuf, time::Duration, + borrow::Cow, + collections::{HashMap, HashSet}, + fs::{File, OpenOptions}, + io::{BufReader, BufWriter, Read}, + net::SocketAddr, + path::PathBuf, + str::FromStr, + sync::Arc, + time::Duration, }; use anyhow::{bail, Context}; @@ -12,13 +20,14 @@ use librqbit_core::{ }; use parking_lot::RwLock; use reqwest::Url; +use serde::{Deserialize, Serialize}; use tokio_stream::StreamExt; -use tracing::{debug, error_span, info, warn}; +use tracing::{debug, error, error_span, info, warn}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, peer_connection::PeerConnectionOptions, - spawn_utils::BlockingSpawner, + spawn_utils::{spawn, BlockingSpawner}, torrent_state::{ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState}, }; @@ -27,26 +36,62 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"]; pub type TorrentId = usize; #[derive(Default)] -pub struct SessionLocked { +pub struct SessionDatabase { next_id: usize, torrents: HashMap, } -impl SessionLocked { +impl SessionDatabase { fn add_torrent(&mut self, torrent: ManagedTorrentHandle) -> TorrentId { let idx = self.next_id; self.torrents.insert(idx, torrent); self.next_id += 1; idx } + + fn serialize(&self) -> SerializedSessionDatabase { + SerializedSessionDatabase { + torrents: self + .torrents + .values() + .map(|torrent| SerializedTorrent { + trackers: torrent + .info() + .trackers + .iter() + .map(|u| u.to_string()) + .collect(), + info_hash: torrent.info_hash().as_string(), + only_files: torrent.only_files.clone(), + is_paused: torrent.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))), + output_folder: torrent.info().out_dir.clone(), + }) + .collect(), + } + } +} + +#[derive(Serialize, Deserialize)] +struct SerializedTorrent { + info_hash: String, + trackers: HashSet, + output_folder: PathBuf, + only_files: Option>, + is_paused: bool, +} + +#[derive(Serialize, Deserialize)] +struct SerializedSessionDatabase { + torrents: Vec, } pub struct Session { peer_id: Id20, dht: Option, + persistence_filename: PathBuf, peer_opts: PeerConnectionOptions, spawner: BlockingSpawner, - locked: RwLock, + db: RwLock, output_folder: PathBuf, } @@ -86,6 +131,7 @@ fn compute_only_files>( #[derive(Default, Clone)] pub struct AddTorrentOptions { + pub paused: bool, pub only_files_regex: Option, pub only_files: Option>, pub overwrite: bool, @@ -164,20 +210,24 @@ impl<'a> AddTorrent<'a> { pub struct SessionOptions { pub disable_dht: bool, pub disable_dht_persistence: bool, + pub persistence: bool, pub dht_config: Option, pub peer_id: Option, pub peer_opts: Option, } impl Session { - pub async fn new(output_folder: PathBuf, spawner: BlockingSpawner) -> anyhow::Result { + pub async fn new( + output_folder: PathBuf, + spawner: BlockingSpawner, + ) -> anyhow::Result> { Self::new_with_opts(output_folder, spawner, SessionOptions::default()).await } pub async fn new_with_opts( output_folder: PathBuf, spawner: BlockingSpawner, opts: SessionOptions, - ) -> anyhow::Result { + ) -> anyhow::Result> { let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let dht = if opts.disable_dht { None @@ -191,25 +241,117 @@ impl Session { Some(dht) }; let peer_opts = opts.peer_opts.unwrap_or_default(); - - Ok(Self { + let session_filename = output_folder.join(".rqbit-session.json"); + let session = Arc::new(Self { + persistence_filename: session_filename, peer_id, dht, peer_opts, spawner, output_folder, - locked: RwLock::new(SessionLocked::default()), - }) + db: RwLock::new(Default::default()), + }); + + if opts.persistence { + let session = session.clone(); + spawn( + "session persistene", + error_span!("session persistence"), + async move { + // Populate initial from the state filename + if let Err(e) = session.populate_from_stored().await { + error!("could not populate session from stored file: {:?}", e); + } + + let session = Arc::downgrade(&session); + + loop { + tokio::time::sleep(Duration::from_secs(10)).await; + let session = match session.upgrade() { + Some(s) => s, + None => break, + }; + if let Err(e) = session.dump_to_disk() { + error!("error dumping session to disk: {:?}", e); + } + } + + Ok(()) + }, + ); + } + + Ok(session) } pub fn get_dht(&self) -> Option<&Dht> { self.dht.as_ref() } + async fn populate_from_stored(&self) -> anyhow::Result<()> { + let mut rdr = BufReader::new( + std::fs::File::open(&self.persistence_filename).with_context(|| { + format!("error opening session file {:?}", self.persistence_filename) + })?, + ); + let db: SerializedSessionDatabase = + serde_json::from_reader(&mut rdr).context("error deserializing session database")?; + for storrent in db.torrents.into_iter() { + let magnet = Magnet { + info_hash: Id20::from_str(&storrent.info_hash) + .context("error deserializing info_hash")?, + trackers: storrent.trackers.into_iter().collect(), + }; + if let Err(e) = self + .add_torrent( + AddTorrent::Url(Cow::Owned(magnet.to_string())), + Some(AddTorrentOptions { + paused: storrent.is_paused, + output_folder: Some( + storrent + .output_folder + .to_str() + .context("broken path")? + .to_owned(), + ), + only_files: storrent.only_files, + overwrite: true, + ..Default::default() + }), + ) + .await + { + error!("error adding torrent from stored session: {:?}", e) + } + } + Ok(()) + } + + fn dump_to_disk(&self) -> anyhow::Result<()> { + let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap()); + let mut tmp = BufWriter::new( + std::fs::OpenOptions::new() + .create(true) + .create_new(true) + .truncate(true) + .write(true) + .open(&tmp_filename) + .with_context(|| format!("error opening {:?}", tmp_filename))?, + ); + let serialized = self.db.read().serialize(); + serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?; + drop(tmp); + + std::fs::rename(&tmp_filename, &self.persistence_filename) + .context("error renaming persistence file")?; + debug!("wrote persistence to {:?}", &self.persistence_filename); + Ok(()) + } + pub fn with_torrents( &self, callback: impl Fn(&mut dyn Iterator) -> R, ) -> R { - callback(&mut self.locked.read().torrents.iter().map(|(id, t)| (*id, t))) + callback(&mut self.db.read().torrents.iter().map(|(id, t)| (*id, t))) } pub async fn add_torrent( @@ -407,7 +549,7 @@ impl Session { } let (managed_torrent, id) = { - let mut g = self.locked.write(); + let mut g = self.db.write(); if let Some((id, handle)) = g.torrents.iter().find(|(_, t)| t.info_hash() == info_hash) { return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone())); @@ -422,7 +564,7 @@ impl Session { let span = managed_torrent.info.span.clone(); let _ = span.enter(); managed_torrent - .start(initial_peers, dht_peer_rx) + .start(initial_peers, dht_peer_rx, opts.paused) .context("error starting torrent")?; } @@ -430,12 +572,12 @@ impl Session { } pub fn get(&self, id: TorrentId) -> Option { - self.locked.read().torrents.get(&id).cloned() + self.db.read().torrents.get(&id).cloned() } pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { let removed = self - .locked + .db .write() .torrents .remove(&id) @@ -477,7 +619,7 @@ impl Session { .as_ref() .map(|dht| dht.get_peers(handle.info_hash())) .transpose()?; - handle.start(Default::default(), peer_rx)?; + handle.start(Default::default(), peer_rx, false)?; Ok(()) } } diff --git a/crates/librqbit/src/spawn_utils.rs b/crates/librqbit/src/spawn_utils.rs index 957a837..1e404bd 100644 --- a/crates/librqbit/src/spawn_utils.rs +++ b/crates/librqbit/src/spawn_utils.rs @@ -1,7 +1,7 @@ use tracing::{debug, trace, warn, Instrument}; pub fn spawn( - name: &str, + _name: &str, span: tracing::Span, fut: impl std::future::Future> + Send + 'static, ) -> tokio::task::JoinHandle<()> { @@ -17,7 +17,7 @@ pub fn spawn( } } .instrument(span.or_current()); - tokio::task::Builder::new().name(name).spawn(fut).unwrap() + tokio::task::spawn(fut) } #[derive(Clone, Copy, Debug)] diff --git a/crates/librqbit/src/torrent_state/mod.rs b/crates/librqbit/src/torrent_state/mod.rs index 502a2dd..e95c5b3 100644 --- a/crates/librqbit/src/torrent_state/mod.rs +++ b/crates/librqbit/src/torrent_state/mod.rs @@ -85,7 +85,7 @@ pub struct ManagedTorrentInfo { pub struct ManagedTorrent { pub info: Arc, - only_files: Option>, + pub(crate) only_files: Option>, locked: RwLock, } @@ -138,6 +138,7 @@ impl ManagedTorrent { self: &Arc, initial_peers: Vec, peer_rx: Option + Unpin + Send + Sync + 'static>, + start_paused: bool, ) -> anyhow::Result<()> { let mut g = self.locked.write(); @@ -185,6 +186,11 @@ impl ManagedTorrent { return Ok(()); } + if start_paused { + g.state = ManagedTorrentState::Paused(paused); + return Ok(()); + } + let live = TorrentStateLive::new(paused); g.state = ManagedTorrentState::Live(live.clone()); diff --git a/crates/librqbit_core/src/magnet.rs b/crates/librqbit_core/src/magnet.rs index 5b5739e..12e09d9 100644 --- a/crates/librqbit_core/src/magnet.rs +++ b/crates/librqbit_core/src/magnet.rs @@ -41,6 +41,17 @@ impl Magnet { } } +impl std::fmt::Display for Magnet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "magnet:?xt=urn:btih:{}&tr={}", + self.info_hash.as_string(), + self.trackers.join("&tr=") + ) + } +} + #[cfg(test)] mod tests { #[test] diff --git a/crates/rqbit/Cargo.toml b/crates/rqbit/Cargo.toml index a0c5410..e2a1d97 100644 --- a/crates/rqbit/Cargo.toml +++ b/crates/rqbit/Cargo.toml @@ -13,7 +13,7 @@ readme = "README.md" [features] default = ["sha1-system", "default-tls", "webui"] -tokio-console = ["console-subscriber",] +tokio-console = ["console-subscriber", "tokio/tracing"] webui = ["librqbit/webui"] timed_existence = ["librqbit/timed_existence"] sha1-system = ["librqbit/sha1-system"] @@ -25,7 +25,7 @@ rust-tls = ["librqbit/rust-tls"] [dependencies] librqbit = {path="../librqbit", default-features=false, version = "3.3.0"} dht = {path="../dht", package="librqbit-dht", version="3.1.0"} -tokio = {version = "1", features = ["macros", "rt-multi-thread", "tracing"]} +tokio = {version = "1", features = ["macros", "rt-multi-thread"]} console-subscriber = {version = "0.2", optional = true} anyhow = "1" clap = {version = "4", features = ["derive", "deprecated"]} diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 4e79124..dfdab9a 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -279,6 +279,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> disable_dht: opts.disable_dht, disable_dht_persistence: opts.disable_dht_persistence, dht_config: None, + persistence: true, peer_id: None, peer_opts: Some(PeerConnectionOptions { connect_timeout: Some(opts.peer_connect_timeout), @@ -342,15 +343,13 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> match &opts.subcommand { SubCommand::Server(server_opts) => match &server_opts.subcommand { ServerSubcommand::Start(start_opts) => { - let session = Arc::new( - Session::new_with_opts( - PathBuf::from(&start_opts.output_folder), - spawner, - sopts, - ) - .await - .context("error initializing rqbit session")?, - ); + let session = Session::new_with_opts( + PathBuf::from(&start_opts.output_folder), + spawner, + sopts, + ) + .await + .context("error initializing rqbit session")?; spawn( "stats_printer", trace_span!("stats_printer"), @@ -416,21 +415,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> } Ok(()) } else { - let session = Arc::new( - Session::new_with_opts( - download_opts - .output_folder - .as_ref() - .map(PathBuf::from) - .context( - "output_folder is required if can't connect to an existing server", - )?, - spawner, - sopts, - ) - .await - .context("error initializing rqbit session")?, - ); + let session = Session::new_with_opts( + download_opts + .output_folder + .as_ref() + .map(PathBuf::from) + .context( + "output_folder is required if can't connect to an existing server", + )?, + spawner, + sopts, + ) + .await + .context("error initializing rqbit session")?; spawn( "stats_printer", trace_span!("stats_printer"),