diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 6565e698373..0eeaa01c452 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -17,17 +17,14 @@ from homeassistant.const import ( ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_OPTIONS, - CONF_PLATFORM, ) -from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.trigger import ( Trigger, - TriggerActionType, + TriggerActionRunnerType, TriggerConfig, - TriggerData, - TriggerInfo, move_top_level_schema_fields_to_options, ) from homeassistant.helpers.typing import ConfigType @@ -127,17 +124,13 @@ _CONFIG_SCHEMA = vol.Schema( class EventTrigger(Trigger): """Z-Wave JS event trigger.""" - _hass: HomeAssistant _options: dict[str, Any] _event_source: str _event_name: str _event_data_filter: dict - _job: HassJob - _trigger_data: TriggerData _unsubs: list[Callable] - - _platform_type = PLATFORM_TYPE + _action_runner: TriggerActionRunnerType @classmethod async def async_validate_complete_config( @@ -176,15 +169,11 @@ class EventTrigger(Trigger): def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: """Initialize trigger.""" - self._hass = hass + super().__init__(hass, config) assert config.options is not None self._options = config.options - async def async_attach( - self, - action: TriggerActionType, - trigger_info: TriggerInfo, - ) -> CALLBACK_TYPE: + async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: """Attach a trigger.""" dev_reg = dr.async_get(self._hass) options = self._options @@ -198,8 +187,7 @@ class EventTrigger(Trigger): self._event_source = options[ATTR_EVENT_SOURCE] self._event_name = options[ATTR_EVENT] self._event_data_filter = options.get(ATTR_EVENT_DATA, {}) - self._job = HassJob(action) - self._trigger_data = trigger_info["trigger_data"] + self._action_runner = run_action self._unsubs: list[Callable] = [] self._create_zwave_listeners() @@ -226,8 +214,6 @@ class EventTrigger(Trigger): return payload = { - **self._trigger_data, - CONF_PLATFORM: self._platform_type, ATTR_EVENT_SOURCE: self._event_source, ATTR_EVENT: self._event_name, ATTR_EVENT_DATA: event_data, @@ -237,21 +223,17 @@ class EventTrigger(Trigger): f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted" ) + description = primary_desc if device: device_name = device.name_by_user or device.name payload[ATTR_DEVICE_ID] = device.id home_and_node_id = get_home_and_node_id_from_device_entry(device) assert home_and_node_id - payload[ATTR_NODE_ID] = home_and_node_id[1] - payload["description"] = f"{primary_desc} on {device_name}" - else: - payload["description"] = primary_desc + payload[ATTR_NODE_ID] = home_and_node_id[1] # type: ignore[assignment] + description = f"{primary_desc} on {device_name}" - payload["description"] = ( - f"{payload['description']} with event data: {event_data}" - ) - - self._hass.async_run_hass_job(self._job, {"trigger": payload}) + description = f"{description} with event data: {event_data}" + self._action_runner(description, payload) @callback def _async_remove(self) -> None: diff --git a/homeassistant/components/zwave_js/triggers/value_updated.py b/homeassistant/components/zwave_js/triggers/value_updated.py index 14ab0996189..3dab649ab67 100644 --- a/homeassistant/components/zwave_js/triggers/value_updated.py +++ b/homeassistant/components/zwave_js/triggers/value_updated.py @@ -11,21 +11,14 @@ from zwave_js_server.const import CommandClass from zwave_js_server.model.driver import Driver from zwave_js_server.model.value import Value, get_value_id_str -from homeassistant.const import ( - ATTR_DEVICE_ID, - ATTR_ENTITY_ID, - CONF_OPTIONS, - CONF_PLATFORM, - MATCH_ALL, -) -from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback +from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_OPTIONS, MATCH_ALL +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.trigger import ( Trigger, - TriggerActionType, + TriggerActionRunnerType, TriggerConfig, - TriggerInfo, move_top_level_schema_fields_to_options, ) from homeassistant.helpers.typing import ConfigType @@ -100,12 +93,7 @@ async def async_validate_trigger_config( async def async_attach_trigger( - hass: HomeAssistant, - options: ConfigType, - action: TriggerActionType, - trigger_info: TriggerInfo, - *, - platform_type: str = PLATFORM_TYPE, + hass: HomeAssistant, options: ConfigType, run_action: TriggerActionRunnerType ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" dev_reg = dr.async_get(hass) @@ -121,9 +109,6 @@ async def async_attach_trigger( endpoint = options.get(ATTR_ENDPOINT) property_key = options.get(ATTR_PROPERTY_KEY) unsubs: list[Callable] = [] - job = HassJob(action) - - trigger_data = trigger_info["trigger_data"] @callback def async_on_value_updated( @@ -152,10 +137,8 @@ async def async_attach_trigger( return device_name = device.name_by_user or device.name - + description = f"Z-Wave value {value.value_id} updated on {device_name}" payload = { - **trigger_data, - CONF_PLATFORM: platform_type, ATTR_DEVICE_ID: device.id, ATTR_NODE_ID: value.node.node_id, ATTR_COMMAND_CLASS: value.command_class, @@ -169,10 +152,9 @@ async def async_attach_trigger( ATTR_PREVIOUS_VALUE_RAW: prev_value_raw, ATTR_CURRENT_VALUE: curr_value, ATTR_CURRENT_VALUE_RAW: curr_value_raw, - "description": f"Z-Wave value {value.value_id} updated on {device_name}", } - hass.async_run_hass_job(job, {"trigger": payload}) + run_action(description, payload) @callback def async_remove() -> None: @@ -223,7 +205,6 @@ async def async_attach_trigger( class ValueUpdatedTrigger(Trigger): """Z-Wave JS value updated trigger.""" - _hass: HomeAssistant _options: dict[str, Any] @classmethod @@ -245,16 +226,10 @@ class ValueUpdatedTrigger(Trigger): def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: """Initialize trigger.""" - self._hass = hass + super().__init__(hass, config) assert config.options is not None self._options = config.options - async def async_attach( - self, - action: TriggerActionType, - trigger_info: TriggerInfo, - ) -> CALLBACK_TYPE: + async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: """Attach a trigger.""" - return await async_attach_trigger( - self._hass, self._options, action, trigger_info - ) + return await async_attach_trigger(self._hass, self._options, run_action) diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 9ebd3367846..e7c9024b853 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -178,6 +178,9 @@ _TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( class Trigger(abc.ABC): """Trigger class.""" + _job: HassJob | None = None + _trigger_info: TriggerInfo | None = None + @classmethod async def async_validate_complete_config( cls, hass: HomeAssistant, complete_config: ConfigType @@ -212,13 +215,41 @@ class Trigger(abc.ABC): def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: """Initialize trigger.""" + self._hass = hass + self._config = config + + async def async_attach( + self, action: TriggerActionType, trigger_info: TriggerInfo + ) -> CALLBACK_TYPE: + """Attach the trigger.""" + self._job = HassJob(action) + self._trigger_info = trigger_info + + @callback + def run_action( + description: str, + extra_trigger_payload: dict[str, Any], + context: Context | None = None, + ) -> None: + """Run action with trigger variables.""" + assert self._job + assert self._trigger_info + + payload = { + "trigger": { + **self._trigger_info["trigger_data"], + CONF_PLATFORM: self._config.key, + "description": description, + **extra_trigger_payload, + } + } + + self._hass.async_run_hass_job(self._job, payload, context) + + return await self._async_attach(run_action) @abc.abstractmethod - async def async_attach( - self, - action: TriggerActionType, - trigger_info: TriggerInfo, - ) -> CALLBACK_TYPE: + async def _async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: """Attach the trigger.""" @@ -257,6 +288,19 @@ class TriggerConfig: options: dict[str, Any] | None = None +class TriggerActionRunnerType(Protocol): + """Protocol type for the trigger action runner helper callback.""" + + @callback + def __call__( + self, + description: str, + extra_trigger_payload: dict[str, Any], + context: Context | None = None, + ) -> None: + """Define trigger action runner type.""" + + class TriggerActionType(Protocol): """Protocol type for trigger action callback.""" diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 7402cf2899f..d248b57fc74 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -23,9 +23,7 @@ from homeassistant.helpers.trigger import ( DATA_PLUGGABLE_ACTIONS, PluggableAction, Trigger, - TriggerActionType, - TriggerConfig, - TriggerInfo, + TriggerActionRunnerType, _async_get_trigger_platform, async_initialize_triggers, async_validate_trigger_config, @@ -536,30 +534,23 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None: """Validate config.""" return config - def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: - """Initialize trigger.""" - class MockTrigger1(MockTrigger): """Mock trigger 1.""" - async def async_attach( - self, - action: TriggerActionType, - trigger_info: TriggerInfo, + async def _async_attach( + self, run_action: TriggerActionRunnerType ) -> CALLBACK_TYPE: """Attach a trigger.""" - action({"trigger": "test_trigger_1"}) + run_action("trigger 1 desc", {"extra": "test_trigger_1"}) class MockTrigger2(MockTrigger): """Mock trigger 2.""" - async def async_attach( - self, - action: TriggerActionType, - trigger_info: TriggerInfo, + async def _async_attach( + self, run_action: TriggerActionRunnerType ) -> CALLBACK_TYPE: """Attach a trigger.""" - action({"trigger": "test_trigger_2"}) + run_action("trigger 2 desc", {"extra": "test_trigger_2"}) async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]: return { @@ -589,11 +580,31 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None: action_calls.append([*args]) await async_initialize_triggers(hass, config_1, cb_action, "test", "", log_cb) - assert action_calls == [[{"trigger": "test_trigger_1"}]] + assert len(action_calls) == 1 + assert action_calls[0][0] == { + "trigger": { + "alias": None, + "description": "trigger 1 desc", + "extra": "test_trigger_1", + "id": "0", + "idx": "0", + "platform": "test", + } + } action_calls.clear() await async_initialize_triggers(hass, config_2, cb_action, "test", "", log_cb) - assert action_calls == [[{"trigger": "test_trigger_2"}]] + assert len(action_calls) == 1 + assert action_calls[0][0] == { + "trigger": { + "alias": None, + "description": "trigger 2 desc", + "extra": "test_trigger_2", + "id": "0", + "idx": "0", + "platform": "test.trig_2", + } + } action_calls.clear() with pytest.raises(KeyError):