Allow overriding trigger runner helper

This commit is contained in:
abmantis
2025-10-01 16:24:50 +01:00
parent c3f45d594b
commit bdd448fbe0
4 changed files with 71 additions and 46 deletions

View File

@@ -213,9 +213,36 @@ class Trigger(abc.ABC):
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
"""Initialize trigger."""
async def async_attach_action(
self, hass: HomeAssistant, action: TriggerActionType
) -> CALLBACK_TYPE:
"""Attach the trigger to an action."""
job = HassJob(action)
@callback
def run_action(
description: str,
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> asyncio.Future[Any] | None:
"""Run action with trigger variables."""
payload = {
"trigger": {
"description": description,
**extra_trigger_payload,
}
}
return hass.async_run_hass_job(job, payload, context)
return await self.async_attach_runner(run_action)
@abc.abstractmethod
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
"""Attach the trigger."""
async def async_attach_runner(
self, run_action: RunTriggerActionCallback
) -> CALLBACK_TYPE:
"""Attach the trigger to an action runner."""
class TriggerProtocol(Protocol):
@@ -253,7 +280,7 @@ class TriggerConfig:
options: dict[str, Any] | None = None
class TriggerActionRunnerType(Protocol):
class RunTriggerActionCallback(Protocol):
"""Protocol type for the trigger action runner helper callback."""
@callback
@@ -453,15 +480,22 @@ async def async_validate_trigger_config(
def _trigger_action_wrapper(
hass: HomeAssistant, action: Callable, conf: ConfigType
hass: HomeAssistant,
action: Callable,
conf: ConfigType,
extra_trigger_payload: dict[str, Any],
) -> Callable:
"""Wrap trigger action with extra vars if configured.
"""Wrap trigger action with extra vars.
If action is a coroutine function, a coroutine function will be returned.
If action is a callback, a callback will be returned.
"""
if CONF_VARIABLES not in conf:
return action
def update_run_variables(run_variables: dict[str, Any]) -> None:
run_variables.get("trigger", {}).update(extra_trigger_payload)
if CONF_VARIABLES in conf:
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
# Check for partials to properly determine if coroutine function
check_func = action
@@ -477,8 +511,7 @@ def _trigger_action_wrapper(
run_variables: dict[str, Any], context: Context | None = None
) -> Any:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
update_run_variables(run_variables)
return await action(run_variables, context)
wrapper_func = async_with_vars
@@ -490,8 +523,7 @@ def _trigger_action_wrapper(
run_variables: dict[str, Any], context: Context | None = None
) -> Any:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
update_run_variables(run_variables)
return action(run_variables, context)
if is_callback(check_func):
@@ -507,30 +539,15 @@ async def _async_attach_trigger_cls(
trigger_cls: type[Trigger],
trigger_key: str,
conf: ConfigType,
action: TriggerActionType,
action: Callable,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Initialize a new Trigger class and attach it."""
job = HassJob(action)
@callback
def run_action(
description: str,
extra_trigger_payload: dict[str, Any],
context: Context | None = None,
) -> asyncio.Future[Any] | None:
"""Run action with trigger variables."""
payload = {
"trigger": {
**trigger_info["trigger_data"],
CONF_PLATFORM: trigger_key,
"description": description,
**extra_trigger_payload,
}
}
return hass.async_run_hass_job(job, payload, context)
extra_trigger_payload = {
**trigger_info["trigger_data"],
CONF_PLATFORM: trigger_key,
}
action_wrapper = _trigger_action_wrapper(hass, action, conf, extra_trigger_payload)
trigger = trigger_cls(
hass,
@@ -540,7 +557,7 @@ async def _async_attach_trigger_cls(
options=conf.get(CONF_OPTIONS),
),
)
return await trigger.async_attach(run_action)
return await trigger.async_attach_action(hass, action_wrapper)
async def async_initialize_triggers(
@@ -582,7 +599,6 @@ async def async_initialize_triggers(
trigger_data=trigger_data,
)
action_wrapper = _trigger_action_wrapper(hass, action, conf)
if hasattr(platform, "async_get_triggers"):
trigger_descriptors = await platform.async_get_triggers(hass)
relative_trigger_key = get_relative_description_key(
@@ -590,9 +606,14 @@ async def async_initialize_triggers(
)
trigger_cls = trigger_descriptors[relative_trigger_key]
coro = _async_attach_trigger_cls(
hass, trigger_cls, trigger_key, conf, action_wrapper, info
hass, trigger_cls, trigger_key, conf, action, info
)
else:
action_wrapper = (
_trigger_action_wrapper(hass, action, conf, {})
if CONF_VARIABLES in conf
else action
)
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
triggers.append(create_eager_task(coro))