mirror of
https://github.com/home-assistant/core.git
synced 2025-11-08 18:39:30 +00:00
Allow overriding trigger runner helper
This commit is contained in:
@@ -213,9 +213,36 @@ class Trigger(abc.ABC):
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
"""Initialize trigger."""
|
||||
|
||||
async def async_attach_action(
|
||||
self, hass: HomeAssistant, action: TriggerActionType
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach the trigger to an action."""
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
def run_action(
|
||||
description: str,
|
||||
extra_trigger_payload: dict[str, Any],
|
||||
context: Context | None = None,
|
||||
) -> asyncio.Future[Any] | None:
|
||||
"""Run action with trigger variables."""
|
||||
|
||||
payload = {
|
||||
"trigger": {
|
||||
"description": description,
|
||||
**extra_trigger_payload,
|
||||
}
|
||||
}
|
||||
|
||||
return hass.async_run_hass_job(job, payload, context)
|
||||
|
||||
return await self.async_attach_runner(run_action)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
||||
"""Attach the trigger."""
|
||||
async def async_attach_runner(
|
||||
self, run_action: RunTriggerActionCallback
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Attach the trigger to an action runner."""
|
||||
|
||||
|
||||
class TriggerProtocol(Protocol):
|
||||
@@ -253,7 +280,7 @@ class TriggerConfig:
|
||||
options: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TriggerActionRunnerType(Protocol):
|
||||
class RunTriggerActionCallback(Protocol):
|
||||
"""Protocol type for the trigger action runner helper callback."""
|
||||
|
||||
@callback
|
||||
@@ -453,15 +480,22 @@ async def async_validate_trigger_config(
|
||||
|
||||
|
||||
def _trigger_action_wrapper(
|
||||
hass: HomeAssistant, action: Callable, conf: ConfigType
|
||||
hass: HomeAssistant,
|
||||
action: Callable,
|
||||
conf: ConfigType,
|
||||
extra_trigger_payload: dict[str, Any],
|
||||
) -> Callable:
|
||||
"""Wrap trigger action with extra vars if configured.
|
||||
"""Wrap trigger action with extra vars.
|
||||
|
||||
If action is a coroutine function, a coroutine function will be returned.
|
||||
If action is a callback, a callback will be returned.
|
||||
"""
|
||||
if CONF_VARIABLES not in conf:
|
||||
return action
|
||||
|
||||
def update_run_variables(run_variables: dict[str, Any]) -> None:
|
||||
run_variables.get("trigger", {}).update(extra_trigger_payload)
|
||||
if CONF_VARIABLES in conf:
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = action
|
||||
@@ -477,8 +511,7 @@ def _trigger_action_wrapper(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> Any:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
update_run_variables(run_variables)
|
||||
return await action(run_variables, context)
|
||||
|
||||
wrapper_func = async_with_vars
|
||||
@@ -490,8 +523,7 @@ def _trigger_action_wrapper(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> Any:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
update_run_variables(run_variables)
|
||||
return action(run_variables, context)
|
||||
|
||||
if is_callback(check_func):
|
||||
@@ -507,30 +539,15 @@ async def _async_attach_trigger_cls(
|
||||
trigger_cls: type[Trigger],
|
||||
trigger_key: str,
|
||||
conf: ConfigType,
|
||||
action: TriggerActionType,
|
||||
action: Callable,
|
||||
trigger_info: TriggerInfo,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Initialize a new Trigger class and attach it."""
|
||||
job = HassJob(action)
|
||||
|
||||
@callback
|
||||
def run_action(
|
||||
description: str,
|
||||
extra_trigger_payload: dict[str, Any],
|
||||
context: Context | None = None,
|
||||
) -> asyncio.Future[Any] | None:
|
||||
"""Run action with trigger variables."""
|
||||
|
||||
payload = {
|
||||
"trigger": {
|
||||
**trigger_info["trigger_data"],
|
||||
CONF_PLATFORM: trigger_key,
|
||||
"description": description,
|
||||
**extra_trigger_payload,
|
||||
}
|
||||
}
|
||||
|
||||
return hass.async_run_hass_job(job, payload, context)
|
||||
extra_trigger_payload = {
|
||||
**trigger_info["trigger_data"],
|
||||
CONF_PLATFORM: trigger_key,
|
||||
}
|
||||
action_wrapper = _trigger_action_wrapper(hass, action, conf, extra_trigger_payload)
|
||||
|
||||
trigger = trigger_cls(
|
||||
hass,
|
||||
@@ -540,7 +557,7 @@ async def _async_attach_trigger_cls(
|
||||
options=conf.get(CONF_OPTIONS),
|
||||
),
|
||||
)
|
||||
return await trigger.async_attach(run_action)
|
||||
return await trigger.async_attach_action(hass, action_wrapper)
|
||||
|
||||
|
||||
async def async_initialize_triggers(
|
||||
@@ -582,7 +599,6 @@ async def async_initialize_triggers(
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
action_wrapper = _trigger_action_wrapper(hass, action, conf)
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||
relative_trigger_key = get_relative_description_key(
|
||||
@@ -590,9 +606,14 @@ async def async_initialize_triggers(
|
||||
)
|
||||
trigger_cls = trigger_descriptors[relative_trigger_key]
|
||||
coro = _async_attach_trigger_cls(
|
||||
hass, trigger_cls, trigger_key, conf, action_wrapper, info
|
||||
hass, trigger_cls, trigger_key, conf, action, info
|
||||
)
|
||||
else:
|
||||
action_wrapper = (
|
||||
_trigger_action_wrapper(hass, action, conf, {})
|
||||
if CONF_VARIABLES in conf
|
||||
else action
|
||||
)
|
||||
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
|
||||
|
||||
triggers.append(create_eager_task(coro))
|
||||
|
||||
Reference in New Issue
Block a user