rqbit/crates/librqbit_core/src/hash_id.rs
2024-07-16 13:49:51 +08:00

217 lines
6 KiB
Rust

use data_encoding::BASE32;
use serde::{Deserialize, Deserializer, Serialize};
use std::{cmp::Ordering, str::FromStr};
/// A fixed-size identifier made of `N` raw bytes.
///
/// The wrapped array is public, so it can be read or constructed directly.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Id<const N: usize>(pub [u8; N]);
impl<const N: usize> Id<N> {
pub fn new(from: [u8; N]) -> Id<N> {
Id(from)
}
pub fn as_string(&self) -> String {
hex::encode(self.0)
}
pub fn distance(&self, other: &Id<N>) -> Id<N> {
let mut xor = [0u8; N];
for (idx, (s, o)) in self
.0
.iter()
.copied()
.zip(other.0.iter().copied())
.enumerate()
{
xor[idx] = s ^ o;
}
Id(xor)
}
pub fn get_bit(&self, bit: u8) -> bool {
let n = self.0[(bit / 8) as usize];
let mask = 1 << (7 - bit % 8);
n & mask > 0
}
pub fn set_bit(&mut self, bit: u8, value: bool) {
let n = &mut self.0[(bit / 8) as usize];
if value {
*n |= 1 << (7 - bit % 8)
} else {
let mask = !(1 << (7 - bit % 8));
*n &= mask;
}
}
pub fn set_bits_range(&mut self, r: std::ops::Range<u8>, value: bool) {
for bit in r {
self.set_bit(bit, value)
}
}
}
impl<const N: usize> Default for Id<N> {
fn default() -> Self {
Id([0; N])
}
}
impl<const N: usize> std::fmt::Debug for Id<N> {
    /// Writes every byte as two lowercase hex digits, with no separators.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stops at, and returns, the first write error.
        self.0
            .iter()
            .try_for_each(|byte| write!(f, "{:02x?}", byte))
    }
}
impl<const N: usize> FromStr for Id<N> {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut out = [0u8; N];
let base32_encoded_size = (N as f64 / 5f64).ceil() as usize * 8;
if s.len() == N * 2 {
hex::decode_to_slice(s, &mut out)?;
Ok(Id(out))
// try decode as base32
} else if s.len() == base32_encoded_size {
match BASE32.decode(s.as_bytes()) {
Ok(decoded) => {
out.copy_from_slice(&decoded);
Ok(Id(out))
}
Err(err) => {
anyhow::bail!(
"fail to decode base32 string {}: {}",
s,
err
)
}
}
} else {
anyhow::bail!(
"expected a hex string of length {} or {}",
N * 2,
base32_encoded_size
);
}
}
}
impl<const N: usize> Serialize for Id<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0)
}
}
impl<'de, const N: usize> Deserialize<'de> for Id<N> {
    /// Deserializes an identifier from either `N` raw bytes or a hex string of
    /// `N * 2` characters.
    ///
    /// Uses `deserialize_any`, so the input format decides which visitor
    /// method is called.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_length` error carrying the length actually received
    /// when the input has the wrong size, or a custom error when hex decoding
    /// fails.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct IdVisitor<const N: usize>;

        impl<'de, const N: usize> serde::de::Visitor<'de> for IdVisitor<N> {
            type Value = Id<N>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(
                    formatter,
                    "a byte array of length {} or a hex string of length {}",
                    N,
                    N * 2
                )
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if v.len() != N * 2 {
                    // `invalid_length` expects the length that was encountered,
                    // not the one that was wanted.
                    return Err(E::invalid_length(v.len(), &self));
                }
                let mut out = [0u8; N];
                hex::decode_to_slice(v, &mut out).map_err(E::custom)?;
                Ok(Id(out))
            }

            fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_bytes(v)
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if v.len() != N {
                    // Report the received length; the expected one comes from `expecting`.
                    return Err(E::invalid_length(v.len(), &self));
                }
                let mut buf = [0u8; N];
                buf.copy_from_slice(v);
                Ok(Id(buf))
            }
        }

        deserializer.deserialize_any(IdVisitor::<N>)
    }
}
impl<const N: usize> PartialOrd<Id<N>> for Id<N> {
    /// Always returns `Some`, delegating to the total order from `Ord`.
    fn partial_cmp(&self, other: &Id<N>) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<const N: usize> Ord for Id<N> {
    /// Orders identifiers lexicographically by byte, first byte most significant.
    fn cmp(&self, other: &Id<N>) -> Ordering {
        // Arrays already compare element by element from the front.
        self.0.cmp(&other.0)
    }
}
/// A 20-byte hash used throughout librqbit, for torrent info hashes, peer ids, etc.
pub type Id20 = Id<20>;
/// A 32-byte hash used in BitTorrent v2, for torrent info hashes, piece hashing, etc.
pub type Id32 = Id<32>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Bits 9..17 cover the low seven bits of byte 1 and the top bit of byte 2.
    #[test]
    fn test_set_bit_range() {
        let mut id = Id20::default();
        id.set_bits_range(9..17, true);
        assert_eq!(
            id,
            Id20::new([0, 127, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        )
    }

    /// A 64-character hex string parses as an `Id32` and encodes back unchanged.
    #[test]
    fn test_id32_from_str() {
        let hex_str = "06f04cc728bef957a658876ef807f0514e4d715392969998efef584d2c3e435e";
        let ih = Id32::from_str(hex_str).unwrap();
        // Check the parsed value, not just that parsing succeeded.
        assert_eq!(ih.as_string(), hex_str);
    }

    /// The base32 and hex spellings of the same hash parse to equal ids.
    #[test]
    fn test_id20_base32_encoded_from_str() {
        let base32_str = "Z7QRDHYSJCA4U4HXGBXTFYUSDFGIRQMV";
        let ih1 = Id20::from_str(base32_str).unwrap();
        let hex_str = "cfe1119f124881ca70f7306f32e292194c88c195";
        let ih2 = Id20::from_str(hex_str).unwrap();
        assert_eq!(ih1, ih2);
    }
}