Session persistence

This commit is contained in:
Igor Katson 2023-11-25 02:36:19 +00:00
parent e467787c38
commit bec5e1be7f
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
7 changed files with 204 additions and 51 deletions

View file

@ -1,5 +1,2 @@
[target.arm-unknown-linux-gnueabihf] [target.arm-unknown-linux-gnueabihf]
rustflags = ["-l", "atomic"] rustflags = ["-l", "atomic"]
[build]
rustflags = ["--cfg", "tokio_unstable"]

View file

@ -1,5 +1,13 @@
use std::{ use std::{
borrow::Cow, collections::HashMap, io::Read, net::SocketAddr, path::PathBuf, time::Duration, borrow::Cow,
collections::{HashMap, HashSet},
fs::{File, OpenOptions},
io::{BufReader, BufWriter, Read},
net::SocketAddr,
path::PathBuf,
str::FromStr,
sync::Arc,
time::Duration,
}; };
use anyhow::{bail, Context}; use anyhow::{bail, Context};
@ -12,13 +20,14 @@ use librqbit_core::{
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
use reqwest::Url; use reqwest::Url;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{debug, error_span, info, warn}; use tracing::{debug, error, error_span, info, warn};
use crate::{ use crate::{
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
peer_connection::PeerConnectionOptions, peer_connection::PeerConnectionOptions,
spawn_utils::BlockingSpawner, spawn_utils::{spawn, BlockingSpawner},
torrent_state::{ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState}, torrent_state::{ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState},
}; };
@ -27,26 +36,62 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
pub type TorrentId = usize; pub type TorrentId = usize;
#[derive(Default)] #[derive(Default)]
pub struct SessionLocked { pub struct SessionDatabase {
next_id: usize, next_id: usize,
torrents: HashMap<usize, ManagedTorrentHandle>, torrents: HashMap<usize, ManagedTorrentHandle>,
} }
impl SessionLocked { impl SessionDatabase {
fn add_torrent(&mut self, torrent: ManagedTorrentHandle) -> TorrentId { fn add_torrent(&mut self, torrent: ManagedTorrentHandle) -> TorrentId {
let idx = self.next_id; let idx = self.next_id;
self.torrents.insert(idx, torrent); self.torrents.insert(idx, torrent);
self.next_id += 1; self.next_id += 1;
idx idx
} }
fn serialize(&self) -> SerializedSessionDatabase {
SerializedSessionDatabase {
torrents: self
.torrents
.values()
.map(|torrent| SerializedTorrent {
trackers: torrent
.info()
.trackers
.iter()
.map(|u| u.to_string())
.collect(),
info_hash: torrent.info_hash().as_string(),
only_files: torrent.only_files.clone(),
is_paused: torrent.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))),
output_folder: torrent.info().out_dir.clone(),
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize)]
struct SerializedTorrent {
info_hash: String,
trackers: HashSet<String>,
output_folder: PathBuf,
only_files: Option<Vec<usize>>,
is_paused: bool,
}
#[derive(Serialize, Deserialize)]
struct SerializedSessionDatabase {
torrents: Vec<SerializedTorrent>,
} }
pub struct Session { pub struct Session {
peer_id: Id20, peer_id: Id20,
dht: Option<Dht>, dht: Option<Dht>,
persistence_filename: PathBuf,
peer_opts: PeerConnectionOptions, peer_opts: PeerConnectionOptions,
spawner: BlockingSpawner, spawner: BlockingSpawner,
locked: RwLock<SessionLocked>, db: RwLock<SessionDatabase>,
output_folder: PathBuf, output_folder: PathBuf,
} }
@ -86,6 +131,7 @@ fn compute_only_files<ByteBuf: AsRef<[u8]>>(
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub struct AddTorrentOptions { pub struct AddTorrentOptions {
pub paused: bool,
pub only_files_regex: Option<String>, pub only_files_regex: Option<String>,
pub only_files: Option<Vec<usize>>, pub only_files: Option<Vec<usize>>,
pub overwrite: bool, pub overwrite: bool,
@ -164,20 +210,24 @@ impl<'a> AddTorrent<'a> {
pub struct SessionOptions { pub struct SessionOptions {
pub disable_dht: bool, pub disable_dht: bool,
pub disable_dht_persistence: bool, pub disable_dht_persistence: bool,
pub persistence: bool,
pub dht_config: Option<PersistentDhtConfig>, pub dht_config: Option<PersistentDhtConfig>,
pub peer_id: Option<Id20>, pub peer_id: Option<Id20>,
pub peer_opts: Option<PeerConnectionOptions>, pub peer_opts: Option<PeerConnectionOptions>,
} }
impl Session { impl Session {
pub async fn new(output_folder: PathBuf, spawner: BlockingSpawner) -> anyhow::Result<Self> { pub async fn new(
output_folder: PathBuf,
spawner: BlockingSpawner,
) -> anyhow::Result<Arc<Self>> {
Self::new_with_opts(output_folder, spawner, SessionOptions::default()).await Self::new_with_opts(output_folder, spawner, SessionOptions::default()).await
} }
pub async fn new_with_opts( pub async fn new_with_opts(
output_folder: PathBuf, output_folder: PathBuf,
spawner: BlockingSpawner, spawner: BlockingSpawner,
opts: SessionOptions, opts: SessionOptions,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Arc<Self>> {
let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id);
let dht = if opts.disable_dht { let dht = if opts.disable_dht {
None None
@ -191,25 +241,117 @@ impl Session {
Some(dht) Some(dht)
}; };
let peer_opts = opts.peer_opts.unwrap_or_default(); let peer_opts = opts.peer_opts.unwrap_or_default();
let session_filename = output_folder.join(".rqbit-session.json");
Ok(Self { let session = Arc::new(Self {
persistence_filename: session_filename,
peer_id, peer_id,
dht, dht,
peer_opts, peer_opts,
spawner, spawner,
output_folder, output_folder,
locked: RwLock::new(SessionLocked::default()), db: RwLock::new(Default::default()),
}) });
if opts.persistence {
let session = session.clone();
spawn(
"session persistene",
error_span!("session persistence"),
async move {
// Populate initial from the state filename
if let Err(e) = session.populate_from_stored().await {
error!("could not populate session from stored file: {:?}", e);
}
let session = Arc::downgrade(&session);
loop {
tokio::time::sleep(Duration::from_secs(10)).await;
let session = match session.upgrade() {
Some(s) => s,
None => break,
};
if let Err(e) = session.dump_to_disk() {
error!("error dumping session to disk: {:?}", e);
}
}
Ok(())
},
);
}
Ok(session)
} }
pub fn get_dht(&self) -> Option<&Dht> { pub fn get_dht(&self) -> Option<&Dht> {
self.dht.as_ref() self.dht.as_ref()
} }
async fn populate_from_stored(&self) -> anyhow::Result<()> {
let mut rdr = BufReader::new(
std::fs::File::open(&self.persistence_filename).with_context(|| {
format!("error opening session file {:?}", self.persistence_filename)
})?,
);
let db: SerializedSessionDatabase =
serde_json::from_reader(&mut rdr).context("error deserializing session database")?;
for storrent in db.torrents.into_iter() {
let magnet = Magnet {
info_hash: Id20::from_str(&storrent.info_hash)
.context("error deserializing info_hash")?,
trackers: storrent.trackers.into_iter().collect(),
};
if let Err(e) = self
.add_torrent(
AddTorrent::Url(Cow::Owned(magnet.to_string())),
Some(AddTorrentOptions {
paused: storrent.is_paused,
output_folder: Some(
storrent
.output_folder
.to_str()
.context("broken path")?
.to_owned(),
),
only_files: storrent.only_files,
overwrite: true,
..Default::default()
}),
)
.await
{
error!("error adding torrent from stored session: {:?}", e)
}
}
Ok(())
}
fn dump_to_disk(&self) -> anyhow::Result<()> {
let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap());
let mut tmp = BufWriter::new(
std::fs::OpenOptions::new()
.create(true)
.create_new(true)
.truncate(true)
.write(true)
.open(&tmp_filename)
.with_context(|| format!("error opening {:?}", tmp_filename))?,
);
let serialized = self.db.read().serialize();
serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?;
drop(tmp);
std::fs::rename(&tmp_filename, &self.persistence_filename)
.context("error renaming persistence file")?;
debug!("wrote persistence to {:?}", &self.persistence_filename);
Ok(())
}
pub fn with_torrents<R>( pub fn with_torrents<R>(
&self, &self,
callback: impl Fn(&mut dyn Iterator<Item = (TorrentId, &ManagedTorrentHandle)>) -> R, callback: impl Fn(&mut dyn Iterator<Item = (TorrentId, &ManagedTorrentHandle)>) -> R,
) -> R { ) -> R {
callback(&mut self.locked.read().torrents.iter().map(|(id, t)| (*id, t))) callback(&mut self.db.read().torrents.iter().map(|(id, t)| (*id, t)))
} }
pub async fn add_torrent( pub async fn add_torrent(
@ -407,7 +549,7 @@ impl Session {
} }
let (managed_torrent, id) = { let (managed_torrent, id) = {
let mut g = self.locked.write(); let mut g = self.db.write();
if let Some((id, handle)) = g.torrents.iter().find(|(_, t)| t.info_hash() == info_hash) if let Some((id, handle)) = g.torrents.iter().find(|(_, t)| t.info_hash() == info_hash)
{ {
return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone())); return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone()));
@ -422,7 +564,7 @@ impl Session {
let span = managed_torrent.info.span.clone(); let span = managed_torrent.info.span.clone();
let _ = span.enter(); let _ = span.enter();
managed_torrent managed_torrent
.start(initial_peers, dht_peer_rx) .start(initial_peers, dht_peer_rx, opts.paused)
.context("error starting torrent")?; .context("error starting torrent")?;
} }
@ -430,12 +572,12 @@ impl Session {
} }
pub fn get(&self, id: TorrentId) -> Option<ManagedTorrentHandle> { pub fn get(&self, id: TorrentId) -> Option<ManagedTorrentHandle> {
self.locked.read().torrents.get(&id).cloned() self.db.read().torrents.get(&id).cloned()
} }
pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> {
let removed = self let removed = self
.locked .db
.write() .write()
.torrents .torrents
.remove(&id) .remove(&id)
@ -477,7 +619,7 @@ impl Session {
.as_ref() .as_ref()
.map(|dht| dht.get_peers(handle.info_hash())) .map(|dht| dht.get_peers(handle.info_hash()))
.transpose()?; .transpose()?;
handle.start(Default::default(), peer_rx)?; handle.start(Default::default(), peer_rx, false)?;
Ok(()) Ok(())
} }
} }

View file

@ -1,7 +1,7 @@
use tracing::{debug, trace, warn, Instrument}; use tracing::{debug, trace, warn, Instrument};
pub fn spawn( pub fn spawn(
name: &str, _name: &str,
span: tracing::Span, span: tracing::Span,
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static, fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
) -> tokio::task::JoinHandle<()> { ) -> tokio::task::JoinHandle<()> {
@ -17,7 +17,7 @@ pub fn spawn(
} }
} }
.instrument(span.or_current()); .instrument(span.or_current());
tokio::task::Builder::new().name(name).spawn(fut).unwrap() tokio::task::spawn(fut)
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]

View file

@ -85,7 +85,7 @@ pub struct ManagedTorrentInfo {
pub struct ManagedTorrent { pub struct ManagedTorrent {
pub info: Arc<ManagedTorrentInfo>, pub info: Arc<ManagedTorrentInfo>,
only_files: Option<Vec<usize>>, pub(crate) only_files: Option<Vec<usize>>,
locked: RwLock<ManagedTorrentLocked>, locked: RwLock<ManagedTorrentLocked>,
} }
@ -138,6 +138,7 @@ impl ManagedTorrent {
self: &Arc<Self>, self: &Arc<Self>,
initial_peers: Vec<SocketAddr>, initial_peers: Vec<SocketAddr>,
peer_rx: Option<impl StreamExt<Item = SocketAddr> + Unpin + Send + Sync + 'static>, peer_rx: Option<impl StreamExt<Item = SocketAddr> + Unpin + Send + Sync + 'static>,
start_paused: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut g = self.locked.write(); let mut g = self.locked.write();
@ -185,6 +186,11 @@ impl ManagedTorrent {
return Ok(()); return Ok(());
} }
if start_paused {
g.state = ManagedTorrentState::Paused(paused);
return Ok(());
}
let live = TorrentStateLive::new(paused); let live = TorrentStateLive::new(paused);
g.state = ManagedTorrentState::Live(live.clone()); g.state = ManagedTorrentState::Live(live.clone());

View file

@ -41,6 +41,17 @@ impl Magnet {
} }
} }
impl std::fmt::Display for Magnet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"magnet:?xt=urn:btih:{}&tr={}",
self.info_hash.as_string(),
self.trackers.join("&tr=")
)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[test] #[test]

