mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 19:09:32 +00:00
Allow overriding trigger runner helper
This commit is contained in:
@@ -23,8 +23,8 @@ from homeassistant.helpers import config_validation as cv, device_registry as dr
|
|||||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.trigger import (
|
from homeassistant.helpers.trigger import (
|
||||||
|
RunTriggerActionCallback,
|
||||||
Trigger,
|
Trigger,
|
||||||
TriggerActionRunnerType,
|
|
||||||
TriggerConfig,
|
TriggerConfig,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
@@ -131,7 +131,7 @@ class EventTrigger(Trigger):
|
|||||||
_event_name: str
|
_event_name: str
|
||||||
_event_data_filter: dict
|
_event_data_filter: dict
|
||||||
_unsubs: list[Callable]
|
_unsubs: list[Callable]
|
||||||
_action_runner: TriggerActionRunnerType
|
_action_runner: RunTriggerActionCallback
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def async_validate_complete_config(
|
async def async_validate_complete_config(
|
||||||
@@ -174,7 +174,9 @@ class EventTrigger(Trigger):
|
|||||||
assert config.options is not None
|
assert config.options is not None
|
||||||
self._options = config.options
|
self._options = config.options
|
||||||
|
|
||||||
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
async def async_attach_runner(
|
||||||
|
self, run_action: RunTriggerActionCallback
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
"""Attach a trigger."""
|
"""Attach a trigger."""
|
||||||
dev_reg = dr.async_get(self._hass)
|
dev_reg = dr.async_get(self._hass)
|
||||||
options = self._options
|
options = self._options
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from homeassistant.helpers import config_validation as cv, device_registry as dr
|
|||||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
from homeassistant.helpers.trigger import (
|
from homeassistant.helpers.trigger import (
|
||||||
|
RunTriggerActionCallback,
|
||||||
Trigger,
|
Trigger,
|
||||||
TriggerActionRunnerType,
|
|
||||||
TriggerConfig,
|
TriggerConfig,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
@@ -93,7 +93,7 @@ async def async_validate_trigger_config(
|
|||||||
|
|
||||||
|
|
||||||
async def async_attach_trigger(
|
async def async_attach_trigger(
|
||||||
hass: HomeAssistant, options: ConfigType, run_action: TriggerActionRunnerType
|
hass: HomeAssistant, options: ConfigType, run_action: RunTriggerActionCallback
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Listen for state changes based on configuration."""
|
"""Listen for state changes based on configuration."""
|
||||||
dev_reg = dr.async_get(hass)
|
dev_reg = dr.async_get(hass)
|
||||||
@@ -231,6 +231,8 @@ class ValueUpdatedTrigger(Trigger):
|
|||||||
assert config.options is not None
|
assert config.options is not None
|
||||||
self._options = config.options
|
self._options = config.options
|
||||||
|
|
||||||
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
async def async_attach_runner(
|
||||||
|
self, run_action: RunTriggerActionCallback
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
"""Attach a trigger."""
|
"""Attach a trigger."""
|
||||||
return await async_attach_trigger(self._hass, self._options, run_action)
|
return await async_attach_trigger(self._hass, self._options, run_action)
|
||||||
|
|||||||
@@ -213,9 +213,36 @@ class Trigger(abc.ABC):
|
|||||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||||
"""Initialize trigger."""
|
"""Initialize trigger."""
|
||||||
|
|
||||||
|
async def async_attach_action(
|
||||||
|
self, hass: HomeAssistant, action: TriggerActionType
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach the trigger to an action."""
|
||||||
|
job = HassJob(action)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def run_action(
|
||||||
|
description: str,
|
||||||
|
extra_trigger_payload: dict[str, Any],
|
||||||
|
context: Context | None = None,
|
||||||
|
) -> asyncio.Future[Any] | None:
|
||||||
|
"""Run action with trigger variables."""
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"trigger": {
|
||||||
|
"description": description,
|
||||||
|
**extra_trigger_payload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hass.async_run_hass_job(job, payload, context)
|
||||||
|
|
||||||
|
return await self.async_attach_runner(run_action)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def async_attach(self, run_action: TriggerActionRunnerType) -> CALLBACK_TYPE:
|
async def async_attach_runner(
|
||||||
"""Attach the trigger."""
|
self, run_action: RunTriggerActionCallback
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Attach the trigger to an action runner."""
|
||||||
|
|
||||||
|
|
||||||
class TriggerProtocol(Protocol):
|
class TriggerProtocol(Protocol):
|
||||||
@@ -253,7 +280,7 @@ class TriggerConfig:
|
|||||||
options: dict[str, Any] | None = None
|
options: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class TriggerActionRunnerType(Protocol):
|
class RunTriggerActionCallback(Protocol):
|
||||||
"""Protocol type for the trigger action runner helper callback."""
|
"""Protocol type for the trigger action runner helper callback."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@@ -453,15 +480,22 @@ async def async_validate_trigger_config(
|
|||||||
|
|
||||||
|
|
||||||
def _trigger_action_wrapper(
|
def _trigger_action_wrapper(
|
||||||
hass: HomeAssistant, action: Callable, conf: ConfigType
|
hass: HomeAssistant,
|
||||||
|
action: Callable,
|
||||||
|
conf: ConfigType,
|
||||||
|
extra_trigger_payload: dict[str, Any],
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Wrap trigger action with extra vars if configured.
|
"""Wrap trigger action with extra vars.
|
||||||
|
|
||||||
If action is a coroutine function, a coroutine function will be returned.
|
If action is a coroutine function, a coroutine function will be returned.
|
||||||
If action is a callback, a callback will be returned.
|
If action is a callback, a callback will be returned.
|
||||||
"""
|
"""
|
||||||
if CONF_VARIABLES not in conf:
|
|
||||||
return action
|
def update_run_variables(run_variables: dict[str, Any]) -> None:
|
||||||
|
run_variables.get("trigger", {}).update(extra_trigger_payload)
|
||||||
|
if CONF_VARIABLES in conf:
|
||||||
|
trigger_variables = conf[CONF_VARIABLES]
|
||||||
|
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||||
|
|
||||||
# Check for partials to properly determine if coroutine function
|
# Check for partials to properly determine if coroutine function
|
||||||
check_func = action
|
check_func = action
|
||||||
@@ -477,8 +511,7 @@ def _trigger_action_wrapper(
|
|||||||
run_variables: dict[str, Any], context: Context | None = None
|
run_variables: dict[str, Any], context: Context | None = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Wrap action with extra vars."""
|
"""Wrap action with extra vars."""
|
||||||
trigger_variables = conf[CONF_VARIABLES]
|
update_run_variables(run_variables)
|
||||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
|
||||||
return await action(run_variables, context)
|
return await action(run_variables, context)
|
||||||
|
|
||||||
wrapper_func = async_with_vars
|
wrapper_func = async_with_vars
|
||||||
@@ -490,8 +523,7 @@ def _trigger_action_wrapper(
|
|||||||
run_variables: dict[str, Any], context: Context | None = None
|
run_variables: dict[str, Any], context: Context | None = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Wrap action with extra vars."""
|
"""Wrap action with extra vars."""
|
||||||
trigger_variables = conf[CONF_VARIABLES]
|
update_run_variables(run_variables)
|
||||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
|
||||||
return action(run_variables, context)
|
return action(run_variables, context)
|
||||||
|
|
||||||
if is_callback(check_func):
|
if is_callback(check_func):
|
||||||
@@ -507,30 +539,15 @@ async def _async_attach_trigger_cls(
|
|||||||
trigger_cls: type[Trigger],
|
trigger_cls: type[Trigger],
|
||||||
trigger_key: str,
|
trigger_key: str,
|
||||||
conf: ConfigType,
|
conf: ConfigType,
|
||||||
action: TriggerActionType,
|
action: Callable,
|
||||||
trigger_info: TriggerInfo,
|
trigger_info: TriggerInfo,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Initialize a new Trigger class and attach it."""
|
"""Initialize a new Trigger class and attach it."""
|
||||||
job = HassJob(action)
|
extra_trigger_payload = {
|
||||||
|
**trigger_info["trigger_data"],
|
||||||
@callback
|
CONF_PLATFORM: trigger_key,
|
||||||
def run_action(
|
}
|
||||||
description: str,
|
action_wrapper = _trigger_action_wrapper(hass, action, conf, extra_trigger_payload)
|
||||||
extra_trigger_payload: dict[str, Any],
|
|
||||||
context: Context | None = None,
|
|
||||||
) -> asyncio.Future[Any] | None:
|
|
||||||
"""Run action with trigger variables."""
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"trigger": {
|
|
||||||
**trigger_info["trigger_data"],
|
|
||||||
CONF_PLATFORM: trigger_key,
|
|
||||||
"description": description,
|
|
||||||
**extra_trigger_payload,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return hass.async_run_hass_job(job, payload, context)
|
|
||||||
|
|
||||||
trigger = trigger_cls(
|
trigger = trigger_cls(
|
||||||
hass,
|
hass,
|
||||||
@@ -540,7 +557,7 @@ async def _async_attach_trigger_cls(
|
|||||||
options=conf.get(CONF_OPTIONS),
|
options=conf.get(CONF_OPTIONS),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await trigger.async_attach(run_action)
|
return await trigger.async_attach_action(hass, action_wrapper)
|
||||||
|
|
||||||
|
|
||||||
async def async_initialize_triggers(
|
async def async_initialize_triggers(
|
||||||
@@ -582,7 +599,6 @@ async def async_initialize_triggers(
|
|||||||
trigger_data=trigger_data,
|
trigger_data=trigger_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
action_wrapper = _trigger_action_wrapper(hass, action, conf)
|
|
||||||
if hasattr(platform, "async_get_triggers"):
|
if hasattr(platform, "async_get_triggers"):
|
||||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||||
relative_trigger_key = get_relative_description_key(
|
relative_trigger_key = get_relative_description_key(
|
||||||
@@ -590,9 +606,14 @@ async def async_initialize_triggers(
|
|||||||
)
|
)
|
||||||
trigger_cls = trigger_descriptors[relative_trigger_key]
|
trigger_cls = trigger_descriptors[relative_trigger_key]
|
||||||
coro = _async_attach_trigger_cls(
|
coro = _async_attach_trigger_cls(
|
||||||
hass, trigger_cls, trigger_key, conf, action_wrapper, info
|
hass, trigger_cls, trigger_key, conf, action, info
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
action_wrapper = (
|
||||||
|
_trigger_action_wrapper(hass, action, conf, {})
|
||||||
|
if CONF_VARIABLES in conf
|
||||||
|
else action
|
||||||
|
)
|
||||||
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
|
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
|
||||||
|
|
||||||
triggers.append(create_eager_task(coro))
|
triggers.append(create_eager_task(coro))
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ from homeassistant.helpers.automation import move_top_level_schema_fields_to_opt
|
|||||||
from homeassistant.helpers.trigger import (
|
from homeassistant.helpers.trigger import (
|
||||||
DATA_PLUGGABLE_ACTIONS,
|
DATA_PLUGGABLE_ACTIONS,
|
||||||
PluggableAction,
|
PluggableAction,
|
||||||
|
RunTriggerActionCallback,
|
||||||
Trigger,
|
Trigger,
|
||||||
TriggerActionRunnerType,
|
|
||||||
_async_get_trigger_platform,
|
_async_get_trigger_platform,
|
||||||
async_initialize_triggers,
|
async_initialize_triggers,
|
||||||
async_validate_trigger_config,
|
async_validate_trigger_config,
|
||||||
@@ -463,8 +463,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
|||||||
class MockTrigger1(MockTrigger):
|
class MockTrigger1(MockTrigger):
|
||||||
"""Mock trigger 1."""
|
"""Mock trigger 1."""
|
||||||
|
|
||||||
async def async_attach(
|
async def async_attach_runner(
|
||||||
self, run_action: TriggerActionRunnerType
|
self, run_action: RunTriggerActionCallback
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Attach a trigger."""
|
"""Attach a trigger."""
|
||||||
run_action("trigger 1 desc", {"extra": "test_trigger_1"})
|
run_action("trigger 1 desc", {"extra": "test_trigger_1"})
|
||||||
@@ -472,8 +472,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
|||||||
class MockTrigger2(MockTrigger):
|
class MockTrigger2(MockTrigger):
|
||||||
"""Mock trigger 2."""
|
"""Mock trigger 2."""
|
||||||
|
|
||||||
async def async_attach(
|
async def async_attach_runner(
|
||||||
self, run_action: TriggerActionRunnerType
|
self, run_action: RunTriggerActionCallback
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
"""Attach a trigger."""
|
"""Attach a trigger."""
|
||||||
run_action("trigger 2 desc", {"extra": "test_trigger_2"})
|
run_action("trigger 2 desc", {"extra": "test_trigger_2"})
|
||||||
|
|||||||
Reference in New Issue
Block a user