mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Add support for variables on trigger (#68275)
This commit is contained in:
parent
ad84a02b8e
commit
ad1e43e083
@ -1316,12 +1316,25 @@ CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
|
||||
)
|
||||
|
||||
TRIGGER_BASE_SCHEMA = vol.Schema(
|
||||
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str}
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): str,
|
||||
vol.Optional(CONF_ID): str,
|
||||
vol.Optional(CONF_VARIABLES): SCRIPT_VARIABLES_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
TRIGGER_SCHEMA = vol.All(
|
||||
ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)]
|
||||
)
|
||||
|
||||
_base_trigger_validator_schema = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
# This is first round of validation, we don't want to process the config here already,
|
||||
# just ensure basics as platform and ID are there.
|
||||
def _base_trigger_validator(value: Any) -> Any:
|
||||
_base_trigger_validator_schema(value)
|
||||
return value
|
||||
|
||||
|
||||
TRIGGER_SCHEMA = vol.All(ensure_list, [_base_trigger_validator])
|
||||
|
||||
_SCRIPT_DELAY_SCHEMA = vol.Schema(
|
||||
{
|
||||
|
@ -3,13 +3,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_ID, CONF_PLATFORM
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.const import CONF_ID, CONF_PLATFORM, CONF_VARIABLES
|
||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
|
||||
@ -55,6 +56,25 @@ async def async_validate_trigger_config(
|
||||
return config
|
||||
|
||||
|
||||
def _trigger_action_wrapper(
|
||||
hass: HomeAssistant, action: Callable, conf: ConfigType
|
||||
) -> Callable:
|
||||
"""Wrap trigger action with extra vars if configured."""
|
||||
if CONF_VARIABLES not in conf:
|
||||
return action
|
||||
|
||||
@functools.wraps(action)
|
||||
async def with_vars(
|
||||
run_variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
"""Wrap action with extra vars."""
|
||||
trigger_variables = conf[CONF_VARIABLES]
|
||||
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||
await action(run_variables, context)
|
||||
|
||||
return with_vars
|
||||
|
||||
|
||||
async def async_initialize_triggers(
|
||||
hass: HomeAssistant,
|
||||
trigger_config: list[ConfigType],
|
||||
@ -80,7 +100,12 @@ async def async_initialize_triggers(
|
||||
"variables": variables,
|
||||
"trigger_data": trigger_data,
|
||||
}
|
||||
triggers.append(platform.async_attach_trigger(hass, conf, action, info))
|
||||
|
||||
triggers.append(
|
||||
platform.async_attach_trigger(
|
||||
hass, conf, _trigger_action_wrapper(hass, action, conf), info
|
||||
)
|
||||
)
|
||||
|
||||
attach_results = await asyncio.gather(*triggers, return_exceptions=True)
|
||||
removes: list[Callable[[], None]] = []
|
||||
|
@ -8,6 +8,15 @@ from homeassistant.helpers.trigger import (
|
||||
_async_get_trigger_platform,
|
||||
async_validate_trigger_config,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calls(hass):
|
||||
"""Track calls to a mock service."""
|
||||
return async_mock_service(hass, "test", "automation")
|
||||
|
||||
|
||||
async def test_bad_trigger_platform(hass):
|
||||
@ -24,3 +33,36 @@ async def test_trigger_subtype(hass):
|
||||
) as integration_mock:
|
||||
await _async_get_trigger_platform(hass, {"platform": "test.subtype"})
|
||||
assert integration_mock.call_args == call(hass, "test")
|
||||
|
||||
|
||||
async def test_trigger_variables(hass):
|
||||
"""Test trigger variables."""
|
||||
|
||||
|
||||
async def test_if_fires_on_event(hass, calls):
|
||||
"""Test the firing of events."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "event",
|
||||
"event_type": "test_event",
|
||||
"variables": {
|
||||
"name": "Paulus",
|
||||
"via_event": "{{ trigger.event.event_type }}",
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {"hello": "{{ name }} + {{ via_event }}"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["hello"] == "Paulus + test_event"
|
||||
|
Loading…
x
Reference in New Issue
Block a user