View file

@ -13,7 +13,7 @@ readme = "README.md"
[features] [features]
default = ["sha1-system", "default-tls", "webui"] default = ["sha1-system", "default-tls", "webui"]
tokio-console = ["console-subscriber",] tokio-console = ["console-subscriber", "tokio/tracing"]
webui = ["librqbit/webui"] webui = ["librqbit/webui"]
timed_existence = ["librqbit/timed_existence"] timed_existence = ["librqbit/timed_existence"]
sha1-system = ["librqbit/sha1-system"] sha1-system = ["librqbit/sha1-system"]
@ -25,7 +25,7 @@ rust-tls = ["librqbit/rust-tls"]
[dependencies] [dependencies]
librqbit = {path="../librqbit", default-features=false, version = "3.3.0"} librqbit = {path="../librqbit", default-features=false, version = "3.3.0"}
dht = {path="../dht", package="librqbit-dht", version="3.1.0"} dht = {path="../dht", package="librqbit-dht", version="3.1.0"}
tokio = {version = "1", features = ["macros", "rt-multi-thread", "tracing"]} tokio = {version = "1", features = ["macros", "rt-multi-thread"]}
console-subscriber = {version = "0.2", optional = true} console-subscriber = {version = "0.2", optional = true}
anyhow = "1" anyhow = "1"
clap = {version = "4", features = ["derive", "deprecated"]} clap = {version = "4", features = ["derive", "deprecated"]}

