Allow overriding trigger runner helper

This commit is contained in:
abmantis
2025-10-01 16:24:50 +01:00
parent c3f45d594b
commit bdd448fbe0
4 changed files with 71 additions and 46 deletions

View File

@@ -23,8 +23,8 @@ from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
RunTriggerActionCallback,
Trigger, Trigger,
TriggerActionRunnerType,
TriggerConfig, TriggerConfig,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@@ -131,7 +131,7 @@ class EventTrigger(Trigger):
_event_name: str _event_name: str
_event_data_filter: dict _event_data_filter: dict
_unsubs: list[Callable] _unsubs: list[Callable]
_action_runner: TriggerActionRunnerType _action_runner: RunTriggerActionCallback
@classmethod @classmethod
async def async_validate_complete_config( async def async_validate_complete_config(
@@ -174,7 +174,9 @@ class EventTrigger(Trigger):
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: async def async_attach_runner(
self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
dev_reg = dr.async_get(self._hass) dev_reg = dr.async_get(self._hass)
options = self._options options = self._options

View File

@@ -17,8 +17,8 @@ from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
RunTriggerActionCallback,
Trigger, Trigger,
TriggerActionRunnerType,
TriggerConfig, TriggerConfig,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@@ -93,7 +93,7 @@ async def async_validate_trigger_config(
async def async_attach_trigger( async def async_attach_trigger(
hass: HomeAssistant, options: ConfigType, run_action: TriggerActionRunnerType hass: HomeAssistant, options: ConfigType, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
@@ -231,6 +231,8 @@ class ValueUpdatedTrigger(Trigger):
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: async def async_attach_runner(
self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
return await async_attach_trigger(self._hass, self._options, run_action) return await async_attach_trigger(self._hass, self._options, run_action)

View File

@@ -213,9 +213,36 @@ class Trigger(abc.ABC):
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None: def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger.""" """Initialize trigger."""
async def async_attach_action(
self, hass: HomeAssistant, action: TriggerActionType
) -> CALLBACK_TYPE:
"""Attach the trigger to an action."""
job = HassJob(action)
@callback
def run_action(
description: str,
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> asyncio.Future[Any] | None:
"""Run action with trigger variables."""
payload = {
"trigger": {
"description": description,
**extra_trigger_payload,
}
}
return hass.async_run_hass_job(job, payload, context)
return await self.async_attach_runner(run_action)
@abc.abstractmethod @abc.abstractmethod
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE: async def async_attach_runner(
"""Attach the trigger.""" self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE:
"""Attach the trigger to an action runner."""
class TriggerProtocol(Protocol): class TriggerProtocol(Protocol):
@@ -253,7 +280,7 @@ class TriggerConfig:
options: dict[str, Any] | None = None options: dict[str, Any] | None = None
class TriggerActionRunnerType(Protocol): class RunTriggerActionCallback(Protocol):
"""Protocol type for the trigger action runner helper callback.""" """Protocol type for the trigger action runner helper callback."""
@callback @callback
@@ -453,15 +480,22 @@ async def async_validate_trigger_config(
def _trigger_action_wrapper( def _trigger_action_wrapper(
hass: HomeAssistant, action: Callable, conf: ConfigType hass: HomeAssistant,
action: Callable,
conf: ConfigType,
extra_trigger_payload: dict[str, Any],
) -> Callable: ) -> Callable:
"""Wrap trigger action with extra vars if configured. """Wrap trigger action with extra vars.
If action is a coroutine function, a coroutine function will be returned. If action is a coroutine function, a coroutine function will be returned.
If action is a callback, a callback will be returned. If action is a callback, a callback will be returned.
""" """
if CONF_VARIABLES not in conf:
return action def update_run_variables(run_variables: dict[str, Any]) -> None:
run_variables.get("trigger", {}).update(extra_trigger_payload)
if CONF_VARIABLES in conf:
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
check_func = action check_func = action
@@ -477,8 +511,7 @@ def _trigger_action_wrapper(
run_variables: dict[str, Any], context: Context | None = None run_variables: dict[str, Any], context: Context | None = None
) -> Any: ) -> Any:
"""Wrap action with extra vars.""" """Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES] update_run_variables(run_variables)
run_variables.update(trigger_variables.async_render(hass, run_variables))
return await action(run_variables, context) return await action(run_variables, context)
wrapper_func = async_with_vars wrapper_func = async_with_vars
@@ -490,8 +523,7 @@ def _trigger_action_wrapper(
run_variables: dict[str, Any], context: Context | None = None run_variables: dict[str, Any], context: Context | None = None
) -> Any: ) -> Any:
"""Wrap action with extra vars.""" """Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES] update_run_variables(run_variables)
run_variables.update(trigger_variables.async_render(hass, run_variables))
return action(run_variables, context) return action(run_variables, context)
if is_callback(check_func): if is_callback(check_func):
@@ -507,30 +539,15 @@ async def _async_attach_trigger_cls(
trigger_cls: type[Trigger], trigger_cls: type[Trigger],
trigger_key: str, trigger_key: str,
conf: ConfigType, conf: ConfigType,
action: TriggerActionType, action: Callable,
trigger_info: TriggerInfo, trigger_info: TriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Initialize a new Trigger class and attach it.""" """Initialize a new Trigger class and attach it."""
job = HassJob(action) extra_trigger_payload = {
**trigger_info["trigger_data"],
@callback CONF_PLATFORM: trigger_key,
def run_action( }
description: str, action_wrapper = _trigger_action_wrapper(hass, action, conf, extra_trigger_payload)
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> asyncio.Future[Any] | None:
"""Run action with trigger variables."""
payload = {
"trigger": {
**trigger_info["trigger_data"],
CONF_PLATFORM: trigger_key,
"description": description,
**extra_trigger_payload,
}
}
return hass.async_run_hass_job(job, payload, context)
trigger = trigger_cls( trigger = trigger_cls(
hass, hass,
@@ -540,7 +557,7 @@ async def _async_attach_trigger_cls(
options=conf.get(CONF_OPTIONS), options=conf.get(CONF_OPTIONS),
), ),
) )
return await trigger.async_attach(run_action) return await trigger.async_attach_action(hass, action_wrapper)
async def async_initialize_triggers( async def async_initialize_triggers(
@@ -582,7 +599,6 @@ async def async_initialize_triggers(
trigger_data=trigger_data, trigger_data=trigger_data,
) )
action_wrapper = _trigger_action_wrapper(hass, action, conf)
if hasattr(platform, "async_get_triggers"): if hasattr(platform, "async_get_triggers"):
trigger_descriptors = await platform.async_get_triggers(hass) trigger_descriptors = await platform.async_get_triggers(hass)
relative_trigger_key = get_relative_description_key( relative_trigger_key = get_relative_description_key(
@@ -590,9 +606,14 @@ async def async_initialize_triggers(
) )
trigger_cls = trigger_descriptors[relative_trigger_key] trigger_cls = trigger_descriptors[relative_trigger_key]
coro = _async_attach_trigger_cls( coro = _async_attach_trigger_cls(
hass, trigger_cls, trigger_key, conf, action_wrapper, info hass, trigger_cls, trigger_key, conf, action, info
) )
else: else:
action_wrapper = (
_trigger_action_wrapper(hass, action, conf, {})
if CONF_VARIABLES in conf
else action
)
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info) coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
triggers.append(create_eager_task(coro)) triggers.append(create_eager_task(coro))

View File

@@ -23,8 +23,8 @@ from homeassistant.helpers.automation import move_top_level_schema_fields_to_opt
from homeassistant.helpers.trigger import ( from homeassistant.helpers.trigger import (
DATA_PLUGGABLE_ACTIONS, DATA_PLUGGABLE_ACTIONS,
PluggableAction, PluggableAction,
RunTriggerActionCallback,
Trigger, Trigger,
TriggerActionRunnerType,
_async_get_trigger_platform, _async_get_trigger_platform,
async_initialize_triggers, async_initialize_triggers,
async_validate_trigger_config, async_validate_trigger_config,
@@ -463,8 +463,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
class MockTrigger1(MockTrigger): class MockTrigger1(MockTrigger):
"""Mock trigger 1.""" """Mock trigger 1."""
async def async_attach( async def async_attach_runner(
self, run_action: TriggerActionRunnerType self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
run_action("trigger 1 desc", {"extra": "test_trigger_1"}) run_action("trigger 1 desc", {"extra": "test_trigger_1"})
@@ -472,8 +472,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
class MockTrigger2(MockTrigger): class MockTrigger2(MockTrigger):
"""Mock trigger 2.""" """Mock trigger 2."""
async def async_attach( async def async_attach_runner(
self, run_action: TriggerActionRunnerType self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
run_action("trigger 2 desc", {"extra": "test_trigger_2"}) run_action("trigger 2 desc", {"extra": "test_trigger_2"})