Session persistence rewritten completely

This commit is contained in:
Igor Katson 2024-08-15 10:40:48 +01:00
parent c4fc107c4e
commit 83592ca866
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
12 changed files with 431 additions and 299 deletions

View file

@ -1,11 +1,9 @@
use std::{
any::TypeId,
borrow::Cow,
collections::{HashMap, HashSet},
io::{BufReader, BufWriter, Read},
io::Read,
net::SocketAddr,
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
time::Duration,
};
@ -15,6 +13,9 @@ use crate::{
merge_streams::merge_streams,
peer_connection::PeerConnectionOptions,
read_buf::ReadBuf,
session_persistence::{
json::JsonSessionPersistenceStore, BoxSessionPersistenceStore, SessionPersistenceStore,
},
spawn_utils::BlockingSpawner,
storage::{
filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage,
@ -27,7 +28,7 @@ use crate::{
ManagedTorrentInfo,
};
use anyhow::{bail, Context};
use bencode::{bencode_serialize_to_writer, BencodeDeserializer};
use bencode::bencode_serialize_to_writer;
use buffers::{ByteBuf, ByteBufOwned, ByteBufT};
use bytes::Bytes;
use clone_to_owned::CloneToOwned;
@ -48,7 +49,7 @@ use librqbit_core::{
};
use parking_lot::RwLock;
use peer_binary_protocol::Handshake;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::StreamExt;
use tokio_util::sync::{CancellationToken, DropGuard};
@ -103,130 +104,15 @@ impl SessionDatabase {
}
let idx = self.next_id;
self.torrents.insert(idx, torrent);
self.next_id += 1;
self.next_id = self.next_id.max(idx) + 1;
idx
}
fn serialize(&self) -> SerializedSessionDatabase {
SerializedSessionDatabase {
torrents: self
.torrents
.iter()
// We don't support serializing / deserializing of other storage types.
.filter(|(_, torrent)| {
torrent
.storage_factory
.is_type_id(TypeId::of::<FilesystemStorageFactory>())
})
.map(|(id, torrent)| {
(
*id,
SerializedTorrent {
trackers: torrent
.info()
.trackers
.iter()
.map(|u| u.to_string())
.collect(),
info_hash: torrent.info_hash().as_string(),
// TODO: this could take up too much space / time / resources to write on interval.
// Store this outside the JSON file
//
// torrent_bytes: torrent.info.torrent_bytes.clone(),
torrent_bytes: Bytes::new(),
info: torrent.info().info.clone(),
only_files: torrent.only_files().clone(),
is_paused: torrent
.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))),
output_folder: torrent.info().options.output_folder.clone(),
},
)
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize)]
struct SerializedTorrent {
info_hash: String,
#[serde(
serialize_with = "serialize_torrent",
deserialize_with = "deserialize_torrent"
)]
info: TorrentMetaV1Info<ByteBufOwned>,
#[serde(
serialize_with = "serialize_torrent_bytes",
deserialize_with = "deserialize_torrent_bytes",
default
)]
torrent_bytes: Bytes,
trackers: HashSet<String>,
output_folder: PathBuf,
only_files: Option<Vec<usize>>,
is_paused: bool,
}
fn serialize_torrent<S>(
t: &TorrentMetaV1Info<ByteBufOwned>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use base64::{engine::general_purpose, Engine as _};
use serde::ser::Error;
let mut writer = Vec::new();
bencode_serialize_to_writer(t, &mut writer).map_err(S::Error::custom)?;
let s = general_purpose::STANDARD_NO_PAD.encode(&writer);
s.serialize(serializer)
}
fn deserialize_torrent<'de, D>(deserializer: D) -> Result<TorrentMetaV1Info<ByteBufOwned>, D::Error>
where
D: Deserializer<'de>,
{
use base64::{engine::general_purpose, Engine as _};
use serde::de::Error;
let s = String::deserialize(deserializer)?;
let b = general_purpose::STANDARD_NO_PAD
.decode(s)
.map_err(D::Error::custom)?;
TorrentMetaV1Info::<ByteBufOwned>::deserialize(&mut BencodeDeserializer::new_from_buf(&b))
.map_err(D::Error::custom)
}
fn serialize_torrent_bytes<S>(t: &Bytes, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use base64::{engine::general_purpose, Engine as _};
let s = general_purpose::STANDARD_NO_PAD.encode(t);
s.serialize(serializer)
}
fn deserialize_torrent_bytes<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
where
D: Deserializer<'de>,
{
use base64::{engine::general_purpose, Engine as _};
use serde::de::Error;
let s = String::deserialize(deserializer)?;
let b = general_purpose::STANDARD_NO_PAD
.decode(s)
.map_err(D::Error::custom)?;
Ok(b.into())
}
#[derive(Serialize, Deserialize)]
struct SerializedSessionDatabase {
torrents: HashMap<usize, SerializedTorrent>,
}
pub struct Session {
peer_id: Id20,
dht: Option<Dht>,
persistence_filename: PathBuf,
persistence: Option<Box<dyn SessionPersistenceStore>>,
peer_opts: PeerConnectionOptions,
spawner: BlockingSpawner,
db: RwLock<SessionDatabase>,
@ -463,6 +349,11 @@ impl<'a> AddTorrent<'a> {
}
}
pub enum SessionPersistenceConfig {
/// The filename for persistence. By default uses an OS-specific folder.
Json { folder: Option<PathBuf> },
}
#[derive(Default)]
pub struct SessionOptions {
/// Turn on to disable DHT.
@ -476,9 +367,7 @@ pub struct SessionOptions {
/// Turn on to dump session contents into a file periodically, so that on next start
/// all remembered torrents will continue where they left off.
pub persistence: bool,
/// The filename for persistence. By default uses an OS-specific folder.
pub persistence_filename: Option<PathBuf>,
pub persistence: Option<SessionPersistenceConfig>,
/// The peer ID to use. If not specified, a random one will be generated.
pub peer_id: Option<Id20>,
@ -557,11 +446,6 @@ impl Session {
Self::new_with_opts(default_output_folder, SessionOptions::default())
}
pub fn default_persistence_filename() -> anyhow::Result<PathBuf> {
let dir = get_configuration_directory("session")?;
Ok(dir.data_dir().join("session.json"))
}
pub fn cancellation_token(&self) -> &CancellationToken {
&self.cancellation_token
}
@ -576,15 +460,16 @@ impl Session {
let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id);
let token = CancellationToken::new();
let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range {
let (l, p) = create_tcp_listener(port_range)
.await
.context("error listening on TCP")?;
info!("Listening on 0.0.0.0:{p} for incoming peer connections");
(Some(l), Some(p))
} else {
(None, None)
};
let (tcp_listener, tcp_listen_port) =
if let Some(port_range) = opts.listen_port_range.clone() {
let (l, p) = create_tcp_listener(port_range)
.await
.context("error listening on TCP")?;
info!("Listening on 0.0.0.0:{p} for incoming peer connections");
(Some(l), Some(p))
} else {
(None, None)
};
let dht = if opts.disable_dht {
None
@ -606,11 +491,36 @@ impl Session {
Some(dht)
};
let peer_opts = opts.peer_opts.unwrap_or_default();
let persistence_filename = match opts.persistence_filename {
Some(filename) => filename,
None if !opts.persistence => PathBuf::new(),
None => Self::default_persistence_filename()?,
};
async fn persistence_factory(
opts: &SessionOptions,
) -> anyhow::Result<Option<BoxSessionPersistenceStore>> {
pub fn default_persistence_folder() -> anyhow::Result<PathBuf> {
let dir = get_configuration_directory("session")?;
Ok(dir.data_dir().to_owned())
}
match &opts.persistence {
Some(SessionPersistenceConfig::Json { folder }) => {
let folder = match folder.as_ref() {
Some(f) => f.clone(),
None => default_persistence_folder()?,
};
Ok(Some(Box::new(
JsonSessionPersistenceStore::new(folder)
.await
.context("error initializing JsonSessionPersistenceStore")?,
)))
}
None => Ok(None),
}
}
let persistence = persistence_factory(&opts)
.await
.context("error initializing session persistence store")?;
let spawner = BlockingSpawner::default();
let (disk_write_tx, disk_write_rx) = opts
@ -646,7 +556,7 @@ impl Session {
let stream_connector = Arc::new(StreamConnector::from(proxy_config));
let session = Arc::new(Self {
persistence_filename,
persistence,
peer_id,
dht,
peer_opts,
@ -688,18 +598,34 @@ impl Session {
}
}
if opts.persistence {
info!(
"will use {:?} for session persistence",
session.persistence_filename
);
if let Some(parent) = session.persistence_filename.parent() {
std::fs::create_dir_all(parent).with_context(|| {
format!("couldn't create directory {:?} for session storage", parent)
})?;
if let Some(persistence) = session.persistence.as_ref() {
info!("will use {persistence:?} for session persistence");
let mut ps = persistence.stream_all().await?;
let mut added_all = false;
let mut futs = FuturesUnordered::new();
while !added_all || !futs.is_empty() {
tokio::select! {
Some(res) = futs.next(), if !futs.is_empty() => {
if let Err(e) = res {
error!("error adding torrent to session: {e:?}");
}
},
st = ps.next(), if !added_all => {
if let Some(st) = st {
let (id, st) = st?;
let span = error_span!("add_torrent", info_hash=?st.info_hash());
let (add_torrent, mut opts) = st.into_add_torrent()?;
opts.preferred_id = Some(id);
let fut = session.add_torrent(add_torrent, Some(opts)).instrument(span);
futs.push(fut);
} else {
added_all = true;
}
},
}
}
let persistence_task = session.clone().task_persistence();
session.spawn(error_span!("session_persistence"), persistence_task);
}
Ok(session)
@ -707,29 +633,6 @@ impl Session {
.boxed()
}
async fn task_persistence(self: Arc<Self>) -> anyhow::Result<()> {
// Populate initial from the state filename
if let Err(e) = self.populate_from_stored().await {
error!("could not populate session from stored file: {:?}", e);
}
let session = Arc::downgrade(&self);
drop(self);
loop {
tokio::time::sleep(Duration::from_secs(10)).await;
let session = match session.upgrade() {
Some(s) => s,
None => break,
};
if let Err(e) = session.dump_to_disk() {
error!("error dumping session to disk: {:?}", e);
}
}
Ok(())
}
async fn check_incoming_connection(
&self,
addr: SocketAddr,
@ -868,102 +771,6 @@ impl Session {
tokio::time::sleep(Duration::from_secs(1)).await;
}
async fn populate_from_stored(self: &Arc<Self>) -> anyhow::Result<()> {
let mut rdr = match std::fs::File::open(&self.persistence_filename) {
Ok(f) => BufReader::new(f),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
Err(e) => {
return Err(e).context(format!(
"error opening session file {:?}",
self.persistence_filename
))
}
};
let db: SerializedSessionDatabase =
serde_json::from_reader(&mut rdr).context("error deserializing session database")?;
let mut futures = Vec::new();
for (id, storrent) in db.torrents.into_iter() {
let trackers: Vec<ByteBufOwned> = storrent
.trackers
.into_iter()
.map(|t| ByteBufOwned::from(t.into_bytes()))
.collect();
let torrent_bytes = storrent.torrent_bytes;
let add_torrent = if !torrent_bytes.is_empty() {
AddTorrent::TorrentFileBytes(torrent_bytes)
} else {
let info_hash = Id20::from_str(&storrent.info_hash)?;
debug!(?info_hash, "torrent added before 6.1.0, need to readd");
let info = TorrentMetaV1Owned {
announce: trackers.first().cloned(),
announce_list: vec![trackers],
info: storrent.info,
comment: None,
created_by: None,
encoding: None,
publisher: None,
publisher_url: None,
creation_date: None,
info_hash,
};
AddTorrent::TorrentInfo(Box::new(info))
};
futures.push({
let session = self.clone();
async move {
session
.add_torrent(
add_torrent,
Some(AddTorrentOptions {
paused: storrent.is_paused,
output_folder: Some(
storrent
.output_folder
.to_str()
.context("broken path")?
.to_owned(),
),
only_files: storrent.only_files,
overwrite: true,
preferred_id: Some(id),
..Default::default()
}),
)
.await
.map_err(|e| {
error!("error adding torrent from stored session: {:?}", e);
e
})
}
});
}
futures::future::join_all(futures).await;
Ok(())
}
fn dump_to_disk(&self) -> anyhow::Result<()> {
let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap());
let mut tmp = BufWriter::new(
std::fs::OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(&tmp_filename)
.with_context(|| format!("error opening {:?}", tmp_filename))?,
);
let serialized = self.db.read().serialize();
serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?;
drop(tmp);
std::fs::rename(&tmp_filename, &self.persistence_filename)
.context("error renaming persistence file")?;
trace!(filename=?self.persistence_filename, "wrote persistence");
Ok(())
}
/// Run a callback given the currently managed torrents.
pub fn with_torrents<R>(
&self,
@ -1256,13 +1063,17 @@ impl Session {
{
return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone()));
}
let next_id = g.torrents.len();
let next_id = g.next_id;
let managed_torrent =
builder.build(error_span!(parent: None, "torrent", id = next_id))?;
let id = g.add_torrent(managed_torrent.clone(), opts.preferred_id);
(managed_torrent, id)
};
if let Some(p) = self.persistence.as_ref() {
p.store(id, &managed_torrent).await?;
}
// Merge "initial_peers" and "peer_rx" into one stream.
let peer_rx = merge_two_optional_streams(
if !initial_peers.is_empty() {
@ -1289,7 +1100,7 @@ impl Session {
self.db.read().torrents.get(&id).cloned()
}
pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> {
pub async fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> {
let removed = self
.db
.write()
@ -1301,6 +1112,12 @@ impl Session {
debug!("error pausing torrent before deletion: {e:?}")
}
if let Some(p) = self.persistence.as_ref() {
if let Err(e) = p.delete(id).await {
error!(error=?e, "error deleting torrent from database");
}
}
let storage = removed
.with_state_mut(|s| match s.take() {
ManagedTorrentState::Initializing(p) => p.files.take().ok(),