diff --git a/Cargo.lock b/Cargo.lock index 918d812..6930d26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2530,8 +2530,10 @@ dependencies = [ "regex", "serde", "serde_json", + "signal-hook", "size_format", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "upnp-serve", @@ -2818,6 +2820,16 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 9d9ae02..d6bad63 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -403,6 +403,8 @@ pub struct SessionOptions { // socks5://[username:password@]host:port pub socks_proxy_url: Option, + pub cancellation_token: Option, + // how many concurrent torrent initializations can happen pub concurrent_init_limit: Option, @@ -481,7 +483,7 @@ impl Session { ) -> BoxFuture<'static, anyhow::Result>> { async move { let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); - let token = CancellationToken::new(); + let token = opts.cancellation_token.take().unwrap_or_default(); let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range.clone() { diff --git a/crates/rqbit/Cargo.toml b/crates/rqbit/Cargo.toml index 2adb816..bc0b094 100644 --- a/crates/rqbit/Cargo.toml +++ b/crates/rqbit/Cargo.toml @@ -46,6 +46,8 @@ bytes = "1.5.0" openssl = { version = "0.10", features = ["vendored"], optional = true } upnp-serve = { path = "../upnp-serve" } libc = "0.2.158" +signal-hook = "0.3.17" +tokio-util = "0.7.11" [dev-dependencies] futures = { version = "0.3" } diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index a456ec9..c3da5ad 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -1,4 +1,4 @@ -use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, thread, time::Duration}; use anyhow::{bail, Context}; use clap::{CommandFactory, Parser, ValueEnum}; @@ -17,7 +17,8 @@ use librqbit::{ }; use size_format::SizeFormatterBinary as SF; use tokio::net::TcpListener; -use tracing::{error, error_span, info, trace_span, warn}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, error_span, info, trace_span, warn}; #[derive(Debug, Clone, Copy, ValueEnum)] enum LogLevel { @@ -357,10 +358,30 @@ fn main() -> anyhow::Result<()> { .max_blocking_threads(opts.max_blocking_threads as usize) .build()?; - rt.block_on(async_main(opts)) + let token = tokio_util::sync::CancellationToken::new(); + { + let token = token.clone(); + use signal_hook::{consts::SIGINT, consts::SIGTERM, iterator::Signals}; + let mut signals = Signals::new([SIGINT, SIGTERM])?; + thread::spawn(move || { + if let Some(sig) = signals.forever().next() { + warn!("Received signal {:?}", sig); + token.cancel(); + } + }); + } + + rt.block_on(async move { + let res = async_main(opts, token.clone()).await; + if let Err(e) = res { + error!("error running rqbit: {e:?}"); + std::process::exit(1); + } + std::process::exit(0); + }) } -async fn async_main(opts: Opts) -> anyhow::Result<()> { +async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> { let log_config = init_logging(InitLoggingOptions { default_rust_log_value: Some(match opts.log_level.unwrap_or(LogLevel::Info) { LogLevel::Trace => "trace", @@ -420,6 +441,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { concurrent_init_limit: Some(opts.concurrent_init_limit), root_span: None, fastresume: false, + cancellation_token: Some(cancel.clone()), }; let stats_printer = |session: Arc| async move { @@ -553,15 +575,15 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { Some(srv) => { let upnp_fut = srv.run_ssdp_forever(); - tokio::pin!(http_api_fut); - tokio::pin!(upnp_fut); - tokio::select! { - r = &mut http_api_fut => r, - r = &mut upnp_fut => r + r = http_api_fut => r, + r = upnp_fut => r } } - None => http_api_fut.await, + None => tokio::select! { + _ = cancel.cancelled() => bail!("cancelled"), + r = http_api_fut => r, + }, }; res.context("error running rqbit server") @@ -726,8 +748,13 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { if download_opts.exit_on_finish { let results = futures::future::join_all( handles.iter().map(|h| h.wait_until_completed()), - ) - .await; + ); + let results = tokio::select! { + _ = cancel.cancelled() => { + bail!("cancelled"); + }, + r = results => r + }; if results.iter().any(|r| r.is_err()) { anyhow::bail!("some downloads failed") } @@ -735,9 +762,8 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { Ok(()) } else { // Sleep forever. - loop { - tokio::time::sleep(Duration::from_secs(60)).await; - } + cancel.cancelled().await; + bail!("cancelled"); } } else { anyhow::bail!("no torrents were added") diff --git a/crates/upnp-serve/src/lib.rs b/crates/upnp-serve/src/lib.rs index 2c70c20..eb781b6 100644 --- a/crates/upnp-serve/src/lib.rs +++ b/crates/upnp-serve/src/lib.rs @@ -72,6 +72,7 @@ impl UpnpServer { description_http_location, server_string: "Linux/3.4 UPnP/1.0 rqbit/1".to_owned(), notify_interval: Duration::from_secs(60), + shutdown: opts.cancellation_token.clone(), }) .await .context("error initializing SsdpRunner")?; diff --git a/crates/upnp-serve/src/ssdp.rs b/crates/upnp-serve/src/ssdp.rs index 30facd9..8d6416b 100644 --- a/crates/upnp-serve/src/ssdp.rs +++ b/crates/upnp-serve/src/ssdp.rs @@ -6,6 +6,7 @@ use std::{ use anyhow::{bail, Context}; use bstr::BStr; use tokio::net::UdpSocket; +use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; use crate::constants::{UPNP_KIND_MEDIASERVER, UPNP_KIND_ROOT_DEVICE}; @@ -14,6 +15,9 @@ const UPNP_PORT: u16 = 1900; const UPNP_BROADCAST_IP: Ipv4Addr = Ipv4Addr::new(239, 255, 255, 250); const UPNP_BROADCAST_ADDR: SocketAddrV4 = SocketAddrV4::new(UPNP_BROADCAST_IP, UPNP_PORT); +const NTS_ALIVE: &str = "ssdp:alive"; +const NTS_BYEBYE: &str = "ssdp:byebye"; + #[derive(Debug)] pub enum SsdpMessage<'a, 'h> { MSearch(SsdpMSearchRequest<'a>), @@ -93,6 +97,7 @@ pub struct SsdpRunnerOptions { pub description_http_location: String, pub server_string: String, pub notify_interval: Duration, + pub shutdown: CancellationToken, } pub struct SsdpRunner { @@ -104,10 +109,10 @@ impl SsdpRunner { pub async fn new(opts: SsdpRunnerOptions) -> anyhow::Result { let bind_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, UPNP_PORT); trace!(addr=?bind_addr, "binding UDP"); - let socket = - tokio::net::UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, UPNP_PORT)) - .await - .context("error binding")?; + let socket = tokio::net::UdpSocket::bind(bind_addr) + .await + .context(bind_addr) + .context("error binding")?; trace!(multiaddr=?UPNP_BROADCAST_IP, interface=?Ipv4Addr::UNSPECIFIED, "joining multicast v4 group"); socket @@ -117,7 +122,7 @@ impl SsdpRunner { Ok(Self { opts, socket }) } - fn generate_notify_message(&self, kind: &str) -> String { + fn generate_notify_message(&self, kind: &str, nts: &str) -> String { let usn: &str = &self.opts.usn; let description_http_location = &self.opts.description_http_location; let server: &str = &self.opts.server_string; @@ -128,7 +133,7 @@ Host: {bcast_addr}\r Cache-Control: max-age=75\r Location: {description_http_location}\r NT: {kind}\r -NTS: ssdp:alive\r +NTS: {nts}\r Server: {server}\r USN: {usn}::{kind}\r \r @@ -153,9 +158,9 @@ Content-Length: 0\r\n\r\n" ) } - async fn try_send_notifies(&self) { + async fn try_send_notifies(&self, nts: &str) { for kind in [UPNP_KIND_ROOT_DEVICE, UPNP_KIND_MEDIASERVER] { - let msg = self.generate_notify_message(kind); + let msg = self.generate_notify_message(kind, nts); trace!(content=?msg, addr=?UPNP_BROADCAST_ADDR, "sending SSDP NOTIFY"); if let Err(e) = self .socket @@ -163,15 +168,17 @@ Content-Length: 0\r\n\r\n" .await { warn!(error=?e, "error sending SSDP NOTIFY") + } else { + debug!(kind, nts, "sent SSDP NOTIFY") } } } - async fn task_send_notifies_periodically(&self) -> anyhow::Result<()> { + async fn task_send_alive_notifies_periodically(&self) -> anyhow::Result<()> { let mut interval = tokio::time::interval(self.opts.notify_interval); loop { interval.tick().await; - self.try_send_notifies().await; + self.try_send_notifies(NTS_ALIVE).await; } } @@ -242,7 +249,7 @@ MX: 2\r\n\r\n"; self.send_msearch().await?; let t1 = self.task_respond_on_msearches(); - let t2 = self.task_send_notifies_periodically(); + let t2 = self.task_send_alive_notifies_periodically(); tokio::pin!(t1); tokio::pin!(t2); @@ -250,6 +257,10 @@ MX: 2\r\n\r\n"; tokio::select! { r = &mut t1 => r, r = &mut t2 => r, + _ = self.opts.shutdown.cancelled() => { + self.try_send_notifies(NTS_BYEBYE).await; + bail!("canceled"); + } } } }