diff --git a/Cargo.lock b/Cargo.lock index b2794d8..5367cf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1695,7 +1695,9 @@ version = "0.1.1" dependencies = [ "anyhow", "async-recursion", + "bstr", "futures", + "httparse", "network-interface", "reqwest", "serde", diff --git a/crates/upnp/Cargo.toml b/crates/upnp/Cargo.toml index f242e49..16d0929 100644 --- a/crates/upnp/Cargo.toml +++ b/crates/upnp/Cargo.toml @@ -22,6 +22,8 @@ futures = "0.3" url = "2" async-recursion = "1" network-interface = { git = 'https://github.com/ikatson/network-interface', branch = "compile-on-freebsd" } +httparse = "1.9.4" +bstr = "1.10.0" [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/crates/upnp/src/lib.rs b/crates/upnp/src/lib.rs index 20ea7bb..5c0d6e1 100644 --- a/crates/upnp/src/lib.rs +++ b/crates/upnp/src/lib.rs @@ -1,11 +1,12 @@ use anyhow::{bail, Context}; +use bstr::BStr; use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt}; use network_interface::NetworkInterfaceConfig; use reqwest::Client; use serde::Deserialize; use serde_xml_rs::from_str; use std::{ - collections::{HashMap, HashSet}, + collections::HashSet, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, time::Duration, }; @@ -20,7 +21,7 @@ const SSDP_SEARCH_REQUEST: &str = "M-SEARCH * HTTP/1.1\r\n\ Host: 239.255.255.250:1900\r\n\ Man: \"ssdp:discover\"\r\n\ MX: 3\r\n\ - ST: upnp:rootdevice\r\n\ + ST: urn:schemas-upnp-org:service:WANIPConnection:1\r\n\ \r\n"; fn get_local_ip_relative_to(local_dest: Ipv4Addr) -> anyhow::Result { @@ -265,16 +266,30 @@ async fn discover_services(location: Url) -> anyhow::Result { } fn parse_upnp_discover_response( - response: &str, + buf: &[u8], received_from: SocketAddr, ) -> anyhow::Result { - let mut headers = HashMap::new(); - for line in response.lines() { - if let Some((key, value)) = line.split_once(": ") { - headers.insert(key.to_lowercase(), value.trim_end().to_string()); + let mut headers = [httparse::EMPTY_HEADER; 16]; + let mut resp = httparse::Response::new(&mut headers); + resp.parse(buf).context("error parsing response")?; + + trace!(?resp, "parsed SSDP response"); + match resp.code { + Some(200) => {} + other => anyhow::bail!("bad response code {other:?}, expected 200"), + } + let mut location = None; + for header in resp.headers { + match header.name { + "location" | "LOCATION" | "Location" => { + location = Some( + std::str::from_utf8(header.value).context("bad utf-8 in location header")?, + ) + } + _ => continue, } } - let location = headers.get("location").context("missing location header")?; + let location = location.context("missing location header")?; let location = Url::parse(location).with_context(|| format!("failed parsing location {location}"))?; Ok(UpnpDiscoverResponse { @@ -352,20 +367,13 @@ impl UpnpPortForwarder { timed_out = true; } Ok((len, addr)) = socket.recv_from(&mut buffer), if !timed_out => { - let response = match std::str::from_utf8(&buffer[..len]) { - Ok(response) => response, - Err(_) => { - warn!(%addr, "received invalid utf-8"); - continue; - }, - }; - trace!(%addr, response, "response"); + let response = &buffer[..len]; match parse_upnp_discover_response(response, addr) { Ok(r) => { tx.send(r)?; discovered += 1; }, - Err(e) => warn!("failed to parse response: {e:#}"), + Err(e) => warn!(error=?e, response=?BStr::new(response), "failed to parse SSDP response"), }; }, }