diff --git a/homeassistant/components/ssdp/__init__.py b/homeassistant/components/ssdp/__init__.py index ce2901d4f1a..af8f5915f57 100644 --- a/homeassistant/components/ssdp/__init__.py +++ b/homeassistant/components/ssdp/__init__.py @@ -56,6 +56,7 @@ ATTR_UPNP_UDN = "UDN" ATTR_UPNP_UPC = "UPC" ATTR_UPNP_PRESENTATION_URL = "presentationURL" +PRIMARY_MATCH_KEYS = [ATTR_UPNP_MANUFACTURER, "st", ATTR_UPNP_DEVICE_TYPE] DISCOVERY_MAPPING = { "usn": ATTR_SSDP_USN, @@ -124,7 +125,10 @@ async def async_get_discovery_info_by_udn( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the SSDP integration.""" - scanner = hass.data[DOMAIN] = Scanner(hass) + integration_matchers = IntegrationMatchers() + integration_matchers.async_setup(await async_get_ssdp(hass)) + + scanner = hass.data[DOMAIN] = Scanner(hass, integration_matchers) asyncio.create_task(scanner.async_start()) @@ -156,10 +160,57 @@ def _async_headers_match( return True +class IntegrationMatchers: + """Optimized integration matching.""" + + def __init__(self) -> None: + """Init optimized integration matching.""" + self._match_by_key: dict[ + str, dict[str, list[tuple[str, dict[str, str]]]] + ] | None = None + + @core_callback + def async_setup( + self, integration_matchers: dict[str, list[dict[str, str]]] + ) -> None: + """Build matchers by key. + + Here we convert the primary match keys into their own + dicts so we can do lookups of the primary match + key to find the match dict. + """ + self._match_by_key = {} + for key in PRIMARY_MATCH_KEYS: + matchers_by_key = self._match_by_key[key] = {} + for domain, matchers in integration_matchers.items(): + for matcher in matchers: + if match_value := matcher.get(key): + matchers_by_key.setdefault(match_value, []).append( + (domain, matcher) + ) + + @core_callback + def async_matching_domains(self, info_with_desc: CaseInsensitiveDict) -> set[str]: + """Find domains matching the passed CaseInsensitiveDict.""" + assert self._match_by_key is not None + domains = set() + for key, matchers_by_key in self._match_by_key.items(): + if not (match_value := info_with_desc.get(key)): + continue + for domain, matcher in matchers_by_key.get(match_value, []): + if domain in domains: + continue + if all(info_with_desc.get(k) == v for (k, v) in matcher.items()): + domains.add(domain) + return domains + + class Scanner: """Class to manage SSDP searching and SSDP advertisements.""" - def __init__(self, hass: HomeAssistant) -> None: + def __init__( + self, hass: HomeAssistant, integration_matchers: IntegrationMatchers + ) -> None: """Initialize class.""" self.hass = hass self._cancel_scan: Callable[[], None] | None = None @@ -167,7 +218,7 @@ class Scanner: self._callbacks: list[tuple[SsdpCallback, dict[str, str]]] = [] self._flow_dispatcher: FlowDispatcher | None = None self._description_cache: DescriptionCache | None = None - self._integration_matchers: dict[str, list[dict[str, str]]] | None = None + self.integration_matchers = integration_matchers @property def _ssdp_devices(self) -> list[SsdpDevice]: @@ -271,7 +322,6 @@ class Scanner: requester = AiohttpSessionRequester(session, True, 10) self._description_cache = DescriptionCache(requester) self._flow_dispatcher = FlowDispatcher(self.hass) - self._integration_matchers = await async_get_ssdp(self.hass) await self._async_start_ssdp_listeners() @@ -323,16 +373,6 @@ class Scanner: if _async_headers_match(combined_headers, match_dict) ] - @core_callback - def _async_matching_domains(self, info_with_desc: CaseInsensitiveDict) -> set[str]: - assert self._integration_matchers is not None - domains = set() - for domain, matchers in self._integration_matchers.items(): - for matcher in matchers: - if all(info_with_desc.get(k) == v for (k, v) in matcher.items()): - domains.add(domain) - return domains - async def _ssdp_listener_callback( self, ssdp_device: SsdpDevice, dst: DeviceOrServiceType, source: SsdpSource ) -> None: @@ -351,7 +391,7 @@ class Scanner: ssdp_change = SSDP_SOURCE_SSDP_CHANGE_MAPPING[source] await _async_process_callbacks(callbacks, discovery_info, ssdp_change) - for domain in self._async_matching_domains(info_with_desc): + for domain in self.integration_matchers.async_matching_domains(info_with_desc): _LOGGER.debug("Discovered %s at %s", domain, location) flow: SSDPFlow = {