Improve support for trigger platforms with multiple triggers (#144827)

* Improve support for trigger platforms with multiple triggers

* Adjust zwave_js

* Refactor the Trigger class

* Silence mypy

* Adjust

* Revert "Adjust"

This reverts commit 17b3d16a267d54c082b12f07550faa8ac4ac3a49.

* Revert "Silence mypy"

This reverts commit c2a011b16f9b02880fc3dc673b5b12501f7995fc.

* Reapply "Adjust"

This reverts commit c64ba202dd19da9de08c504f8163ec51acbebab0.

* Apply suggestions from code review

* Revert "Apply suggestions from code review"

This reverts commit 0314955c5a15548b8a4ce69aab7b25452fe4b1e0.
This commit is contained in:
Erik Montnemery 2025-06-10 20:48:51 +02:00 committed by GitHub
parent dbfecf99dc
commit 26fe23eb5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 118 additions and 79 deletions

View File

@ -29,7 +29,6 @@ from homeassistant.helpers import (
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import trigger
from .config_validation import VALUE_SCHEMA from .config_validation import VALUE_SCHEMA
from .const import ( from .const import (
ATTR_COMMAND_CLASS, ATTR_COMMAND_CLASS,
@ -67,6 +66,8 @@ from .triggers.value_updated import (
ATTR_FROM, ATTR_FROM,
ATTR_TO, ATTR_TO,
PLATFORM_TYPE as VALUE_UPDATED_PLATFORM_TYPE, PLATFORM_TYPE as VALUE_UPDATED_PLATFORM_TYPE,
async_attach_trigger as attach_value_updated_trigger,
async_validate_trigger_config as validate_value_updated_trigger_config,
) )
# Trigger types # Trigger types
@ -448,10 +449,10 @@ async def async_attach_trigger(
ATTR_TO, ATTR_TO,
], ],
) )
zwave_js_config = await trigger.async_validate_trigger_config( zwave_js_config = await validate_value_updated_trigger_config(
hass, zwave_js_config hass, zwave_js_config
) )
return await trigger.async_attach_trigger( return await attach_value_updated_trigger(
hass, zwave_js_config, action, trigger_info hass, zwave_js_config, action, trigger_info
) )

View File

