diff --git a/crates/librqbit/src/limits.rs b/crates/librqbit/src/limits.rs index 5df402c..4fd9be8 100644 --- a/crates/librqbit/src/limits.rs +++ b/crates/librqbit/src/limits.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; use std::time::Duration; use leaky_bucket::RateLimiter; +use parking_lot::RwLock; use peer_binary_protocol::PIECE_MESSAGE_DEFAULT_LEN; use serde::Deserialize; use serde::Serialize; @@ -11,38 +13,68 @@ pub struct LimitsConfig { pub download_bps: Option, } -#[derive(Default)] -pub struct Limits { - down: Option, - up: Option, -} +struct Limit(RwLock>>); -impl Limits { - pub fn new(config: LimitsConfig) -> Self { - let new = |bps: usize| -> RateLimiter { - let b_per_100_ms = bps.div_ceil(10); +impl Limit { + fn new_inner(bps: Option) -> Arc> { + let bps = match bps { + Some(bps) => bps, + None => return Arc::new(None), + }; + let b_per_100_ms = bps.div_ceil(10); + Arc::new(Some( RateLimiter::builder() .interval(Duration::from_millis(100)) .refill(b_per_100_ms) // whatever the limit is, we need to be able to download / upload a chunk .max(PIECE_MESSAGE_DEFAULT_LEN.max(bps)) - .build() - }; + .build(), + )) + } + + fn new(bps: Option) -> Self { + Self(RwLock::new(Self::new_inner(bps))) + } + + async fn acquire(&self, size: usize) { + let lim = self.0.read().clone(); + if let Some(rl) = lim.as_ref() { + rl.acquire(size).await + } + } + + fn set(&self, limit: Option) { + let new = Self::new_inner(limit); + *self.0.write() = new; + } +} + +pub struct Limits { + down: Limit, + up: Limit, +} + +impl Limits { + pub fn new(config: LimitsConfig) -> Self { Self { - down: config.download_bps.map(new), - up: config.upload_bps.map(new), + down: Limit::new(config.download_bps), + up: Limit::new(config.upload_bps), } } pub async fn prepare_for_upload(&self, len: usize) { - if let Some(rl) = self.up.as_ref() { - rl.acquire(len).await; - } + self.up.acquire(len).await } pub async fn prepare_for_download(&self, len: usize) { - if let Some(rl) = self.down.as_ref() { - rl.acquire(len).await; - } + self.down.acquire(len).await + } + + pub fn set_upload_bps(&self, bps: Option) { + self.up.set(bps); + } + + pub fn set_download_bps(&self, bps: Option) { + self.down.set(bps); } }