diff --git a/homeassistant/components/zwave_js/trigger.py b/homeassistant/components/zwave_js/trigger.py index 9cb1a3e1d7e..923798d1177 100644 --- a/homeassistant/components/zwave_js/trigger.py +++ b/homeassistant/components/zwave_js/trigger.py @@ -2,45 +2,23 @@ from __future__ import annotations -from homeassistant.const import CONF_PLATFORM -from homeassistant.core import CALLBACK_TYPE, HomeAssistant -from homeassistant.helpers.trigger import ( - TriggerActionType, - TriggerInfo, - TriggerProtocol, -) -from homeassistant.helpers.typing import ConfigType +from homeassistant.core import HomeAssistant +from homeassistant.helpers.trigger import Trigger from .triggers import event, value_updated TRIGGERS = { - "value_updated": value_updated, - "event": event, + event.PLATFORM_TYPE: Trigger( + event.async_validate_trigger_config, + event.async_attach_trigger, + ), + value_updated.PLATFORM_TYPE: Trigger( + value_updated.async_validate_trigger_config, + value_updated.async_attach_trigger, + ), } -def _get_trigger_platform(config: ConfigType) -> TriggerProtocol: - """Return trigger platform.""" - platform_split = config[CONF_PLATFORM].split(".", maxsplit=1) - if len(platform_split) < 2 or platform_split[1] not in TRIGGERS: - raise ValueError(f"Unknown Z-Wave JS trigger platform {config[CONF_PLATFORM]}") - return TRIGGERS[platform_split[1]] - - -async def async_validate_trigger_config( - hass: HomeAssistant, config: ConfigType -) -> ConfigType: - """Validate config.""" - platform = _get_trigger_platform(config) - return await platform.async_validate_trigger_config(hass, config) - - -async def async_attach_trigger( - hass: HomeAssistant, - config: ConfigType, - action: TriggerActionType, - trigger_info: TriggerInfo, -) -> CALLBACK_TYPE: - """Attach trigger of specified platform.""" - platform = _get_trigger_platform(config) - return await platform.async_attach_trigger(hass, config, action, trigger_info) +async def async_get_triggers(hass: HomeAssistant) -> dict[str, Trigger]: + """Return the triggers for Z-Wave JS.""" + return TRIGGERS diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index a27c85a5c58..bb07d4e54f1 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -49,12 +49,29 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has ) +@dataclass +class Trigger: + """Trigger descriptor data class.""" + + async_validate_trigger_config: Callable[ + [HomeAssistant, ConfigType], + Coroutine[Any, Any, ConfigType], + ] + async_attach_trigger: Callable[ + [HomeAssistant, ConfigType, TriggerActionType, TriggerInfo], + Coroutine[Any, Any, CALLBACK_TYPE], + ] + + class TriggerProtocol(Protocol): """Define the format of trigger modules. - Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. + New implementations should only implement async_get_triggers. """ + async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, Trigger]: + """Return the triggers provided by this integration.""" + TRIGGER_SCHEMA: vol.Schema async def async_validate_trigger_config( @@ -219,13 +236,14 @@ class PluggableAction: async def _async_get_trigger_platform( hass: HomeAssistant, config: ConfigType ) -> TriggerProtocol: - platform_and_sub_type = config[CONF_PLATFORM].split(".") + trigger_key: str = config[CONF_PLATFORM] + platform_and_sub_type = trigger_key.split(".") platform = platform_and_sub_type[0] platform = _PLATFORM_ALIASES.get(platform, platform) try: integration = await async_get_integration(hass, platform) except IntegrationNotFound: - raise vol.Invalid(f"Invalid trigger '{platform}' specified") from None + raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None try: return await integration.async_get_platform("trigger") except ImportError: @@ -241,7 +259,13 @@ async def async_validate_trigger_config( config = [] for conf in trigger_config: platform = await _async_get_trigger_platform(hass, conf) - if hasattr(platform, "async_validate_trigger_config"): + if hasattr(platform, "async_get_triggers"): + trigger_descriptors = await platform.async_get_triggers(hass) + trigger_key: str = conf[CONF_PLATFORM] + if not (trigger := trigger_descriptors[trigger_key]): + raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") + conf = await trigger.async_validate_trigger_config(hass, conf) + elif hasattr(platform, "async_validate_trigger_config"): conf = await platform.async_validate_trigger_config(hass, conf) else: conf = platform.TRIGGER_SCHEMA(conf) @@ -337,11 +361,15 @@ async def async_initialize_triggers( trigger_data=trigger_data, ) + if hasattr(platform, "async_get_triggers"): + trigger_descriptors = await platform.async_get_triggers(hass) + attach_fn = trigger_descriptors[conf[CONF_PLATFORM]].async_attach_trigger + else: + attach_fn = platform.async_attach_trigger + triggers.append( create_eager_task( - platform.async_attach_trigger( - hass, conf, _trigger_action_wrapper(hass, action, conf), info - ) + attach_fn(hass, conf, _trigger_action_wrapper(hass, action, conf), info) ) ) diff --git a/tests/components/zwave_js/test_trigger.py b/tests/components/zwave_js/test_trigger.py index 8c345619a90..02675544644 100644 --- a/tests/components/zwave_js/test_trigger.py +++ b/tests/components/zwave_js/test_trigger.py @@ -1,6 +1,6 @@ """The tests for Z-Wave JS automation triggers.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest import voluptuous as vol @@ -11,14 +11,11 @@ from zwave_js_server.model.node import Node from homeassistant.components import automation from homeassistant.components.zwave_js import DOMAIN from homeassistant.components.zwave_js.helpers import get_device_id -from homeassistant.components.zwave_js.trigger import ( - _get_trigger_platform, - async_validate_trigger_config, -) +from homeassistant.components.zwave_js.trigger import TRIGGERS from homeassistant.components.zwave_js.triggers.trigger_helpers import ( async_bypass_dynamic_config_validation, ) -from homeassistant.const import CONF_PLATFORM, SERVICE_RELOAD +from homeassistant.const import SERVICE_RELOAD from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr from homeassistant.setup import async_setup_component @@ -977,22 +974,10 @@ async def test_zwave_js_event_invalid_config_entry_id( caplog.clear() -async def test_async_validate_trigger_config(hass: HomeAssistant) -> None: - """Test async_validate_trigger_config.""" - mock_platform = AsyncMock() - with patch( - "homeassistant.components.zwave_js.trigger._get_trigger_platform", - return_value=mock_platform, - ): - mock_platform.async_validate_trigger_config.return_value = {} - await async_validate_trigger_config(hass, {}) - mock_platform.async_validate_trigger_config.assert_awaited() - - async def test_invalid_trigger_configs(hass: HomeAssistant) -> None: """Test invalid trigger configs.""" with pytest.raises(vol.Invalid): - await async_validate_trigger_config( + await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config( hass, { "platform": f"{DOMAIN}.event", @@ -1003,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None: ) with pytest.raises(vol.Invalid): - await async_validate_trigger_config( + await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config( hass, { "platform": f"{DOMAIN}.value_updated", @@ -1041,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded( await hass.config_entries.async_unload(integration.entry_id) # Test full validation for both events - assert await async_validate_trigger_config( + assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config( hass, { "platform": f"{DOMAIN}.value_updated", @@ -1051,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded( }, ) - assert await async_validate_trigger_config( + assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config( hass, { "platform": f"{DOMAIN}.event", @@ -1115,12 +1100,6 @@ async def test_zwave_js_trigger_config_entry_unloaded( ) -def test_get_trigger_platform_failure() -> None: - """Test _get_trigger_platform.""" - with pytest.raises(ValueError): - _get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"}) - - async def test_server_reconnect_event( hass: HomeAssistant, client,