@ -2,45 +2,17 @@
from __future__ import annotations from __future__ import annotations
from homeassistant.const import CONF_PLATFORM from homeassistant.core import HomeAssistant
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.helpers.trigger import Trigger
from homeassistant.helpers.trigger import (
TriggerActionType,
TriggerInfo,
TriggerProtocol,
)
from homeassistant.helpers.typing import ConfigType
from .triggers import event, value_updated from .triggers import event, value_updated
TRIGGERS = { TRIGGERS = {
"value_updated": value_updated, event.PLATFORM_TYPE: event.EventTrigger,
"event": event, value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
} }
def _get_trigger_platform(config: ConfigType) -> TriggerProtocol: async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
"""Return trigger platform.""" """Return the triggers for Z-Wave JS."""
platform_split = config[CONF_PLATFORM].split(".", maxsplit=1) return TRIGGERS
if len(platform_split) < 2 or platform_split[1] not in TRIGGERS:
raise ValueError(f"Unknown Z-Wave JS trigger platform {config[CONF_PLATFORM]}")
return TRIGGERS[platform_split[1]]
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
platform = _get_trigger_platform(config)
return await platform.async_validate_trigger_config(hass, config)
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach trigger of specified platform."""
platform = _get_trigger_platform(config)
return await platform.async_attach_trigger(hass, config, action, trigger_info)

View File

@ -16,7 +16,7 @@ from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import Trigger, TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from ..const import ( from ..const import (
@ -251,3 +251,29 @@ async def async_attach_trigger(
_create_zwave_listeners() _create_zwave_listeners()
return async_remove return async_remove
class EventTrigger(Trigger):
"""Z-Wave JS event trigger."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize trigger."""
self._config = config
self._hass = hass
@classmethod
async def async_validate_trigger_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return await async_validate_trigger_config(hass, config)
async def async_attach_trigger(
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
return await async_attach_trigger(
self._hass, self._config, action, trigger_info
)

View File

@ -14,7 +14,7 @@ from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, M
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.trigger import Trigger, TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from ..config_validation import VALUE_SCHEMA from ..config_validation import VALUE_SCHEMA
@ -202,3 +202,29 @@ async def async_attach_trigger(
_create_zwave_listeners() _create_zwave_listeners()
return async_remove return async_remove
class ValueUpdatedTrigger(Trigger):
"""Z-Wave JS value updated trigger."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize trigger."""
self._config = config
self._hass = hass
@classmethod
async def async_validate_trigger_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return await async_validate_trigger_config(hass, config)
async def async_attach_trigger(
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
return await async_attach_trigger(
self._hass, self._config, action, trigger_info
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import abc
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
@ -49,12 +50,37 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has
) )
class Trigger(abc.ABC):
"""Trigger class."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize trigger."""
@classmethod
@abc.abstractmethod
async def async_validate_trigger_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
@abc.abstractmethod
async def async_attach_trigger(
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
class TriggerProtocol(Protocol): class TriggerProtocol(Protocol):
"""Define the format of trigger modules. """Define the format of trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config. New implementations should only implement async_get_triggers.
""" """
async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, type[Trigger]]:
"""Return the triggers provided by this integration."""
TRIGGER_SCHEMA: vol.Schema TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config( async def async_validate_trigger_config(
@ -219,13 +245,14 @@ class PluggableAction:
async def _async_get_trigger_platform( async def _async_get_trigger_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> TriggerProtocol: ) -> TriggerProtocol:
platform_and_sub_type = config[CONF_PLATFORM].split(".") trigger_key: str = config[CONF_PLATFORM]
platform_and_sub_type = trigger_key.split(".")
platform = platform_and_sub_type[0] platform = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform) platform = _PLATFORM_ALIASES.get(platform, platform)
try: try:
integration = await async_get_integration(hass, platform) integration = await async_get_integration(hass, platform)
except IntegrationNotFound: except IntegrationNotFound:
raise vol.Invalid(f"Invalid trigger '{platform}' specified") from None raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None
try: try:
return await integration.async_get_platform("trigger") return await integration.async_get_platform("trigger")
except ImportError: except ImportError:
@ -241,7 +268,13 @@ async def async_validate_trigger_config(
config = [] config = []
for conf in trigger_config: for conf in trigger_config:
platform = await _async_get_trigger_platform(hass, conf) platform = await _async_get_trigger_platform(hass, conf)
if hasattr(platform, "async_validate_trigger_config"): if hasattr(platform, "async_get_triggers"):
trigger_descriptors = await platform.async_get_triggers(hass)
trigger_key: str = conf[CONF_PLATFORM]
if not (trigger := trigger_descriptors[trigger_key]):
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
conf = await trigger.async_validate_trigger_config(hass, conf)
elif hasattr(platform, "async_validate_trigger_config"):
conf = await platform.async_validate_trigger_config(hass, conf) conf = await platform.async_validate_trigger_config(hass, conf)
else: else:
conf = platform.TRIGGER_SCHEMA(conf) conf = platform.TRIGGER_SCHEMA(conf)
@ -337,13 +370,15 @@ async def async_initialize_triggers(
trigger_data=trigger_data, trigger_data=trigger_data,
) )
triggers.append( action_wrapper = _trigger_action_wrapper(hass, action, conf)
create_eager_task( if hasattr(platform, "async_get_triggers"):
platform.async_attach_trigger( trigger_descriptors = await platform.async_get_triggers(hass)
hass, conf, _trigger_action_wrapper(hass, action, conf), info trigger = trigger_descriptors[conf[CONF_PLATFORM]](hass, conf)
) coro = trigger.async_attach_trigger(action_wrapper, info)
) else:
) coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
triggers.append(create_eager_task(coro))
attach_results = await asyncio.gather(*triggers, return_exceptions=True) attach_results = await asyncio.gather(*triggers, return_exceptions=True)
removes: list[Callable[[], None]] = [] removes: list[Callable[[], None]] = []

View File

@ -1,6 +1,6 @@
"""The tests for Z-Wave JS automation triggers.""" """The tests for Z-Wave JS automation triggers."""
from unittest.mock import AsyncMock, patch from unittest.mock import patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -11,14 +11,11 @@ from zwave_js_server.model.node import Node
from homeassistant.components import automation from homeassistant.components import automation
from homeassistant.components.zwave_js import DOMAIN from homeassistant.components.zwave_js import DOMAIN
from homeassistant.components.zwave_js.helpers import get_device_id from homeassistant.components.zwave_js.helpers import get_device_id
from homeassistant.components.zwave_js.trigger import ( from homeassistant.components.zwave_js.trigger import TRIGGERS
_get_trigger_platform,
async_validate_trigger_config,
)
from homeassistant.components.zwave_js.triggers.trigger_helpers import ( from homeassistant.components.zwave_js.triggers.trigger_helpers import (
async_bypass_dynamic_config_validation, async_bypass_dynamic_config_validation,
) )
from homeassistant.const import CONF_PLATFORM, SERVICE_RELOAD from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -977,22 +974,10 @@ async def test_zwave_js_event_invalid_config_entry_id(
caplog.clear() caplog.clear()
async def test_async_validate_trigger_config(hass: HomeAssistant) -> None:
"""Test async_validate_trigger_config."""
mock_platform = AsyncMock()
with patch(
"homeassistant.components.zwave_js.trigger._get_trigger_platform",
return_value=mock_platform,
):
mock_platform.async_validate_trigger_config.return_value = {}
await async_validate_trigger_config(hass, {})
mock_platform.async_validate_trigger_config.assert_awaited()
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None: async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
"""Test invalid trigger configs.""" """Test invalid trigger configs."""
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
await async_validate_trigger_config( await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.event", "platform": f"{DOMAIN}.event",
@ -1003,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
) )
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
await async_validate_trigger_config( await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.value_updated", "platform": f"{DOMAIN}.value_updated",
@ -1041,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
await hass.config_entries.async_unload(integration.entry_id) await hass.config_entries.async_unload(integration.entry_id)
# Test full validation for both events # Test full validation for both events
assert await async_validate_trigger_config( assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.value_updated", "platform": f"{DOMAIN}.value_updated",
@ -1051,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
}, },
) )
assert await async_validate_trigger_config( assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
hass, hass,
{ {
"platform": f"{DOMAIN}.event", "platform": f"{DOMAIN}.event",
@ -1115,12 +1100,6 @@ async def test_zwave_js_trigger_config_entry_unloaded(
) )
def test_get_trigger_platform_failure() -> None:
"""Test _get_trigger_platform."""
with pytest.raises(ValueError):
_get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"})
async def test_server_reconnect_event( async def test_server_reconnect_event(
hass: HomeAssistant, hass: HomeAssistant,
client, client,