mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 05:47:10 +00:00
Improve support for trigger platforms with multiple triggers (#144827)
* Improve support for trigger platforms with multiple triggers * Adjust zwave_js * Refactor the Trigger class * Silence mypy * Adjust * Revert "Adjust" This reverts commit 17b3d16a267d54c082b12f07550faa8ac4ac3a49. * Revert "Silence mypy" This reverts commit c2a011b16f9b02880fc3dc673b5b12501f7995fc. * Reapply "Adjust" This reverts commit c64ba202dd19da9de08c504f8163ec51acbebab0. * Apply suggestions from code review * Revert "Apply suggestions from code review" This reverts commit 0314955c5a15548b8a4ce69aab7b25452fe4b1e0.
This commit is contained in:
parent
dbfecf99dc
commit
26fe23eb5c
@ -29,7 +29,6 @@ from homeassistant.helpers import (
|
|||||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import trigger
|
|
||||||
from .config_validation import VALUE_SCHEMA
|
from .config_validation import VALUE_SCHEMA
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_COMMAND_CLASS,
|
ATTR_COMMAND_CLASS,
|
||||||
@ -67,6 +66,8 @@ from .triggers.value_updated import (
|
|||||||
ATTR_FROM,
|
ATTR_FROM,
|
||||||
ATTR_TO,
|
ATTR_TO,
|
||||||
PLATFORM_TYPE as VALUE_UPDATED_PLATFORM_TYPE,
|
PLATFORM_TYPE as VALUE_UPDATED_PLATFORM_TYPE,
|
||||||
|
async_attach_trigger as attach_value_updated_trigger,
|
||||||
|
async_validate_trigger_config as validate_value_updated_trigger_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger types
|
# Trigger types
|
||||||
@ -448,10 +449,10 @@ async def async_attach_trigger(
|
|||||||
ATTR_TO,
|
ATTR_TO,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
zwave_js_config = await trigger.async_validate_trigger_config(
|
zwave_js_config = await validate_value_updated_trigger_config(
|
||||||
hass, zwave_js_config
|
hass, zwave_js_config
|
||||||
)
|
)
|
||||||
return await trigger.async_attach_trigger(
|
return await attach_value_updated_trigger(
|
||||||
hass, zwave_js_config, action, trigger_info
|
hass, zwave_js_config, action, trigger_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,45 +2,17 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from homeassistant.const import CONF_PLATFORM
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
from homeassistant.helpers.trigger import Trigger
|
||||||
from homeassistant.helpers.trigger import (
|
|
||||||
TriggerActionType,
|
|
||||||
TriggerInfo,
|
|
||||||
TriggerProtocol,
|
|
||||||
)
|
|
||||||
from homeassistant.helpers.typing import ConfigType
|
|
||||||
|
|
||||||
from .triggers import event, value_updated
|
from .triggers import event, value_updated
|
||||||
|
|
||||||
TRIGGERS = {
|
TRIGGERS = {
|
||||||
"value_updated": value_updated,
|
event.PLATFORM_TYPE: event.EventTrigger,
|
||||||
"event": event,
|
value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_trigger_platform(config: ConfigType) -> TriggerProtocol:
|
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||||
"""Return trigger platform."""
|
"""Return the triggers for Z-Wave JS."""
|
||||||
platform_split = config[CONF_PLATFORM].split(".", maxsplit=1)
|
return TRIGGERS
|
||||||
if len(platform_split) < 2 or platform_split[1] not in TRIGGERS:
|
|
||||||
raise ValueError(f"Unknown Z-Wave JS trigger platform {config[CONF_PLATFORM]}")
|
|
||||||
return TRIGGERS[platform_split[1]]
|
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_trigger_config(
|
|
||||||
hass: HomeAssistant, config: ConfigType
|
|
||||||
) -> ConfigType:
|
|
||||||
"""Validate config."""
|
|
||||||
platform = _get_trigger_platform(config)
|
|
||||||
return await platform.async_validate_trigger_config(hass, config)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_attach_trigger(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
config: ConfigType,
|
|
||||||
action: TriggerActionType,
|
|
||||||
trigger_info: TriggerInfo,
|
|
||||||
) -> CALLBACK_TYPE:
|
|
||||||
"""Attach trigger of specified platform."""
|
|
||||||
platform = _get_trigger_platform(config)
|
|
||||||
return await platform.async_attach_trigger(hass, config, action, trigger_info)
|
|
||||||
|
@ -16,7 +16,7 @@ from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM
|
|||||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
from homeassistant.helpers.trigger import Trigger, TriggerActionType, TriggerInfo
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from ..const import (
|
from ..const import (
|
||||||
@ -251,3 +251,29 @@ async def async_attach_trigger(
|
|||||||
_create_zwave_listeners()
|
_create_zwave_listeners()
|
||||||
|
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
|
|
||||||
|
class EventTrigger(Trigger):
|
||||||
|
"""Z-Wave JS event trigger."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
|
"""Initialize trigger."""
|
||||||
|
self._config = config
|
||||||
|
self._hass = hass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_validate_trigger_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
return await async_validate_trigger_config(hass, config)
|
||||||
|
|
||||||
|
async def async_attach_trigger(
|
||||||
|
self,
|
||||||
|
action: TriggerActionType,
|
||||||
|
trigger_info: TriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach a trigger."""
|
||||||
|
return await async_attach_trigger(
|
||||||
|
self._hass, self._config, action, trigger_info
|
||||||
|
)
|
||||||
|
@ -14,7 +14,7 @@ from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, M
|
|||||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
from homeassistant.helpers.trigger import Trigger, TriggerActionType, TriggerInfo
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from ..config_validation import VALUE_SCHEMA
|
from ..config_validation import VALUE_SCHEMA
|
||||||
@ -202,3 +202,29 @@ async def async_attach_trigger(
|
|||||||
_create_zwave_listeners()
|
_create_zwave_listeners()
|
||||||
|
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
|
|
||||||
|
class ValueUpdatedTrigger(Trigger):
|
||||||
|
"""Z-Wave JS value updated trigger."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
|
"""Initialize trigger."""
|
||||||
|
self._config = config
|
||||||
|
self._hass = hass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_validate_trigger_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
return await async_validate_trigger_config(hass, config)
|
||||||
|
|
||||||
|
async def async_attach_trigger(
|
||||||
|
self,
|
||||||
|
action: TriggerActionType,
|
||||||
|
trigger_info: TriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach a trigger."""
|
||||||
|
return await async_attach_trigger(
|
||||||
|
self._hass, self._config, action, trigger_info
|
||||||
|
)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable, Coroutine
|
||||||
@ -49,12 +50,37 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Trigger(abc.ABC):
|
||||||
|
"""Trigger class."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
|
||||||
|
"""Initialize trigger."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def async_validate_trigger_config(
|
||||||
|
cls, hass: HomeAssistant, config: ConfigType
|
||||||
|
) -> ConfigType:
|
||||||
|
"""Validate config."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def async_attach_trigger(
|
||||||
|
self,
|
||||||
|
action: TriggerActionType,
|
||||||
|
trigger_info: TriggerInfo,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach a trigger."""
|
||||||
|
|
||||||
|
|
||||||
class TriggerProtocol(Protocol):
|
class TriggerProtocol(Protocol):
|
||||||
"""Define the format of trigger modules.
|
"""Define the format of trigger modules.
|
||||||
|
|
||||||
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
|
New implementations should only implement async_get_triggers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||||
|
"""Return the triggers provided by this integration."""
|
||||||
|
|
||||||
TRIGGER_SCHEMA: vol.Schema
|
TRIGGER_SCHEMA: vol.Schema
|
||||||
|
|
||||||
async def async_validate_trigger_config(
|
async def async_validate_trigger_config(
|
||||||
@ -219,13 +245,14 @@ class PluggableAction:
|
|||||||
async def _async_get_trigger_platform(
|
async def _async_get_trigger_platform(
|
||||||
hass: HomeAssistant, config: ConfigType
|
hass: HomeAssistant, config: ConfigType
|
||||||
) -> TriggerProtocol:
|
) -> TriggerProtocol:
|
||||||
platform_and_sub_type = config[CONF_PLATFORM].split(".")
|
trigger_key: str = config[CONF_PLATFORM]
|
||||||
|
platform_and_sub_type = trigger_key.split(".")
|
||||||
platform = platform_and_sub_type[0]
|
platform = platform_and_sub_type[0]
|
||||||
platform = _PLATFORM_ALIASES.get(platform, platform)
|
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||||
try:
|
try:
|
||||||
integration = await async_get_integration(hass, platform)
|
integration = await async_get_integration(hass, platform)
|
||||||
except IntegrationNotFound:
|
except IntegrationNotFound:
|
||||||
raise vol.Invalid(f"Invalid trigger '{platform}' specified") from None
|
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None
|
||||||
try:
|
try:
|
||||||
return await integration.async_get_platform("trigger")
|
return await integration.async_get_platform("trigger")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -241,7 +268,13 @@ async def async_validate_trigger_config(
|
|||||||
config = []
|
config = []
|
||||||
for conf in trigger_config:
|
for conf in trigger_config:
|
||||||
platform = await _async_get_trigger_platform(hass, conf)
|
platform = await _async_get_trigger_platform(hass, conf)
|
||||||
if hasattr(platform, "async_validate_trigger_config"):
|
if hasattr(platform, "async_get_triggers"):
|
||||||
|
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||||
|
trigger_key: str = conf[CONF_PLATFORM]
|
||||||
|
if not (trigger := trigger_descriptors[trigger_key]):
|
||||||
|
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
|
||||||
|
conf = await trigger.async_validate_trigger_config(hass, conf)
|
||||||
|
elif hasattr(platform, "async_validate_trigger_config"):
|
||||||
conf = await platform.async_validate_trigger_config(hass, conf)
|
conf = await platform.async_validate_trigger_config(hass, conf)
|
||||||
else:
|
else:
|
||||||
conf = platform.TRIGGER_SCHEMA(conf)
|
conf = platform.TRIGGER_SCHEMA(conf)
|
||||||
@ -337,13 +370,15 @@ async def async_initialize_triggers(
|
|||||||
trigger_data=trigger_data,
|
trigger_data=trigger_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
triggers.append(
|
action_wrapper = _trigger_action_wrapper(hass, action, conf)
|
||||||
create_eager_task(
|
if hasattr(platform, "async_get_triggers"):
|
||||||
platform.async_attach_trigger(
|
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||||
hass, conf, _trigger_action_wrapper(hass, action, conf), info
|
trigger = trigger_descriptors[conf[CONF_PLATFORM]](hass, conf)
|
||||||
)
|
coro = trigger.async_attach_trigger(action_wrapper, info)
|
||||||
)
|
else:
|
||||||
)
|
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
|
||||||
|
|
||||||
|
triggers.append(create_eager_task(coro))
|
||||||
|
|
||||||
attach_results = await asyncio.gather(*triggers, return_exceptions=True)
|
attach_results = await asyncio.gather(*triggers, return_exceptions=True)
|
||||||
removes: list[Callable[[], None]] = []
|
removes: list[Callable[[], None]] = []
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""The tests for Z-Wave JS automation triggers."""
|
"""The tests for Z-Wave JS automation triggers."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -11,14 +11,11 @@ from zwave_js_server.model.node import Node
|
|||||||
from homeassistant.components import automation
|
from homeassistant.components import automation
|
||||||
from homeassistant.components.zwave_js import DOMAIN
|
from homeassistant.components.zwave_js import DOMAIN
|
||||||
from homeassistant.components.zwave_js.helpers import get_device_id
|
from homeassistant.components.zwave_js.helpers import get_device_id
|
||||||
from homeassistant.components.zwave_js.trigger import (
|
from homeassistant.components.zwave_js.trigger import TRIGGERS
|
||||||
_get_trigger_platform,
|
|
||||||
async_validate_trigger_config,
|
|
||||||
)
|
|
||||||
from homeassistant.components.zwave_js.triggers.trigger_helpers import (
|
from homeassistant.components.zwave_js.triggers.trigger_helpers import (
|
||||||
async_bypass_dynamic_config_validation,
|
async_bypass_dynamic_config_validation,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_PLATFORM, SERVICE_RELOAD
|
from homeassistant.const import SERVICE_RELOAD
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import device_registry as dr
|
from homeassistant.helpers import device_registry as dr
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -977,22 +974,10 @@ async def test_zwave_js_event_invalid_config_entry_id(
|
|||||||
caplog.clear()
|
caplog.clear()
|
||||||
|
|
||||||
|
|
||||||
async def test_async_validate_trigger_config(hass: HomeAssistant) -> None:
|
|
||||||
"""Test async_validate_trigger_config."""
|
|
||||||
mock_platform = AsyncMock()
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.zwave_js.trigger._get_trigger_platform",
|
|
||||||
return_value=mock_platform,
|
|
||||||
):
|
|
||||||
mock_platform.async_validate_trigger_config.return_value = {}
|
|
||||||
await async_validate_trigger_config(hass, {})
|
|
||||||
mock_platform.async_validate_trigger_config.assert_awaited()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
||||||
"""Test invalid trigger configs."""
|
"""Test invalid trigger configs."""
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
await async_validate_trigger_config(
|
await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
"platform": f"{DOMAIN}.event",
|
"platform": f"{DOMAIN}.event",
|
||||||
@ -1003,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
await async_validate_trigger_config(
|
await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
"platform": f"{DOMAIN}.value_updated",
|
"platform": f"{DOMAIN}.value_updated",
|
||||||
@ -1041,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
|||||||
await hass.config_entries.async_unload(integration.entry_id)
|
await hass.config_entries.async_unload(integration.entry_id)
|
||||||
|
|
||||||
# Test full validation for both events
|
# Test full validation for both events
|
||||||
assert await async_validate_trigger_config(
|
assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
"platform": f"{DOMAIN}.value_updated",
|
"platform": f"{DOMAIN}.value_updated",
|
||||||
@ -1051,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert await async_validate_trigger_config(
|
assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
"platform": f"{DOMAIN}.event",
|
"platform": f"{DOMAIN}.event",
|
||||||
@ -1115,12 +1100,6 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_trigger_platform_failure() -> None:
|
|
||||||
"""Test _get_trigger_platform."""
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
_get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"})
|
|
||||||
|
|
||||||
|
|
||||||
async def test_server_reconnect_event(
|
async def test_server_reconnect_event(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
client,
|
client,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user