diff --git a/Cargo.lock b/Cargo.lock index 45b64ba..0f54b89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,6 +233,7 @@ dependencies = [ "anyhow", "bencode", "clone_to_owned", + "directories", "futures", "hex 0.4.3", "indexmap", @@ -256,6 +257,26 @@ dependencies = [ "generic-array 0.14.4", ] +[[package]] +name = "directories" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e69600ff1703123957937708eb27f7a564e48885c537782722ed0ba3189ce1d7" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03d86534ed367a67548dc68113a0f5db55432fdfbb6e6f9d77704397d95d5780" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "encoding_rs" version = "0.8.28" @@ -1252,6 +1273,16 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64" +dependencies = [ + "getrandom 0.2.3", + "redox_syscall", +] + [[package]] name = "regex" version = "1.5.4" diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 7ea32e3..54cd8a4 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -19,6 +19,7 @@ pretty_env_logger = "0.4" futures = "0.3" rand = "0.8" indexmap = "1.7" +directories = "3" clone_to_owned = {path="../clone_to_owned"} librqbit_core = {path="../librqbit_core"} diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 2eea931..aa1058e 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -11,7 +11,6 @@ use crate::{ MessageKind, Node, }, routing_table::{InsertResult, RoutingTable}, - DHT_BOOTSTRAP, }; use anyhow::Context; use bencode::ByteString; @@ -26,7 +25,7 @@ use tokio::{ net::UdpSocket, sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, }; -use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; #[derive(Debug, Serialize)] pub struct DhtStats { @@ -58,12 +57,17 @@ struct DhtState { } impl DhtState { - fn new(id: Id20, sender: UnboundedSender<(Message, SocketAddr)>) -> Self { + fn new( + id: Id20, + sender: UnboundedSender<(Message, SocketAddr)>, + routing_table: Option, + ) -> Self { + let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id)); Self { id, next_transaction_id: 0, outstanding_requests: Default::default(), - routing_table: RoutingTable::new(id), + routing_table, sender, seen_peers: Default::default(), get_peers_subscribers: Default::default(), @@ -569,10 +573,14 @@ impl Stream for PeerStream { ) -> Poll> { loop { if let Some((pos, end)) = self.initial_peers_pos.take() { - let g = self.state.lock(); - let seen = g.seen_peers.get(&self.info_hash).unwrap(); - let addr = *seen.get_index(pos).unwrap(); - drop(g); + let addr = *self + .state + .lock() + .seen_peers + .get(&self.info_hash) + .unwrap() + .get_index(pos) + .unwrap(); if pos < end { self.initial_peers_pos = Some((pos + 1, end)); } @@ -580,50 +588,52 @@ impl Stream for PeerStream { return Poll::Ready(Some(addr)); } - let r = match self.broadcast_rx.poll_next_unpin(cx) { - Poll::Ready(r) => match r { - Some(r) => r, - None => return Poll::Ready(None), - }, - Poll::Pending => return Poll::Pending, - }; - - match r { - Ok(v) => { + match self.broadcast_rx.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(v))) => { self.absolute_stream_pos += 1; return Poll::Ready(Some(v)); } - Err(e) => match e { - tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(lagged_by) => { - debug!("peer stream is lagged by {}", lagged_by); - let s = self.absolute_stream_pos; - let e = s + lagged_by as usize; - self.initial_peers_pos = Some((s, e)); - continue; - } - }, - } + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(lagged_by)))) => { + debug!("peer stream is lagged by {}", lagged_by); + let s = self.absolute_stream_pos; + let e = s + lagged_by as usize; + self.initial_peers_pos = Some((s, e)); + continue; + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }; } } } +#[derive(Default)] +pub struct DhtConfig { + pub peer_id: Option, + pub bootstrap_addrs: Option>, + pub routing_table: Option, +} + impl Dht { pub async fn new() -> anyhow::Result { - Self::with_bootstrap_addrs(DHT_BOOTSTRAP).await + Self::with_config(DhtConfig::default()).await } - pub async fn with_bootstrap_addrs(bootstrap_addrs: &[&str]) -> anyhow::Result { + pub async fn with_config(config: DhtConfig) -> anyhow::Result { let socket = UdpSocket::bind("0.0.0.0:0") .await .context("error binding socket")?; - let peer_id = generate_peer_id(); + let peer_id = config.peer_id.unwrap_or_else(generate_peer_id); info!("starting up DHT with peer id {:?}", peer_id); - let bootstrap_addrs = bootstrap_addrs - .iter() - .map(|s| s.to_string()) - .collect::>(); + let bootstrap_addrs = config + .bootstrap_addrs + .unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect()); let (in_tx, in_rx) = unbounded_channel(); - let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone()))); + let state = Arc::new(Mutex::new(DhtState::new( + peer_id, + in_tx.clone(), + config.routing_table, + ))); tokio::spawn({ let state = state.clone(); diff --git a/crates/dht/src/lib.rs b/crates/dht/src/lib.rs index ff5fee1..7dba05c 100644 --- a/crates/dht/src/lib.rs +++ b/crates/dht/src/lib.rs @@ -1,10 +1,12 @@ mod bprotocol; mod dht; +mod persistence; mod routing_table; mod utils; -pub use dht::Dht; pub use dht::DhtStats; +pub use dht::{Dht, DhtConfig}; pub use librqbit_core::id20::Id20; +pub use persistence::{PersistentDht, PersistentDhtConfig}; pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]; diff --git a/crates/dht/src/persistence.rs b/crates/dht/src/persistence.rs new file mode 100644 index 0000000..1cdbff0 --- /dev/null +++ b/crates/dht/src/persistence.rs @@ -0,0 +1,119 @@ +// TODO: this now stores only the routing table, but we also need AT LEAST the same socket address... + +use std::fs::OpenOptions; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use anyhow::Context; +use log::{debug, error, info, warn}; +use tokio::spawn; + +use crate::dht::{Dht, DhtConfig}; +use crate::routing_table::RoutingTable; + +#[derive(Default, Clone)] +pub struct PersistentDhtConfig { + pub dump_interval: Option, + pub config_filename: Option, +} + +pub struct PersistentDht { + // config_filename: PathBuf, +} + +fn dump_dht(dht: &Dht, filename: &Path, tempfile_name: &Path) -> anyhow::Result<()> { + let mut file = OpenOptions::new() + .truncate(true) + .create(true) + .write(true) + .open(&tempfile_name) + .with_context(|| format!("error opening {:?}", tempfile_name))?; + + match dht.with_routing_table(|r| serde_json::to_writer(&mut file, r)) { + Ok(_) => { + debug!("dumped DHT to {:?}", &tempfile_name); + } + Err(e) => { + return Err(e).with_context(|| { + format!("error serializing DHT routing table to {:?}", tempfile_name) + }) + } + } + + std::fs::rename(tempfile_name, filename) + .with_context(|| format!("error renaming {:?} to {:?}", tempfile_name, filename)) +} + +impl PersistentDht { + pub async fn create(config: Option) -> anyhow::Result { + let mut config = config.unwrap_or_default(); + let config_filename = match config.config_filename.take() { + Some(config_filename) => config_filename, + None => { + let dirs = directories::ProjectDirs::from("com", "rqbit", "dht") + .context("cannot determine project directory for com.rqbit.dht")?; + let path = dirs.cache_dir().join("dht.json"); + info!("will store DHT routing table to {:?} periodically", &path); + path + } + }; + + if let Some(parent) = config_filename.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("error creating dir {:?}", &parent))?; + } + + let routing_table = match OpenOptions::new().read(true).open(&config_filename) { + Ok(dht_json) => match serde_json::from_reader::<_, RoutingTable>(&dht_json) { + Ok(r) => { + info!("loaded DHT routing table from {:?}", &config_filename); + Some(r) + } + Err(e) => { + warn!( + "cannot deserialize routing table from file {:?}: {:#}", + &config_filename, e + ); + None + } + }, + Err(e) => match e.kind() { + std::io::ErrorKind::NotFound => None, + _ => return Err(e).with_context(|| format!("error reading {:?}", config_filename)), + }, + }; + let peer_id = routing_table.as_ref().map(|r| r.id()); + let dht_config = DhtConfig { + peer_id, + routing_table, + ..Default::default() + }; + let dht = Dht::with_config(dht_config).await?; + + spawn({ + let dht = dht.clone(); + let dump_interval = config + .dump_interval + .unwrap_or_else(|| Duration::from_secs(3)); + async move { + let tempfile_name = { + let file_name = format!("dht.json.tmp.{}", std::process::id()); + let mut tmp = config_filename.clone(); + tmp.set_file_name(file_name); + tmp + }; + + loop { + tokio::time::sleep(dump_interval).await; + debug!("dumping DHT to {:?}", &config_filename); + + match dump_dht(&dht, &config_filename, &tempfile_name) { + Ok(_) => debug!("dumped DHT to {:?}", &config_filename), + Err(e) => error!("error dumping DHT to {:?}: {:#}", &config_filename, e), + } + } + } + }); + Ok(dht) + } +} diff --git a/crates/dht/src/routing_table.rs b/crates/dht/src/routing_table.rs index 096cf48..049262e 100644 --- a/crates/dht/src/routing_table.rs +++ b/crates/dht/src/routing_table.rs @@ -5,16 +5,16 @@ use std::{ use librqbit_core::id20::Id20; use log::debug; -use serde::{ser::SerializeMap, Serialize}; +use serde::{ser::SerializeMap, Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] enum BucketTreeNodeData { // TODO: maybe replace that with SmallVec<8>? Leaf(Vec), LeftRight(usize, usize), } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct BucketTreeNode { bits: u8, #[serde(serialize_with = "crate::utils::serialize_id20")] @@ -29,6 +29,48 @@ pub struct BucketTree { data: Vec, } +impl<'de> Deserialize<'de> for BucketTree { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = BucketTree; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "a map with key \"flat\"") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut data: Option> = None; + loop { + match map.next_key::()?.as_deref() { + Some("flat") => { + let buckets = map.next_value::>()?; + data = Some(buckets) + } + Some(_) => { + map.next_value::()?; + } + None => { + use serde::de::Error; + match data.take() { + Some(data) => return Ok(BucketTree { data }), + None => return Err(A::Error::missing_field("flat")), + } + } + } + } + } + } + deserializer.deserialize_map(Visitor) + } +} + impl Serialize for BucketTree { fn serialize(&self, serializer: S) -> Result where @@ -212,24 +254,13 @@ impl BucketTree { BucketTreeIterator::new(self) } - pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> { + fn get_leaf(&self, id: &Id20) -> usize { let mut idx = 0; loop { let node = &self.data[idx]; - if !(*id >= node.start && *id <= node.end_inclusive) { - return None; - }; - match &node.data { - BucketTreeNodeData::Leaf(_) => { - // re-borrow mutably - if let BucketTreeNodeData::Leaf(nodes) = &mut self.data[idx].data { - return nodes.iter_mut().find(|b| b.id == *id); - } - unreachable!() - } + match node.data { + BucketTreeNodeData::Leaf(_) => return idx, BucketTreeNodeData::LeftRight(left_idx, right_idx) => { - let left_idx = *left_idx; - let right_idx = *right_idx; let left = &self.data[left_idx]; if *id >= left.start && *id <= left.end_inclusive { idx = left_idx; @@ -241,26 +272,18 @@ impl BucketTree { } } - pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult { - let mut current = 0; - loop { - let node = &self.data[current]; - debug_assert!(id >= node.start && id <= node.end_inclusive); - match &node.data { - BucketTreeNodeData::Leaf(_) => { - return self.insert_into_leaf(current, self_id, id, addr); - } - BucketTreeNodeData::LeftRight(left_idx, right_idx) => { - let left = &self.data[*left_idx]; - if id <= left.end_inclusive { - current = *left_idx; - continue; - } - current = *right_idx; - } - } + pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> { + let idx = self.get_leaf(id); + match &mut self.data[idx].data { + BucketTreeNodeData::Leaf(nodes) => nodes.iter_mut().find(|b| b.id == *id), + BucketTreeNodeData::LeftRight(_, _) => unreachable!(), } } + + pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult { + let idx = self.get_leaf(&id); + self.insert_into_leaf(idx, self_id, id, addr) + } fn insert_into_leaf( &mut self, mut idx: usize, @@ -367,7 +390,7 @@ impl Default for BucketTree { } } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoutingTableNode { #[serde(serialize_with = "crate::utils::serialize_id20")] id: Id20, @@ -425,7 +448,7 @@ impl RoutingTableNode { } } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct RoutingTable { #[serde(serialize_with = "crate::utils::serialize_id20")] id: Id20, @@ -441,6 +464,9 @@ impl RoutingTable { size: 0, } } + pub fn id(&self) -> Id20 { + self.id + } pub fn len(&self) -> usize { self.size } @@ -487,7 +513,10 @@ impl RoutingTable { #[cfg(test)] mod tests { - use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + use std::{ + io::Cursor, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + }; use librqbit_core::id20::Id20; use rand::Rng; @@ -619,4 +648,11 @@ mod tests { let rtable = generate_table(None); assert_eq!(rtable.sorted_by_distance_from(id).len(), rtable.size); } + + #[test] + fn serialize_deserialize_routing_table() { + let table = generate_table(Some(1000)); + let v = serde_json::to_vec(&table).unwrap(); + let detable: RoutingTable = serde_json::from_reader(Cursor::new(v)).unwrap(); + } } diff --git a/crates/librqbit_core/src/id20.rs b/crates/librqbit_core/src/id20.rs index b234087..5833fc3 100644 --- a/crates/librqbit_core/src/id20.rs +++ b/crates/librqbit_core/src/id20.rs @@ -48,7 +48,26 @@ impl<'de> Deserialize<'de> for Id20 { type Value = Id20; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a 20 byte slice") + write!(formatter, "a 20 byte slice or a 40 byte string") + } + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + if v.len() != 40 { + return Err(E::invalid_length(40, &self)); + } + let mut out = [0u8; 20]; + match hex::decode_to_slice(v, &mut out) { + Ok(_) => Ok(Id20(out)), + Err(e) => Err(E::custom(e)), + } + } + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result + where + E: serde::de::Error, + { + self.visit_bytes(v) } fn visit_bytes(self, v: &[u8]) -> Result where @@ -62,7 +81,7 @@ impl<'de> Deserialize<'de> for Id20 { Ok(Id20(buf)) } } - deserializer.deserialize_bytes(Visitor {}) + deserializer.deserialize_any(Visitor {}) } } diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index ba85aff..965ae09 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -2,7 +2,7 @@ use std::{fs::File, io::Read, net::SocketAddr, str::FromStr, time::Duration}; use anyhow::Context; use clap::Clap; -use dht::{Dht, Id20}; +use dht::{Dht, Id20, PersistentDht}; use futures::StreamExt; use librqbit::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, @@ -191,7 +191,11 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()> let dht = if opts.disable_dht { None } else { - Some(Dht::new().await.context("error initializing DHT")?) + Some( + PersistentDht::create(None) + .await + .context("error initializing DHT")?, + ) }; let peer_opts = PeerConnectionOptions {