Session persistence
This commit is contained in:
parent
e467787c38
commit
bec5e1be7f
7 changed files with 204 additions and 51 deletions
|
|
@ -1,5 +1,2 @@
|
|||
[target.arm-unknown-linux-gnueabihf]
|
||||
rustflags = ["-l", "atomic"]
|
||||
|
||||
[build]
|
||||
rustflags = ["--cfg", "tokio_unstable"]
|
||||
rustflags = ["-l", "atomic"]
|
||||
|
|
@ -1,5 +1,13 @@
|
|||
use std::{
|
||||
borrow::Cow, collections::HashMap, io::Read, net::SocketAddr, path::PathBuf, time::Duration,
|
||||
borrow::Cow,
|
||||
collections::{HashMap, HashSet},
|
||||
fs::{File, OpenOptions},
|
||||
io::{BufReader, BufWriter, Read},
|
||||
net::SocketAddr,
|
||||
path::PathBuf,
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
|
|
@ -12,13 +20,14 @@ use librqbit_core::{
|
|||
};
|
||||
use parking_lot::RwLock;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, error_span, info, warn};
|
||||
use tracing::{debug, error, error_span, info, warn};
|
||||
|
||||
use crate::{
|
||||
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
|
||||
peer_connection::PeerConnectionOptions,
|
||||
spawn_utils::BlockingSpawner,
|
||||
spawn_utils::{spawn, BlockingSpawner},
|
||||
torrent_state::{ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState},
|
||||
};
|
||||
|
||||
|
|
@ -27,26 +36,62 @@ pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"];
|
|||
pub type TorrentId = usize;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SessionLocked {
|
||||
pub struct SessionDatabase {
|
||||
next_id: usize,
|
||||
torrents: HashMap<usize, ManagedTorrentHandle>,
|
||||
}
|
||||
|
||||
impl SessionLocked {
|
||||
impl SessionDatabase {
|
||||
fn add_torrent(&mut self, torrent: ManagedTorrentHandle) -> TorrentId {
|
||||
let idx = self.next_id;
|
||||
self.torrents.insert(idx, torrent);
|
||||
self.next_id += 1;
|
||||
idx
|
||||
}
|
||||
|
||||
fn serialize(&self) -> SerializedSessionDatabase {
|
||||
SerializedSessionDatabase {
|
||||
torrents: self
|
||||
.torrents
|
||||
.values()
|
||||
.map(|torrent| SerializedTorrent {
|
||||
trackers: torrent
|
||||
.info()
|
||||
.trackers
|
||||
.iter()
|
||||
.map(|u| u.to_string())
|
||||
.collect(),
|
||||
info_hash: torrent.info_hash().as_string(),
|
||||
only_files: torrent.only_files.clone(),
|
||||
is_paused: torrent.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))),
|
||||
output_folder: torrent.info().out_dir.clone(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SerializedTorrent {
|
||||
info_hash: String,
|
||||
trackers: HashSet<String>,
|
||||
output_folder: PathBuf,
|
||||
only_files: Option<Vec<usize>>,
|
||||
is_paused: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SerializedSessionDatabase {
|
||||
torrents: Vec<SerializedTorrent>,
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
peer_id: Id20,
|
||||
dht: Option<Dht>,
|
||||
persistence_filename: PathBuf,
|
||||
peer_opts: PeerConnectionOptions,
|
||||
spawner: BlockingSpawner,
|
||||
locked: RwLock<SessionLocked>,
|
||||
db: RwLock<SessionDatabase>,
|
||||
output_folder: PathBuf,
|
||||
}
|
||||
|
||||
|
|
@ -86,6 +131,7 @@ fn compute_only_files<ByteBuf: AsRef<[u8]>>(
|
|||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct AddTorrentOptions {
|
||||
pub paused: bool,
|
||||
pub only_files_regex: Option<String>,
|
||||
pub only_files: Option<Vec<usize>>,
|
||||
pub overwrite: bool,
|
||||
|
|
@ -164,20 +210,24 @@ impl<'a> AddTorrent<'a> {
|
|||
pub struct SessionOptions {
|
||||
pub disable_dht: bool,
|
||||
pub disable_dht_persistence: bool,
|
||||
pub persistence: bool,
|
||||
pub dht_config: Option<PersistentDhtConfig>,
|
||||
pub peer_id: Option<Id20>,
|
||||
pub peer_opts: Option<PeerConnectionOptions>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn new(output_folder: PathBuf, spawner: BlockingSpawner) -> anyhow::Result<Self> {
|
||||
pub async fn new(
|
||||
output_folder: PathBuf,
|
||||
spawner: BlockingSpawner,
|
||||
) -> anyhow::Result<Arc<Self>> {
|
||||
Self::new_with_opts(output_folder, spawner, SessionOptions::default()).await
|
||||
}
|
||||
pub async fn new_with_opts(
|
||||
output_folder: PathBuf,
|
||||
spawner: BlockingSpawner,
|
||||
opts: SessionOptions,
|
||||
) -> anyhow::Result<Self> {
|
||||
) -> anyhow::Result<Arc<Self>> {
|
||||
let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id);
|
||||
let dht = if opts.disable_dht {
|
||||
None
|
||||
|
|
@ -191,25 +241,117 @@ impl Session {
|
|||
Some(dht)
|
||||
};
|
||||
let peer_opts = opts.peer_opts.unwrap_or_default();
|
||||
|
||||
Ok(Self {
|
||||
let session_filename = output_folder.join(".rqbit-session.json");
|
||||
let session = Arc::new(Self {
|
||||
persistence_filename: session_filename,
|
||||
peer_id,
|
||||
dht,
|
||||
peer_opts,
|
||||
spawner,
|
||||
output_folder,
|
||||
locked: RwLock::new(SessionLocked::default()),
|
||||
})
|
||||
db: RwLock::new(Default::default()),
|
||||
});
|
||||
|
||||
if opts.persistence {
|
||||
let session = session.clone();
|
||||
spawn(
|
||||
"session persistene",
|
||||
error_span!("session persistence"),
|
||||
async move {
|
||||
// Populate initial from the state filename
|
||||
if let Err(e) = session.populate_from_stored().await {
|
||||
error!("could not populate session from stored file: {:?}", e);
|
||||
}
|
||||
|
||||
let session = Arc::downgrade(&session);
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
let session = match session.upgrade() {
|
||||
Some(s) => s,
|
||||
None => break,
|
||||
};
|
||||
if let Err(e) = session.dump_to_disk() {
|
||||
error!("error dumping session to disk: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
pub fn get_dht(&self) -> Option<&Dht> {
|
||||
self.dht.as_ref()
|
||||
}
|
||||
|
||||
async fn populate_from_stored(&self) -> anyhow::Result<()> {
|
||||
let mut rdr = BufReader::new(
|
||||
std::fs::File::open(&self.persistence_filename).with_context(|| {
|
||||
format!("error opening session file {:?}", self.persistence_filename)
|
||||
})?,
|
||||
);
|
||||
let db: SerializedSessionDatabase =
|
||||
serde_json::from_reader(&mut rdr).context("error deserializing session database")?;
|
||||
for storrent in db.torrents.into_iter() {
|
||||
let magnet = Magnet {
|
||||
info_hash: Id20::from_str(&storrent.info_hash)
|
||||
.context("error deserializing info_hash")?,
|
||||
trackers: storrent.trackers.into_iter().collect(),
|
||||
};
|
||||
if let Err(e) = self
|
||||
.add_torrent(
|
||||
AddTorrent::Url(Cow::Owned(magnet.to_string())),
|
||||
Some(AddTorrentOptions {
|
||||
paused: storrent.is_paused,
|
||||
output_folder: Some(
|
||||
storrent
|
||||
.output_folder
|
||||
.to_str()
|
||||
.context("broken path")?
|
||||
.to_owned(),
|
||||
),
|
||||
only_files: storrent.only_files,
|
||||
overwrite: true,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("error adding torrent from stored session: {:?}", e)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dump_to_disk(&self) -> anyhow::Result<()> {
|
||||
let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap());
|
||||
let mut tmp = BufWriter::new(
|
||||
std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.create_new(true)
|
||||
.truncate(true)
|
||||
.write(true)
|
||||
.open(&tmp_filename)
|
||||
.with_context(|| format!("error opening {:?}", tmp_filename))?,
|
||||
);
|
||||
let serialized = self.db.read().serialize();
|
||||
serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?;
|
||||
drop(tmp);
|
||||
|
||||
std::fs::rename(&tmp_filename, &self.persistence_filename)
|
||||
.context("error renaming persistence file")?;
|
||||
debug!("wrote persistence to {:?}", &self.persistence_filename);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn with_torrents<R>(
|
||||
&self,
|
||||
callback: impl Fn(&mut dyn Iterator<Item = (TorrentId, &ManagedTorrentHandle)>) -> R,
|
||||
) -> R {
|
||||
callback(&mut self.locked.read().torrents.iter().map(|(id, t)| (*id, t)))
|
||||
callback(&mut self.db.read().torrents.iter().map(|(id, t)| (*id, t)))
|
||||
}
|
||||
|
||||
pub async fn add_torrent(
|
||||
|
|
@ -407,7 +549,7 @@ impl Session {
|
|||
}
|
||||
|
||||
let (managed_torrent, id) = {
|
||||
let mut g = self.locked.write();
|
||||
let mut g = self.db.write();
|
||||
if let Some((id, handle)) = g.torrents.iter().find(|(_, t)| t.info_hash() == info_hash)
|
||||
{
|
||||
return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone()));
|
||||
|
|
@ -422,7 +564,7 @@ impl Session {
|
|||
let span = managed_torrent.info.span.clone();
|
||||
let _ = span.enter();
|
||||
managed_torrent
|
||||
.start(initial_peers, dht_peer_rx)
|
||||
.start(initial_peers, dht_peer_rx, opts.paused)
|
||||
.context("error starting torrent")?;
|
||||
}
|
||||
|
||||
|
|
@ -430,12 +572,12 @@ impl Session {
|
|||
}
|
||||
|
||||
pub fn get(&self, id: TorrentId) -> Option<ManagedTorrentHandle> {
|
||||
self.locked.read().torrents.get(&id).cloned()
|
||||
self.db.read().torrents.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> {
|
||||
let removed = self
|
||||
.locked
|
||||
.db
|
||||
.write()
|
||||
.torrents
|
||||
.remove(&id)
|
||||
|
|
@ -477,7 +619,7 @@ impl Session {
|
|||
.as_ref()
|
||||
.map(|dht| dht.get_peers(handle.info_hash()))
|
||||
.transpose()?;
|
||||
handle.start(Default::default(), peer_rx)?;
|
||||
handle.start(Default::default(), peer_rx, false)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use tracing::{debug, trace, warn, Instrument};
|
||||
|
||||
pub fn spawn(
|
||||
name: &str,
|
||||
_name: &str,
|
||||
span: tracing::Span,
|
||||
fut: impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
|
|
@ -17,7 +17,7 @@ pub fn spawn(
|
|||
}
|
||||
}
|
||||
.instrument(span.or_current());
|
||||
tokio::task::Builder::new().name(name).spawn(fut).unwrap()
|
||||
tokio::task::spawn(fut)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ pub struct ManagedTorrentInfo {
|
|||
|
||||
pub struct ManagedTorrent {
|
||||
pub info: Arc<ManagedTorrentInfo>,
|
||||
only_files: Option<Vec<usize>>,
|
||||
pub(crate) only_files: Option<Vec<usize>>,
|
||||
locked: RwLock<ManagedTorrentLocked>,
|
||||
}
|
||||
|
||||
|
|
@ -138,6 +138,7 @@ impl ManagedTorrent {
|
|||
self: &Arc<Self>,
|
||||
initial_peers: Vec<SocketAddr>,
|
||||
peer_rx: Option<impl StreamExt<Item = SocketAddr> + Unpin + Send + Sync + 'static>,
|
||||
start_paused: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut g = self.locked.write();
|
||||
|
||||
|
|
@ -185,6 +186,11 @@ impl ManagedTorrent {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
if start_paused {
|
||||
g.state = ManagedTorrentState::Paused(paused);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let live = TorrentStateLive::new(paused);
|
||||
g.state = ManagedTorrentState::Live(live.clone());
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,17 @@ impl Magnet {
|
|||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Magnet {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"magnet:?xt=urn:btih:{}&tr={}",
|
||||
self.info_hash.as_string(),
|
||||
self.trackers.join("&tr=")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ readme = "README.md"
|
|||
|
||||
[features]
|
||||
default = ["sha1-system", "default-tls", "webui"]
|
||||
tokio-console = ["console-subscriber",]
|
||||
tokio-console = ["console-subscriber", "tokio/tracing"]
|
||||
webui = ["librqbit/webui"]
|
||||
timed_existence = ["librqbit/timed_existence"]
|
||||
sha1-system = ["librqbit/sha1-system"]
|
||||
|
|
@ -25,7 +25,7 @@ rust-tls = ["librqbit/rust-tls"]
|
|||
[dependencies]
|
||||
librqbit = {path="../librqbit", default-features=false, version = "3.3.0"}
|
||||
dht = {path="../dht", package="librqbit-dht", version="3.1.0"}
|
||||
tokio = {version = "1", features = ["macros", "rt-multi-thread", "tracing"]}
|
||||
tokio = {version = "1", features = ["macros", "rt-multi-thread"]}
|
||||
console-subscriber = {version = "0.2", optional = true}
|
||||
anyhow = "1"
|
||||
clap = {version = "4", features = ["derive", "deprecated"]}
|
||||
|
|
|
|||
|
|
@ -279,6 +279,7 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
disable_dht: opts.disable_dht,
|
||||
disable_dht_persistence: opts.disable_dht_persistence,
|
||||
dht_config: None,
|
||||
persistence: true,
|
||||
peer_id: None,
|
||||
peer_opts: Some(PeerConnectionOptions {
|
||||
connect_timeout: Some(opts.peer_connect_timeout),
|
||||
|
|
@ -342,15 +343,13 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
match &opts.subcommand {
|
||||
SubCommand::Server(server_opts) => match &server_opts.subcommand {
|
||||
ServerSubcommand::Start(start_opts) => {
|
||||
let session = Arc::new(
|
||||
Session::new_with_opts(
|
||||
PathBuf::from(&start_opts.output_folder),
|
||||
spawner,
|
||||
sopts,
|
||||
)
|
||||
.await
|
||||
.context("error initializing rqbit session")?,
|
||||
);
|
||||
let session = Session::new_with_opts(
|
||||
PathBuf::from(&start_opts.output_folder),
|
||||
spawner,
|
||||
sopts,
|
||||
)
|
||||
.await
|
||||
.context("error initializing rqbit session")?;
|
||||
spawn(
|
||||
"stats_printer",
|
||||
trace_span!("stats_printer"),
|
||||
|
|
@ -416,21 +415,19 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
|
|||
}
|
||||
Ok(())
|
||||
} else {
|
||||
let session = Arc::new(
|
||||
Session::new_with_opts(
|
||||
download_opts
|
||||
.output_folder
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.context(
|
||||
"output_folder is required if can't connect to an existing server",
|
||||
)?,
|
||||
spawner,
|
||||
sopts,
|
||||
)
|
||||
.await
|
||||
.context("error initializing rqbit session")?,
|
||||
);
|
||||
let session = Session::new_with_opts(
|
||||
download_opts
|
||||
.output_folder
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.context(
|
||||
"output_folder is required if can't connect to an existing server",
|
||||
)?,
|
||||
spawner,
|
||||
sopts,
|
||||
)
|
||||
.await
|
||||
.context("error initializing rqbit session")?;
|
||||
spawn(
|
||||
"stats_printer",
|
||||
trace_span!("stats_printer"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue