diff --git a/crates/librqbit/src/api.rs b/crates/librqbit/src/api.rs index b23b895..f795cb7 100644 --- a/crates/librqbit/src/api.rs +++ b/crates/librqbit/src/api.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, net::SocketAddr, sync::Arc}; +use std::{collections::HashSet, marker::PhantomData, net::SocketAddr, str::FromStr, sync::Arc}; use anyhow::Context; use buffers::ByteBufOwned; @@ -36,6 +36,93 @@ pub struct Api { line_broadcast: Option, } +#[derive(Debug, Clone, Copy)] +pub enum TorrentIdOrHash { + Id(TorrentId), + Hash(Id20), +} + +impl Serialize for TorrentIdOrHash { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + TorrentIdOrHash::Id(id) => id.serialize(serializer), + TorrentIdOrHash::Hash(h) => h.as_string().serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for TorrentIdOrHash { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Default)] + struct V<'de> { + p: PhantomData<&'de ()>, + } + impl<'de> serde::de::Visitor<'de> for V<'de> { + type Value = TorrentIdOrHash; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("integer or 40 byte info hash") + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + TorrentIdOrHash::parse(v) + .map_err(|_| E::custom("expected integer or 40 byte info hash")) + } + } + + deserializer.deserialize_str(V::default()) + } +} + +impl std::fmt::Display for TorrentIdOrHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TorrentIdOrHash::Id(id) => write!(f, "{}", id), + TorrentIdOrHash::Hash(h) => write!(f, "{:?}", h), + } + } +} + +impl From for TorrentIdOrHash { + fn from(value: TorrentId) -> Self { + TorrentIdOrHash::Id(value) + } +} + +impl From for TorrentIdOrHash { + fn from(value: Id20) -> Self { + TorrentIdOrHash::Hash(value) + } +} + +impl<'a> TryFrom<&'a str> for TorrentIdOrHash { + type Error = anyhow::Error; + + fn try_from(value: &'a str) -> std::result::Result { + Self::parse(value) + } +} + +impl TorrentIdOrHash { + pub fn parse(s: &str) -> anyhow::Result { + if s.len() == 40 { + let id = Id20::from_str(s)?; + return Ok(id.into()); + } + let id: TorrentId = s.parse()?; + Ok(id.into()) + } +} + impl Api { pub fn new( session: Arc, @@ -53,7 +140,7 @@ impl Api { &self.session } - pub fn mgr_handle(&self, idx: TorrentId) -> Result { + pub fn mgr_handle(&self, idx: TorrentIdOrHash) -> Result { self.session .get(idx) .ok_or(ApiError::torrent_not_found(idx)) @@ -71,14 +158,18 @@ impl Api { TorrentListResponse { torrents: items } } - pub fn api_torrent_details(&self, idx: TorrentId) -> Result { + pub fn api_torrent_details(&self, idx: TorrentIdOrHash) -> Result { let handle = self.mgr_handle(idx)?; let info_hash = handle.info().info_hash; let only_files = handle.only_files(); make_torrent_details(&info_hash, &handle.info().info, only_files.as_deref()) } - pub fn torrent_file_mime_type(&self, idx: TorrentId, file_idx: usize) -> Result<&'static str> { + pub fn torrent_file_mime_type( + &self, + idx: TorrentIdOrHash, + file_idx: usize, + ) -> Result<&'static str> { let handle = self.mgr_handle(idx)?; let info = &handle.info().info; torrent_file_mime_type(info, file_idx) @@ -86,7 +177,7 @@ impl Api { pub fn api_peer_stats( &self, - idx: TorrentId, + idx: TorrentIdOrHash, filter: PeerStatsFilter, ) -> Result { let handle = self.mgr_handle(idx)?; @@ -96,7 +187,10 @@ impl Api { .per_peer_stats_snapshot(filter)) } - pub async fn api_torrent_action_pause(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_pause( + &self, + idx: TorrentIdOrHash, + ) -> Result { let handle = self.mgr_handle(idx)?; self.session() .pause(&handle) @@ -106,7 +200,10 @@ impl Api { Ok(Default::default()) } - pub async fn api_torrent_action_start(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_start( + &self, + idx: TorrentIdOrHash, + ) -> Result { let handle = self.mgr_handle(idx)?; self.session .unpause(&handle) @@ -116,7 +213,10 @@ impl Api { Ok(Default::default()) } - pub async fn api_torrent_action_forget(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_forget( + &self, + idx: TorrentIdOrHash, + ) -> Result { self.session .delete(idx, false) .await @@ -124,7 +224,10 @@ impl Api { Ok(Default::default()) } - pub async fn api_torrent_action_delete(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_delete( + &self, + idx: TorrentIdOrHash, + ) -> Result { self.session .delete(idx, true) .await @@ -134,7 +237,7 @@ impl Api { pub async fn api_torrent_action_update_only_files( &self, - idx: TorrentId, + idx: TorrentIdOrHash, only_files: &HashSet, ) -> Result { let handle = self.mgr_handle(idx)?; @@ -240,23 +343,23 @@ impl Api { Ok(dht.with_routing_table(|r| r.clone())) } - pub fn api_stats_v0(&self, idx: TorrentId) -> Result { + pub fn api_stats_v0(&self, idx: TorrentIdOrHash) -> Result { let mgr = self.mgr_handle(idx)?; let live = mgr.live().context("torrent not live")?; Ok(LiveStats::from(&*live)) } - pub fn api_stats_v1(&self, idx: TorrentId) -> Result { + pub fn api_stats_v1(&self, idx: TorrentIdOrHash) -> Result { let mgr = self.mgr_handle(idx)?; Ok(mgr.stats()) } - pub fn api_dump_haves(&self, idx: usize) -> Result { + pub fn api_dump_haves(&self, idx: TorrentIdOrHash) -> Result { let mgr = self.mgr_handle(idx)?; Ok(mgr.with_chunk_tracker(|chunks| format!("{:?}", chunks.get_have_pieces()))?) } - pub fn api_stream(&self, idx: TorrentId, file_id: usize) -> Result { + pub fn api_stream(&self, idx: TorrentIdOrHash, file_id: usize) -> Result { let mgr = self.mgr_handle(idx)?; Ok(mgr.stream(file_id)?) } diff --git a/crates/librqbit/src/api_error.rs b/crates/librqbit/src/api_error.rs index 6b082f6..5710817 100644 --- a/crates/librqbit/src/api_error.rs +++ b/crates/librqbit/src/api_error.rs @@ -2,6 +2,8 @@ use axum::response::{IntoResponse, Response}; use http::StatusCode; use serde::{Serialize, Serializer}; +use crate::api::TorrentIdOrHash; + // Convenience error type. #[derive(Debug)] pub struct ApiError { @@ -19,7 +21,7 @@ impl ApiError { } } - pub const fn torrent_not_found(torrent_id: usize) -> Self { + pub const fn torrent_not_found(torrent_id: TorrentIdOrHash) -> Self { Self { status: Some(StatusCode::NOT_FOUND), kind: ApiErrorKind::TorrentNotFound(torrent_id), @@ -75,7 +77,7 @@ impl ApiError { #[derive(Debug)] enum ApiErrorKind { - TorrentNotFound(usize), + TorrentNotFound(TorrentIdOrHash), DhtDisabled, Text(&'static str), Other(anyhow::Error), @@ -93,7 +95,7 @@ impl Serialize for ApiError { status: u16, status_text: String, #[serde(skip_serializing_if = "Option::is_none")] - id: Option, + id: Option, } let mut serr: SerializedError = SerializedError { error_kind: match self.kind { diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index a300d24..5e67f9a 100644 --- a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -18,7 +18,7 @@ use tracing::{debug, info, trace}; use axum::Router; -use crate::api::Api; +use crate::api::{Api, TorrentIdOrHash}; use crate::peer_connection::PeerConnectionOptions; use crate::session::{AddTorrent, AddTorrentOptions, SUPPORTED_SCHEMES}; use crate::torrent_state::peer::stats::snapshot::PeerStatsFilter; @@ -124,7 +124,7 @@ impl HttpApi { async fn torrent_details( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_torrent_details(idx).map(axum::Json) } @@ -168,7 +168,7 @@ impl HttpApi { fn build_playlist_content( host: &str, - it: impl IntoIterator, + it: impl IntoIterator, ) -> impl IntoResponse { let body = it .into_iter() @@ -240,7 +240,7 @@ impl HttpApi { async fn torrent_playlist( State(state): State, headers: HeaderMap, - Path(idx): Path, + Path(idx): Path, ) -> Result { let host = get_host(&headers)?; let playlist_items = torrent_playlist_items(&*state.mgr_handle(idx)?)?; @@ -263,7 +263,7 @@ impl HttpApi { torrent_playlist_items(handle) .map(move |items| { items.into_iter().map(move |(file_idx, filename)| { - (torrent_idx, file_idx, filename) + (torrent_idx.into(), file_idx, filename) }) }) .ok() @@ -276,28 +276,28 @@ impl HttpApi { async fn torrent_haves( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_dump_haves(idx) } async fn torrent_stats_v0( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_stats_v0(idx).map(axum::Json) } async fn torrent_stats_v1( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_stats_v1(idx).map(axum::Json) } async fn peer_stats( State(state): State, - Path(idx): Path, + Path(idx): Path, Query(filter): Query, ) -> Result { state.api_peer_stats(idx, filter).map(axum::Json) @@ -305,7 +305,7 @@ impl HttpApi { async fn torrent_stream_file( State(state): State, - Path((idx, file_id)): Path<(usize, usize)>, + Path((idx, file_id)): Path<(TorrentIdOrHash, usize)>, headers: http::HeaderMap, ) -> Result { let mut stream = state.api_stream(idx, file_id)?; @@ -321,7 +321,7 @@ impl HttpApi { } let range_header = headers.get(http::header::RANGE); - trace!(torrent_id=idx, file_id=file_id, range=?range_header, "request for HTTP stream"); + trace!(torrent_id=%idx, file_id=file_id, range=?range_header, "request for HTTP stream"); if let Some(range) = range_header { let offset: Option = range @@ -366,28 +366,28 @@ impl HttpApi { async fn torrent_action_pause( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_torrent_action_pause(idx).await.map(axum::Json) } async fn torrent_action_start( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_torrent_action_start(idx).await.map(axum::Json) } async fn torrent_action_forget( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_torrent_action_forget(idx).await.map(axum::Json) } async fn torrent_action_delete( State(state): State, - Path(idx): Path, + Path(idx): Path, ) -> Result { state.api_torrent_action_delete(idx).await.map(axum::Json) } @@ -399,7 +399,7 @@ impl HttpApi { async fn torrent_action_update_only_files( State(state): State, - Path(idx): Path, + Path(idx): Path, axum::Json(req): axum::Json, ) -> Result { state diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 848e4f4..f645156 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -9,6 +9,7 @@ use std::{ }; use crate::{ + api::TorrentIdOrHash, dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, merge_streams::merge_streams, peer_connection::PeerConnectionOptions, @@ -1091,11 +1092,36 @@ impl Session { Ok(AddTorrentResponse::Added(id, managed_torrent)) } - pub fn get(&self, id: TorrentId) -> Option { - self.db.read().torrents.get(&id).cloned() + pub fn get(&self, id: TorrentIdOrHash) -> Option { + match id { + TorrentIdOrHash::Id(id) => self.db.read().torrents.get(&id).cloned(), + TorrentIdOrHash::Hash(id) => self.db.read().torrents.iter().find_map(|(_, v)| { + if v.info_hash() == id { + Some(v.clone()) + } else { + None + } + }), + } } - pub async fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { + pub async fn delete(&self, id: TorrentIdOrHash, delete_files: bool) -> anyhow::Result<()> { + let id = match id { + TorrentIdOrHash::Id(id) => id, + TorrentIdOrHash::Hash(h) => self + .db + .read() + .torrents + .values() + .find_map(|v| { + if v.info_hash() == h { + Some(v.id()) + } else { + None + } + }) + .context("no such torrent in db")?, + }; let removed = self .db .write() diff --git a/crates/librqbit/src/tests/e2e.rs b/crates/librqbit/src/tests/e2e.rs index 6485a55..3370ecd 100644 --- a/crates/librqbit/src/tests/e2e.rs +++ b/crates/librqbit/src/tests/e2e.rs @@ -228,7 +228,7 @@ async fn test_e2e_download() { } info!("handle is completed"); - session.delete(id, false).await.unwrap(); + session.delete(id.into(), false).await.unwrap(); info!("deleted handle");