mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Add support for variables on trigger (#68275)
This commit is contained in:
parent
ad84a02b8e
commit
ad1e43e083
@ -1316,12 +1316,25 @@ CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
TRIGGER_BASE_SCHEMA = vol.Schema(
|
TRIGGER_BASE_SCHEMA = vol.Schema(
|
||||||
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str}
|
{
|
||||||
|
vol.Required(CONF_PLATFORM): str,
|
||||||
|
vol.Optional(CONF_ID): str,
|
||||||
|
vol.Optional(CONF_VARIABLES): SCRIPT_VARIABLES_SCHEMA,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
TRIGGER_SCHEMA = vol.All(
|
|
||||||
ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)]
|
_base_trigger_validator_schema = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
|
||||||
)
|
|
||||||
|
|
||||||
|
# This is first round of validation, we don't want to process the config here already,
|
||||||
|
# just ensure basics as platform and ID are there.
|
||||||
|
def _base_trigger_validator(value: Any) -> Any:
|
||||||
|
_base_trigger_validator_schema(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
TRIGGER_SCHEMA = vol.All(ensure_list, [_base_trigger_validator])
|
||||||
|
|
||||||
_SCRIPT_DELAY_SCHEMA = vol.Schema(
|
_SCRIPT_DELAY_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
|
@ -3,13 +3,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import CONF_ID, CONF_PLATFORM
|
from homeassistant.const import CONF_ID, CONF_PLATFORM, CONF_VARIABLES
|
||||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||||
|
|
||||||
@ -55,6 +56,25 @@ async def async_validate_trigger_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _trigger_action_wrapper(
|
||||||
|
hass: HomeAssistant, action: Callable, conf: ConfigType
|
||||||
|
) -> Callable:
|
||||||
|
"""Wrap trigger action with extra vars if configured."""
|
||||||
|
if CONF_VARIABLES not in conf:
|
||||||
|
return action
|
||||||
|
|
||||||
|
@functools.wraps(action)
|
||||||
|
async def with_vars(
|
||||||
|
run_variables: dict[str, Any], context: Context | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Wrap action with extra vars."""
|
||||||
|
trigger_variables = conf[CONF_VARIABLES]
|
||||||
|
run_variables.update(trigger_variables.async_render(hass, run_variables))
|
||||||
|
await action(run_variables, context)
|
||||||
|
|
||||||
|
return with_vars
|
||||||
|
|
||||||
|
|
||||||
async def async_initialize_triggers(
|
async def async_initialize_triggers(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
trigger_config: list[ConfigType],
|
trigger_config: list[ConfigType],
|
||||||
@ -80,7 +100,12 @@ async def async_initialize_triggers(
|
|||||||
"variables": variables,
|
"variables": variables,
|
||||||
"trigger_data": trigger_data,
|
"trigger_data": trigger_data,
|
||||||
}
|
}
|
||||||
triggers.append(platform.async_attach_trigger(hass, conf, action, info))
|
|
||||||
|
triggers.append(
|
||||||
|
platform.async_attach_trigger(
|
||||||
|
hass, conf, _trigger_action_wrapper(hass, action, conf), info
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
attach_results = await asyncio.gather(*triggers, return_exceptions=True)
|
attach_results = await asyncio.gather(*triggers, return_exceptions=True)
|
||||||
removes: list[Callable[[], None]] = []
|
removes: list[Callable[[], None]] = []
|
||||||
|
@ -8,6 +8,15 @@ from homeassistant.helpers.trigger import (
|
|||||||
_async_get_trigger_platform,
|
_async_get_trigger_platform,
|
||||||
async_validate_trigger_config,
|
async_validate_trigger_config,
|
||||||
)
|
)
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import async_mock_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def calls(hass):
|
||||||
|
"""Track calls to a mock service."""
|
||||||
|
return async_mock_service(hass, "test", "automation")
|
||||||
|
|
||||||
|
|
||||||
async def test_bad_trigger_platform(hass):
|
async def test_bad_trigger_platform(hass):
|
||||||
@ -24,3 +33,36 @@ async def test_trigger_subtype(hass):
|
|||||||
) as integration_mock:
|
) as integration_mock:
|
||||||
await _async_get_trigger_platform(hass, {"platform": "test.subtype"})
|
await _async_get_trigger_platform(hass, {"platform": "test.subtype"})
|
||||||
assert integration_mock.call_args == call(hass, "test")
|
assert integration_mock.call_args == call(hass, "test")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_trigger_variables(hass):
|
||||||
|
"""Test trigger variables."""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_if_fires_on_event(hass, calls):
|
||||||
|
"""Test the firing of events."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "event",
|
||||||
|
"event_type": "test_event",
|
||||||
|
"variables": {
|
||||||
|
"name": "Paulus",
|
||||||
|
"via_event": "{{ trigger.event.event_type }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
"data_template": {"hello": "{{ name }} + {{ via_event }}"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data["hello"] == "Paulus + test_event"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user