Graceful shutdown

This commit is contained in:
Igor Katson 2024-08-26 18:25:22 +01:00
parent 4ae22f2a3d
commit bf9d75e748
No known key found for this signature in database
GPG key ID: B4EC22B66D61A3F5
6 changed files with 81 additions and 27 deletions

12
Cargo.lock generated
View file

@ -2530,8 +2530,10 @@ dependencies = [
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
"signal-hook",
"size_format", "size_format",
"tokio", "tokio",
"tokio-util",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"upnp-serve", "upnp-serve",
@ -2818,6 +2820,16 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.2" version = "1.4.2"

View file

@ -403,6 +403,8 @@ pub struct SessionOptions {
// socks5://[username:password@]host:port // socks5://[username:password@]host:port
pub socks_proxy_url: Option<String>, pub socks_proxy_url: Option<String>,
pub cancellation_token: Option<CancellationToken>,
// how many concurrent torrent initializations can happen // how many concurrent torrent initializations can happen
pub concurrent_init_limit: Option<usize>, pub concurrent_init_limit: Option<usize>,
@ -481,7 +483,7 @@ impl Session {
) -> BoxFuture<'static, anyhow::Result<Arc<Self>>> { ) -> BoxFuture<'static, anyhow::Result<Arc<Self>>> {
async move { async move {
let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id);
let token = CancellationToken::new(); let token = opts.cancellation_token.take().unwrap_or_default();
let (tcp_listener, tcp_listen_port) = let (tcp_listener, tcp_listen_port) =
if let Some(port_range) = opts.listen_port_range.clone() { if let Some(port_range) = opts.listen_port_range.clone() {

View file

@ -46,6 +46,8 @@ bytes = "1.5.0"
openssl = { version = "0.10", features = ["vendored"], optional = true } openssl = { version = "0.10", features = ["vendored"], optional = true }
upnp-serve = { path = "../upnp-serve" } upnp-serve = { path = "../upnp-serve" }
libc = "0.2.158" libc = "0.2.158"
signal-hook = "0.3.17"
tokio-util = "0.7.11"
[dev-dependencies] [dev-dependencies]
futures = { version = "0.3" } futures = { version = "0.3" }

View file

@ -1,4 +1,4 @@
use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, thread, time::Duration};
use anyhow::{bail, Context}; use anyhow::{bail, Context};
use clap::{CommandFactory, Parser, ValueEnum}; use clap::{CommandFactory, Parser, ValueEnum};
@ -17,7 +17,8 @@ use librqbit::{
}; };
use size_format::SizeFormatterBinary as SF; use size_format::SizeFormatterBinary as SF;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{error, error_span, info, trace_span, warn}; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, error_span, info, trace_span, warn};
#[derive(Debug, Clone, Copy, ValueEnum)] #[derive(Debug, Clone, Copy, ValueEnum)]
enum LogLevel { enum LogLevel {
@ -357,10 +358,30 @@ fn main() -> anyhow::Result<()> {
.max_blocking_threads(opts.max_blocking_threads as usize) .max_blocking_threads(opts.max_blocking_threads as usize)
.build()?; .build()?;
rt.block_on(async_main(opts)) let token = tokio_util::sync::CancellationToken::new();
{
let token = token.clone();
use signal_hook::{consts::SIGINT, consts::SIGTERM, iterator::Signals};
let mut signals = Signals::new([SIGINT, SIGTERM])?;
thread::spawn(move || {
if let Some(sig) = signals.forever().next() {
warn!("Received signal {:?}", sig);
token.cancel();
}
});
}
rt.block_on(async move {
let res = async_main(opts, token.clone()).await;
if let Err(e) = res {
error!("error running rqbit: {e:?}");
std::process::exit(1);
}
std::process::exit(0);
})
} }
async fn async_main(opts: Opts) -> anyhow::Result<()> { async fn async_main(opts: Opts, cancel: CancellationToken) -> anyhow::Result<()> {
let log_config = init_logging(InitLoggingOptions { let log_config = init_logging(InitLoggingOptions {
default_rust_log_value: Some(match opts.log_level.unwrap_or(LogLevel::Info) { default_rust_log_value: Some(match opts.log_level.unwrap_or(LogLevel::Info) {
LogLevel::Trace => "trace", LogLevel::Trace => "trace",
@ -420,6 +441,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
concurrent_init_limit: Some(opts.concurrent_init_limit), concurrent_init_limit: Some(opts.concurrent_init_limit),
root_span: None, root_span: None,
fastresume: false, fastresume: false,
cancellation_token: Some(cancel.clone()),
}; };
let stats_printer = |session: Arc<Session>| async move { let stats_printer = |session: Arc<Session>| async move {
@ -553,15 +575,15 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
Some(srv) => { Some(srv) => {
let upnp_fut = srv.run_ssdp_forever(); let upnp_fut = srv.run_ssdp_forever();
tokio::pin!(http_api_fut);
tokio::pin!(upnp_fut);
tokio::select! { tokio::select! {
r = &mut http_api_fut => r, r = http_api_fut => r,
r = &mut upnp_fut => r r = upnp_fut => r
} }
} }
None => http_api_fut.await, None => tokio::select! {
_ = cancel.cancelled() => bail!("cancelled"),
r = http_api_fut => r,
},
}; };
res.context("error running rqbit server") res.context("error running rqbit server")
@ -726,8 +748,13 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
if download_opts.exit_on_finish { if download_opts.exit_on_finish {
let results = futures::future::join_all( let results = futures::future::join_all(
handles.iter().map(|h| h.wait_until_completed()), handles.iter().map(|h| h.wait_until_completed()),
) );
.await; let results = tokio::select! {
_ = cancel.cancelled() => {
bail!("cancelled");
},
r = results => r
};
if results.iter().any(|r| r.is_err()) { if results.iter().any(|r| r.is_err()) {
anyhow::bail!("some downloads failed") anyhow::bail!("some downloads failed")
} }
@ -735,9 +762,8 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> {
Ok(()) Ok(())
} else { } else {
// Sleep forever. // Sleep forever.
loop { cancel.cancelled().await;
tokio::time::sleep(Duration::from_secs(60)).await; bail!("cancelled");
}
} }
} else { } else {
anyhow::bail!("no torrents were added") anyhow::bail!("no torrents were added")

View file

@ -72,6 +72,7 @@ impl UpnpServer {
description_http_location, description_http_location,
server_string: "Linux/3.4 UPnP/1.0 rqbit/1".to_owned(), server_string: "Linux/3.4 UPnP/1.0 rqbit/1".to_owned(),
notify_interval: Duration::from_secs(60), notify_interval: Duration::from_secs(60),
shutdown: opts.cancellation_token.clone(),
}) })
.await .await
.context("error initializing SsdpRunner")?; .context("error initializing SsdpRunner")?;

View file

@ -6,6 +6,7 @@ use std::{
use anyhow::{bail, Context}; use anyhow::{bail, Context};
use bstr::BStr; use bstr::BStr;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::constants::{UPNP_KIND_MEDIASERVER, UPNP_KIND_ROOT_DEVICE}; use crate::constants::{UPNP_KIND_MEDIASERVER, UPNP_KIND_ROOT_DEVICE};
@ -14,6 +15,9 @@ const UPNP_PORT: u16 = 1900;
const UPNP_BROADCAST_IP: Ipv4Addr = Ipv4Addr::new(239, 255, 255, 250); const UPNP_BROADCAST_IP: Ipv4Addr = Ipv4Addr::new(239, 255, 255, 250);
const UPNP_BROADCAST_ADDR: SocketAddrV4 = SocketAddrV4::new(UPNP_BROADCAST_IP, UPNP_PORT); const UPNP_BROADCAST_ADDR: SocketAddrV4 = SocketAddrV4::new(UPNP_BROADCAST_IP, UPNP_PORT);
const NTS_ALIVE: &str = "ssdp:alive";
const NTS_BYEBYE: &str = "ssdp:byebye";
#[derive(Debug)] #[derive(Debug)]
pub enum SsdpMessage<'a, 'h> { pub enum SsdpMessage<'a, 'h> {
MSearch(SsdpMSearchRequest<'a>), MSearch(SsdpMSearchRequest<'a>),
@ -93,6 +97,7 @@ pub struct SsdpRunnerOptions {
pub description_http_location: String, pub description_http_location: String,
pub server_string: String, pub server_string: String,
pub notify_interval: Duration, pub notify_interval: Duration,
pub shutdown: CancellationToken,
} }
pub struct SsdpRunner { pub struct SsdpRunner {
@ -104,10 +109,10 @@ impl SsdpRunner {
pub async fn new(opts: SsdpRunnerOptions) -> anyhow::Result<Self> { pub async fn new(opts: SsdpRunnerOptions) -> anyhow::Result<Self> {
let bind_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, UPNP_PORT); let bind_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, UPNP_PORT);
trace!(addr=?bind_addr, "binding UDP"); trace!(addr=?bind_addr, "binding UDP");
let socket = let socket = tokio::net::UdpSocket::bind(bind_addr)
tokio::net::UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, UPNP_PORT)) .await
.await .context(bind_addr)
.context("error binding")?; .context("error binding")?;
trace!(multiaddr=?UPNP_BROADCAST_IP, interface=?Ipv4Addr::UNSPECIFIED, "joining multicast v4 group"); trace!(multiaddr=?UPNP_BROADCAST_IP, interface=?Ipv4Addr::UNSPECIFIED, "joining multicast v4 group");
socket socket
@ -117,7 +122,7 @@ impl SsdpRunner {
Ok(Self { opts, socket }) Ok(Self { opts, socket })
} }
fn generate_notify_message(&self, kind: &str) -> String { fn generate_notify_message(&self, kind: &str, nts: &str) -> String {
let usn: &str = &self.opts.usn; let usn: &str = &self.opts.usn;
let description_http_location = &self.opts.description_http_location; let description_http_location = &self.opts.description_http_location;
let server: &str = &self.opts.server_string; let server: &str = &self.opts.server_string;
@ -128,7 +133,7 @@ Host: {bcast_addr}\r
Cache-Control: max-age=75\r Cache-Control: max-age=75\r
Location: {description_http_location}\r Location: {description_http_location}\r
NT: {kind}\r NT: {kind}\r
NTS: ssdp:alive\r NTS: {nts}\r
Server: {server}\r Server: {server}\r
USN: {usn}::{kind}\r USN: {usn}::{kind}\r
\r \r
@ -153,9 +158,9 @@ Content-Length: 0\r\n\r\n"
) )
} }
async fn try_send_notifies(&self) { async fn try_send_notifies(&self, nts: &str) {
for kind in [UPNP_KIND_ROOT_DEVICE, UPNP_KIND_MEDIASERVER] { for kind in [UPNP_KIND_ROOT_DEVICE, UPNP_KIND_MEDIASERVER] {
let msg = self.generate_notify_message(kind); let msg = self.generate_notify_message(kind, nts);
trace!(content=?msg, addr=?UPNP_BROADCAST_ADDR, "sending SSDP NOTIFY"); trace!(content=?msg, addr=?UPNP_BROADCAST_ADDR, "sending SSDP NOTIFY");
if let Err(e) = self if let Err(e) = self
.socket .socket
@ -163,15 +168,17 @@ Content-Length: 0\r\n\r\n"
.await .await
{ {
warn!(error=?e, "error sending SSDP NOTIFY") warn!(error=?e, "error sending SSDP NOTIFY")
} else {
debug!(kind, nts, "sent SSDP NOTIFY")
} }
} }
} }
async fn task_send_notifies_periodically(&self) -> anyhow::Result<()> { async fn task_send_alive_notifies_periodically(&self) -> anyhow::Result<()> {
let mut interval = tokio::time::interval(self.opts.notify_interval); let mut interval = tokio::time::interval(self.opts.notify_interval);
loop { loop {
interval.tick().await; interval.tick().await;
self.try_send_notifies().await; self.try_send_notifies(NTS_ALIVE).await;
} }
} }
@ -242,7 +249,7 @@ MX: 2\r\n\r\n";
self.send_msearch().await?; self.send_msearch().await?;
let t1 = self.task_respond_on_msearches(); let t1 = self.task_respond_on_msearches();
let t2 = self.task_send_notifies_periodically(); let t2 = self.task_send_alive_notifies_periodically();
tokio::pin!(t1); tokio::pin!(t1);
tokio::pin!(t2); tokio::pin!(t2);
@ -250,6 +257,10 @@ MX: 2\r\n\r\n";
tokio::select! { tokio::select! {
r = &mut t1 => r, r = &mut t1 => r,
r = &mut t2 => r, r = &mut t2 => r,
_ = self.opts.shutdown.cancelled() => {
self.try_send_notifies(NTS_BYEBYE).await;
bail!("canceled");
}
} }
} }
} }