From 7b1c0c2df20fa23281a05f224a1cd6c0b029d6ca Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 2 Sep 2023 16:19:45 -0700 Subject: [PATCH] Extend template entities with a script section (#96175) * Extend template entities with a script section This allows making a trigger entity that triggers a few times a day, and allows collecting data from a service resopnse which can be fed into a template entity. The current alternatives are to publish and subscribe to events or to store data in input entities. * Make variables set in actions accessible to templates * Format code --------- Co-authored-by: Erik --- homeassistant/components/script/__init__.py | 3 +- homeassistant/components/template/__init__.py | 19 ++++++-- homeassistant/components/template/config.py | 3 +- homeassistant/components/template/const.py | 1 + .../components/websocket_api/commands.py | 4 +- homeassistant/helpers/script.py | 15 +++++-- tests/components/template/test_sensor.py | 44 +++++++++++++++++++ 7 files changed, 79 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 8530aa3b04c..13b25a00053 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -563,7 +563,8 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity): ) coro = self._async_run(variables, context) if wait: - return await coro + script_result = await coro + return script_result.service_response if script_result else None # Caller does not want to wait for called script to finish so let script run in # separate Task. Make a new empty script stack; scripts are allowed to diff --git a/homeassistant/components/template/__init__.py b/homeassistant/components/template/__init__.py index e9ced060491..c4ba7081f5a 100644 --- a/homeassistant/components/template/__init__.py +++ b/homeassistant/components/template/__init__.py @@ -20,11 +20,12 @@ from homeassistant.helpers import ( update_coordinator, ) from homeassistant.helpers.reload import async_reload_integration_platforms +from homeassistant.helpers.script import Script from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_integration -from .const import CONF_TRIGGER, DOMAIN, PLATFORMS +from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS _LOGGER = logging.getLogger(__name__) @@ -133,6 +134,7 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator): self.config = config self._unsub_start: Callable[[], None] | None = None self._unsub_trigger: Callable[[], None] | None = None + self._script: Script | None = None @property def unique_id(self) -> str | None: @@ -170,6 +172,14 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator): async def _attach_triggers(self, start_event=None) -> None: """Attach the triggers.""" + if CONF_ACTION in self.config: + self._script = Script( + self.hass, + self.config[CONF_ACTION], + self.name, + DOMAIN, + ) + if start_event is not None: self._unsub_start = None @@ -183,8 +193,11 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator): start_event is not None, ) - @callback - def _handle_triggered(self, run_variables, context=None): + async def _handle_triggered(self, run_variables, context=None): + if self._script: + script_result = await self._script.async_run(run_variables, context) + if script_result: + run_variables = script_result.variables self.async_set_updated_data( {"run_variables": run_variables, "context": context} ) diff --git a/homeassistant/components/template/config.py b/homeassistant/components/template/config.py index 2261bde2659..54c82d88c74 100644 --- a/homeassistant/components/template/config.py +++ b/homeassistant/components/template/config.py @@ -22,7 +22,7 @@ from . import ( select as select_platform, sensor as sensor_platform, ) -from .const import CONF_TRIGGER, DOMAIN +from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN PACKAGE_MERGE_HINT = "list" @@ -30,6 +30,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema( { vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA, + vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Optional(NUMBER_DOMAIN): vol.All( cv.ensure_list, [number_platform.NUMBER_SCHEMA] ), diff --git a/homeassistant/components/template/const.py b/homeassistant/components/template/const.py index 9b371125750..6805c0ad812 100644 --- a/homeassistant/components/template/const.py +++ b/homeassistant/components/template/const.py @@ -2,6 +2,7 @@ from homeassistant.const import Platform +CONF_ACTION = "action" CONF_AVAILABILITY_TEMPLATE = "availability_template" CONF_ATTRIBUTE_TEMPLATES = "attribute_templates" CONF_TRIGGER = "trigger" diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index bbcbfa6ecb8..c6564967a39 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -713,12 +713,12 @@ async def handle_execute_script( context = connection.context(msg) script_obj = Script(hass, script_config, f"{const.DOMAIN} script", const.DOMAIN) - response = await script_obj.async_run(msg.get("variables"), context=context) + script_result = await script_obj.async_run(msg.get("variables"), context=context) connection.send_result( msg["id"], { "context": context, - "response": response, + "response": script_result.service_response if script_result else None, }, ) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 4035d55b325..c9d8de23b96 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Mapping, Sequence from contextlib import asynccontextmanager, suppress from contextvars import ContextVar from copy import copy +from dataclasses import dataclass from datetime import datetime, timedelta from functools import partial import itertools @@ -401,7 +402,7 @@ class _ScriptRun: ) self._log("Executing step %s%s", self._script.last_action, _timeout) - async def async_run(self) -> ServiceResponse: + async def async_run(self) -> ScriptRunResult | None: """Run script.""" # Push the script to the script execution stack if (script_stack := script_stack_cv.get()) is None: @@ -443,7 +444,7 @@ class _ScriptRun: script_stack.pop() self._finish() - return response + return ScriptRunResult(response, self._variables) async def _async_step(self, log_exceptions): continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False) @@ -1189,6 +1190,14 @@ class _IfData(TypedDict): if_else: Script | None +@dataclass +class ScriptRunResult: + """Container with the result of a script run.""" + + service_response: ServiceResponse + variables: dict + + class Script: """Representation of a script.""" @@ -1480,7 +1489,7 @@ class Script: run_variables: _VarsType | None = None, context: Context | None = None, started_action: Callable[..., Any] | None = None, - ) -> ServiceResponse: + ) -> ScriptRunResult | None: """Run script.""" if context is None: self._log( diff --git a/tests/components/template/test_sensor.py b/tests/components/template/test_sensor.py index 47e307bc6aa..cf9f3724020 100644 --- a/tests/components/template/test_sensor.py +++ b/tests/components/template/test_sensor.py @@ -1582,3 +1582,47 @@ async def test_trigger_entity_restore_state( assert state.attributes["entity_picture"] == "/local/dogs.png" assert state.attributes["plus_one"] == 3 assert state.attributes["another"] == 1 + + +@pytest.mark.parametrize(("count", "domain"), [(1, "template")]) +@pytest.mark.parametrize( + "config", + [ + { + "template": [ + { + "unique_id": "listening-test-event", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + { + "variables": { + "my_variable": "{{ trigger.event.data.beer + 1 }}" + }, + }, + ], + "sensor": [ + { + "name": "Hello Name", + "state": "{{ my_variable + 1 }}", + } + ], + }, + ], + }, + ], +) +async def test_trigger_action( + hass: HomeAssistant, start_ha, entity_registry: er.EntityRegistry +) -> None: + """Test trigger entity with an action works.""" + state = hass.states.get("sensor.hello_name") + assert state is not None + assert state.state == STATE_UNKNOWN + + context = Context() + hass.bus.async_fire("test_event", {"beer": 1}, context=context) + await hass.async_block_till_done() + + state = hass.states.get("sensor.hello_name") + assert state.state == "3" + assert state.context is context