use data_encoding::BASE32; use serde::{Deserialize, Deserializer, Serialize}; use std::{cmp::Ordering, str::FromStr}; #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct Id(pub [u8; N]); impl Id { pub fn new(from: [u8; N]) -> Id { Id(from) } pub fn as_string(&self) -> String { hex::encode(self.0) } pub fn distance(&self, other: &Id) -> Id { let mut xor = [0u8; N]; for (idx, (s, o)) in self .0 .iter() .copied() .zip(other.0.iter().copied()) .enumerate() { xor[idx] = s ^ o; } Id(xor) } pub fn get_bit(&self, bit: u8) -> bool { let n = self.0[(bit / 8) as usize]; let mask = 1 << (7 - bit % 8); n & mask > 0 } pub fn set_bit(&mut self, bit: u8, value: bool) { let n = &mut self.0[(bit / 8) as usize]; if value { *n |= 1 << (7 - bit % 8) } else { let mask = !(1 << (7 - bit % 8)); *n &= mask; } } pub fn set_bits_range(&mut self, r: std::ops::Range, value: bool) { for bit in r { self.set_bit(bit, value) } } } impl Default for Id { fn default() -> Self { Id([0; N]) } } impl std::fmt::Debug for Id { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for byte in self.0 { write!(f, "{:02x?}", byte)?; } Ok(()) } } impl FromStr for Id { type Err = anyhow::Error; fn from_str(s: &str) -> Result { let mut out = [0u8; N]; let base32_encoded_size = (N as f64 / 5f64).ceil() as usize * 8; if s.len() == N * 2 { hex::decode_to_slice(s, &mut out)?; Ok(Id(out)) // try decode as base32 } else if s.len() == base32_encoded_size { match BASE32.decode(s.as_bytes()) { Ok(decoded) => { out.copy_from_slice(&decoded); Ok(Id(out)) } Err(err) => { anyhow::bail!( "fail to decode base32 string {}: {}", s, err ) } } } else { anyhow::bail!( "expected a hex string of length {} or {}", N * 2, base32_encoded_size ); } } } impl Serialize for Id { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_bytes(&self.0) } } impl<'de, const N: usize> Deserialize<'de> for Id { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct 
IdVisitor; impl<'de, const N: usize> serde::de::Visitor<'de> for IdVisitor { type Value = Id; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter .write_str("a byte array of length ") .and_then(|_| formatter.write_fmt(format_args!("{}", N))) } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { if v.len() != N * 2 { return Err(E::invalid_length(40, &self)); } let mut out = [0u8; N]; match hex::decode_to_slice(v, &mut out) { Ok(_) => Ok(Id(out)), Err(e) => Err(E::custom(e)), } } fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result where E: serde::de::Error, { self.visit_bytes(v) } fn visit_bytes(self, v: &[u8]) -> Result where E: serde::de::Error, { if v.len() != N { return Err(E::invalid_length(N, &self)); } let mut buf = [0u8; N]; buf.copy_from_slice(v); Ok(Id(buf)) } } deserializer.deserialize_any(IdVisitor {}) } } impl PartialOrd> for Id { fn partial_cmp(&self, other: &Id) -> Option { Some(self.cmp(other)) } } impl Ord for Id { fn cmp(&self, other: &Id) -> Ordering { for (s, o) in self.0.iter().copied().zip(other.0.iter().copied()) { match s.cmp(&o) { Ordering::Less => return Ordering::Less, Ordering::Equal => continue, Ordering::Greater => return Ordering::Greater, } } Ordering::Equal } } /// A 20-byte hash used throughout librqbit, for torrent info hashes, peer ids etc. pub type Id20 = Id<20>; /// A 32-byte hash used in Bittorrent V2, for torrent info hashes, piece hashing, etc. 
pub type Id32 = Id<32>;

#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_set_bit_range() {
        let mut id = Id20::default();
        id.set_bits_range(9..17, true);
        assert_eq!(
            id,
            Id20::new([0, 127, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
        )
    }

    #[test]
    fn test_get_set_bit() {
        let mut id = Id20::default();
        id.set_bit(0, true);
        id.set_bit(159, true);
        assert!(id.get_bit(0));
        assert!(id.get_bit(159));
        assert!(!id.get_bit(1));
        // Bit 0 is the MSB of the first byte; bit 159 is the LSB of the last.
        assert_eq!(id.0[0], 0x80);
        assert_eq!(id.0[19], 0x01);
        id.set_bit(0, false);
        assert!(!id.get_bit(0));
    }

    #[test]
    fn test_distance() {
        let a = Id20::new([0xff; 20]);
        let b = Id20::new([0x0f; 20]);
        assert_eq!(a.distance(&b), Id20::new([0xf0; 20]));
        assert_eq!(a.distance(&b), b.distance(&a));
        assert_eq!(a.distance(&a), Id20::default());
    }

    #[test]
    fn test_ordering_is_big_endian() {
        let mut hi = Id20::default();
        hi.0[0] = 1;
        let mut lo = Id20::default();
        lo.0[19] = 0xff;
        assert!(hi > lo);
        assert_eq!(hi.cmp(&hi), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_id32_from_str() {
        let str = "06f04cc728bef957a658876ef807f0514e4d715392969998efef584d2c3e435e";
        let ih = Id32::from_str(str).unwrap();
        // Hex parsing must round-trip through both `as_string` and `Debug`.
        assert_eq!(ih.as_string(), str);
        assert_eq!(format!("{:?}", ih), str);
    }

    #[test]
    fn test_id20_base32_encoded_from_str() {
        let str = "Z7QRDHYSJCA4U4HXGBXTFYUSDFGIRQMV";
        let ih1 = Id20::from_str(str).unwrap();
        let s2 = "cfe1119f124881ca70f7306f32e292194c88c195";
        let ih2 = Id20::from_str(s2).unwrap();
        assert_eq!(ih1, ih2);
    }

    #[test]
    fn test_from_str_rejects_bad_length() {
        assert!(Id20::from_str("").is_err());
        assert!(Id20::from_str("abc").is_err());
    }
}