Use relative trigger keys (#149846)

This commit is contained in:
Artur Pragacz 2025-08-05 00:01:40 +02:00 committed by GitHub
parent d48cc03be7
commit 53c9c42148
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 128 additions and 48 deletions

View File

@ -11,7 +11,7 @@
} }
}, },
"triggers": { "triggers": {
"mqtt": { "_": {
"trigger": "mdi:swap-horizontal" "trigger": "mdi:swap-horizontal"
} }
} }

View File

@ -1285,7 +1285,7 @@
} }
}, },
"triggers": { "triggers": {
"mqtt": { "_": {
"name": "MQTT", "name": "MQTT",
"description": "When a specific message is received on a given MQTT topic.", "description": "When a specific message is received on a given MQTT topic.",
"description_configured": "When an MQTT message has been received", "description_configured": "When an MQTT message has been received",

View File

@ -1,6 +1,6 @@
# Describes the format for MQTT triggers # Describes the format for MQTT triggers
mqtt: _:
fields: fields:
payload: payload:
example: "on" example: "on"

View File

@ -8,8 +8,8 @@ from homeassistant.helpers.trigger import Trigger
from .triggers import event, value_updated from .triggers import event, value_updated
TRIGGERS = { TRIGGERS = {
event.PLATFORM_TYPE: event.EventTrigger, event.RELATIVE_PLATFORM_TYPE: event.EventTrigger,
value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger, value_updated.RELATIVE_PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
} }

View File

@ -34,8 +34,11 @@ from ..helpers import (
) )
from .trigger_helpers import async_bypass_dynamic_config_validation from .trigger_helpers import async_bypass_dynamic_config_validation
# Relative platform type should be <SUBMODULE_NAME>
RELATIVE_PLATFORM_TYPE = f"{__name__.rsplit('.', maxsplit=1)[-1]}"
# Platform type should be <DOMAIN>.<SUBMODULE_NAME> # Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}" PLATFORM_TYPE = f"{DOMAIN}.{RELATIVE_PLATFORM_TYPE}"
def validate_non_node_event_source(obj: dict) -> dict: def validate_non_node_event_source(obj: dict) -> dict:

View File

@ -37,8 +37,11 @@ from ..const import (
from ..helpers import async_get_nodes_from_targets, get_device_id from ..helpers import async_get_nodes_from_targets, get_device_id
from .trigger_helpers import async_bypass_dynamic_config_validation from .trigger_helpers import async_bypass_dynamic_config_validation
# Relative platform type should be <SUBMODULE_NAME>
RELATIVE_PLATFORM_TYPE = f"{__name__.rsplit('.', maxsplit=1)[-1]}"
# Platform type should be <DOMAIN>.<SUBMODULE_NAME> # Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}" PLATFORM_TYPE = f"{DOMAIN}.{RELATIVE_PLATFORM_TYPE}"
ATTR_FROM = "from" ATTR_FROM = "from"
ATTR_TO = "to" ATTR_TO = "to"

View File

@ -0,0 +1,21 @@
"""Helpers for automation."""
def get_absolute_description_key(domain: str, key: str) -> str:
"""Return the absolute description key."""
if not key.startswith("_"):
return f"{domain}.{key}"
key = key[1:] # Remove leading underscore
if not key:
return domain
return key
def get_relative_description_key(domain: str, key: str) -> str:
"""Return the relative description key."""
platform, *subtype = key.split(".", 1)
if platform != domain:
return f"_{key}"
if not subtype:
return "_"
return subtype[0]

View File

@ -644,6 +644,13 @@ def slug(value: Any) -> str:
raise vol.Invalid(f"invalid slug {value} (try {slg})") raise vol.Invalid(f"invalid slug {value} (try {slg})")
def underscore_slug(value: Any) -> str:
"""Validate value is a valid slug, possibly starting with an underscore."""
if value.startswith("_"):
return f"_{slug(value[1:])}"
return slug(value)
def schema_with_slug_keys( def schema_with_slug_keys(
value_schema: dict | Callable, *, slug_validator: Callable[[Any], str] = slug value_schema: dict | Callable, *, slug_validator: Callable[[Any], str] = slug
) -> Callable: ) -> Callable:

View File

@ -40,9 +40,9 @@ from homeassistant.loader import (
from homeassistant.util.async_ import create_eager_task from homeassistant.util.async_ import create_eager_task
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
from homeassistant.util.yaml import load_yaml_dict from homeassistant.util.yaml import load_yaml_dict
from homeassistant.util.yaml.loader import JSON_TYPE
from . import config_validation as cv, selector from . import config_validation as cv, selector
from .automation import get_absolute_description_key, get_relative_description_key
from .integration_platform import async_process_integration_platforms from .integration_platform import async_process_integration_platforms
from .selector import TargetSelector from .selector import TargetSelector
from .template import Template from .template import Template
@ -100,7 +100,7 @@ def starts_with_dot(key: str) -> str:
_TRIGGERS_SCHEMA = vol.Schema( _TRIGGERS_SCHEMA = vol.Schema(
{ {
vol.Remove(vol.All(str, starts_with_dot)): object, vol.Remove(vol.All(str, starts_with_dot)): object,
cv.slug: vol.Any(None, _TRIGGER_SCHEMA), cv.underscore_slug: vol.Any(None, _TRIGGER_SCHEMA),
} }
) )
@ -139,6 +139,7 @@ async def _register_trigger_platform(
if hasattr(platform, "async_get_triggers"): if hasattr(platform, "async_get_triggers"):
for trigger_key in await platform.async_get_triggers(hass): for trigger_key in await platform.async_get_triggers(hass):
trigger_key = get_absolute_description_key(integration_domain, trigger_key)
hass.data[TRIGGERS][trigger_key] = integration_domain hass.data[TRIGGERS][trigger_key] = integration_domain
new_triggers.add(trigger_key) new_triggers.add(trigger_key)
elif hasattr(platform, "async_validate_trigger_config") or hasattr( elif hasattr(platform, "async_validate_trigger_config") or hasattr(
@ -357,9 +358,8 @@ class PluggableAction:
async def _async_get_trigger_platform( async def _async_get_trigger_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, trigger_key: str
) -> TriggerProtocol: ) -> tuple[str, TriggerProtocol]:
trigger_key: str = config[CONF_PLATFORM]
platform_and_sub_type = trigger_key.split(".") platform_and_sub_type = trigger_key.split(".")
platform = platform_and_sub_type[0] platform = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform) platform = _PLATFORM_ALIASES.get(platform, platform)
@ -368,7 +368,7 @@ async def _async_get_trigger_platform(
except IntegrationNotFound: except IntegrationNotFound:
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None
try: try:
return await integration.async_get_platform("trigger") return platform, await integration.async_get_platform("trigger")
except ImportError: except ImportError:
raise vol.Invalid( raise vol.Invalid(
f"Integration '{platform}' does not provide trigger support" f"Integration '{platform}' does not provide trigger support"
@ -381,11 +381,14 @@ async def async_validate_trigger_config(
"""Validate triggers.""" """Validate triggers."""
config = [] config = []
for conf in trigger_config: for conf in trigger_config:
platform = await _async_get_trigger_platform(hass, conf) trigger_key: str = conf[CONF_PLATFORM]
platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
if hasattr(platform, "async_get_triggers"): if hasattr(platform, "async_get_triggers"):
trigger_descriptors = await platform.async_get_triggers(hass) trigger_descriptors = await platform.async_get_triggers(hass)
trigger_key: str = conf[CONF_PLATFORM] relative_trigger_key = get_relative_description_key(
if not (trigger := trigger_descriptors.get(trigger_key)): platform_domain, trigger_key
)
if not (trigger := trigger_descriptors.get(relative_trigger_key)):
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
conf = await trigger.async_validate_trigger_config(hass, conf) conf = await trigger.async_validate_trigger_config(hass, conf)
elif hasattr(platform, "async_validate_trigger_config"): elif hasattr(platform, "async_validate_trigger_config"):
@ -471,7 +474,8 @@ async def async_initialize_triggers(
if not enabled: if not enabled:
continue continue
platform = await _async_get_trigger_platform(hass, conf) trigger_key: str = conf[CONF_PLATFORM]
platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
trigger_id = conf.get(CONF_ID, f"{idx}") trigger_id = conf.get(CONF_ID, f"{idx}")
trigger_idx = f"{idx}" trigger_idx = f"{idx}"
trigger_alias = conf.get(CONF_ALIAS) trigger_alias = conf.get(CONF_ALIAS)
@ -487,7 +491,10 @@ async def async_initialize_triggers(
action_wrapper = _trigger_action_wrapper(hass, action, conf) action_wrapper = _trigger_action_wrapper(hass, action, conf)
if hasattr(platform, "async_get_triggers"): if hasattr(platform, "async_get_triggers"):
trigger_descriptors = await platform.async_get_triggers(hass) trigger_descriptors = await platform.async_get_triggers(hass)
trigger = trigger_descriptors[conf[CONF_PLATFORM]](hass, conf) relative_trigger_key = get_relative_description_key(
platform_domain, trigger_key
)
trigger = trigger_descriptors[relative_trigger_key](hass, conf)
coro = trigger.async_attach_trigger(action_wrapper, info) coro = trigger.async_attach_trigger(action_wrapper, info)
else: else:
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info) coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
@ -525,11 +532,11 @@ async def async_initialize_triggers(
return remove_triggers return remove_triggers
def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE: def _load_triggers_file(integration: Integration) -> dict[str, Any]:
"""Load triggers file for an integration.""" """Load triggers file for an integration."""
try: try:
return cast( return cast(
JSON_TYPE, dict[str, Any],
_TRIGGERS_SCHEMA( _TRIGGERS_SCHEMA(
load_yaml_dict(str(integration.file_path / "triggers.yaml")) load_yaml_dict(str(integration.file_path / "triggers.yaml"))
), ),
@ -549,11 +556,14 @@ def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_T
def _load_triggers_files( def _load_triggers_files(
hass: HomeAssistant, integrations: Iterable[Integration] integrations: Iterable[Integration],
) -> dict[str, JSON_TYPE]: ) -> dict[str, dict[str, Any]]:
"""Load trigger files for multiple integrations.""" """Load trigger files for multiple integrations."""
return { return {
integration.domain: _load_triggers_file(hass, integration) integration.domain: {
get_absolute_description_key(integration.domain, key): value
for key, value in _load_triggers_file(integration).items()
}
for integration in integrations for integration in integrations
} }
@ -574,7 +584,7 @@ async def async_get_all_descriptions(
return descriptions_cache return descriptions_cache
# Files we loaded for missing descriptions # Files we loaded for missing descriptions
new_triggers_descriptions: dict[str, JSON_TYPE] = {} new_triggers_descriptions: dict[str, dict[str, Any]] = {}
# We try to avoid making a copy in the event the cache is good, # We try to avoid making a copy in the event the cache is good,
# but now we must make a copy in case new triggers get added # but now we must make a copy in case new triggers get added
# while we are loading the missing ones so we do not # while we are loading the missing ones so we do not
@ -601,7 +611,7 @@ async def async_get_all_descriptions(
if integrations: if integrations:
new_triggers_descriptions = await hass.async_add_executor_job( new_triggers_descriptions = await hass.async_add_executor_job(
_load_triggers_files, hass, integrations _load_triggers_files, integrations
) )
# Make a copy of the old cache and add missing descriptions to it # Make a copy of the old cache and add missing descriptions to it
@ -610,7 +620,7 @@ async def async_get_all_descriptions(
domain = triggers[missing_trigger] domain = triggers[missing_trigger]
if ( if (
yaml_description := new_triggers_descriptions.get(domain, {}).get( # type: ignore[union-attr] yaml_description := new_triggers_descriptions.get(domain, {}).get(
missing_trigger missing_trigger
) )
) is None: ) is None:

View File

@ -136,7 +136,7 @@ TRIGGER_ICONS_SCHEMA = cv.schema_with_slug_keys(
vol.Optional("trigger"): icon_value_validator, vol.Optional("trigger"): icon_value_validator,
} }
), ),
slug_validator=translation_key_validator, slug_validator=cv.underscore_slug,
) )

View File

@ -450,7 +450,7 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema:
slug_validator=translation_key_validator, slug_validator=translation_key_validator,
), ),
}, },
slug_validator=translation_key_validator, slug_validator=cv.underscore_slug,
), ),
vol.Optional("conversation"): { vol.Optional("conversation"): {
vol.Required("agent"): { vol.Required("agent"): {

View File

@ -50,7 +50,7 @@ TRIGGER_SCHEMA = vol.Any(
TRIGGERS_SCHEMA = vol.Schema( TRIGGERS_SCHEMA = vol.Schema(
{ {
vol.Remove(vol.All(str, trigger.starts_with_dot)): object, vol.Remove(vol.All(str, trigger.starts_with_dot)): object,
cv.slug: TRIGGER_SCHEMA, cv.underscore_slug: TRIGGER_SCHEMA,
} }
) )

View File

@ -806,10 +806,10 @@ async def test_subscribe_triggers(
) -> None: ) -> None:
"""Test trigger_platforms/subscribe command.""" """Test trigger_platforms/subscribe command."""
sun_trigger_descriptions = """ sun_trigger_descriptions = """
sun: {} _: {}
""" """
tag_trigger_descriptions = """ tag_trigger_descriptions = """
tag: {} _: {}
""" """
def _load_yaml(fname, secrets=None): def _load_yaml(fname, secrets=None):

View File

@ -977,7 +977,7 @@ async def test_zwave_js_event_invalid_config_entry_id(
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None: async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
"""Test invalid trigger configs.""" """Test invalid trigger configs."""
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config( await TRIGGERS["event"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.event", "platform": f"{DOMAIN}.event",
@ -988,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
) )
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config( await TRIGGERS["value_updated"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.value_updated", "platform": f"{DOMAIN}.value_updated",
@ -1026,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
await hass.config_entries.async_unload(integration.entry_id) await hass.config_entries.async_unload(integration.entry_id)
# Test full validation for both events # Test full validation for both events
assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config( assert await TRIGGERS["value_updated"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.value_updated", "platform": f"{DOMAIN}.value_updated",
@ -1036,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
}, },
) )
assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config( assert await TRIGGERS["event"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.event", "platform": f"{DOMAIN}.event",

View File

@ -0,0 +1,36 @@
"""Test automation helpers."""
import pytest
from homeassistant.helpers.automation import (
get_absolute_description_key,
get_relative_description_key,
)
@pytest.mark.parametrize(
("relative_key", "absolute_key"),
[
("turned_on", "homeassistant.turned_on"),
("_", "homeassistant"),
("_state", "state"),
],
)
def test_absolute_description_key(relative_key: str, absolute_key: str) -> None:
"""Test absolute description key."""
DOMAIN = "homeassistant"
assert get_absolute_description_key(DOMAIN, relative_key) == absolute_key
@pytest.mark.parametrize(
("relative_key", "absolute_key"),
[
("turned_on", "homeassistant.turned_on"),
("_", "homeassistant"),
("_state", "state"),
],
)
def test_relative_description_key(relative_key: str, absolute_key: str) -> None:
"""Test relative description key."""
DOMAIN = "homeassistant"
assert get_relative_description_key(DOMAIN, absolute_key) == relative_key

View File

@ -50,7 +50,7 @@ async def test_trigger_subtype(hass: HomeAssistant) -> None:
"homeassistant.helpers.trigger.async_get_integration", "homeassistant.helpers.trigger.async_get_integration",
return_value=MagicMock(async_get_platform=AsyncMock()), return_value=MagicMock(async_get_platform=AsyncMock()),
) as integration_mock: ) as integration_mock:
await _async_get_trigger_platform(hass, {"platform": "test.subtype"}) await _async_get_trigger_platform(hass, "test.subtype")
assert integration_mock.call_args == call(hass, "test") assert integration_mock.call_args == call(hass, "test")
@ -493,8 +493,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
hass: HomeAssistant, hass: HomeAssistant,
) -> dict[str, type[Trigger]]: ) -> dict[str, type[Trigger]]:
return { return {
"test": MockTrigger1, "_": MockTrigger1,
"test.trig_2": MockTrigger2, "trig_2": MockTrigger2,
} }
mock_integration(hass, MockModule("test")) mock_integration(hass, MockModule("test"))
@ -534,7 +534,7 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
"sun_trigger_descriptions", "sun_trigger_descriptions",
[ [
""" """
sun: _:
fields: fields:
event: event:
example: sunrise example: sunrise
@ -551,7 +551,7 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
.anchor: &anchor .anchor: &anchor
- sunrise - sunrise
- sunset - sunset
sun: _:
fields: fields:
event: event:
example: sunrise example: sunrise
@ -569,7 +569,7 @@ async def test_async_get_all_descriptions(
) -> None: ) -> None:
"""Test async_get_all_descriptions.""" """Test async_get_all_descriptions."""
tag_trigger_descriptions = """ tag_trigger_descriptions = """
tag: _:
fields: fields:
entity: entity:
selector: selector:
@ -607,7 +607,7 @@ async def test_async_get_all_descriptions(
# Test we only load triggers.yaml for integrations with triggers, # Test we only load triggers.yaml for integrations with triggers,
# system_health has no triggers # system_health has no triggers
assert proxy_load_triggers_files.mock_calls[0][1][1] == unordered( assert proxy_load_triggers_files.mock_calls[0][1][0] == unordered(
[ [
await async_get_integration(hass, DOMAIN_SUN), await async_get_integration(hass, DOMAIN_SUN),
] ]
@ -615,7 +615,7 @@ async def test_async_get_all_descriptions(
# system_health does not have triggers and should not be in descriptions # system_health does not have triggers and should not be in descriptions
assert descriptions == { assert descriptions == {
DOMAIN_SUN: { "sun": {
"fields": { "fields": {
"event": { "event": {
"example": "sunrise", "example": "sunrise",
@ -650,7 +650,7 @@ async def test_async_get_all_descriptions(
new_descriptions = await trigger.async_get_all_descriptions(hass) new_descriptions = await trigger.async_get_all_descriptions(hass)
assert new_descriptions is not descriptions assert new_descriptions is not descriptions
assert new_descriptions == { assert new_descriptions == {
DOMAIN_SUN: { "sun": {
"fields": { "fields": {
"event": { "event": {
"example": "sunrise", "example": "sunrise",
@ -666,7 +666,7 @@ async def test_async_get_all_descriptions(
"offset": {"selector": {"time": {}}}, "offset": {"selector": {"time": {}}},
} }
}, },
DOMAIN_TAG: { "tag": {
"fields": { "fields": {
"entity": { "entity": {
"selector": { "selector": {
@ -736,7 +736,7 @@ async def test_async_get_all_descriptions_with_bad_description(
) -> None: ) -> None:
"""Test async_get_all_descriptions.""" """Test async_get_all_descriptions."""
sun_service_descriptions = """ sun_service_descriptions = """
sun: _:
fields: not_a_dict fields: not_a_dict
""" """
@ -760,7 +760,7 @@ async def test_async_get_all_descriptions_with_bad_description(
assert ( assert (
"Unable to parse triggers.yaml for the sun integration: " "Unable to parse triggers.yaml for the sun integration: "
"expected a dictionary for dictionary value @ data['sun']['fields']" "expected a dictionary for dictionary value @ data['_']['fields']"
) in caplog.text ) in caplog.text
@ -787,7 +787,7 @@ async def test_subscribe_triggers(
) -> None: ) -> None:
"""Test trigger.async_subscribe_platform_events.""" """Test trigger.async_subscribe_platform_events."""
sun_trigger_descriptions = """ sun_trigger_descriptions = """
sun: {} _: {}
""" """
def _load_yaml(fname, secrets=None): def _load_yaml(fname, secrets=None):