From ad1e43e08357d6e8461ecf3e33a5ce2e3277126b Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 18 Mar 2022 01:25:22 -0700 Subject: [PATCH] Add support for variables on trigger (#68275) --- homeassistant/helpers/config_validation.py | 21 ++++++++--- homeassistant/helpers/trigger.py | 31 ++++++++++++++-- tests/helpers/test_trigger.py | 42 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 51cd38569fb..fb920c4ef1b 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1316,12 +1316,25 @@ CONDITION_ACTION_SCHEMA: vol.Schema = vol.Schema( ) TRIGGER_BASE_SCHEMA = vol.Schema( - {vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str} + { + vol.Required(CONF_PLATFORM): str, + vol.Optional(CONF_ID): str, + vol.Optional(CONF_VARIABLES): SCRIPT_VARIABLES_SCHEMA, + } ) -TRIGGER_SCHEMA = vol.All( - ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)] -) + +_base_trigger_validator_schema = TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) + + +# This is first round of validation, we don't want to process the config here already, +# just ensure basics as platform and ID are there. +def _base_trigger_validator(value: Any) -> Any: + _base_trigger_validator_schema(value) + return value + + +TRIGGER_SCHEMA = vol.All(ensure_list, [_base_trigger_validator]) _SCRIPT_DELAY_SCHEMA = vol.Schema( { diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 0b18ad9aa42..79ac9f33f24 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -3,13 +3,14 @@ from __future__ import annotations import asyncio from collections.abc import Callable +import functools import logging from typing import Any import voluptuous as vol -from homeassistant.const import CONF_ID, CONF_PLATFORM -from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback +from homeassistant.const import CONF_ID, CONF_PLATFORM, CONF_VARIABLES +from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -55,6 +56,25 @@ async def async_validate_trigger_config( return config +def _trigger_action_wrapper( + hass: HomeAssistant, action: Callable, conf: ConfigType +) -> Callable: + """Wrap trigger action with extra vars if configured.""" + if CONF_VARIABLES not in conf: + return action + + @functools.wraps(action) + async def with_vars( + run_variables: dict[str, Any], context: Context | None = None + ) -> None: + """Wrap action with extra vars.""" + trigger_variables = conf[CONF_VARIABLES] + run_variables.update(trigger_variables.async_render(hass, run_variables)) + await action(run_variables, context) + + return with_vars + + async def async_initialize_triggers( hass: HomeAssistant, trigger_config: list[ConfigType], @@ -80,7 +100,12 @@ async def async_initialize_triggers( "variables": variables, "trigger_data": trigger_data, } - triggers.append(platform.async_attach_trigger(hass, conf, action, info)) + + triggers.append( + platform.async_attach_trigger( + hass, conf, _trigger_action_wrapper(hass, action, conf), info + ) + ) attach_results = await asyncio.gather(*triggers, return_exceptions=True) removes: list[Callable[[], None]] = [] diff --git a/tests/helpers/test_trigger.py b/tests/helpers/test_trigger.py index 7afdb629792..598906b48c3 100644 --- a/tests/helpers/test_trigger.py +++ b/tests/helpers/test_trigger.py @@ -8,6 +8,15 @@ from homeassistant.helpers.trigger import ( _async_get_trigger_platform, async_validate_trigger_config, ) +from homeassistant.setup import async_setup_component + +from tests.common import async_mock_service + + +@pytest.fixture +def calls(hass): + """Track calls to a mock service.""" + return async_mock_service(hass, "test", "automation") async def test_bad_trigger_platform(hass): @@ -24,3 +33,36 @@ async def test_trigger_subtype(hass): ) as integration_mock: await _async_get_trigger_platform(hass, {"platform": "test.subtype"}) assert integration_mock.call_args == call(hass, "test") + + +async def test_trigger_variables(hass): + """Test trigger variables.""" + + +async def test_if_fires_on_event(hass, calls): + """Test the firing of events.""" + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "event", + "event_type": "test_event", + "variables": { + "name": "Paulus", + "via_event": "{{ trigger.event.event_type }}", + }, + }, + "action": { + "service": "test.automation", + "data_template": {"hello": "{{ name }} + {{ via_event }}"}, + }, + } + }, + ) + + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0].data["hello"] == "Paulus + test_event"