Add support for variables on trigger (#68275)

This commit is contained in:
Paulus Schoutsen 2022-03-18 01:25:22 -07:00 committed by GitHub
parent ad84a02b8e
commit ad1e43e083
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 7 deletions

View File

@ -1316,12 +1316,25 @@ CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema(
) )
TRIGGER_BASE_SCHEMA = vol.Schema( TRIGGER_BASE_SCHEMA = vol.Schema(
{vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str} {
vol.Required(CONF_PLATFORM): str,
vol.Optional(CONF_ID): str,
vol.Optional(CONF_VARIABLES): SCRIPT_VARIABLES_SCHEMA,
}
) )
TRIGGER_SCHEMA = vol.All(
ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)] _base_trigger_validator_schema = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
)
# This is first round of validation, we don't want to process the config here already,
# just ensure basics as platform and ID are there.
def _base_trigger_validator(value: Any) -> Any:
_base_trigger_validator_schema(value)
return value
TRIGGER_SCHEMA = vol.All(ensure_list, [_base_trigger_validator])
_SCRIPT_DELAY_SCHEMA = vol.Schema( _SCRIPT_DELAY_SCHEMA = vol.Schema(
{ {

View File

@ -3,13 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable
import functools
import logging import logging
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_ID, CONF_PLATFORM from homeassistant.const import CONF_ID, CONF_PLATFORM, CONF_VARIABLES
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -55,6 +56,25 @@ async def async_validate_trigger_config(
return config return config
def _trigger_action_wrapper(
hass: HomeAssistant, action: Callable, conf: ConfigType
) -> Callable:
"""Wrap trigger action with extra vars if configured."""
if CONF_VARIABLES not in conf:
return action
@functools.wraps(action)
async def with_vars(
run_variables: dict[str, Any], context: Context | None = None
) -> None:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
await action(run_variables, context)
return with_vars
async def async_initialize_triggers( async def async_initialize_triggers(
hass: HomeAssistant, hass: HomeAssistant,
trigger_config: list[ConfigType], trigger_config: list[ConfigType],
@ -80,7 +100,12 @@ async def async_initialize_triggers(
"variables": variables, "variables": variables,
"trigger_data": trigger_data, "trigger_data": trigger_data,
} }
triggers.append(platform.async_attach_trigger(hass, conf, action, info))
triggers.append(
platform.async_attach_trigger(
hass, conf, _trigger_action_wrapper(hass, action, conf), info
)
)
attach_results = await asyncio.gather(*triggers, return_exceptions=True) attach_results = await asyncio.gather(*triggers, return_exceptions=True)
removes: list[Callable[[], None]] = [] removes: list[Callable[[], None]] = []

View File

@ -8,6 +8,15 @@ from homeassistant.helpers.trigger import (
_async_get_trigger_platform, _async_get_trigger_platform,
async_validate_trigger_config, async_validate_trigger_config,
) )
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
@pytest.fixture
def calls(hass):
"""Track calls to a mock service."""
return async_mock_service(hass, "test", "automation")
async def test_bad_trigger_platform(hass): async def test_bad_trigger_platform(hass):
@ -24,3 +33,36 @@ async def test_trigger_subtype(hass):
) as integration_mock: ) as integration_mock:
await _async_get_trigger_platform(hass, {"platform": "test.subtype"}) await _async_get_trigger_platform(hass, {"platform": "test.subtype"})
assert integration_mock.call_args == call(hass, "test") assert integration_mock.call_args == call(hass, "test")
async def test_trigger_variables(hass):
"""Test trigger variables."""
async def test_if_fires_on_event(hass, calls):
"""Test the firing of events."""
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "event",
"event_type": "test_event",
"variables": {
"name": "Paulus",
"via_event": "{{ trigger.event.event_type }}",
},
},
"action": {
"service": "test.automation",
"data_template": {"hello": "{{ name }} + {{ via_event }}"},
},
}
},
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["hello"] == "Paulus + test_event"