Compare commits

...

8 Commits

2 changed files with 161 additions and 22 deletions

View File

@@ -33,6 +33,10 @@ FLATTENED_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]
] = HassKey("websocket_automation_flat_service_description_cache")
AUTOMATION_COMPONENT_LOOKUP_CACHE: HassKey[
list[tuple[Mapping[str, Any], _AutomationComponentLookupTable]]
] = HassKey("websocket_automation_component_lookup_cache")
@dataclass(slots=True, kw_only=True)
class _EntityFilter:
@@ -107,6 +111,14 @@ class _AutomationComponentLookupData:
)
@dataclass(slots=True, kw_only=True)
class _AutomationComponentLookupTable:
"""Helper class for looking up automation components."""
domain_components: dict[str | None, list[_AutomationComponentLookupData]]
component_count: int
def _get_automation_component_domains(
target_description: dict[str, Any],
) -> set[str | None]:
@@ -138,6 +150,42 @@ def _get_automation_component_domains(
return domains
def _get_automation_component_lookup_table(
hass: HomeAssistant, component_descriptions: Mapping[str, Mapping[str, Any] | None]
) -> _AutomationComponentLookupTable:
"""Get a dict of automation components keyed by domain, along with the total number of components."""
if AUTOMATION_COMPONENT_LOOKUP_CACHE not in hass.data:
hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] = []
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE]
for cached_descriptions, cached_lookup in cache:
if cached_descriptions is component_descriptions:
_LOGGER.debug("Using cached automation component lookup data")
return cached_lookup
lookup_table = _AutomationComponentLookupTable(
domain_components={}, component_count=0
)
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
)
for domain in domains:
lookup_table.domain_components.setdefault(domain, []).append(lookup_data)
lookup_table.component_count += 1
cache.append((component_descriptions, lookup_table))
if len(cache) > 3: # Should have a max of 3: triggers, conditions, services
cache.pop(0)
return lookup_table
def _async_get_automation_components_for_target(
hass: HomeAssistant,
target_selection: ConfigType,
@@ -155,27 +203,15 @@ def _async_get_automation_components_for_target(
)
_LOGGER.debug("Extracted entities for lookup: %s", extracted)
# Build lookup structure: domain -> list of trigger/condition/service lookup data
domain_components: dict[str | None, list[_AutomationComponentLookupData]] = {}
component_count = 0
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
)
for domain in domains:
domain_components.setdefault(domain, []).append(lookup_data)
component_count += 1
_LOGGER.debug("Automation components per domain: %s", domain_components)
lookup_table = _get_automation_component_lookup_table(hass, component_descriptions)
_LOGGER.debug(
"Automation components per domain: %s", lookup_table.domain_components
)
entity_infos = entity_sources(hass)
matched_components: set[str] = set()
for entity_id in extracted.referenced | extracted.indirectly_referenced:
if component_count == len(matched_components):
if lookup_table.component_count == len(matched_components):
# All automation components matched already, so we don't need to iterate further
break
@@ -187,7 +223,7 @@ def _async_get_automation_components_for_target(
entity_domain = entity_id.split(".")[0]
entity_integration = entity_info["domain"]
for domain in (entity_domain, entity_integration, None):
for component_data in domain_components.get(domain, []):
for component_data in lookup_table.domain_components.get(domain, []):
if component_data.component in matched_components:
continue
if component_data.matches(

View File

@@ -24,6 +24,10 @@ from homeassistant.components.websocket_api.auth import (
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.automation import (
AUTOMATION_COMPONENT_LOOKUP_CACHE,
_get_automation_component_lookup_table,
)
from homeassistant.components.websocket_api.commands import (
ALL_CONDITION_DESCRIPTIONS_JSON_CACHE,
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE,
@@ -3665,6 +3669,7 @@ async def test_get_triggers_conditions_for_target(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
automation_component: str,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test get_triggers_for_target/get_conditions_for_target command with mixed target types."""
@@ -3803,7 +3808,9 @@ async def test_get_triggers_conditions_for_target(
await hass.async_block_till_done()
async def assert_command(
target: dict[str, list[str]], expected: list[str]
target: dict[str, list[str]],
expected: list[str],
expect_lookup_cache: bool = True,
) -> Any:
"""Call the command and assert expected triggers/conditions."""
await websocket_client.send_json_auto_id(
@@ -3815,8 +3822,15 @@ async def test_get_triggers_conditions_for_target(
assert msg["success"]
assert sorted(msg["result"]) == sorted(expected)
assert (
"Using cached automation component lookup data" in caplog.text
) == expect_lookup_cache
caplog.clear()
# Test entity target - unknown entity
await assert_command({"entity_id": ["light.unknown_entity"]}, [])
await assert_command(
{"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False
)
# Test entity target - entity not in registry
await assert_command(
@@ -3936,6 +3950,7 @@ async def test_get_services_for_target(
mock_load_yaml: Mock,
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test get_services_for_target command with mixed target types."""
@@ -4047,7 +4062,11 @@ async def test_get_services_for_target(
)
await hass.async_block_till_done()
async def assert_services(target: dict[str, list[str]], expected: list[str]) -> Any:
async def assert_services(
target: dict[str, list[str]],
expected: list[str],
expect_lookup_cache: bool = True,
) -> Any:
"""Call the command and assert expected services."""
await websocket_client.send_json_auto_id(
{"type": "get_services_for_target", "target": target}
@@ -4058,8 +4077,15 @@ async def test_get_services_for_target(
assert msg["success"]
assert sorted(msg["result"]) == sorted(expected)
assert (
"Using cached automation component lookup data" in caplog.text
) == expect_lookup_cache
caplog.clear()
# Test entity target - unknown entity
await assert_services({"entity_id": ["light.unknown_entity"]}, [])
await assert_services(
{"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False
)
# Test entity target - entity not in registry
await assert_services(
@@ -4240,3 +4266,80 @@ async def test_get_services_for_target_caching(
third_flat_descriptions = mock_get_components.call_args_list[2][0][3]
assert "new_domain.new_service" in third_flat_descriptions
assert third_flat_descriptions is not first_flat_descriptions
async def test_get_automation_component_lookup_table_cache(
hass: HomeAssistant,
) -> None:
"""Test that _get_automation_component_lookup_table caches and rotates properly."""
trigger_descriptions: dict[str, dict[str, Any] | None] = {
"light.turned_on": {"target": {"entity": [{"domain": ["light"]}]}},
"switch.turned_on": {"target": {"entity": [{"domain": ["switch"]}]}},
}
condition_descriptions: dict[str, dict[str, Any] | None] = {
"light.is_on": {"target": {"entity": [{"domain": ["light"]}]}},
"sensor.is_above": {"target": {"entity": [{"domain": ["sensor"]}]}},
}
service_descriptions: dict[str, dict[str, Any] | None] = {
"light.turn_on": {"target": {"entity": [{"domain": ["light"]}]}},
"climate.set_temperature": {"target": {"entity": [{"domain": ["climate"]}]}},
}
# First call with triggers - cache should be created with 1 entry
trigger_result1 = _get_automation_component_lookup_table(hass, trigger_descriptions)
assert AUTOMATION_COMPONENT_LOOKUP_CACHE in hass.data
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE]
assert len(cache) == 1
# Second call with same triggers - should return cached result
trigger_result2 = _get_automation_component_lookup_table(hass, trigger_descriptions)
assert trigger_result1 is trigger_result2
assert len(cache) == 1
# Call with conditions - cache should have 2 entries
condition_result1 = _get_automation_component_lookup_table(
hass, condition_descriptions
)
assert len(cache) == 2
assert condition_result1 is not trigger_result1
# Call with services - cache should have 3 entries
service_result1 = _get_automation_component_lookup_table(hass, service_descriptions)
assert len(cache) == 3
assert service_result1 is not trigger_result1
assert service_result1 is not condition_result1
# Verify all 3 return cached results
assert (
_get_automation_component_lookup_table(hass, trigger_descriptions)
is trigger_result1
)
assert (
_get_automation_component_lookup_table(hass, condition_descriptions)
is condition_result1
)
assert (
_get_automation_component_lookup_table(hass, service_descriptions)
is service_result1
)
assert len(cache) == 3
# Add a 4th description dict - oldest cache entry should be evicted
extra_descriptions: dict[str, dict[str, Any] | None] = {
"fan.turn_on": {"target": {"entity": [{"domain": ["fan"]}]}},
}
extra_result = _get_automation_component_lookup_table(hass, extra_descriptions)
assert len(cache) == 3
# Trigger cache entry should have been evicted (it was oldest)
trigger_result3 = _get_automation_component_lookup_table(hass, trigger_descriptions)
assert trigger_result3 is not trigger_result1
assert len(cache) == 3
assert (
_get_automation_component_lookup_table(hass, extra_descriptions) is extra_result
)
assert (
_get_automation_component_lookup_table(hass, trigger_descriptions)
is trigger_result3
)