Avoid the same task ending up multiple times in the ssdp asyncio.gather (#42815)

This commit is contained in:
J. Nick Koston 2020-11-04 06:25:37 -10:00 committed by GitHub
parent c094f4b907
commit 3661d8380a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -79,7 +79,8 @@ class Scanner:
async def _process_entries(self, entries):
"""Process SSDP entries."""
tasks = []
entries_to_process = []
unseen_locations = set()
for entry in entries:
key = (entry.st, entry.location)
@ -89,21 +90,21 @@ class Scanner:
self.seen.add(key)
tasks.append(self._process_entry(entry))
entries_to_process.append(entry)
if not tasks:
if entry.location not in self._description_cache:
unseen_locations.add(entry.location)
if not entries_to_process:
return
to_load = [
result for result in await asyncio.gather(*tasks) if result is not None
]
if not to_load:
return
if unseen_locations:
await self._fetch_descriptions(list(unseen_locations))
tasks = []
for entry, info, domains in to_load:
for entry in entries_to_process:
info, domains = self._process_entry(entry)
for domain in domains:
_LOGGER.debug("Discovered %s at %s", domain, entry.location)
tasks.append(
@ -112,9 +113,29 @@ class Scanner:
)
)
await asyncio.wait(tasks)
if tasks:
await asyncio.gather(*tasks)
async def _process_entry(self, entry):
async def _fetch_descriptions(self, locations):
"""Fetch descriptions from locations."""
for idx, result in enumerate(
await asyncio.gather(
*[self._fetch_description(location) for location in locations],
return_exceptions=True,
)
):
location = locations[idx]
if isinstance(result, Exception):
_LOGGER.exception(
"Failed to fetch ssdp data from: %s", location, exc_info=result
)
continue
self._description_cache[location] = result
def _process_entry(self, entry):
"""Process a single entry."""
info = {"st": entry.st}
@ -123,17 +144,13 @@ class Scanner:
info[key] = entry.values[key]
if entry.location:
# Multiple entries usually share same location. Make sure
# we fetch it only once.
info_req = self._description_cache.get(entry.location)
if info_req is None:
info_req = self._description_cache[
entry.location
] = self.hass.async_create_task(self._fetch_description(entry.location))
return (None, [])
info.update(await info_req)
info.update(info_req)
domains = set()
for domain, matchers in self._integration_matchers.items():
@ -142,9 +159,9 @@ class Scanner:
domains.add(domain)
if domains:
return (entry, info_from_entry(entry, info), domains)
return (info_from_entry(entry, info), domains)
return None
return (None, [])
async def _fetch_description(self, xml_location):
"""Fetch an XML description."""