mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 02:49:40 +00:00
Simplify firing of trigger actions
This commit is contained in:
@@ -17,17 +17,14 @@ from homeassistant.const import (
|
||||
ATTR_DEVICE_ID,
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_OPTIONS,
|
||||
CONF_PLATFORM,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.trigger import (
|
||||
Trigger,
|
||||
TriggerActionType,
|
||||
TriggerActionRunnerType,
|
||||
TriggerConfig,
|
||||
TriggerData,
|
||||
TriggerInfo,
|
||||
move_top_level_schema_fields_to_options,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
@@ -127,17 +124,13 @@ _CONFIG_SCHEMA = vol.Schema(
|
||||
class EventTrigger(Trigger):
|
||||
"""Z-Wave JS event trigger."""
|
||||
|
||||
_hass: HomeAssistant
|
||||
_options: dict[str, Any]
|
||||
|
||||
_event_source: str
|
||||
_event_name: str
|
||||
_event_data_filter: dict
|
||||
_job: HassJob
|
||||
_trigger_data: TriggerData
|
||||
_unsubs: list[Callable]
|
||||
|
||||
_platform_type = PLATFORM_TYPE
|
||||
_action_runner: TriggerActionRunnerType
|
||||
|
||||
@classmethod
|
||||
async def async_validate_complete_config(
|
||||
@@ -176,15 +169,11 @@ class EventTrigger(Trigger):
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
"""Initialize trigger."""
|
||||
self._hass = hass
|
||||
super().__init__(hass, config)
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
|
||||
async def async_attach(
|
||||
self,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
dev_reg = dr.async_get(self._hass)
|
||||
options = self._options
|
||||
@@ -198,8 +187,7 @@ class EventTrigger(Trigger):
|
||||
self._event_source = options[ATTR_EVENT_SOURCE]
|
||||
self._event_name = options[ATTR_EVENT]
|
||||
self._event_data_filter = options.get(ATTR_EVENT_DATA, {})
|
||||
self._job = HassJob(action)
|
||||
self._trigger_data = trigger_info["trigger_data"]
|
||||
self._action_runner = run_action
|
||||
self._unsubs: list[Callable] = []
|
||||
|
||||
self._create_zwave_listeners()
|
||||
@@ -226,8 +214,6 @@ class EventTrigger(Trigger):
|
||||
return
|
||||
|
||||
payload = {
|
||||
**self._trigger_data,
|
||||
CONF_PLATFORM: self._platform_type,
|
||||
ATTR_EVENT_SOURCE: self._event_source,
|
||||
ATTR_EVENT: self._event_name,
|
||||
ATTR_EVENT_DATA: event_data,
|
||||
@@ -237,21 +223,17 @@ class EventTrigger(Trigger):
|
||||
f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted"
|
||||
)
|
||||
|
||||
description = primary_desc
|
||||
if device:
|
||||
device_name = device.name_by_user or device.name
|
||||
payload[ATTR_DEVICE_ID] = device.id
|
||||
home_and_node_id = get_home_and_node_id_from_device_entry(device)
|
||||
assert home_and_node_id
|
||||
payload[ATTR_NODE_ID] = home_and_node_id[1]
|
||||
payload["description"] = f"{primary_desc} on {device_name}"
|
||||
else:
|
||||
payload["description"] = primary_desc
|
||||
payload[ATTR_NODE_ID] = home_and_node_id[1] # type: ignore[assignment]
|
||||
description = f"{primary_desc} on {device_name}"
|
||||
|
||||
payload["description"] = (
|
||||
f"{payload['description']} with event data: {event_data}"
|
||||
)
|
||||
|
||||
self._hass.async_run_hass_job(self._job, {"trigger": payload})
|
||||
description = f"{description} with event data: {event_data}"
|
||||
self._action_runner(description, payload)
|
||||
|
||||
@callback
|
||||
def _async_remove(self) -> None:
|
||||
|
||||
@@ -11,21 +11,14 @@ from zwave_js_server.const import CommandClass
|
||||
from zwave_js_server.model.driver import Driver
|
||||
from zwave_js_server.model.value import Value, get_value_id_str
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_ID,
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_OPTIONS,
|
||||
CONF_PLATFORM,
|
||||
MATCH_ALL,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_OPTIONS, MATCH_ALL
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.trigger import (
|
||||
Trigger,
|
||||
TriggerActionType,
|
||||
TriggerActionRunnerType,
|
||||
TriggerConfig,
|
||||
TriggerInfo,
|
||||
move_top_level_schema_fields_to_options,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
@@ -100,12 +93,7 @@ async def async_validate_trigger_config(
|
||||
|
||||
|
||||
async def async_attach_trigger(
|
||||
hass: HomeAssistant,
|
||||
options: ConfigType,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
*,
|
||||
platform_type: str = PLATFORM_TYPE,
|
||||
hass: HomeAssistant, options: ConfigType, run_action: TriggerActionRunnerType
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen for state changes based on configuration."""
|
||||
dev_reg = dr.async_get(hass)
|
||||
@@ -121,9 +109,6 @@ async def async_attach_trigger(
|
||||
endpoint = options.get(ATTR_ENDPOINT)
|
||||
property_key = options.get(ATTR_PROPERTY_KEY)
|
||||
unsubs: list[Callable] = []
|
||||
job = HassJob(action)
|
||||
|
||||
trigger_data = trigger_info["trigger_data"]
|
||||
|
||||
@callback
|
||||
def async_on_value_updated(
|
||||
@@ -152,10 +137,8 @@ async def async_attach_trigger(
|
||||
return
|
||||
|
||||
device_name = device.name_by_user or device.name
|
||||
|
||||
description = f"Z-Wave value {value.value_id} updated on {device_name}"
|
||||
payload = {
|
||||
**trigger_data,
|
||||
CONF_PLATFORM: platform_type,
|
||||
ATTR_DEVICE_ID: device.id,
|
||||
ATTR_NODE_ID: value.node.node_id,
|
||||
ATTR_COMMAND_CLASS: value.command_class,
|
||||
@@ -169,10 +152,9 @@ async def async_attach_trigger(
|
||||
ATTR_PREVIOUS_VALUE_RAW: prev_value_raw,
|
||||
ATTR_CURRENT_VALUE: curr_value,
|
||||
ATTR_CURRENT_VALUE_RAW: curr_value_raw,
|
||||
"description": f"Z-Wave value {value.value_id} updated on {device_name}",
|
||||
}
|
||||
|
||||
hass.async_run_hass_job(job, {"trigger": payload})
|
||||
run_action(description, payload)
|
||||
|
||||
@callback
|
||||
def async_remove() -> None:
|
||||
@@ -223,7 +205,6 @@ async def async_attach_trigger(
|
||||
class ValueUpdatedTrigger(Trigger):
|
||||
"""Z-Wave JS value updated trigger."""
|
||||
|
||||
_hass: HomeAssistant
|
||||
_options: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
@@ -245,16 +226,10 @@ class ValueUpdatedTrigger(Trigger):
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
"""Initialize trigger."""
|
||||
self._hass = hass
|
||||
super().__init__(hass, config)
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
|
||||
async def async_attach(
|
||||
self,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
return await async_attach_trigger(
|
||||
self._hass, self._options, action, trigger_info
|
||||
)
|
||||
return await async_attach_trigger(self._hass, self._options, run_action)
|
||||
|
||||
@@ -178,6 +178,9 @@ _TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||
class Trigger(abc.ABC):
|
||||
"""Trigger class."""
|
||||
|
||||
_job: HassJob | None = None
|
||||
_trigger_info: TriggerInfo | None = None
|
||||
|
||||
@classmethod
|
||||
async def async_validate_complete_config(
|
||||
cls, hass: HomeAssistant, complete_config: ConfigType
|
||||
@@ -212,13 +215,41 @@ class Trigger(abc.ABC):
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
"""Initialize trigger."""
|
||||
self._hass = hass
|
||||
self._config = config
|
||||
|
||||
async def async_attach(
|
||||
self, action: TriggerActionType, trigger_info: TriggerInfo
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach the trigger."""
|
||||
self._job = HassJob(action)
|
||||
self._trigger_info = trigger_info
|
||||
|
||||
@callback
|
||||
def run_action(
|
||||
description: str,
|
||||
extra_trigger_payload: dict[str, Any],
|
||||
context: Context | None = None,
|
||||
) -> None:
|
||||
"""Run action with trigger variables."""
|
||||
assert self._job
|
||||
assert self._trigger_info
|
||||
|
||||
payload = {
|
||||
"trigger": {
|
||||
**self._trigger_info["trigger_data"],
|
||||
CONF_PLATFORM: self._config.key,
|
||||
"description": description,
|
||||
**extra_trigger_payload,
|
||||
}
|
||||
}
|
||||
|
||||
self._hass.async_run_hass_job(self._job, payload, context)
|
||||
|
||||
return await self._async_attach(run_action)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_attach(
|
||||
self,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
||||
"""Attach the trigger."""
|
||||
|
||||
|
||||
@@ -257,6 +288,19 @@ class TriggerConfig:
|
||||
options: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TriggerActionRunnerType(Protocol):
|
||||
"""Protocol type for the trigger action runner helper callback."""
|
||||
|
||||
@callback
|
||||
def __call__(
|
||||
self,
|
||||
description: str,
|
||||
extra_trigger_payload: dict[str, Any],
|
||||
context: Context | None = None,
|
||||
) -> None:
|
||||
"""Define trigger action runner type."""
|
||||
|
||||
|
||||
class TriggerActionType(Protocol):
|
||||
"""Protocol type for trigger action callback."""
|
||||
|
||||
|
||||
@@ -23,9 +23,7 @@ from homeassistant.helpers.trigger import (
|
||||
DATA_PLUGGABLE_ACTIONS,
|
||||
PluggableAction,
|
||||
Trigger,
|
||||
TriggerActionType,
|
||||
TriggerConfig,
|
||||
TriggerInfo,
|
||||
TriggerActionRunnerType,
|
||||
_async_get_trigger_platform,
|
||||
async_initialize_triggers,
|
||||
async_validate_trigger_config,
|
||||
@@ -536,30 +534,23 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
||||
"""Validate config."""
|
||||
return config
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
"""Initialize trigger."""
|
||||
|
||||
class MockTrigger1(MockTrigger):
|
||||
"""Mock trigger 1."""
|
||||
|
||||
async def async_attach(
|
||||
self,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
async def _async_attach(
|
||||
self, run_action: TriggerActionRunnerType
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
action({"trigger": "test_trigger_1"})
|
||||
run_action("trigger 1 desc", {"extra": "test_trigger_1"})
|
||||
|
||||
class MockTrigger2(MockTrigger):
|
||||
"""Mock trigger 2."""
|
||||
|
||||
async def async_attach(
|
||||
self,
|
||||
action: TriggerActionType,
|
||||
trigger_info: TriggerInfo,
|
||||
async def _async_attach(
|
||||
self, run_action: TriggerActionRunnerType
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
action({"trigger": "test_trigger_2"})
|
||||
run_action("trigger 2 desc", {"extra": "test_trigger_2"})
|
||||
|
||||
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||
return {
|
||||
@@ -589,11 +580,31 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
||||
action_calls.append([*args])
|
||||
|
||||
await async_initialize_triggers(hass, config_1, cb_action, "test", "", log_cb)
|
||||
assert action_calls == [[{"trigger": "test_trigger_1"}]]
|
||||
assert len(action_calls) == 1
|
||||
assert action_calls[0][0] == {
|
||||
"trigger": {
|
||||
"alias": None,
|
||||
"description": "trigger 1 desc",
|
||||
"extra": "test_trigger_1",
|
||||
"id": "0",
|
||||
"idx": "0",
|
||||
"platform": "test",
|
||||
}
|
||||
}
|
||||
action_calls.clear()
|
||||
|
||||
await async_initialize_triggers(hass, config_2, cb_action, "test", "", log_cb)
|
||||
assert action_calls == [[{"trigger": "test_trigger_2"}]]
|
||||
assert len(action_calls) == 1
|
||||
assert action_calls[0][0] == {
|
||||
"trigger": {
|
||||
"alias": None,
|
||||
"description": "trigger 2 desc",
|
||||
"extra": "test_trigger_2",
|
||||
"id": "0",
|
||||
"idx": "0",
|
||||
"platform": "test.trig_2",
|
||||
}
|
||||
}
|
||||
action_calls.clear()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
|
||||
Reference in New Issue
Block a user