Avoid the same task ending up multiple times in the ssdp asyncio.gather (#42815)

commit 3661d8380a
parent c094f4b907
Author: J. Nick Koston, 2020-11-04 06:25:37 -10:00 (committed via GitHub)

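Note (editorial aside, not part of the commit): the change makes the scanner deduplicate description locations up front, so each URL is handed to asyncio.gather exactly once instead of a cached fetch task being awaited from several places. A minimal runnable sketch of that dedupe-before-gather idea, with hypothetical names:

import asyncio

async def fetch_description(location):
    """Stand-in for fetching an SSDP description over HTTP (hypothetical)."""
    await asyncio.sleep(0)
    return f"<xml from {location}>"

async def main():
    # Several SSDP entries usually share one location; dedupe first so
    # each URL is fetched exactly once and no awaitable is gathered twice.
    locations = [
        "http://192.0.2.1/desc.xml",
        "http://192.0.2.1/desc.xml",
        "http://192.0.2.2/desc.xml",
    ]
    unseen = list(set(locations))
    results = await asyncio.gather(*(fetch_description(loc) for loc in unseen))
    print(dict(zip(unseen, results)))

asyncio.run(main())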

@@ -79,7 +79,8 @@ class Scanner:
     async def _process_entries(self, entries):
         """Process SSDP entries."""
-        tasks = []
+        entries_to_process = []
+        unseen_locations = set()
 
         for entry in entries:
             key = (entry.st, entry.location)
@@ -89,21 +90,21 @@ class Scanner:
             self.seen.add(key)
-            tasks.append(self._process_entry(entry))
+            entries_to_process.append(entry)
+
+            if entry.location not in self._description_cache:
+                unseen_locations.add(entry.location)
 
-        if not tasks:
+        if not entries_to_process:
             return
 
-        to_load = [
-            result for result in await asyncio.gather(*tasks) if result is not None
-        ]
-        if not to_load:
-            return
+        if unseen_locations:
+            await self._fetch_descriptions(list(unseen_locations))
 
         tasks = []
-        for entry, info, domains in to_load:
+        for entry in entries_to_process:
+            info, domains = self._process_entry(entry)
             for domain in domains:
                 _LOGGER.debug("Discovered %s at %s", domain, entry.location)
                 tasks.append(
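Aside: the hunk above splits _process_entries into two phases, batch-fetching descriptions for unseen locations and then matching entries synchronously against the filled cache. A self-contained sketch of that two-phase shape (simplified stand-ins, not the real Scanner API):

import asyncio

class MiniScanner:
    """Toy model of the two-phase flow; names and data are simplified."""

    def __init__(self):
        self.seen = set()
        self._description_cache = {}

    async def _fetch_description(self, location):
        await asyncio.sleep(0)  # stand-in for the real HTTP request
        return f"<xml from {location}>"

    async def _process_entries(self, entries):
        entries_to_process = []
        unseen_locations = set()
        for st, location in entries:
            key = (st, location)
            if key in self.seen:
                continue
            self.seen.add(key)
            entries_to_process.append((st, location))
            if location not in self._description_cache:
                unseen_locations.add(location)
        # Phase 1: fetch every unseen location exactly once, concurrently.
        unseen = list(unseen_locations)
        results = await asyncio.gather(*(self._fetch_description(u) for u in unseen))
        self._description_cache.update(zip(unseen, results))
        # Phase 2: match entries synchronously against the filled cache.
        for st, location in entries_to_process:
            print(st, "->", self._description_cache[location])

asyncio.run(MiniScanner()._process_entries(
    [("st:a", "http://192.0.2.1/d.xml"), ("st:b", "http://192.0.2.1/d.xml")]
))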
@@ -112,9 +113,29 @@ class Scanner:
                     )
                 )
 
-        await asyncio.wait(tasks)
+        if tasks:
+            await asyncio.gather(*tasks)
 
-    async def _process_entry(self, entry):
+    async def _fetch_descriptions(self, locations):
+        """Fetch descriptions from locations."""
+        for idx, result in enumerate(
+            await asyncio.gather(
+                *[self._fetch_description(location) for location in locations],
+                return_exceptions=True,
+            )
+        ):
+            location = locations[idx]
+
+            if isinstance(result, Exception):
+                _LOGGER.exception(
+                    "Failed to fetch ssdp data from: %s", location, exc_info=result
+                )
+                continue
+
+            self._description_cache[location] = result
+
+    def _process_entry(self, entry):
         """Process a single entry."""
         info = {"st": entry.st}
@@ -123,17 +144,13 @@ class Scanner:
                 info[key] = entry.values[key]
 
         if entry.location:
             # Multiple entries usually share same location. Make sure
             # we fetch it only once.
             info_req = self._description_cache.get(entry.location)
             if info_req is None:
-                info_req = self._description_cache[
-                    entry.location
-                ] = self.hass.async_create_task(self._fetch_description(entry.location))
+                return (None, [])
 
-            info.update(await info_req)
+            info.update(info_req)
 
         domains = set()
         for domain, matchers in self._integration_matchers.items():
@@ -142,9 +159,9 @@ class Scanner:
                     domains.add(domain)
 
         if domains:
-            return (entry, info_from_entry(entry, info), domains)
+            return (info_from_entry(entry, info), domains)
 
-        return None
+        return (None, [])
 
     async def _fetch_description(self, xml_location):
         """Fetch an XML description."""