Compare commits

...

1 Commits

Author SHA1 Message Date
abmantis
8bd5eed295 Cache flattened service descriptions in websocket api 2025-11-29 00:10:38 +00:00
2 changed files with 107 additions and 7 deletions

View File

@@ -24,9 +24,14 @@ from homeassistant.helpers.trigger import (
async_get_all_descriptions as async_get_all_trigger_descriptions,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.hass_dict import HassKey
_LOGGER = logging.getLogger(__name__)
FLAT_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any] | None]]
] = HassKey("websocket_automation_flat_service_description_cache")
@dataclass(slots=True, kw_only=True)
class _EntityFilter:
@@ -217,12 +222,29 @@ async def async_get_services_for_target(
) -> set[str]:
"""Get services for a target."""
descriptions = await async_get_all_service_descriptions(hass)
# Flatten dicts to be keyed by domain.name to match trigger/condition format
descriptions_flatten = {
f"{domain}.{service_name}": desc
for domain, services in descriptions.items()
for service_name, desc in services.items()
}
def get_flattened_service_descriptions() -> dict[str, dict[str, Any] | None]:
"""Get flattened service descriptions, with caching."""
if FLAT_SERVICE_DESCRIPTIONS_CACHE in hass.data:
cached_descriptions, cached_flat_descriptions = hass.data[
FLAT_SERVICE_DESCRIPTIONS_CACHE
]
# If the descriptions are the same, return the cached flattened version
if cached_descriptions is descriptions:
return cached_flat_descriptions
# Flatten dicts to be keyed by domain.name to match trigger/condition format
flat_descriptions = {
f"{domain}.{service_name}": desc
for domain, services in descriptions.items()
for service_name, desc in services.items()
}
hass.data[FLAT_SERVICE_DESCRIPTIONS_CACHE] = (
descriptions,
flat_descriptions,
)
return flat_descriptions
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, descriptions_flatten
hass, target_selector, expand_group, get_flattened_service_descriptions()
)

View File

@@ -4162,3 +4162,81 @@ async def test_get_services_for_target(
"switch.turn_on",
],
)
@patch("annotatedyaml.loader.load_yaml")
@patch.object(Integration, "has_services", return_value=True)
async def test_get_services_for_target_caching(
mock_has_services: Mock,
mock_load_yaml: Mock,
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
) -> None:
"""Test that flattened service descriptions are cached and reused."""
def get_common_service_descriptions(domain: str):
return f"""
turn_on:
target:
entity:
domain: {domain}
"""
def _load_yaml(fname, secrets=None):
domain = fname.split("/")[-2]
with io.StringIO(get_common_service_descriptions(domain)) as file:
return parse_yaml(file)
mock_load_yaml.side_effect = _load_yaml
await hass.async_block_till_done()
hass.services.async_register("light", "turn_on", lambda call: None)
hass.services.async_register("switch", "turn_on", lambda call: None)
await hass.async_block_till_done()
async def call_command():
await websocket_client.send_json_auto_id(
{
"type": "get_services_for_target",
"target": {"entity_id": ["light.test1"]},
}
)
msg = await websocket_client.receive_json()
assert msg["success"]
with patch(
"homeassistant.components.websocket_api.automation._async_get_automation_components_for_target",
return_value=set(),
) as mock_get_components:
# First call: should create and cache flat descriptions
await call_command()
assert mock_get_components.call_count == 1
first_flat_descriptions = mock_get_components.call_args_list[0][0][3]
assert first_flat_descriptions == {
"light.turn_on": {
"fields": {},
"target": {"entity": [{"domain": ["light"]}]},
},
"switch.turn_on": {
"fields": {},
"target": {"entity": [{"domain": ["switch"]}]},
},
}
# Second call: should reuse cached flat descriptions
await call_command()
assert mock_get_components.call_count == 2
second_flat_descriptions = mock_get_components.call_args_list[1][0][3]
assert first_flat_descriptions is second_flat_descriptions
# Register a new service to invalidate cache
hass.services.async_register("new_domain", "new_service", lambda call: None)
await hass.async_block_till_done()
# Third call: cache should be rebuilt
await call_command()
assert mock_get_components.call_count == 3
third_flat_descriptions = mock_get_components.call_args_list[2][0][3]
assert "new_domain.new_service" in third_flat_descriptions
assert third_flat_descriptions is not first_flat_descriptions