use anyhow::{bail, Context}; use bstr::BStr; use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt}; use network_interface::NetworkInterfaceConfig; use reqwest::Client; use serde::Deserialize; use std::{ collections::HashSet, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}, time::Duration, }; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tracing::{debug, error_span, trace, warn, Instrument, Span}; use url::Url; const SERVICE_TYPE_WAN_IP_CONNECTION: &str = "urn:schemas-upnp-org:service:WANIPConnection:1"; const SSDP_MULTICAST_IP: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(239, 255, 255, 250), 1900)); pub const SSDP_SEARCH_WAN_IPCONNECTION_ST: &str = "urn:schemas-upnp-org:service:WANIPConnection:1"; pub const SSDP_SEARCH_ROOT_ST: &str = "upnp:rootdevice"; pub fn make_ssdp_search_request(kind: &str) -> String { format!( "M-SEARCH * HTTP/1.1\r\n\ Host: 239.255.255.250:1900\r\n\ Man: \"ssdp:discover\"\r\n\ MX: 3\r\n\ ST: {kind}\r\n\ \r\n" ) } // .to_bits() isn't yet available on min rust version we support (1.75 at the time of writing this) const fn ip_bits_v6(addr: Ipv6Addr) -> u128 { u128::from_be_bytes(addr.octets()) } pub fn ipv6_is_link_local(ip: Ipv6Addr) -> bool { const LL: Ipv6Addr = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0); const MASK: Ipv6Addr = Ipv6Addr::new(0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0); ip_bits_v6(ip) & ip_bits_v6(MASK) == ip_bits_v6(LL) & ip_bits_v6(MASK) } pub fn get_local_ip_relative_to(local_dest: SocketAddr) -> anyhow::Result { fn ip_bits_v4(addr: Ipv4Addr) -> u32 { u32::from_be_bytes(addr.octets()) } fn masked_v4(ip: Ipv4Addr, mask: Ipv4Addr) -> u32 { ip_bits_v4(ip) & ip_bits_v4(mask) } fn masked_v6(ip: Ipv6Addr, mask: Ipv6Addr) -> u128 { ip_bits_v6(ip) & ip_bits_v6(mask) } let interfaces = network_interface::NetworkInterface::show().context("error listing network interfaces")?; for i in interfaces { for addr in i.addr { trace!(%local_dest, nic=i.index, ip=?addr.ip(), nm=?addr.netmask(), "dbg"); match (local_dest, addr.ip(), addr.netmask()) { // We are connecting to ourselves, return itself. (l, a, _) if l.ip() == a => return Ok(addr.ip()), // IPv4 masks match. (SocketAddr::V4(l), IpAddr::V4(a), Some(IpAddr::V4(m))) if masked_v4(*l.ip(), m) == masked_v4(a, m) => { return Ok(addr.ip()) } // Return IPv6 link-local addresses when source is link-local address and there's a scope_id set. (SocketAddr::V6(l), IpAddr::V6(a), _) if ipv6_is_link_local(*l.ip()) && l.scope_id() > 0 => { if ipv6_is_link_local(a) && l.scope_id() == i.index { return Ok(addr.ip()); } } // If V6 masks match, return. (SocketAddr::V6(l), IpAddr::V6(a), Some(IpAddr::V6(m))) if masked_v6(*l.ip(), m) == masked_v6(a, m) => { return Ok(addr.ip()) } // For IPv6 fallback to returning a random (first encountered) IPv6 address. (SocketAddr::V6(_), IpAddr::V6(_), None) => return Ok(addr.ip()), _ => continue, } } } bail!("couldn't find a local ip address") } async fn forward_port( control_url: Url, local_ip: IpAddr, port: u16, lease_duration: Duration, ) -> anyhow::Result<()> { let request_body = format!( r#" {port} TCP {port} {local_ip} 1 rust UPnP {} "#, lease_duration.as_secs() ); let url = control_url; let client = reqwest::Client::new(); let response = client .post(url.clone()) .header("Content-Type", "text/xml") .header( "SOAPAction", format!("\"{}#AddPortMapping\"", SERVICE_TYPE_WAN_IP_CONNECTION), ) .body(request_body) .send() .await .context("error sending")?; let status = response.status(); let response_text = response .text() .await .context("error reading response text")?; trace!(status = %status, text=response_text, "AddPortMapping response"); if !status.is_success() { bail!("failed port forwarding: {}", status); } else { debug!(%local_ip, port, "successfully port forwarded"); } Ok(()) } #[derive(Clone, Debug, Deserialize, PartialEq, Eq)] pub struct RootDesc { #[serde(rename = "device")] pub devices: Vec, } #[derive(Default, Clone, Debug, Deserialize, PartialEq, Eq)] pub struct DeviceList { #[serde(rename = "device")] pub devices: Vec, } #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] pub struct Device { #[serde(rename = "deviceType")] pub device_type: String, #[serde(rename = "friendlyName", default)] pub friendly_name: String, #[serde(rename = "serviceList", default)] pub service_list: ServiceList, #[serde(rename = "deviceList", default)] pub device_list: DeviceList, } impl Device { pub fn iter_services( &self, parent: Span, ) -> Box + '_> { let self_span = self.span(parent); let services = self.service_list.services.iter().map({ let self_span = self_span.clone(); move |s| (s.span(self_span.clone()), s) }); Box::new(services.chain(self.device_list.devices.iter().flat_map({ let self_span = self_span.clone(); move |d| d.iter_services(self_span.clone()) }))) } pub fn span(&self, parent: tracing::Span) -> tracing::Span { error_span!(parent: parent, "device", device = self.name()) } } impl Device { pub fn name(&self) -> &str { if self.friendly_name.is_empty() { return &self.device_type; } &self.friendly_name } } #[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)] pub struct ServiceList { #[serde(rename = "service", default)] pub services: Vec, } #[derive(Clone, Debug, Deserialize, PartialEq, Eq)] pub struct Service { #[serde(rename = "serviceType")] pub service_type: String, #[serde(rename = "controlURL")] pub control_url: String, #[serde(rename = "SCPDURL")] pub scpd_url: String, #[serde(rename = "eventSubURL", default)] pub event_sub_url: Option, } impl Service { pub fn span(&self, parent: tracing::Span) -> tracing::Span { error_span!(parent: parent, "service", url = self.control_url) } } #[derive(Debug)] struct UpnpEndpoint { discover_response: UpnpDiscoverResponse, data: RootDesc, } impl UpnpEndpoint { fn location(&self) -> &Url { &self.discover_response.location } fn span(&self) -> tracing::Span { error_span!("upnp_endpoint", location = %self.location()) } fn iter_services(&self) -> impl Iterator + '_ { let self_span = self.span(); self.data .devices .iter() .flat_map(move |d| d.iter_services(self_span.clone())) } fn my_local_ip(&self) -> anyhow::Result { let received_from = self.discover_response.received_from; let local_ip = get_local_ip_relative_to(received_from) .with_context(|| format!("can't determine local IP relative to {received_from}"))?; Ok(local_ip) } fn get_wan_ip_control_urls(&self) -> impl Iterator + '_ { self.iter_services() .filter(|(_, s)| s.service_type == SERVICE_TYPE_WAN_IP_CONNECTION) .map(|(span, s)| (span, self.discover_response.location.join(&s.control_url))) .filter_map(|(span, url)| match url { Ok(url) => Some((span, url)), Err(e) => { debug!("bad control url: {e:#}"); None } }) } } #[derive(Debug)] pub struct UpnpDiscoverResponse { pub received_from: SocketAddr, pub location: Url, } pub async fn discover_services(location: Url) -> anyhow::Result { let response = Client::new() .get(location.clone()) .send() .await .context("failed to send GET request")? .text() .await .context("failed to read response body")?; trace!("received from {location}: {response}"); let root_desc: RootDesc = quick_xml::de::from_str(&response) .context("failed to parse response body as xml") .map_err(|e| { debug!("failed to parse this XML: {response}"); e })?; Ok(root_desc) } pub fn parse_upnp_discover_response( buf: &[u8], received_from: SocketAddr, ) -> anyhow::Result { let mut headers = [httparse::EMPTY_HEADER; 16]; let mut resp = httparse::Response::new(&mut headers); resp.parse(buf).context("error parsing response")?; trace!(?resp, "parsed SSDP response"); match resp.code { Some(200) => {} other => anyhow::bail!("bad response code {other:?}, expected 200"), } let mut location = None; for header in resp.headers { match header.name { "location" | "LOCATION" | "Location" => { location = Some( std::str::from_utf8(header.value).context("bad utf-8 in location header")?, ) } _ => continue, } } let location = location.context("missing location header")?; let location = Url::parse(location).with_context(|| format!("failed parsing location {location}"))?; Ok(UpnpDiscoverResponse { location, received_from, }) } pub async fn discover_once( tx: &UnboundedSender, kind: &str, timeout: Duration, ) -> anyhow::Result<()> { let socket = tokio::net::UdpSocket::bind("0.0.0.0:0") .await .context("failed to bind UDP socket")?; let message = make_ssdp_search_request(kind); socket .send_to(message.as_bytes(), SSDP_MULTICAST_IP) .await .context("failed to send SSDP search request")?; let mut buffer = [0; 2048]; let timeout = tokio::time::sleep(timeout); let mut timed_out = false; tokio::pin!(timeout); let mut discovered = 0; while !timed_out { tokio::select! { _ = &mut timeout, if !timed_out => { timed_out = true; } Ok((len, addr)) = socket.recv_from(&mut buffer), if !timed_out => { let response = &buffer[..len]; match parse_upnp_discover_response(response, addr) { Ok(r) => { tx.send(r)?; discovered += 1; }, Err(e) => warn!(error=?e, response=?BStr::new(response), "failed to parse SSDP response"), }; }, } } debug!("discovered {discovered} endpoints"); Ok(()) } pub struct UpnpPortForwarderOptions { pub lease_duration: Duration, pub discover_interval: Duration, pub discover_timeout: Duration, } impl Default for UpnpPortForwarderOptions { fn default() -> Self { Self { discover_interval: Duration::from_secs(60), discover_timeout: Duration::from_secs(10), lease_duration: Duration::from_secs(60), } } } pub struct UpnpPortForwarder { ports: Vec, opts: UpnpPortForwarderOptions, } impl UpnpPortForwarder { pub fn new(ports: Vec, opts: Option) -> anyhow::Result { if ports.is_empty() { bail!("empty ports") } Ok(Self { ports, opts: opts.unwrap_or_default(), }) } async fn parse_endpoint( &self, discover_response: UpnpDiscoverResponse, ) -> anyhow::Result { let services = discover_services(discover_response.location.clone()).await?; Ok(UpnpEndpoint { discover_response, data: services, }) } async fn discover_once( &self, tx: &UnboundedSender, ) -> anyhow::Result<()> { discover_once( tx, SSDP_SEARCH_WAN_IPCONNECTION_ST, self.opts.discover_timeout, ) .await } async fn discovery(&self, tx: UnboundedSender) -> anyhow::Result<()> { let mut discover_interval = tokio::time::interval(self.opts.discover_interval); discover_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { discover_interval.tick().await; if let Err(e) = self.discover_once(&tx).await { warn!("failed to run discovery: {e:#}"); } } } async fn manage_port(&self, control_url: Url, local_ip: IpAddr, port: u16) -> ! { let lease_duration = self.opts.lease_duration; let mut interval = tokio::time::interval(lease_duration / 2); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { interval.tick().await; if let Err(e) = forward_port(control_url.clone(), local_ip, port, lease_duration).await { warn!("failed to forward port: {e:#}"); } } } async fn manage_service(&self, control_url: Url, local_ip: IpAddr) -> anyhow::Result<()> { futures::future::join_all(self.ports.iter().cloned().map(|port| { self.manage_port(control_url.clone(), local_ip, port) .instrument(error_span!("manage_port", port = port)) })) .await; Ok(()) } pub async fn run_forever(self) -> ! { let (discover_tx, mut discover_rx) = unbounded_channel(); let discovery = self.discovery(discover_tx); let mut spawned_tasks = HashSet::::new(); let mut endpoints = FuturesUnordered::new(); let mut service_managers = FuturesUnordered::new(); tokio::pin!(discovery); loop { tokio::select! { _ = &mut discovery => {}, r = discover_rx.recv() => { let r = r.unwrap(); let location = r.location.clone(); endpoints.push(self.parse_endpoint(r).map_err(|e| { debug!("error parsing endpoint: {e:#}"); e }).instrument(error_span!("parse endpoint", location=location.to_string()))); }, Some(Ok(endpoint)) = endpoints.next(), if !endpoints.is_empty() => { let mut local_ip = None; for (span, control_url) in endpoint.get_wan_ip_control_urls() { if spawned_tasks.contains(&control_url) { debug!("already spawned for {}", control_url); continue; } let ip = match local_ip { Some(ip) => ip, None => { match endpoint.my_local_ip() { Ok(ip) => { local_ip = Some(ip); ip }, Err(e) => { warn!("failed to determine local IP for endpoint at {}: {:#}", endpoint.location(), e); break; } } } }; spawned_tasks.insert(control_url.clone()); service_managers.push(self.manage_service(control_url, ip).instrument(span)) } }, _ = service_managers.next(), if !service_managers.is_empty() => { }, } } } } #[cfg(test)] mod tests { use quick_xml::de::from_str; use crate::{Device, DeviceList, RootDesc, Service, ServiceList}; #[test] fn test_parse_root_desc() { let actual = from_str::(include_str!("resources/test/devices-0.xml")).unwrap(); let expected = RootDesc { devices: vec![Device { device_type: "urn:schemas-upnp-org:device:InternetGatewayDevice:1".into(), friendly_name: "ARRIS TG3492LG".into(), service_list: ServiceList { services: vec![Service { service_type: "urn:schemas-upnp-org:service:Layer3Forwarding:1".into(), control_url: "/upnp/control/Layer3Forwarding".into(), scpd_url: "/Layer3ForwardingSCPD.xml".into(), event_sub_url: Some("/upnp/event/Layer3Forwarding".into()), }], }, device_list: DeviceList { devices: vec![Device { device_type: "urn:schemas-upnp-org:device:WANDevice:1".into(), friendly_name: "WANDevice:1".into(), service_list: ServiceList { services: vec![Service { service_type: "urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1".into(), control_url: "/upnp/control/WANCommonInterfaceConfig0".into(), scpd_url: "/WANCommonInterfaceConfigSCPD.xml".into(), event_sub_url: Some("/upnp/event/WANCommonInterfaceConfig0".into()), }], }, device_list: DeviceList { devices: vec![Device { device_type: "urn:schemas-upnp-org:device:WANConnectionDevice:1" .into(), friendly_name: "WANConnectionDevice:1".into(), service_list: ServiceList { services: vec![Service { service_type: "urn:schemas-upnp-org:service:WANIPConnection:1".into(), control_url: "/upnp/control/WANIPConnection0".into(), scpd_url: "/WANIPConnectionServiceSCPD.xml".into(), event_sub_url: Some("/upnp/event/WANIPConnection0".into()), }], }, device_list: DeviceList { devices: vec![] }, }], }, }], }, }], }; assert_eq!(actual, expected); } }