Fix race in loading service descriptions (#109316)

This commit is contained in:
J. Nick Koston 2024-02-01 12:34:23 -06:00 committed by GitHub
parent c61a2b46d4
commit ed726db974
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 0 deletions

View File

@ -608,6 +608,11 @@ async def async_get_all_descriptions(
# Files we loaded for missing descriptions # Files we loaded for missing descriptions
loaded: dict[str, JSON_TYPE] = {} loaded: dict[str, JSON_TYPE] = {}
# We try to avoid making a copy in the event the cache is good,
# but now we must make a copy in case new services get added
# while we are loading the missing ones so we do not
# add the new ones to the cache without their descriptions
services = {domain: service.copy() for domain, service in services.items()}
if domains_with_missing_services: if domains_with_missing_services:
ints_or_excs = await async_get_integrations(hass, domains_with_missing_services) ints_or_excs = await async_get_integrations(hass, domains_with_missing_services)

View File

@ -1,4 +1,5 @@
"""Test service helpers.""" """Test service helpers."""
import asyncio
from collections.abc import Iterable from collections.abc import Iterable
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any
@ -782,6 +783,84 @@ async def test_async_get_all_descriptions_dynamically_created_services(
} }
async def test_async_get_all_descriptions_new_service_added_while_loading(
hass: HomeAssistant,
) -> None:
"""Test async_get_all_descriptions when a new service is added while loading translations."""
group = hass.components.group
group_config = {group.DOMAIN: {}}
await async_setup_component(hass, group.DOMAIN, group_config)
descriptions = await service.async_get_all_descriptions(hass)
assert len(descriptions) == 1
assert "description" in descriptions["group"]["reload"]
assert "fields" in descriptions["group"]["reload"]
logger = hass.components.logger
logger_domain = logger.DOMAIN
logger_config = {logger_domain: {}}
translations_called = asyncio.Event()
translations_wait = asyncio.Event()
async def async_get_translations(
hass: HomeAssistant,
language: str,
category: str,
integrations: Iterable[str] | None = None,
config_flow: bool | None = None,
) -> dict[str, Any]:
"""Return all backend translations."""
translations_called.set()
await translations_wait.wait()
translation_key_prefix = f"component.{logger_domain}.services.set_default_level"
return {
f"{translation_key_prefix}.name": "Translated name",
f"{translation_key_prefix}.description": "Translated description",
f"{translation_key_prefix}.fields.level.name": "Field name",
f"{translation_key_prefix}.fields.level.description": "Field description",
f"{translation_key_prefix}.fields.level.example": "Field example",
}
with patch(
"homeassistant.helpers.service.translation.async_get_translations",
side_effect=async_get_translations,
):
await async_setup_component(hass, logger_domain, logger_config)
task = asyncio.create_task(service.async_get_all_descriptions(hass))
await translations_called.wait()
# Now register a new service while translations are being loaded
hass.services.async_register(logger_domain, "new_service", lambda x: None, None)
service.async_set_service_schema(
hass, logger_domain, "new_service", {"description": "new service"}
)
translations_wait.set()
descriptions = await task
# Two domains should be present
assert len(descriptions) == 2
logger_descriptions = descriptions[logger_domain]
# The new service was loaded after the translations were loaded
# so it should not appear until the next time we fetch
assert "new_service" not in logger_descriptions
set_default_level = logger_descriptions["set_default_level"]
assert set_default_level["name"] == "Translated name"
assert set_default_level["description"] == "Translated description"
set_default_level_fields = set_default_level["fields"]
assert set_default_level_fields["level"]["name"] == "Field name"
assert set_default_level_fields["level"]["description"] == "Field description"
assert set_default_level_fields["level"]["example"] == "Field example"
descriptions = await service.async_get_all_descriptions(hass)
assert "description" in descriptions[logger_domain]["new_service"]
assert descriptions[logger_domain]["new_service"]["description"] == "new service"
async def test_register_with_mixed_case(hass: HomeAssistant) -> None: async def test_register_with_mixed_case(hass: HomeAssistant) -> None:
"""Test registering a service with mixed case. """Test registering a service with mixed case.