Use futures instead of asyncio.Event for async_get_integrations (#93060)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2023-05-14 14:42:04 -05:00 committed by GitHub
parent b95405a7e9
commit d5a0824924
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -889,39 +889,41 @@ async def async_get_integrations(
cache = hass.data[DATA_INTEGRATIONS] = {} cache = hass.data[DATA_INTEGRATIONS] = {}
results: dict[str, Integration | Exception] = {} results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Event] = {} needed: dict[str, asyncio.Future[None]] = {}
in_progress: dict[str, asyncio.Event] = {} in_progress: dict[str, asyncio.Future[None]] = {}
for domain in domains: for domain in domains:
int_or_evt: Integration | asyncio.Event | None = cache.get(domain, _UNDEF) int_or_fut: Integration | asyncio.Future[None] | None = cache.get(
if isinstance(int_or_evt, asyncio.Event): domain, _UNDEF
in_progress[domain] = int_or_evt )
elif int_or_evt is not _UNDEF: if isinstance(int_or_fut, asyncio.Future):
results[domain] = cast(Integration, int_or_evt) in_progress[domain] = int_or_fut
elif int_or_fut is not _UNDEF:
results[domain] = cast(Integration, int_or_fut)
elif "." in domain: elif "." in domain:
results[domain] = ValueError(f"Invalid domain {domain}") results[domain] = ValueError(f"Invalid domain {domain}")
else: else:
needed[domain] = cache[domain] = asyncio.Event() needed[domain] = cache[domain] = hass.loop.create_future()
if in_progress: if in_progress:
await asyncio.gather(*[event.wait() for event in in_progress.values()]) await asyncio.gather(*in_progress.values())
for domain in in_progress: for domain in in_progress:
# When we have waited and it's _UNDEF, it doesn't exist # When we have waited and it's _UNDEF, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it # We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid. # and then restart, because their config will never be valid.
if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF: if (int_or_fut := cache.get(domain, _UNDEF)) is _UNDEF:
results[domain] = IntegrationNotFound(domain) results[domain] = IntegrationNotFound(domain)
else: else:
results[domain] = cast(Integration, int_or_evt) results[domain] = cast(Integration, int_or_fut)
# First we look for custom components # First we look for custom components
if needed: if needed:
# Instead of using resolve_from_root we use the cache of custom # Instead of using resolve_from_root we use the cache of custom
# components to find the integration. # components to find the integration.
custom = await async_get_custom_components(hass) custom = await async_get_custom_components(hass)
for domain, event in needed.items(): for domain, future in needed.items():
if integration := custom.get(domain): if integration := custom.get(domain):
results[domain] = cache[domain] = integration results[domain] = cache[domain] = integration
event.set() future.set_result(None)
for domain in results: for domain in results:
if domain in needed: if domain in needed:
@ -934,7 +936,7 @@ async def async_get_integrations(
integrations = await hass.async_add_executor_job( integrations = await hass.async_add_executor_job(
_resolve_integrations_from_root, hass, components, list(needed) _resolve_integrations_from_root, hass, components, list(needed)
) )
for domain, event in needed.items(): for domain, future in needed.items():
int_or_exc = integrations.get(domain) int_or_exc = integrations.get(domain)
if not int_or_exc: if not int_or_exc:
cache.pop(domain) cache.pop(domain)
@ -946,7 +948,7 @@ async def async_get_integrations(
results[domain] = exc results[domain] = exc
else: else:
results[domain] = cache[domain] = int_or_exc results[domain] = cache[domain] = int_or_exc
event.set() future.set_result(None)
return results return results