diff --git a/crates/librqbit/src/peer_state.rs b/crates/librqbit/src/peer_state.rs index 27bfa63..cc9d50d 100644 --- a/crates/librqbit/src/peer_state.rs +++ b/crates/librqbit/src/peer_state.rs @@ -2,8 +2,10 @@ use std::{collections::HashSet, sync::Arc}; use librqbit_core::id20::Id20; use librqbit_core::lengths::{ChunkInfo, ValidPieceIndex}; +use tokio::sync::mpsc::UnboundedSender; use tokio::sync::{Notify, Semaphore}; +use crate::peer_connection::WriterRequest; use crate::type_aliases::BF; #[derive(Debug, Hash, PartialEq, Eq)] @@ -21,10 +23,14 @@ impl From<&ChunkInfo> for InflightRequest { } } -#[derive(Debug)] +// TODO: Arc can be removed probably, as UnboundedSender should be clone + it can be downgraded to weak. +pub type PeerTx = Arc>; + +#[derive(Debug, Default)] pub enum PeerState { + #[default] Queued, - Connecting, + Connecting(PeerTx), Live(LivePeerState), } @@ -37,10 +43,11 @@ pub struct LivePeerState { pub have_notify: Arc, pub bitfield: Option, pub inflight_requests: HashSet, + pub tx: PeerTx, } impl LivePeerState { - pub fn new(peer_id: Id20) -> Self { + pub fn new(peer_id: Id20, tx: PeerTx) -> Self { LivePeerState { peer_id, i_am_choked: true, @@ -49,6 +56,7 @@ impl LivePeerState { have_notify: Arc::new(Notify::new()), requests_sem: Arc::new(Semaphore::new(0)), inflight_requests: Default::default(), + tx, } } } diff --git a/crates/librqbit/src/torrent_state.rs b/crates/librqbit/src/torrent_state.rs index 2bb39a6..b9c0a86 100644 --- a/crates/librqbit/src/torrent_state.rs +++ b/crates/librqbit/src/torrent_state.rs @@ -10,7 +10,7 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::Context; +use anyhow::{bail, Context}; use buffers::{ByteBuf, ByteString}; use clone_to_owned::CloneToOwned; use futures::{stream::FuturesUnordered, StreamExt}; @@ -40,7 +40,7 @@ use crate::{ peer_connection::{ PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, }, - peer_state::{InflightRequest, LivePeerState, PeerState}, + peer_state::{InflightRequest, LivePeerState, PeerState, PeerTx}, spawn_utils::{spawn, BlockingSpawner}, type_aliases::{PeerHandle, BF}, }; @@ -55,7 +55,6 @@ pub struct PeerStates { states: HashMap, seen: HashSet, inflight_pieces: HashMap, - tx: HashMap>>, } #[derive(Debug, Default)] @@ -73,7 +72,7 @@ impl PeerStates { .values() .fold(AggregatePeerStats::default(), |mut s, p| { match p { - PeerState::Connecting => s.connecting += 1, + PeerState::Connecting(_) => s.connecting += 1, PeerState::Live(_) => s.live += 1, PeerState::Queued => s.queued += 1, }; @@ -82,15 +81,11 @@ impl PeerStates { stats.seen = self.seen.len(); stats } - pub fn add_if_not_seen( - &mut self, - addr: SocketAddr, - tx: UnboundedSender, - ) -> Option { + pub fn add_if_not_seen(&mut self, addr: SocketAddr) -> Option { if self.seen.contains(&addr) { return None; } - let handle = self.add(addr, tx)?; + let handle = self.add(addr)?; self.seen.insert(addr); Some(handle) } @@ -113,23 +108,16 @@ impl PeerStates { self.get_live_mut(handle) .ok_or_else(|| anyhow::anyhow!("peer dropped")) } - pub fn add( - &mut self, - addr: SocketAddr, - tx: UnboundedSender, - ) -> Option { + pub fn add(&mut self, addr: SocketAddr) -> Option { let handle = addr; if self.states.contains_key(&addr) { return None; } self.states.insert(handle, PeerState::Queued); - self.tx.insert(handle, Arc::new(tx)); Some(handle) } pub fn drop_peer(&mut self, handle: PeerHandle) -> Option { - let result = self.states.remove(&handle); - self.tx.remove(&handle); - result + self.states.remove(&handle) } pub fn mark_i_am_choked(&mut self, handle: PeerHandle, is_choked: bool) -> Option { let live = self.get_live_mut(handle)?; @@ -158,8 +146,8 @@ impl PeerStates { live.bitfield = Some(bitfield); Some(prev) } - pub fn clone_tx(&self, handle: PeerHandle) -> Option>> { - Some(self.tx.get(&handle)?.clone()) + pub fn clone_tx(&self, handle: PeerHandle) -> Option { + Some(self.get_live(handle)?.tx.clone()) } pub fn remove_inflight_piece(&mut self, piece: ValidPieceIndex) -> Option { self.inflight_pieces.remove(&piece) @@ -242,8 +230,11 @@ pub struct TorrentState { stats: AtomicStats, options: TorrentStateOptions, + // Limits how many active (occupying network resources) peers there are at a moment in time. peer_semaphore: Semaphore, - peer_queue_tx: UnboundedSender<(SocketAddr, UnboundedReceiver)>, + + // The queue for peer manager to connect to them. + peer_queue_tx: UnboundedSender, finished_notify: Notify, } @@ -292,45 +283,51 @@ impl TorrentState { let state = state.clone(); async move { loop { - let (addr, out_rx) = peer_queue_rx.recv().await.unwrap(); + let addr = peer_queue_rx.recv().await.unwrap(); let permit = state.peer_semaphore.acquire().await.unwrap(); - match state.locked.write().peers.states.get_mut(&addr) { - Some(s @ PeerState::Queued) => *s = PeerState::Connecting, - s => { - warn!("did not expect to see the peer in state {:?}", s); - continue; - } - }; - - let handler = PeerHandler { - addr, - state: state.clone(), - spawner, - }; - let options = PeerConnectionOptions { - connect_timeout: state.options.peer_connect_timeout, - read_write_timeout: state.options.peer_read_write_timeout, - ..Default::default() - }; - let peer_connection = PeerConnection::new( - addr, - state.info_hash, - state.peer_id, - handler, - Some(options), - spawner, - ); - permit.forget(); - spawn(format!("manage_peer({addr})"), async move { - if let Err(e) = peer_connection.manage_peer(out_rx).await { - debug!("error managing peer {}: {:#}", addr, e) - }; - let state = peer_connection.into_handler().state; - state.drop_peer(addr); - state.peer_semaphore.add_permits(1); - Ok::<_, anyhow::Error>(()) + spawn(format!("manage_peer({addr})"), { + let state = state.clone(); + async move { + let rx = match state.locked.write().peers.states.get_mut(&addr) { + Some(s @ PeerState::Queued) => { + let (tx, rx) = unbounded_channel(); + *s = PeerState::Connecting(Arc::new(tx)); + rx + } + s => { + bail!("did not expect to see the peer in state {:?}", s); + } + }; + + let handler = PeerHandler { + addr, + state: state.clone(), + spawner, + }; + let options = PeerConnectionOptions { + connect_timeout: state.options.peer_connect_timeout, + read_write_timeout: state.options.peer_read_write_timeout, + ..Default::default() + }; + let peer_connection = PeerConnection::new( + addr, + state.info_hash, + state.peer_id, + handler, + Some(options), + spawner, + ); + + if let Err(e) = peer_connection.manage_peer(rx).await { + debug!("error managing peer {}: {:#}", addr, e) + }; + let state = peer_connection.into_handler().state; + state.drop_peer(addr); + state.peer_semaphore.add_permits(1); + Ok::<_, anyhow::Error>(()) + } }); } } @@ -456,14 +453,18 @@ impl TorrentState { fn set_peer_live(&self, handle: PeerHandle, h: Handshake) { let mut g = self.locked.write(); - match g.peers.states.get_mut(&handle) { - Some(s @ &mut PeerState::Connecting) => { - *s = PeerState::Live(LivePeerState::new(Id20(h.peer_id))); - } + let s = match g.peers.states.get_mut(&handle) { + Some(s @ PeerState::Connecting(_)) => s, _ => { - warn!("peer {} was in wrong state", handle); + warn!("peer {} was in a wrong state", handle); + return; } - } + }; + let tx = match std::mem::take(s) { + PeerState::Connecting(tx) => tx, + _ => unreachable!(), + }; + *s = PeerState::Live(LivePeerState::new(Id20(h.peer_id), tx)); } fn drop_peer(&self, handle: PeerHandle) -> bool { @@ -511,11 +512,7 @@ impl TorrentState { continue; } - let tx = match g.peers.tx.get(handle) { - Some(tx) => tx, - None => continue, - }; - let tx = Arc::downgrade(tx); + let tx = Arc::downgrade(&live.tx); futures.push(async move { if let Some(tx) = tx.upgrade() { if tx @@ -547,13 +544,13 @@ impl TorrentState { } pub fn add_peer_if_not_seen(self: &Arc, addr: SocketAddr) -> bool { - let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); - match self.locked.write().peers.add_if_not_seen(addr, out_tx) { + // let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::(); + match self.locked.write().peers.add_if_not_seen(addr) { Some(handle) => handle, None => return false, }; - match self.peer_queue_tx.send((addr, out_rx)) { + match self.peer_queue_tx.send(addr) { Ok(_) => {} Err(_) => { warn!("peer adder died, can't add peer")