Simplify firing of trigger actions

This commit is contained in:
abmantis
2025-09-22 18:44:22 +01:00
parent 86dc453c55
commit 7d96a814f9
4 changed files with 98 additions and 86 deletions

View File

@@ -17,17 +17,14 @@ from homeassistant.const import (
ATTR_DEVICE_ID, ATTR_DEVICE_ID,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_OPTIONS, CONF_OPTIONS,
CONF_PLATFORM,
) )
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
Trigger, Trigger,
TriggerActionType, TriggerActionRunnerType,
TriggerConfig, TriggerConfig,
TriggerData,
TriggerInfo,
move_top_level_schema_fields_to_options, move_top_level_schema_fields_to_options,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@@ -127,17 +124,13 @@ _CONFIG_SCHEMA = vol.Schema(
class EventTrigger(Trigger): class EventTrigger(Trigger):
"""Z-Wave JS event trigger.""" """Z-Wave JS event trigger."""
_hass: HomeAssistant
_options: dict[str, Any] _options: dict[str, Any]
_event_source: str _event_source: str
_event_name: str _event_name: str
_event_data_filter: dict _event_data_filter: dict
_job: HassJob
_trigger_data: TriggerData
_unsubs: list[Callable] _unsubs: list[Callable]
_action_runner: TriggerActionRunnerType
_platform_type = PLATFORM_TYPE
@classmethod @classmethod
async def async_validate_complete_config( async def async_validate_complete_config(
@@ -176,15 +169,11 @@ class EventTrigger(Trigger):
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger.""" """Initialize trigger."""
self._hass = hass super().__init__(hass, config)
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_attach( async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
dev_reg = dr.async_get(self._hass) dev_reg = dr.async_get(self._hass)
options = self._options options = self._options
@@ -198,8 +187,7 @@ class EventTrigger(Trigger):
self._event_source = options[ATTR_EVENT_SOURCE] self._event_source = options[ATTR_EVENT_SOURCE]
self._event_name = options[ATTR_EVENT] self._event_name = options[ATTR_EVENT]
self._event_data_filter = options.get(ATTR_EVENT_DATA, {}) self._event_data_filter = options.get(ATTR_EVENT_DATA, {})
self._job = HassJob(action) self._action_runner = run_action
self._trigger_data = trigger_info["trigger_data"]
self._unsubs: list[Callable] = [] self._unsubs: list[Callable] = []
self._create_zwave_listeners() self._create_zwave_listeners()
@@ -226,8 +214,6 @@ class EventTrigger(Trigger):
return return
payload = { payload = {
**self._trigger_data,
CONF_PLATFORM: self._platform_type,
ATTR_EVENT_SOURCE: self._event_source, ATTR_EVENT_SOURCE: self._event_source,
ATTR_EVENT: self._event_name, ATTR_EVENT: self._event_name,
ATTR_EVENT_DATA: event_data, ATTR_EVENT_DATA: event_data,
@@ -237,21 +223,17 @@ class EventTrigger(Trigger):
f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted" f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted"
) )
description = primary_desc
if device: if device:
device_name = device.name_by_user or device.name device_name = device.name_by_user or device.name
payload[ATTR_DEVICE_ID] = device.id payload[ATTR_DEVICE_ID] = device.id
home_and_node_id = get_home_and_node_id_from_device_entry(device) home_and_node_id = get_home_and_node_id_from_device_entry(device)
assert home_and_node_id assert home_and_node_id
payload[ATTR_NODE_ID] = home_and_node_id[1] payload[ATTR_NODE_ID] = home_and_node_id[1] # type: ignore[assignment]
payload["description"] = f"{primary_desc} on {device_name}" description = f"{primary_desc} on {device_name}"
else:
payload["description"] = primary_desc
payload["description"] = ( description = f"{description} with event data: {event_data}"
f"{payload['description']} with event data: {event_data}" self._action_runner(description, payload)
)
self._hass.async_run_hass_job(self._job, {"trigger": payload})
@callback @callback
def _async_remove(self) -> None: def _async_remove(self) -> None:

View File

@@ -11,21 +11,14 @@ from zwave_js_server.const import CommandClass
from zwave_js_server.model.driver import Driver from zwave_js_server.model.driver import Driver
from zwave_js_server.model.value import Value, get_value_id_str from zwave_js_server.model.value import Value, get_value_id_str
from homeassistant.const import ( from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_OPTIONS, MATCH_ALL
ATTR_DEVICE_ID, from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
ATTR_ENTITY_ID,
CONF_OPTIONS,
CONF_PLATFORM,
MATCH_ALL,
)
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
Trigger, Trigger,
TriggerActionType, TriggerActionRunnerType,
TriggerConfig, TriggerConfig,
TriggerInfo,
move_top_level_schema_fields_to_options, move_top_level_schema_fields_to_options,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@@ -100,12 +93,7 @@ async def async_validate_trigger_config(
async def async_attach_trigger( async def async_attach_trigger(
hass: HomeAssistant, hass: HomeAssistant, options: ConfigType, run_action: TriggerActionRunnerType
options: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
*,
platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
@@ -121,9 +109,6 @@ async def async_attach_trigger(
endpoint = options.get(ATTR_ENDPOINT) endpoint = options.get(ATTR_ENDPOINT)
property_key = options.get(ATTR_PROPERTY_KEY) property_key = options.get(ATTR_PROPERTY_KEY)
unsubs: list[Callable] = [] unsubs: list[Callable] = []
job = HassJob(action)
trigger_data = trigger_info["trigger_data"]
@callback @callback
def async_on_value_updated( def async_on_value_updated(
@@ -152,10 +137,8 @@ async def async_attach_trigger(
return return
device_name = device.name_by_user or device.name device_name = device.name_by_user or device.name
description = f"Z-Wave value {value.value_id} updated on {device_name}"
payload = { payload = {
**trigger_data,
CONF_PLATFORM: platform_type,
ATTR_DEVICE_ID: device.id, ATTR_DEVICE_ID: device.id,
ATTR_NODE_ID: value.node.node_id, ATTR_NODE_ID: value.node.node_id,
ATTR_COMMAND_CLASS: value.command_class, ATTR_COMMAND_CLASS: value.command_class,
@@ -169,10 +152,9 @@ async def async_attach_trigger(
ATTR_PREVIOUS_VALUE_RAW: prev_value_raw, ATTR_PREVIOUS_VALUE_RAW: prev_value_raw,
ATTR_CURRENT_VALUE: curr_value, ATTR_CURRENT_VALUE: curr_value,
ATTR_CURRENT_VALUE_RAW: curr_value_raw, ATTR_CURRENT_VALUE_RAW: curr_value_raw,
"description": f"Z-Wave value {value.value_id} updated on {device_name}",
} }
hass.async_run_hass_job(job, {"trigger": payload}) run_action(description, payload)
@callback @callback
def async_remove() -> None: def async_remove() -> None:
@@ -223,7 +205,6 @@ async def async_attach_trigger(
class ValueUpdatedTrigger(Trigger): class ValueUpdatedTrigger(Trigger):
"""Z-Wave JS value updated trigger.""" """Z-Wave JS value updated trigger."""
_hass: HomeAssistant
_options: dict[str, Any] _options: dict[str, Any]
@classmethod @classmethod
@@ -245,16 +226,10 @@ class ValueUpdatedTrigger(Trigger):
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger.""" """Initialize trigger."""
self._hass = hass super().__init__(hass, config)
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_attach( async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
return await async_attach_trigger( return await async_attach_trigger(self._hass, self._options, run_action)
self._hass, self._options, action, trigger_info
)

View File

@@ -178,6 +178,9 @@ _TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
class Trigger(abc.ABC): class Trigger(abc.ABC):
"""Trigger class.""" """Trigger class."""
_job: HassJob | None = None
_trigger_info: TriggerInfo | None = None
@classmethod @classmethod
async def async_validate_complete_config( async def async_validate_complete_config(
cls, hass: HomeAssistant, complete_config: ConfigType cls, hass: HomeAssistant, complete_config: ConfigType
@@ -212,13 +215,41 @@ class Trigger(abc.ABC):
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger.""" """Initialize trigger."""
self._hass = hass
self._config = config
async def async_attach(
self, action: TriggerActionType, trigger_info: TriggerInfo
) -> CALLBACK_TYPE:
"""Attach the trigger."""
self._job = HassJob(action)
self._trigger_info = trigger_info
@callback
def run_action(
description: str,
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> None:
"""Run action with trigger variables."""
assert self._job
assert self._trigger_info
payload = {
"trigger": {
**self._trigger_info["trigger_data"],
CONF_PLATFORM: self._config.key,
"description": description,
**extra_trigger_payload,
}
}
self._hass.async_run_hass_job(self._job, payload, context)
return await self._async_attach(run_action)
@abc.abstractmethod @abc.abstractmethod
async def async_attach( async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
self,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach the trigger.""" """Attach the trigger."""
@@ -257,6 +288,19 @@ class TriggerConfig:
options: dict[str, Any] | None = None options: dict[str, Any] | None = None
class TriggerActionRunnerType(Protocol):
"""Protocol type for the trigger action runner helper callback."""
@callback
def __call__(
self,
description: str,
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> None:
"""Define trigger action runner type."""
class TriggerActionType(Protocol): class TriggerActionType(Protocol):
"""Protocol type for trigger action callback.""" """Protocol type for trigger action callback."""

View File

@@ -23,9 +23,7 @@ from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS, DATA_PLUGGABLE_ACTIONS,
PluggableAction, PluggableAction,
Trigger, Trigger,
TriggerActionType, TriggerActionRunnerType,
TriggerConfig,
TriggerInfo,
_async_get_trigger_platform, _async_get_trigger_platform,
async_initialize_triggers, async_initialize_triggers,
async_validate_trigger_config, async_validate_trigger_config,
@@ -536,30 +534,23 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
"""Validate config.""" """Validate config."""
return config return config
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger."""
class MockTrigger1(MockTrigger): class MockTrigger1(MockTrigger):
"""Mock trigger 1.""" """Mock trigger 1."""
async def async_attach( async def _async_attach(
self, self, run_action: TriggerActionRunnerType
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
action({"trigger": "test_trigger_1"}) run_action("trigger 1 desc", {"extra": "test_trigger_1"})
class MockTrigger2(MockTrigger): class MockTrigger2(MockTrigger):
"""Mock trigger 2.""" """Mock trigger 2."""
async def async_attach( async def _async_attach(
self, self, run_action: TriggerActionRunnerType
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
action({"trigger": "test_trigger_2"}) run_action("trigger 2 desc", {"extra": "test_trigger_2"})
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]: async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
return { return {
@@ -589,11 +580,31 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
action_calls.append([*args]) action_calls.append([*args])
await async_initialize_triggers(hass, config_1, cb_action, "test", "", log_cb) await async_initialize_triggers(hass, config_1, cb_action, "test", "", log_cb)
assert action_calls == [[{"trigger": "test_trigger_1"}]] assert len(action_calls) == 1
assert action_calls[0][0] == {
"trigger": {
"alias": None,
"description": "trigger 1 desc",
"extra": "test_trigger_1",
"id": "0",
"idx": "0",
"platform": "test",
}
}
action_calls.clear() action_calls.clear()
await async_initialize_triggers(hass, config_2, cb_action, "test", "", log_cb) await async_initialize_triggers(hass, config_2, cb_action, "test", "", log_cb)
assert action_calls == [[{"trigger": "test_trigger_2"}]] assert len(action_calls) == 1
assert action_calls[0][0] == {
"trigger": {
"alias": None,
"description": "trigger 2 desc",
"extra": "test_trigger_2",
"id": "0",
"idx": "0",
"platform": "test.trig_2",
}
}
action_calls.clear() action_calls.clear()
with pytest.raises(KeyError): with pytest.raises(KeyError):