Persistent DHT

This commit is contained in:
Igor Katson 2021-07-18 10:53:33 +01:00
parent 52f17a1717
commit 1300faa0b4
8 changed files with 301 additions and 79 deletions

31
Cargo.lock generated
View file

@ -233,6 +233,7 @@ dependencies = [
"anyhow",
"bencode",
"clone_to_owned",
"directories",
"futures",
"hex 0.4.3",
"indexmap",
@ -256,6 +257,26 @@ dependencies = [
"generic-array 0.14.4",
]
[[package]]
name = "directories"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e69600ff1703123957937708eb27f7a564e48885c537782722ed0ba3189ce1d7"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03d86534ed367a67548dc68113a0f5db55432fdfbb6e6f9d77704397d95d5780"
dependencies = [
"libc",
"redox_users",
"winapi",
]
[[package]]
name = "encoding_rs"
version = "0.8.28"
@ -1252,6 +1273,16 @@ dependencies = [
"bitflags",
]
[[package]]
name = "redox_users"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "528532f3d801c87aec9def2add9ca802fe569e44a544afe633765267840abe64"
dependencies = [
"getrandom 0.2.3",
"redox_syscall",
]
[[package]]
name = "regex"
version = "1.5.4"

View file

@ -19,6 +19,7 @@ pretty_env_logger = "0.4"
futures = "0.3"
rand = "0.8"
indexmap = "1.7"
directories = "3"
clone_to_owned = {path="../clone_to_owned"}
librqbit_core = {path="../librqbit_core"}

View file

@ -11,7 +11,6 @@ use crate::{
MessageKind, Node,
},
routing_table::{InsertResult, RoutingTable},
DHT_BOOTSTRAP,
};
use anyhow::Context;
use bencode::ByteString;
@ -26,7 +25,7 @@ use tokio::{
net::UdpSocket,
sync::mpsc::{channel, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender},
};
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
#[derive(Debug, Serialize)]
pub struct DhtStats {
@ -58,12 +57,17 @@ struct DhtState {
}
impl DhtState {
fn new(id: Id20, sender: UnboundedSender<(Message<ByteString>, SocketAddr)>) -> Self {
fn new(
id: Id20,
sender: UnboundedSender<(Message<ByteString>, SocketAddr)>,
routing_table: Option<RoutingTable>,
) -> Self {
let routing_table = routing_table.unwrap_or_else(|| RoutingTable::new(id));
Self {
id,
next_transaction_id: 0,
outstanding_requests: Default::default(),
routing_table: RoutingTable::new(id),
routing_table,
sender,
seen_peers: Default::default(),
get_peers_subscribers: Default::default(),
@ -569,10 +573,14 @@ impl Stream for PeerStream {
) -> Poll<Option<Self::Item>> {
loop {
if let Some((pos, end)) = self.initial_peers_pos.take() {
let g = self.state.lock();
let seen = g.seen_peers.get(&self.info_hash).unwrap();
let addr = *seen.get_index(pos).unwrap();
drop(g);
let addr = *self
.state
.lock()
.seen_peers
.get(&self.info_hash)
.unwrap()
.get_index(pos)
.unwrap();
if pos < end {
self.initial_peers_pos = Some((pos + 1, end));
}
@ -580,50 +588,52 @@ impl Stream for PeerStream {
return Poll::Ready(Some(addr));
}
let r = match self.broadcast_rx.poll_next_unpin(cx) {
Poll::Ready(r) => match r {
Some(r) => r,
None => return Poll::Ready(None),
},
Poll::Pending => return Poll::Pending,
};
match r {
Ok(v) => {
match self.broadcast_rx.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(v))) => {
self.absolute_stream_pos += 1;
return Poll::Ready(Some(v));
}
Err(e) => match e {
tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(lagged_by) => {
debug!("peer stream is lagged by {}", lagged_by);
let s = self.absolute_stream_pos;
let e = s + lagged_by as usize;
self.initial_peers_pos = Some((s, e));
continue;
}
},
}
Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(lagged_by)))) => {
debug!("peer stream is lagged by {}", lagged_by);
let s = self.absolute_stream_pos;
let e = s + lagged_by as usize;
self.initial_peers_pos = Some((s, e));
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
};
}
}
}
#[derive(Default)]
pub struct DhtConfig {
pub peer_id: Option<Id20>,
pub bootstrap_addrs: Option<Vec<String>>,
pub routing_table: Option<RoutingTable>,
}
impl Dht {
pub async fn new() -> anyhow::Result<Self> {
Self::with_bootstrap_addrs(DHT_BOOTSTRAP).await
Self::with_config(DhtConfig::default()).await
}
pub async fn with_bootstrap_addrs(bootstrap_addrs: &[&str]) -> anyhow::Result<Self> {
pub async fn with_config(config: DhtConfig) -> anyhow::Result<Self> {
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.context("error binding socket")?;
let peer_id = generate_peer_id();
let peer_id = config.peer_id.unwrap_or_else(generate_peer_id);
info!("starting up DHT with peer id {:?}", peer_id);
let bootstrap_addrs = bootstrap_addrs
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>();
let bootstrap_addrs = config
.bootstrap_addrs
.unwrap_or_else(|| crate::DHT_BOOTSTRAP.iter().map(|v| v.to_string()).collect());
let (in_tx, in_rx) = unbounded_channel();
let state = Arc::new(Mutex::new(DhtState::new(peer_id, in_tx.clone())));
let state = Arc::new(Mutex::new(DhtState::new(
peer_id,
in_tx.clone(),
config.routing_table,
)));
tokio::spawn({
let state = state.clone();

View file

@ -1,10 +1,12 @@
mod bprotocol;
mod dht;
mod persistence;
mod routing_table;
mod utils;
pub use dht::Dht;
pub use dht::DhtStats;
pub use dht::{Dht, DhtConfig};
pub use librqbit_core::id20::Id20;
pub use persistence::{PersistentDht, PersistentDhtConfig};
pub static DHT_BOOTSTRAP: &[&str] = &["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"];

View file

@ -0,0 +1,119 @@
// TODO: this now stores only the routing table, but we also need AT LEAST the same socket address...
use std::fs::OpenOptions;
use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::Context;
use log::{debug, error, info, warn};
use tokio::spawn;
use crate::dht::{Dht, DhtConfig};
use crate::routing_table::RoutingTable;
#[derive(Default, Clone)]
pub struct PersistentDhtConfig {
pub dump_interval: Option<Duration>,
pub config_filename: Option<PathBuf>,
}
pub struct PersistentDht {
// config_filename: PathBuf,
}
fn dump_dht(dht: &Dht, filename: &Path, tempfile_name: &Path) -> anyhow::Result<()> {
let mut file = OpenOptions::new()
.truncate(true)
.create(true)
.write(true)
.open(&tempfile_name)
.with_context(|| format!("error opening {:?}", tempfile_name))?;
match dht.with_routing_table(|r| serde_json::to_writer(&mut file, r)) {
Ok(_) => {
debug!("dumped DHT to {:?}", &tempfile_name);
}
Err(e) => {
return Err(e).with_context(|| {
format!("error serializing DHT routing table to {:?}", tempfile_name)
})
}
}
std::fs::rename(tempfile_name, filename)
.with_context(|| format!("error renaming {:?} to {:?}", tempfile_name, filename))
}
impl PersistentDht {
pub async fn create(config: Option<PersistentDhtConfig>) -> anyhow::Result<Dht> {
let mut config = config.unwrap_or_default();
let config_filename = match config.config_filename.take() {
Some(config_filename) => config_filename,
None => {
let dirs = directories::ProjectDirs::from("com", "rqbit", "dht")
.context("cannot determine project directory for com.rqbit.dht")?;
let path = dirs.cache_dir().join("dht.json");
info!("will store DHT routing table to {:?} periodically", &path);
path
}
};
if let Some(parent) = config_filename.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("error creating dir {:?}", &parent))?;
}
let routing_table = match OpenOptions::new().read(true).open(&config_filename) {
Ok(dht_json) => match serde_json::from_reader::<_, RoutingTable>(&dht_json) {
Ok(r) => {
info!("loaded DHT routing table from {:?}", &config_filename);
Some(r)
}
Err(e) => {
warn!(
"cannot deserialize routing table from file {:?}: {:#}",
&config_filename, e
);
None
}
},
Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => None,
_ => return Err(e).with_context(|| format!("error reading {:?}", config_filename)),
},
};
let peer_id = routing_table.as_ref().map(|r| r.id());
let dht_config = DhtConfig {
peer_id,
routing_table,
..Default::default()
};
let dht = Dht::with_config(dht_config).await?;
spawn({
let dht = dht.clone();
let dump_interval = config
.dump_interval
.unwrap_or_else(|| Duration::from_secs(3));
async move {
let tempfile_name = {
let file_name = format!("dht.json.tmp.{}", std::process::id());
let mut tmp = config_filename.clone();
tmp.set_file_name(file_name);
tmp
};
loop {
tokio::time::sleep(dump_interval).await;
debug!("dumping DHT to {:?}", &config_filename);
match dump_dht(&dht, &config_filename, &tempfile_name) {
Ok(_) => debug!("dumped DHT to {:?}", &config_filename),
Err(e) => error!("error dumping DHT to {:?}: {:#}", &config_filename, e),
}
}
}
});
Ok(dht)
}
}

View file

@ -5,16 +5,16 @@ use std::{
use librqbit_core::id20::Id20;
use log::debug;
use serde::{ser::SerializeMap, Serialize};
use serde::{ser::SerializeMap, Deserialize, Serialize};
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
enum BucketTreeNodeData {
// TODO: maybe replace that with SmallVec<8>?
Leaf(Vec<RoutingTableNode>),
LeftRight(usize, usize),
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BucketTreeNode {
bits: u8,
#[serde(serialize_with = "crate::utils::serialize_id20")]
@ -29,6 +29,48 @@ pub struct BucketTree {
data: Vec<BucketTreeNode>,
}
impl<'de> Deserialize<'de> for BucketTree {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = BucketTree;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a map with key \"flat\"")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut data: Option<Vec<BucketTreeNode>> = None;
loop {
match map.next_key::<String>()?.as_deref() {
Some("flat") => {
let buckets = map.next_value::<Vec<BucketTreeNode>>()?;
data = Some(buckets)
}
Some(_) => {
map.next_value::<serde::de::IgnoredAny>()?;
}
None => {
use serde::de::Error;
match data.take() {
Some(data) => return Ok(BucketTree { data }),
None => return Err(A::Error::missing_field("flat")),
}
}
}
}
}
}
deserializer.deserialize_map(Visitor)
}
}
impl Serialize for BucketTree {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@ -212,24 +254,13 @@ impl BucketTree {
BucketTreeIterator::new(self)
}
pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> {
fn get_leaf(&self, id: &Id20) -> usize {
let mut idx = 0;
loop {
let node = &self.data[idx];
if !(*id >= node.start && *id <= node.end_inclusive) {
return None;
};
match &node.data {
BucketTreeNodeData::Leaf(_) => {
// re-borrow mutably
if let BucketTreeNodeData::Leaf(nodes) = &mut self.data[idx].data {
return nodes.iter_mut().find(|b| b.id == *id);
}
unreachable!()
}
match node.data {
BucketTreeNodeData::Leaf(_) => return idx,
BucketTreeNodeData::LeftRight(left_idx, right_idx) => {
let left_idx = *left_idx;
let right_idx = *right_idx;
let left = &self.data[left_idx];
if *id >= left.start && *id <= left.end_inclusive {
idx = left_idx;
@ -241,26 +272,18 @@ impl BucketTree {
}
}
pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult {
let mut current = 0;
loop {
let node = &self.data[current];
debug_assert!(id >= node.start && id <= node.end_inclusive);
match &node.data {
BucketTreeNodeData::Leaf(_) => {
return self.insert_into_leaf(current, self_id, id, addr);
}
BucketTreeNodeData::LeftRight(left_idx, right_idx) => {
let left = &self.data[*left_idx];
if id <= left.end_inclusive {
current = *left_idx;
continue;
}
current = *right_idx;
}
}
pub fn get_mut(&mut self, id: &Id20) -> Option<&mut RoutingTableNode> {
let idx = self.get_leaf(id);
match &mut self.data[idx].data {
BucketTreeNodeData::Leaf(nodes) => nodes.iter_mut().find(|b| b.id == *id),
BucketTreeNodeData::LeftRight(_, _) => unreachable!(),
}
}
pub fn add_node(&mut self, self_id: &Id20, id: Id20, addr: SocketAddr) -> InsertResult {
let idx = self.get_leaf(&id);
self.insert_into_leaf(idx, self_id, id, addr)
}
fn insert_into_leaf(
&mut self,
mut idx: usize,
@ -367,7 +390,7 @@ impl Default for BucketTree {
}
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingTableNode {
#[serde(serialize_with = "crate::utils::serialize_id20")]
id: Id20,
@ -425,7 +448,7 @@ impl RoutingTableNode {
}
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingTable {
#[serde(serialize_with = "crate::utils::serialize_id20")]
id: Id20,
@ -441,6 +464,9 @@ impl RoutingTable {
size: 0,
}
}
pub fn id(&self) -> Id20 {
self.id
}
pub fn len(&self) -> usize {
self.size
}
@ -487,7 +513,10 @@ impl RoutingTable {
#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::{
io::Cursor,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
use librqbit_core::id20::Id20;
use rand::Rng;
@ -619,4 +648,11 @@ mod tests {
let rtable = generate_table(None);
assert_eq!(rtable.sorted_by_distance_from(id).len(), rtable.size);
}
#[test]
fn serialize_deserialize_routing_table() {
let table = generate_table(Some(1000));
let v = serde_json::to_vec(&table).unwrap();
let detable: RoutingTable = serde_json::from_reader(Cursor::new(v)).unwrap();
}
}

View file

@ -48,7 +48,26 @@ impl<'de> Deserialize<'de> for Id20 {
type Value = Id20;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a 20 byte slice")
write!(formatter, "a 20 byte slice or a 40 byte string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v.len() != 40 {
return Err(E::invalid_length(40, &self));
}
let mut out = [0u8; 20];
match hex::decode_to_slice(v, &mut out) {
Ok(_) => Ok(Id20(out)),
Err(e) => Err(E::custom(e)),
}
}
fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
self.visit_bytes(v)
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
@ -62,7 +81,7 @@ impl<'de> Deserialize<'de> for Id20 {
Ok(Id20(buf))
}
}
deserializer.deserialize_bytes(Visitor {})
deserializer.deserialize_any(Visitor {})
}
}

View file

@ -2,7 +2,7 @@ use std::{fs::File, io::Read, net::SocketAddr, str::FromStr, time::Duration};
use anyhow::Context;
use clap::Clap;
use dht::{Dht, Id20};
use dht::{Dht, Id20, PersistentDht};
use futures::StreamExt;
use librqbit::{
dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult},
@ -191,7 +191,11 @@ async fn async_main(opts: Opts, spawner: BlockingSpawner) -> anyhow::Result<()>
let dht = if opts.disable_dht {
None
} else {
Some(Dht::new().await.context("error initializing DHT")?)
Some(
PersistentDht::create(None)
.await
.context("error initializing DHT")?,
)
};
let peer_opts = PeerConnectionOptions {