Fix race in async_get_integrations with multiple calls when an integration is not found (#139270)

* Fix race in async_get_integrations with multiple calls when an integration is not found

* Fix race in async_get_integrations with multiple calls when an integration is not found

* Fix race in async_get_integrations with multiple calls when an integration is not found

* tweaks

* tweaks

* tweaks

* restore lost comment

* tweak test

* comment cache

* improve test

* improve comment
This commit is contained in:
J. Nick Koston 2025-02-25 18:08:53 +00:00 committed by GitHub
parent a910fb879c
commit a1d1f6ec97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 31 deletions

View File

@ -40,7 +40,6 @@ from .generated.ssdp import SSDP
from .generated.usb import USB from .generated.usb import USB
from .generated.zeroconf import HOMEKIT, ZEROCONF from .generated.zeroconf import HOMEKIT, ZEROCONF
from .helpers.json import json_bytes, json_fragment from .helpers.json import json_bytes, json_fragment
from .helpers.typing import UNDEFINED
from .util.hass_dict import HassKey from .util.hass_dict import HassKey
from .util.json import JSON_DECODE_EXCEPTIONS, json_loads from .util.json import JSON_DECODE_EXCEPTIONS, json_loads
@ -125,9 +124,9 @@ BLOCKED_CUSTOM_INTEGRATIONS: dict[str, BlockedIntegration] = {
DATA_COMPONENTS: HassKey[dict[str, ModuleType | ComponentProtocol]] = HassKey( DATA_COMPONENTS: HassKey[dict[str, ModuleType | ComponentProtocol]] = HassKey(
"components" "components"
) )
DATA_INTEGRATIONS: HassKey[dict[str, Integration | asyncio.Future[None]]] = HassKey( DATA_INTEGRATIONS: HassKey[
"integrations" dict[str, Integration | asyncio.Future[Integration | IntegrationNotFound]]
) ] = HassKey("integrations")
DATA_MISSING_PLATFORMS: HassKey[dict[str, bool]] = HassKey("missing_platforms") DATA_MISSING_PLATFORMS: HassKey[dict[str, bool]] = HassKey("missing_platforms")
DATA_CUSTOM_COMPONENTS: HassKey[ DATA_CUSTOM_COMPONENTS: HassKey[
dict[str, Integration] | asyncio.Future[dict[str, Integration]] dict[str, Integration] | asyncio.Future[dict[str, Integration]]
@ -1345,7 +1344,7 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio
Raises IntegrationNotLoaded if the integration is not loaded. Raises IntegrationNotLoaded if the integration is not loaded.
""" """
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
int_or_fut = cache.get(domain, UNDEFINED) int_or_fut = cache.get(domain)
# Integration is never subclassed, so we can check for type # Integration is never subclassed, so we can check for type
if type(int_or_fut) is Integration: if type(int_or_fut) is Integration:
return int_or_fut return int_or_fut
@ -1355,7 +1354,7 @@ def async_get_loaded_integration(hass: HomeAssistant, domain: str) -> Integratio
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
"""Get integration.""" """Get integration."""
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
if type(int_or_fut := cache.get(domain, UNDEFINED)) is Integration: if type(int_or_fut := cache.get(domain)) is Integration:
return int_or_fut return int_or_fut
integrations_or_excs = await async_get_integrations(hass, [domain]) integrations_or_excs = await async_get_integrations(hass, [domain])
int_or_exc = integrations_or_excs[domain] int_or_exc = integrations_or_excs[domain]
@ -1370,15 +1369,17 @@ async def async_get_integrations(
"""Get integrations.""" """Get integrations."""
cache = hass.data[DATA_INTEGRATIONS] cache = hass.data[DATA_INTEGRATIONS]
results: dict[str, Integration | Exception] = {} results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Future[None]] = {} needed: dict[str, asyncio.Future[Integration | IntegrationNotFound]] = {}
in_progress: dict[str, asyncio.Future[None]] = {} in_progress: dict[str, asyncio.Future[Integration | IntegrationNotFound]] = {}
for domain in domains: for domain in domains:
int_or_fut = cache.get(domain, UNDEFINED) int_or_fut = cache.get(domain)
# Integration is never subclassed, so we can check for type # Integration is never subclassed, so we can check for type
if type(int_or_fut) is Integration: if type(int_or_fut) is Integration:
results[domain] = int_or_fut results[domain] = int_or_fut
elif int_or_fut is not UNDEFINED: elif int_or_fut:
in_progress[domain] = cast(asyncio.Future[None], int_or_fut) if TYPE_CHECKING:
assert isinstance(int_or_fut, asyncio.Future)
in_progress[domain] = int_or_fut
elif "." in domain: elif "." in domain:
results[domain] = ValueError(f"Invalid domain {domain}") results[domain] = ValueError(f"Invalid domain {domain}")
else: else:
@ -1386,14 +1387,13 @@ async def async_get_integrations(
if in_progress: if in_progress:
await asyncio.wait(in_progress.values()) await asyncio.wait(in_progress.values())
for domain in in_progress: # Here we retrieve the results we waited for
# When we have waited and it's UNDEFINED, it doesn't exist # instead of reading them from the cache since
# We don't cache that it doesn't exist, or else people can't fix it # reading from the cache will have a race if
# and then restart, because their config will never be valid. # the integration gets removed from the cache
if (int_or_fut := cache.get(domain, UNDEFINED)) is UNDEFINED: # because it was not found.
results[domain] = IntegrationNotFound(domain) for domain, future in in_progress.items():
else: results[domain] = future.result()
results[domain] = cast(Integration, int_or_fut)
if not needed: if not needed:
return results return results
@ -1405,7 +1405,7 @@ async def async_get_integrations(
for domain, future in needed.items(): for domain, future in needed.items():
if integration := custom.get(domain): if integration := custom.get(domain):
results[domain] = cache[domain] = integration results[domain] = cache[domain] = integration
future.set_result(None) future.set_result(integration)
for domain in results: for domain in results:
if domain in needed: if domain in needed:
@ -1419,18 +1419,24 @@ async def async_get_integrations(
_resolve_integrations_from_root, hass, components, needed _resolve_integrations_from_root, hass, components, needed
) )
for domain, future in needed.items(): for domain, future in needed.items():
int_or_exc = integrations.get(domain) if integration := integrations.get(domain):
if not int_or_exc: results[domain] = cache[domain] = integration
cache.pop(domain) future.set_result(integration)
results[domain] = IntegrationNotFound(domain)
elif isinstance(int_or_exc, Exception):
cache.pop(domain)
exc = IntegrationNotFound(domain)
exc.__cause__ = int_or_exc
results[domain] = exc
else: else:
results[domain] = cache[domain] = int_or_exc # We don't cache that it doesn't exist as configuration
future.set_result(None) # validation that relies on integrations being loaded
# would be unfixable. For example if a custom integration
# was temporarily removed.
# This allows restoring a missing integration to fix the
# validation error so the config validations checks do not
# block restarting.
del cache[domain]
exc = IntegrationNotFound(domain)
results[domain] = exc
# We don't use set_exception because
# we expect there will be cases where
# the a future exception is never retrieved
future.set_result(exc)
return results return results

View File

@ -2039,3 +2039,59 @@ async def test_manifest_json_fragment_round_trip(hass: HomeAssistant) -> None:
json_loads(json_dumps(integration.manifest_json_fragment)) json_loads(json_dumps(integration.manifest_json_fragment))
== integration.manifest == integration.manifest
) )
async def test_async_get_integrations_multiple_non_existent(
hass: HomeAssistant,
) -> None:
"""Test async_get_integrations with multiple non-existent integrations."""
integrations = await loader.async_get_integrations(hass, ["does_not_exist"])
assert isinstance(integrations["does_not_exist"], loader.IntegrationNotFound)
async def slow_load_failure(
*args: Any, **kwargs: Any
) -> dict[str, loader.Integration]:
await asyncio.sleep(0.1)
return {}
with patch.object(hass, "async_add_executor_job", slow_load_failure):
task1 = hass.async_create_task(
loader.async_get_integrations(hass, ["does_not_exist", "does_not_exist2"])
)
# Task one should now be waiting for executor job
task2 = hass.async_create_task(
loader.async_get_integrations(hass, ["does_not_exist"])
)
# Task two should be waiting for the futures created in task one
task3 = hass.async_create_task(
loader.async_get_integrations(hass, ["does_not_exist2", "does_not_exist"])
)
# Task three should be waiting for the futures created in task one
integrations_1 = await task1
assert isinstance(integrations_1["does_not_exist"], loader.IntegrationNotFound)
assert isinstance(integrations_1["does_not_exist2"], loader.IntegrationNotFound)
integrations_2 = await task2
assert isinstance(integrations_2["does_not_exist"], loader.IntegrationNotFound)
integrations_3 = await task3
assert isinstance(integrations_3["does_not_exist2"], loader.IntegrationNotFound)
assert isinstance(integrations_3["does_not_exist"], loader.IntegrationNotFound)
# Make sure IntegrationNotFound is not cached
# so configuration errors can be fixed as to
# not prevent Home Assistant from being restarted
integration = loader.Integration(
hass,
"custom_components.does_not_exist",
None,
{
"name": "Does not exist",
"domain": "does_not_exist",
},
)
with patch.object(
loader,
"_resolve_integrations_from_root",
return_value={"does_not_exist": integration},
):
integrations = await loader.async_get_integrations(hass, ["does_not_exist"])
assert integrations["does_not_exist"] is integration