View file

@ -279,6 +279,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
disable_dht: opts.disable_dht, disable_dht: opts.disable_dht,
disable_dht_persistence: opts.disable_dht_persistence, disable_dht_persistence: opts.disable_dht_persistence,
dht_config: None, dht_config: None,
persistence: true,
peer_id: None, peer_id: None,
peer_opts: Some(PeerConnectionOptions { peer_opts: Some(PeerConnectionOptions {
connect_timeout: Some(opts.peer_connect_timeout), connect_timeout: Some(opts.peer_connect_timeout),
@ -342,15 +343,13 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
match &opts.subcommand { match &opts.subcommand {
SubCommand::Server(server_opts) => match &server_opts.subcommand { SubCommand::Server(server_opts) => match &server_opts.subcommand {
ServerSubcommand::Start(start_opts) => { ServerSubcommand::Start(start_opts) => {
let session = Arc::new( let session = Session::new_with_opts(
Session::new_with_opts( PathBuf::from(&start_opts.output_folder),
PathBuf::from(&start_opts.output_folder), spawner,
spawner, sopts,
sopts, )
) .await
.await .context("error initializing rqbit session")?;
.context("error initializing rqbit session")?,
);
spawn( spawn(
"stats_printer", "stats_printer",
trace_span!("stats_printer"), trace_span!("stats_printer"),
@ -416,21 +415,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
} }
Ok(()) Ok(())
} else { } else {
let session = Arc::new( let session = Session::new_with_opts(
Session::new_with_opts( download_opts
download_opts .output_folder
.output_folder .as_ref()
.as_ref() .map(PathBuf::from)
.map(PathBuf::from) .context(
.context( "output_folder is required if can't connect to an existing server",
"output_folder is required if can't connect to an existing server", )?,
)?, spawner,
spawner, sopts,
sopts, )
) .await
.await .context("error initializing rqbit session")?;
.context("error initializing rqbit session")?,
);
spawn( spawn(
"stats_printer", "stats_printer",
trace_span!("stats_printer"), trace_span!("stats_printer"),