Compare commits

...

1 Commits

Author SHA1 Message Date
abmantis
4c63435aaf Add get_triggers_for_target websocket command 2025-11-17 23:22:09 +00:00
4 changed files with 246 additions and 5 deletions

View File

@@ -37,6 +37,7 @@ from homeassistant.exceptions import (
from homeassistant.helpers import (
config_validation as cv,
entity,
entity_registry as er,
target as target_helpers,
template,
)
@@ -107,6 +108,7 @@ def async_register_commands(
async_reg(hass, handle_entity_source)
async_reg(hass, handle_execute_script)
async_reg(hass, handle_extract_from_target)
async_reg(hass, handle_get_triggers_for_target)
async_reg(hass, handle_fire_event)
async_reg(hass, handle_get_config)
async_reg(hass, handle_get_services)
@@ -877,6 +879,55 @@ def handle_extract_from_target(
connection.send_result(msg["id"], extracted_dict)
@decorators.websocket_command(
{
vol.Required("type"): "get_triggers_for_target",
vol.Required("target"): cv.TARGET_FIELDS,
vol.Optional("expand_group", default=False): bool,
}
)
@decorators.async_response
async def handle_get_triggers_for_target(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get triggers for target command.
This command returns all triggers that can be used on any entity that are currently
part of a target.
"""
selector_data = target_helpers.TargetSelectorData(msg["target"])
extracted = target_helpers.async_extract_referenced_entity_ids(
hass, selector_data, expand_group=msg["expand_group"]
)
_LOGGER.debug("Extracted entities for trigger lookup: %s", extracted)
descriptions = await async_get_all_trigger_descriptions(hass)
target_triggers = [
trigger
for trigger, description in descriptions.items()
if description and "target" in description
]
_LOGGER.debug("Available target triggers: %s", target_triggers)
# Collect both platform domains and integrations from all entities in target
entity_domains: set[str] = set()
entity_registry = er.async_get(hass)
for entity_id in extracted.referenced.union(extracted.indirectly_referenced):
entity_domains.add(entity_id.split(".")[0])
if entity_entry := entity_registry.async_get(entity_id):
entity_domains.add(entity_entry.platform)
_LOGGER.debug("Relevant domains: %s", entity_domains)
triggers_for_target = {
trigger
for trigger in target_triggers
if trigger.split(".")[0] in entity_domains
}
connection.send_result(msg["id"], triggers_for_target)
@decorators.websocket_command(
{
vol.Required("type"): "subscribe_trigger",

View File

@@ -805,6 +805,8 @@ async def async_get_all_descriptions(
continue
description = {"fields": yaml_description.get("fields", {})}
if target := yaml_description.get("target"):
description["target"] = target
new_descriptions_cache[missing_trigger] = description

View File

@@ -1608,12 +1608,16 @@ def mock_integration(
top_level_files: set[str] | None = None,
) -> loader.Integration:
"""Mock an integration."""
integration = loader.Integration(
hass,
path = (
f"{loader.PACKAGE_BUILTIN}.{module.DOMAIN}"
if built_in
else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}",
pathlib.Path(""),
else f"{loader.PACKAGE_CUSTOM_COMPONENTS}.{module.DOMAIN}"
)
integration = loader.Integration(
hass,
path,
pathlib.Path(path.replace(".", "/")),
module.mock_manifest(),
top_level_files,
)
@@ -2009,7 +2013,6 @@ def get_sensor_display_state(
)
) is None:
return value
with suppress(TypeError, ValueError):
numerical_value = float(value)
value = f"{numerical_value:z.{precision}f}"

View File

@@ -3480,3 +3480,188 @@ async def test_extract_from_target_validation_error(
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert "error" in msg
@patch("annotatedyaml.loader.load_yaml")
@patch.object(Integration, "has_triggers", return_value=True)
async def test_get_triggers_for_target(
mock_has_triggers: Mock,
mock_load_yaml: Mock,
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
label_registry: lr.LabelRegistry,
) -> None:
"""Test get_triggers_for_target command with mixed target types."""
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type]:
return {
"turned_on": Mock,
"non_target_trigger": Mock,
}
mock_platform(hass, "light.trigger", Mock(async_get_triggers=async_get_triggers))
mock_platform(hass, "switch.trigger", Mock(async_get_triggers=async_get_triggers))
mock_platform(hass, "sensor.trigger", Mock(async_get_triggers=async_get_triggers))
mock_platform(
hass,
"mqtt.trigger",
Mock(async_get_triggers=AsyncMock(return_value={"_": Mock})),
)
common_trigger_descriptions = """
turned_on:
target:
entity:
domain: light
fields:
behavior:
required: true
default: any
selector:
select:
options:
- first
- last
- any
non_target_trigger:
fields:
behavior:
required: true
default: any
selector:
select:
options:
- first
- last
- any
"""
mqtt_trigger_descriptions = """
_:
target:
entity:
"""
def _load_yaml(fname, secrets=None):
if fname.endswith("mqtt/triggers.yaml"):
trigger_descriptions = mqtt_trigger_descriptions
else:
trigger_descriptions = common_trigger_descriptions
with io.StringIO(trigger_descriptions) as file:
return parse_yaml(file)
mock_load_yaml.side_effect = _load_yaml
assert await async_setup_component(hass, "light", {})
assert await async_setup_component(hass, "switch", {})
assert await async_setup_component(hass, "sensor", {})
assert await async_setup_component(hass, "mqtt", {})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device1")},
)
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device2")},
)
area_device = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device3")},
)
label2_device = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={("test", "device4")},
)
kitchen_area = area_registry.async_create("Kitchen")
living_room_area = area_registry.async_create("Living Room")
label_area = area_registry.async_create("Bathroom")
label1 = label_registry.async_create("Test Label 1")
label2 = label_registry.async_create("Test Label 2")
# Associate devices with areas and labels
device_registry.async_update_device(area_device.id, area_id=kitchen_area.id)
device_registry.async_update_device(label2_device.id, labels={label2.label_id})
area_registry.async_update(label_area.id, labels={label1.label_id})
# Setup entities with targets
device1_entity1 = entity_registry.async_get_or_create(
"light", "test", "unique1", device_id=device1.id
)
entity_registry.async_get_or_create(
"switch", "test", "unique2", device_id=device1.id
)
entity_registry.async_get_or_create(
"sensor", "test", "unique3", device_id=device2.id
)
entity_registry.async_get_or_create(
"light", "test", "unique4", device_id=area_device.id
)
area_entity = entity_registry.async_get_or_create("switch", "test", "unique5")
entity_registry.async_get_or_create(
"light", "test", "unique6", device_id=label2_device.id
)
label_mqtt_entity = entity_registry.async_get_or_create("switch", "mqtt", "unique7")
# Associate entities with areas and labels
entity_registry.async_update_entity(
area_entity.entity_id, area_id=living_room_area.id
)
entity_registry.async_update_entity(
label_mqtt_entity.entity_id, labels={label1.label_id}
)
async def call_command(target: dict[str, str]) -> Any:
await websocket_client.send_json_auto_id(
{"type": "get_triggers_for_target", "target": target}
)
return await websocket_client.receive_json()
def assert_triggers(msg: dict[str, Any], expected: list[str]) -> None:
"""Assert triggers for target match expected."""
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert sorted(msg["result"]) == sorted(expected)
# Test entity target - unknown entity
msg = await call_command({"entity_id": ["light.unknown_entity"]})
assert_triggers(msg, ["light.turned_on"])
# Test entity target - existing entity from the mqtt integration
msg = await call_command({"entity_id": [label_mqtt_entity.entity_id]})
assert_triggers(msg, ["mqtt", "switch.turned_on"])
# Test device target - multiple devices
msg = await call_command({"device_id": [device1.id, device2.id]})
assert_triggers(msg, ["light.turned_on", "sensor.turned_on", "switch.turned_on"])
# Test area target - multiple areas
msg = await call_command({"area_id": [kitchen_area.id, living_room_area.id]})
assert_triggers(msg, ["light.turned_on", "switch.turned_on"])
# Test label target - multiple labels
msg = await call_command({"label_id": [label1.label_id, label2.label_id]})
assert_triggers(msg, ["light.turned_on", "mqtt", "switch.turned_on"])
# Test mixed target types
msg = await call_command(
{
"entity_id": [device1_entity1.entity_id],
"device_id": [device2.id],
"area_id": [kitchen_area.id],
"label_id": [label1.label_id],
}
)
assert_triggers(
msg, ["light.turned_on", "mqtt", "sensor.turned_on", "switch.turned_on"]
)