mirror of
https://github.com/home-assistant/core.git
synced 2025-08-02 10:08:23 +00:00
Improve support for trigger platforms with multiple triggers
This commit is contained in:
parent
b15c9ad130
commit
ed43b3a4ec
@ -2,45 +2,23 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.helpers.trigger import (
|
||||
TriggerActionType,
|
||||
TriggerInfo,
|
||||
TriggerProtocol,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.trigger import Trigger
|
||||
|
||||
from .triggers import event, value_updated
|
||||
|
||||
TRIGGERS = {
|
||||
"value_updated": value_updated,
|
||||
"event": event,
|
||||
event.PLATFORM_TYPE: Trigger(
|
||||
event.async_validate_trigger_config,
|
||||
event.async_attach_trigger,
|
||||
),
|
||||
value_updated.PLATFORM_TYPE: Trigger(
|
||||
value_updated.async_validate_trigger_config,
|
||||
value_updated.async_attach_trigger,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _get_trigger_platform(config: ConfigType) -> TriggerProtocol:
|
||||
"""Return trigger platform."""
|
||||
platform_split = config[CONF_PLATFORM].split(".", maxsplit=1)
|
||||
if len(platform_split) < 2 or platform_split[1] not in TRIGGERS:
|
||||
raise ValueError(f"Unknown Z-Wave JS trigger platform {config[CONF_PLATFORM]}")
|
||||
return TRIGGERS[platform_split[1]]
|
||||
|
||||
|
||||
async def async_validate_trigger_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate config."""
|
||||
platform = _get_trigger_platform(config)
|
||||
return await platform.async_validate_trigger_config(hass, config)
|
||||
|
||||
|
||||
async def async_attach_trigger(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach trigger of specified platform."""
|
||||
platform = _get_trigger_platform(config)
|
||||
return await platform.async_attach_trigger(hass, config, action, trigger_info)
|
||||
async def async_get_triggers(hass: HomeAssistant) -> dict[str, Trigger]:
|
||||
"""Return the triggers for Z-Wave JS."""
|
||||
return TRIGGERS
|
||||
|
@ -49,12 +49,29 @@ DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = Has
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trigger:
|
||||
"""Trigger descriptor data class."""
|
||||
|
||||
async_validate_trigger_config: Callable[
|
||||
[HomeAssistant, ConfigType],
|
||||
Coroutine[Any, Any, ConfigType],
|
||||
]
|
||||
async_attach_trigger: Callable[
|
||||
[HomeAssistant, ConfigType, TriggerActionType, TriggerInfo],
|
||||
Coroutine[Any, Any, CALLBACK_TYPE],
|
||||
]
|
||||
|
||||
|
||||
class TriggerProtocol(Protocol):
|
||||
"""Define the format of trigger modules.
|
||||
|
||||
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
|
||||
New implementations should only implement async_get_triggers.
|
||||
"""
|
||||
|
||||
async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, Trigger]:
|
||||
"""Return the triggers provided by this integration."""
|
||||
|
||||
TRIGGER_SCHEMA: vol.Schema
|
||||
|
||||
async def async_validate_trigger_config(
|
||||
@ -219,13 +236,14 @@ class PluggableAction:
|
||||
async def _async_get_trigger_platform(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> TriggerProtocol:
|
||||
platform_and_sub_type = config[CONF_PLATFORM].split(".")
|
||||
trigger_key: str = config[CONF_PLATFORM]
|
||||
platform_and_sub_type = trigger_key.split(".")
|
||||
platform = platform_and_sub_type[0]
|
||||
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||
try:
|
||||
integration = await async_get_integration(hass, platform)
|
||||
except IntegrationNotFound:
|
||||
raise vol.Invalid(f"Invalid trigger '{platform}' specified") from None
|
||||
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None
|
||||
try:
|
||||
return await integration.async_get_platform("trigger")
|
||||
except ImportError:
|
||||
@ -241,7 +259,13 @@ async def async_validate_trigger_config(
|
||||
config = []
|
||||
for conf in trigger_config:
|
||||
platform = await _async_get_trigger_platform(hass, conf)
|
||||
if hasattr(platform, "async_validate_trigger_config"):
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||
trigger_key: str = conf[CONF_PLATFORM]
|
||||
if not (trigger := trigger_descriptors[trigger_key]):
|
||||
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
|
||||
conf = await trigger.async_validate_trigger_config(hass, conf)
|
||||
elif hasattr(platform, "async_validate_trigger_config"):
|
||||
conf = await platform.async_validate_trigger_config(hass, conf)
|
||||
else:
|
||||
conf = platform.TRIGGER_SCHEMA(conf)
|
||||
@ -337,11 +361,15 @@ async def async_initialize_triggers(
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||
attach_fn = trigger_descriptors[conf[CONF_PLATFORM]].async_attach_trigger
|
||||
else:
|
||||
attach_fn = platform.async_attach_trigger
|
||||
|
||||
triggers.append(
|
||||
create_eager_task(
|
||||
platform.async_attach_trigger(
|
||||
hass, conf, _trigger_action_wrapper(hass, action, conf), info
|
||||
)
|
||||
attach_fn(hass, conf, _trigger_action_wrapper(hass, action, conf), info)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""The tests for Z-Wave JS automation triggers."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
@ -11,14 +11,11 @@ from zwave_js_server.model.node import Node
|
||||
from homeassistant.components import automation
|
||||
from homeassistant.components.zwave_js import DOMAIN
|
||||
from homeassistant.components.zwave_js.helpers import get_device_id
|
||||
from homeassistant.components.zwave_js.trigger import (
|
||||
_get_trigger_platform,
|
||||
async_validate_trigger_config,
|
||||
)
|
||||
from homeassistant.components.zwave_js.trigger import TRIGGERS
|
||||
from homeassistant.components.zwave_js.triggers.trigger_helpers import (
|
||||
async_bypass_dynamic_config_validation,
|
||||
)
|
||||
from homeassistant.const import CONF_PLATFORM, SERVICE_RELOAD
|
||||
from homeassistant.const import SERVICE_RELOAD
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.setup import async_setup_component
|
||||
@ -977,22 +974,10 @@ async def test_zwave_js_event_invalid_config_entry_id(
|
||||
caplog.clear()
|
||||
|
||||
|
||||
async def test_async_validate_trigger_config(hass: HomeAssistant) -> None:
|
||||
"""Test async_validate_trigger_config."""
|
||||
mock_platform = AsyncMock()
|
||||
with patch(
|
||||
"homeassistant.components.zwave_js.trigger._get_trigger_platform",
|
||||
return_value=mock_platform,
|
||||
):
|
||||
mock_platform.async_validate_trigger_config.return_value = {}
|
||||
await async_validate_trigger_config(hass, {})
|
||||
mock_platform.async_validate_trigger_config.assert_awaited()
|
||||
|
||||
|
||||
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
||||
"""Test invalid trigger configs."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
await async_validate_trigger_config(
|
||||
await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.event",
|
||||
@ -1003,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
||||
)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
await async_validate_trigger_config(
|
||||
await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.value_updated",
|
||||
@ -1041,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
||||
await hass.config_entries.async_unload(integration.entry_id)
|
||||
|
||||
# Test full validation for both events
|
||||
assert await async_validate_trigger_config(
|
||||
assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.value_updated",
|
||||
@ -1051,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
||||
},
|
||||
)
|
||||
|
||||
assert await async_validate_trigger_config(
|
||||
assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.event",
|
||||
@ -1115,12 +1100,6 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
||||
)
|
||||
|
||||
|
||||
def test_get_trigger_platform_failure() -> None:
|
||||
"""Test _get_trigger_platform."""
|
||||
with pytest.raises(ValueError):
|
||||
_get_trigger_platform({CONF_PLATFORM: "zwave_js.invalid"})
|
||||
|
||||
|
||||
async def test_server_reconnect_event(
|
||||
hass: HomeAssistant,
|
||||
client,
|
||||
|
Loading…
x
Reference in New Issue
Block a user