diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index 96c3e20..22e3ba4 100644 --- a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -1,7 +1,7 @@ use anyhow::Context; use dht::{Dht, DhtStats}; use parking_lot::RwLock; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -9,7 +9,7 @@ use warp::hyper::body::Bytes; use warp::hyper::Body; use warp::Filter; -use crate::session::Session; +use crate::session::{AddTorrentOptions, Session}; use crate::torrent_manager::TorrentManagerHandle; use crate::torrent_state::StatsSnapshot; @@ -127,10 +127,14 @@ impl ApiInternal { Some(TorrentDetailsResponse { info_hash, files }) } - async fn api_add_torrent(&self, url: String) -> anyhow::Result { + async fn api_add_torrent( + &self, + url: String, + opts: Option, + ) -> anyhow::Result { let handle = self .session - .add_torrent(url, None) + .add_torrent(url, opts) .await .context("error adding torrent")? .context("expected session.add_torrent() to return a handle")?; @@ -258,16 +262,22 @@ impl HttpApi { move || json_response(inner.api_torrent_list()) }); + #[derive(Deserialize)] + struct TorrentAddQueryParams { + overwrite: Option, + } + let torrent_add = warp::post() .and(warp::path("torrents")) .and(warp::body::bytes()) + .and(warp::query()) .and_then({ let inner = inner.clone(); use warp::http::Response; fn make_response(status: u16, body: String) -> Response { Response::builder().status(status).body(body).unwrap() } - move |body: Bytes| { + move |body: Bytes, params: TorrentAddQueryParams| { let inner = inner.clone(); async move { let url = match String::from_utf8(body.to_vec()) { @@ -279,8 +289,12 @@ impl HttpApi { )) } }; + let opts = AddTorrentOptions { + overwrite: params.overwrite.unwrap_or(false), + ..Default::default() + }; let idx = inner - .api_add_torrent(url) + .api_add_torrent(url, Some(opts)) .await .context("error calling HttpApi::api_add_torrent"); match idx { diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index f34a14f..001917e 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -9,6 +9,7 @@ use librqbit_core::{ torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, }; use log::{info, warn}; +use parking_lot::RwLock; use reqwest::Url; use tokio_stream::StreamExt; @@ -21,11 +22,45 @@ use crate::{ torrent_manager::{TorrentManagerBuilder, TorrentManagerHandle}, }; +pub enum ManagedTorrentState { + Initializing, + Running(TorrentManagerHandle), +} + +pub struct ManagedTorrent { + info_hash: Id20, + output_folder: PathBuf, + state: ManagedTorrentState, +} + +impl PartialEq for ManagedTorrent { + fn eq(&self, other: &Self) -> bool { + self.info_hash == other.info_hash && self.output_folder == other.output_folder + } +} + +#[derive(Default)] +pub struct SessionLocked { + torrents: Vec, +} + +impl SessionLocked { + fn add_torrent(&mut self, torrent: ManagedTorrent) -> Option { + if self.torrents.contains(&torrent) { + return None; + } + let idx = self.torrents.len(); + self.torrents.push(torrent); + Some(idx) + } +} + pub struct Session { peer_id: Id20, dht: Option, peer_opts: PeerConnectionOptions, spawner: BlockingSpawner, + locked: RwLock, output_folder: PathBuf, } @@ -126,6 +161,7 @@ impl Session { peer_opts, spawner, output_folder, + locked: RwLock::new(SessionLocked::default()), }) } pub fn get_dht(&self) -> Option { @@ -271,7 +307,21 @@ impl Session { .map(PathBuf::from) .unwrap_or_else(|| self.output_folder.clone()); - let mut builder = TorrentManagerBuilder::new(info, info_hash, output_folder); + let managed_torrent = ManagedTorrent { + info_hash, + output_folder: output_folder.clone(), + state: ManagedTorrentState::Initializing, + }; + + if self.locked.write().add_torrent(managed_torrent).is_none() { + anyhow::bail!( + "torrent with info_hash {:?} that is downloaded to {:?} is already managed", + info_hash, + &output_folder + ); + }; + + let mut builder = TorrentManagerBuilder::new(info, info_hash, output_folder.clone()); builder .overwrite(opts.overwrite) .spawner(self.spawner) @@ -287,7 +337,40 @@ impl Session { builder.peer_connect_timeout(t); } - let handle = builder.start_manager()?; + let handle = match builder + .start_manager() + .context("error starting torrent manager") + { + Ok(handle) => { + let mut g = self.locked.write(); + let m = g + .torrents + .iter_mut() + .find(|t| t.info_hash == info_hash && t.output_folder == output_folder) + .unwrap(); + m.state = ManagedTorrentState::Running(handle.clone()); + handle + } + Err(error) => { + let mut g = self.locked.write(); + let idx = g + .torrents + .iter() + .position(|t| t.info_hash == info_hash && t.output_folder == output_folder) + .unwrap(); + g.torrents.remove(idx); + return Err(error); + } + }; + { + let mut g = self.locked.write(); + let m = g + .torrents + .iter_mut() + .find(|t| t.info_hash == info_hash && t.output_folder == output_folder) + .unwrap(); + m.state = ManagedTorrentState::Running(handle.clone()); + } for url in trackers { handle.add_tracker(url);