From b8044f60fca0b7ad96bbec5f97a14b59e24db0fd Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 24 Jun 2025 10:13:44 +0200 Subject: [PATCH] Fix trigger config validation (#147408) --- homeassistant/helpers/trigger.py | 2 +- tests/helpers/test_trigger.py | 95 +++++++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 62aebdf6fd7..853b5aaf812 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -271,7 +271,7 @@ async def async_validate_trigger_config( if hasattr(platform, "async_get_triggers"): trigger_descriptors = await platform.async_get_triggers(hass) trigger_key: str = conf[CONF_PLATFORM] - if not (trigger := trigger_descriptors[trigger_key]): + if not (trigger := trigger_descriptors.get(trigger_key)): raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") conf = await trigger.async_validate_trigger_config(hass, conf) elif hasattr(platform, "async_validate_trigger_config"): diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 77f48be170b..f5a2b549f89 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -1,20 +1,32 @@ """The tests for the trigger helper.""" -from unittest.mock import ANY, AsyncMock, MagicMock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest import voluptuous as vol -from homeassistant.core import Context, HomeAssistant, ServiceCall, callback +from homeassistant.core import ( + CALLBACK_TYPE, + Context, + HomeAssistant, + ServiceCall, + callback, +) from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, + Trigger, + TriggerActionType, + TriggerInfo, _async_get_trigger_platform, async_initialize_triggers, async_validate_trigger_config, ) +from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component +from tests.common import MockModule, mock_integration, mock_platform + async def test_bad_trigger_platform(hass: HomeAssistant) -> None: """Test bad trigger platform.""" @@ -428,3 +440,82 @@ async def test_pluggable_action( remove_attach_2() assert not hass.data[DATA_PLUGGABLE_ACTIONS] assert not plug_2 + + +async def test_platform_multiple_triggers(hass: HomeAssistant) -> None: + """Test a trigger platform with multiple trigger.""" + + class MockTrigger(Trigger): + """Mock trigger.""" + + def __init__(self, hass: HomeAssistant, config: ConfigType) -> None: + """Initialize trigger.""" + + @classmethod + async def async_validate_trigger_config( + cls, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + return config + + class MockTrigger1(MockTrigger): + """Mock trigger 1.""" + + async def async_attach_trigger( + self, + action: TriggerActionType, + trigger_info: TriggerInfo, + ) -> CALLBACK_TYPE: + """Attach a trigger.""" + action({"trigger": "test_trigger_1"}) + + class MockTrigger2(MockTrigger): + """Mock trigger 2.""" + + async def async_attach_trigger( + self, + action: TriggerActionType, + trigger_info: TriggerInfo, + ) -> CALLBACK_TYPE: + """Attach a trigger.""" + action({"trigger": "test_trigger_2"}) + + async def async_get_triggers( + hass: HomeAssistant, + ) -> dict[str, type[Trigger]]: + return { + "test": MockTrigger1, + "test.trig_2": MockTrigger2, + } + + mock_integration(hass, MockModule("test")) + mock_platform(hass, "test.trigger", Mock(async_get_triggers=async_get_triggers)) + + config_1 = [{"platform": "test"}] + config_2 = [{"platform": "test.trig_2"}] + config_3 = [{"platform": "test.unknown_trig"}] + assert await async_validate_trigger_config(hass, config_1) == config_1 + assert await async_validate_trigger_config(hass, config_2) == config_2 + with pytest.raises( + vol.Invalid, match="Invalid trigger 'test.unknown_trig' specified" + ): + await async_validate_trigger_config(hass, config_3) + + log_cb = MagicMock() + + action_calls = [] + + @callback + def cb_action(*args): + action_calls.append([*args]) + + await async_initialize_triggers(hass, config_1, cb_action, "test", "", log_cb) + assert action_calls == [[{"trigger": "test_trigger_1"}]] + action_calls.clear() + + await async_initialize_triggers(hass, config_2, cb_action, "test", "", log_cb) + assert action_calls == [[{"trigger": "test_trigger_2"}]] + action_calls.clear() + + with pytest.raises(KeyError): + await async_initialize_triggers(hass, config_3, cb_action, "test", "", log_cb)