From 3661d8380af5035d0c2d0cbce4b89a5de099b902 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 4 Nov 2020 06:25:37 -1000 Subject: [PATCH] Avoid the same task ending up multiple times in the ssdp asyncio.gather (#42815) --- homeassistant/components/ssdp/__init__.py | 57 +++++++++++++++-------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/homeassistant/components/ssdp/__init__.py b/homeassistant/components/ssdp/__init__.py index 7108034cfb8..b3d8a7f2898 100644 --- a/homeassistant/components/ssdp/__init__.py +++ b/homeassistant/components/ssdp/__init__.py @@ -79,7 +79,8 @@ class Scanner: async def _process_entries(self, entries): """Process SSDP entries.""" - tasks = [] + entries_to_process = [] + unseen_locations = set() for entry in entries: key = (entry.st, entry.location) @@ -89,21 +90,21 @@ class Scanner: self.seen.add(key) - tasks.append(self._process_entry(entry)) + entries_to_process.append(entry) - if not tasks: + if entry.location not in self._description_cache: + unseen_locations.add(entry.location) + + if not entries_to_process: return - to_load = [ - result for result in await asyncio.gather(*tasks) if result is not None - ] - - if not to_load: - return + if unseen_locations: + await self._fetch_descriptions(list(unseen_locations)) tasks = [] - for entry, info, domains in to_load: + for entry in entries_to_process: + info, domains = self._process_entry(entry) for domain in domains: _LOGGER.debug("Discovered %s at %s", domain, entry.location) tasks.append( @@ -112,9 +113,29 @@ class Scanner: ) ) - await asyncio.wait(tasks) + if tasks: + await asyncio.gather(*tasks) - async def _process_entry(self, entry): + async def _fetch_descriptions(self, locations): + """Fetch descriptions from locations.""" + + for idx, result in enumerate( + await asyncio.gather( + *[self._fetch_description(location) for location in locations], + return_exceptions=True, + ) + ): + location = locations[idx] + + if isinstance(result, Exception): + _LOGGER.exception( + "Failed to fetch ssdp data from: %s", location, exc_info=result + ) + continue + + self._description_cache[location] = result + + def _process_entry(self, entry): """Process a single entry.""" info = {"st": entry.st} @@ -123,17 +144,13 @@ class Scanner: info[key] = entry.values[key] if entry.location: - # Multiple entries usually share same location. Make sure # we fetch it only once. info_req = self._description_cache.get(entry.location) - if info_req is None: - info_req = self._description_cache[ - entry.location - ] = self.hass.async_create_task(self._fetch_description(entry.location)) + return (None, []) - info.update(await info_req) + info.update(info_req) domains = set() for domain, matchers in self._integration_matchers.items(): @@ -142,9 +159,9 @@ class Scanner: domains.add(domain) if domains: - return (entry, info_from_entry(entry, info), domains) + return (info_from_entry(entry, info), domains) - return None + return (None, []) async def _fetch_description(self, xml_location): """Fetch an XML description."""