Optimize SSDP matching (#56622)

* Optimize SSDP matching

* tweak

* remove

* remove dupe
This commit is contained in:
J. Nick Koston 2021-09-26 16:30:39 -05:00 committed by GitHub
parent f268227d64
commit 6399730d2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -56,6 +56,7 @@ ATTR_UPNP_UDN = "UDN"
ATTR_UPNP_UPC = "UPC"
ATTR_UPNP_PRESENTATION_URL = "presentationURL"
PRIMARY_MATCH_KEYS = [ATTR_UPNP_MANUFACTURER, "st", ATTR_UPNP_DEVICE_TYPE]
DISCOVERY_MAPPING = {
"usn": ATTR_SSDP_USN,
@ -124,7 +125,10 @@ async def async_get_discovery_info_by_udn(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the SSDP integration."""
scanner = hass.data[DOMAIN] = Scanner(hass)
integration_matchers = IntegrationMatchers()
integration_matchers.async_setup(await async_get_ssdp(hass))
scanner = hass.data[DOMAIN] = Scanner(hass, integration_matchers)
asyncio.create_task(scanner.async_start())
@ -156,10 +160,57 @@ def _async_headers_match(
return True
class IntegrationMatchers:
"""Optimized integration matching."""
def __init__(self) -> None:
"""Init optimized integration matching."""
self._match_by_key: dict[
str, dict[str, list[tuple[str, dict[str, str]]]]
] | None = None
@core_callback
def async_setup(
self, integration_matchers: dict[str, list[dict[str, str]]]
) -> None:
"""Build matchers by key.
Here we convert the primary match keys into their own
dicts so we can do lookups of the primary match
key to find the match dict.
"""
self._match_by_key = {}
for key in PRIMARY_MATCH_KEYS:
matchers_by_key = self._match_by_key[key] = {}
for domain, matchers in integration_matchers.items():
for matcher in matchers:
if match_value := matcher.get(key):
matchers_by_key.setdefault(match_value, []).append(
(domain, matcher)
)
@core_callback
def async_matching_domains(self, info_with_desc: CaseInsensitiveDict) -> set[str]:
"""Find domains matching the passed CaseInsensitiveDict."""
assert self._match_by_key is not None
domains = set()
for key, matchers_by_key in self._match_by_key.items():
if not (match_value := info_with_desc.get(key)):
continue
for domain, matcher in matchers_by_key.get(match_value, []):
if domain in domains:
continue
if all(info_with_desc.get(k) == v for (k, v) in matcher.items()):
domains.add(domain)
return domains
class Scanner:
"""Class to manage SSDP searching and SSDP advertisements."""
def __init__(self, hass: HomeAssistant) -> None:
def __init__(
self, hass: HomeAssistant, integration_matchers: IntegrationMatchers
) -> None:
"""Initialize class."""
self.hass = hass
self._cancel_scan: Callable[[], None] | None = None
@ -167,7 +218,7 @@ class Scanner:
self._callbacks: list[tuple[SsdpCallback, dict[str, str]]] = []
self._flow_dispatcher: FlowDispatcher | None = None
self._description_cache: DescriptionCache | None = None
self._integration_matchers: dict[str, list[dict[str, str]]] | None = None
self.integration_matchers = integration_matchers
@property
def _ssdp_devices(self) -> list[SsdpDevice]:
@ -271,7 +322,6 @@ class Scanner:
requester = AiohttpSessionRequester(session, True, 10)
self._description_cache = DescriptionCache(requester)
self._flow_dispatcher = FlowDispatcher(self.hass)
self._integration_matchers = await async_get_ssdp(self.hass)
await self._async_start_ssdp_listeners()
@ -323,16 +373,6 @@ class Scanner:
if _async_headers_match(combined_headers, match_dict)
]
@core_callback
def _async_matching_domains(self, info_with_desc: CaseInsensitiveDict) -> set[str]:
assert self._integration_matchers is not None
domains = set()
for domain, matchers in self._integration_matchers.items():
for matcher in matchers:
if all(info_with_desc.get(k) == v for (k, v) in matcher.items()):
domains.add(domain)
return domains
async def _ssdp_listener_callback(
self, ssdp_device: SsdpDevice, dst: DeviceOrServiceType, source: SsdpSource
) -> None:
@ -351,7 +391,7 @@ class Scanner:
ssdp_change = SSDP_SOURCE_SSDP_CHANGE_MAPPING[source]
await _async_process_callbacks(callbacks, discovery_info, ssdp_change)
for domain in self._async_matching_domains(info_with_desc):
for domain in self.integration_matchers.async_matching_domains(info_with_desc):
_LOGGER.debug("Discovered %s at %s", domain, location)
flow: SSDPFlow = {