From f7e083545250e3468dd10d7c20f852d92a091c43 Mon Sep 17 00:00:00 2001 From: Igor Katson Date: Mon, 16 Sep 2024 11:16:46 +0100 Subject: [PATCH] Remove custom libc/winapi code in favor of duplicating the socket and using both socket2 and tokio --- Cargo.lock | 2 - crates/upnp-serve/Cargo.toml | 4 -- crates/upnp-serve/src/ssdp.rs | 119 +++++++--------------------------- 3 files changed, 24 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 05b39cd..00f23c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2692,7 +2692,6 @@ dependencies = [ "gethostname", "http 1.1.0", "httparse", - "libc", "librqbit-core", "librqbit-sha1-wrapper", "librqbit-upnp", @@ -2710,7 +2709,6 @@ dependencies = [ "tracing-subscriber", "url", "uuid", - "winapi", ] [[package]] diff --git a/crates/upnp-serve/Cargo.toml b/crates/upnp-serve/Cargo.toml index 2ff0b22..176b116 100644 --- a/crates/upnp-serve/Cargo.toml +++ b/crates/upnp-serve/Cargo.toml @@ -39,10 +39,6 @@ socket2 = "0.5.7" quick-xml = { version = "0.36.1", features = ["serialize"] } network-interface = "2.0.0" futures = "0.3.30" -libc = "0.2.158" - -[target.'cfg(windows)'.dependencies] -winapi = { version = "0.3.9", features = ["winsock2"] } [dev-dependencies] tracing-subscriber = "0.3.18" diff --git a/crates/upnp-serve/src/ssdp.rs b/crates/upnp-serve/src/ssdp.rs index ceb71f2..a279ab5 100644 --- a/crates/upnp-serve/src/ssdp.rs +++ b/crates/upnp-serve/src/ssdp.rs @@ -8,7 +8,6 @@ use anyhow::{bail, Context}; use bstr::BStr; use network_interface::NetworkInterfaceConfig; use parking_lot::Mutex; -use tokio::net::UdpSocket; use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; @@ -109,13 +108,18 @@ pub struct SsdpRunnerOptions { pub shutdown: CancellationToken, } +struct UdpSocket { + sock2: socket2::Socket, + tokio: tokio::net::UdpSocket, +} + pub struct SsdpRunner { opts: SsdpRunnerOptions, socket_v4: Option, socket_v6: Option, } -fn socket_presetup(bind_addr: SocketAddr) -> anyhow::Result { +fn socket_presetup(bind_addr: SocketAddr) -> anyhow::Result { let domain = if bind_addr.is_ipv4() { socket2::Domain::IPV4 } else { @@ -136,10 +140,16 @@ fn socket_presetup(bind_addr: SocketAddr) -> anyhow::Result anyhow::Result { @@ -163,7 +173,7 @@ async fn bind_v4_socket() -> anyhow::Result { for ifaddr in default_multiast_membership_ip.chain(all_multicast_membership_ips) { trace!(multiaddr=?SSDM_MCAST_IPV4, interface=?ifaddr, "joining multicast v4 group"); - if let Err(e) = socket.join_multicast_v4(SSDM_MCAST_IPV4, ifaddr) { + if let Err(e) = socket.tokio.join_multicast_v4(SSDM_MCAST_IPV4, ifaddr) { debug!(multiaddr=?SSDM_MCAST_IPV4, interface=?ifaddr, "error joining multicast v4 group: {e:#}"); } } @@ -203,7 +213,7 @@ async fn bind_v6_socket() -> anyhow::Result { continue; } trace!(multiaddr=?multiaddr, interface=?nic.index, "joining multicast v6 group"); - if let Err(e) = socket.join_multicast_v6(&multiaddr, nic.index) { + if let Err(e) = socket.tokio.join_multicast_v6(&multiaddr, nic.index) { debug!(multiaddr=?multiaddr, interface=?nic.index, "error joining multicast v6 group: {e:#}"); } } @@ -219,88 +229,6 @@ struct MulticastOpts { mcast_addr: SocketAddr, } -fn set_mcast_if_v4(sock: &UdpSocket, local_ip: Ipv4Addr) -> anyhow::Result<()> { - // in_addr is the same on unix and windows and contains just the 4 bytes of IPv4 in network - // byte order. - let addr = u32::from_ne_bytes(local_ip.octets()); - let sz: usize = std::mem::size_of_val(&addr); - - trace!(addr = %local_ip, "setting IP_MULTICAST_IF"); - - let ret: i32; - #[cfg(target_os = "windows")] - { - use std::os::windows::io::AsRawSocket; - ret = unsafe { - winapi::um::winsock2::setsockopt( - sock.as_raw_socket().try_into()?, - winapi::shared::ws2def::IPPROTO_IP, - winapi::shared::ws2ipdef::IP_MULTICAST_IF, - &addr as *const _ as _, - sz.try_into()?, - ) - }; - } - #[cfg(not(target_os = "windows"))] - { - use std::os::fd::{AsFd, AsRawFd}; - ret = unsafe { - libc::setsockopt( - sock.as_fd().as_raw_fd(), - libc::IPPROTO_IP, - libc::IP_MULTICAST_IF, - &addr as *const _ as _, - sz.try_into()?, - ) - }; - } - if ret < 0 { - return Err(std::io::Error::last_os_error().into()); - } - Ok(()) -} - -fn set_mcast_if_v6(sock: &UdpSocket, dev_idx: u32) -> anyhow::Result<()> { - // in_addr is the same on unix and windows and contains just the 4 bytes of IPv4 in network - // byte order. - trace!(dev_idx, "setting IP_MULTICAST_IF"); - - let ret: i32; - #[cfg(target_os = "windows")] - { - use std::os::windows::io::AsRawSocket; - let sz: usize = std::mem::size_of_val(&dev_idx); - ret = unsafe { - winapi::um::winsock2::setsockopt( - sock.as_raw_socket().try_into()?, - winapi::shared::ws2def::IPPROTO_IPV6, - winapi::shared::ws2ipdef::IPV6_MULTICAST_IF, - &dev_idx as *const _ as _, - sz.try_into()?, - ) - }; - } - #[cfg(not(target_os = "windows"))] - { - use std::os::fd::{AsFd, AsRawFd}; - let dev_idx = dev_idx as i32; - let sz: usize = std::mem::size_of_val(&dev_idx); - ret = unsafe { - libc::setsockopt( - sock.as_fd().as_raw_fd(), - libc::IPPROTO_IPV6, - libc::IPV6_MULTICAST_IF, - &dev_idx as *const _ as _, - sz.try_into()?, - ) - }; - } - if ret < 0 { - return Err(std::io::Error::last_os_error().into()); - } - Ok(()) -} - impl MulticastOpts { fn addr_no_scope(&self) -> SocketAddr { let mut addr = self.mcast_addr; @@ -443,14 +371,14 @@ Content-Length: 0\r\n\r\n" // gets sent out of the interface we want (otherwise it'll get sent through // default one). (IpAddr::V4(ip), Some(sock_v4), _) => { - if let Err(e) = set_mcast_if_v4(sock_v4, ip) { - debug!(addr=%ip, "error calling set_mcast_if_v4: {e:#}"); + if let Err(e) = sock_v4.sock2.set_multicast_if_v4(&ip) { + debug!(addr=%ip, "error calling set_multicast_if_v4: {e:#}"); } sock_v4 } (IpAddr::V6(_), _, Some(sock_v6)) => { - if let Err(e) = set_mcast_if_v6(sock_v6, opts.interface_id) { - debug!(oif_id=opts.interface_id, "error calling set_mcast_if_v6: {e:#}"); + if let Err(e) = sock_v6.sock2.set_multicast_if_v6(opts.interface_id) { + debug!(oif_id=opts.interface_id, "error calling set_multicast_if_v6: {e:#}"); } sock_v6 }, @@ -460,7 +388,7 @@ Content-Length: 0\r\n\r\n" }, }; - match sock.send_to(payload.as_slice(), opts.mcast_addr).await { + match sock.tokio.send_to(payload.as_slice(), opts.mcast_addr).await { Ok(sz) => trace!(addr=%opts.mcast_addr, oif_id=opts.interface_id, oif_addr=%opts.interface_addr, size=sz, payload=?payload, "sent"), Err(e) => { debug!(addr=%opts.mcast_addr, oif_id=opts.interface_id, oif_addr=%opts.interface_addr, payload=?payload, "error sending: {e:#}") @@ -515,7 +443,8 @@ Content-Length: 0\r\n\r\n" if let Ok(st) = std::str::from_utf8(msg.st) { let response = self.generate_ssdp_discover_response(st, addr)?; trace!(content = response, ?addr, "sending SSDP discover response"); - sock.send_to(response.as_bytes(), addr) + sock.tokio + .send_to(response.as_bytes(), addr) .await .context("error sending")?; } @@ -531,7 +460,7 @@ Content-Length: 0\r\n\r\n" }; loop { - let (sz, addr) = match sock.recv_from(&mut buf).await { + let (sz, addr) = match sock.tokio.recv_from(&mut buf).await { Ok((sz, addr)) => (sz, addr), Err(e) => { warn!(error=?e, "error receving");