Extend template entities with a script section (#96175)

* Extend template entities with a script section

This allows making a trigger entity that triggers a few times a day,
and allows collecting data from a service resopnse which can be
fed into a template entity.

The current alternatives are to publish and subscribe to events or to
store data in input entities.

* Make variables set in actions accessible to templates

* Format code

---------

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Allen Porter 2023-09-02 16:19:45 -07:00 committed by GitHub
parent 6312f34538
commit 7b1c0c2df2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 79 additions and 10 deletions

View File

@ -563,7 +563,8 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
) )
coro = self._async_run(variables, context) coro = self._async_run(variables, context)
if wait: if wait:
return await coro script_result = await coro
return script_result.service_response if script_result else None
# Caller does not want to wait for called script to finish so let script run in # Caller does not want to wait for called script to finish so let script run in
# separate Task. Make a new empty script stack; scripts are allowed to # separate Task. Make a new empty script stack; scripts are allowed to

View File

@ -20,11 +20,12 @@ from homeassistant.helpers import (
update_coordinator, update_coordinator,
) )
from homeassistant.helpers.reload import async_reload_integration_platforms from homeassistant.helpers.reload import async_reload_integration_platforms
from homeassistant.helpers.script import Script
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from .const import CONF_TRIGGER, DOMAIN, PLATFORMS from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -133,6 +134,7 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator):
self.config = config self.config = config
self._unsub_start: Callable[[], None] | None = None self._unsub_start: Callable[[], None] | None = None
self._unsub_trigger: Callable[[], None] | None = None self._unsub_trigger: Callable[[], None] | None = None
self._script: Script | None = None
@property @property
def unique_id(self) -> str | None: def unique_id(self) -> str | None:
@ -170,6 +172,14 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator):
async def _attach_triggers(self, start_event=None) -> None: async def _attach_triggers(self, start_event=None) -> None:
"""Attach the triggers.""" """Attach the triggers."""
if CONF_ACTION in self.config:
self._script = Script(
self.hass,
self.config[CONF_ACTION],
self.name,
DOMAIN,
)
if start_event is not None: if start_event is not None:
self._unsub_start = None self._unsub_start = None
@ -183,8 +193,11 @@ class TriggerUpdateCoordinator(update_coordinator.DataUpdateCoordinator):
start_event is not None, start_event is not None,
) )
@callback async def _handle_triggered(self, run_variables, context=None):
def _handle_triggered(self, run_variables, context=None): if self._script:
script_result = await self._script.async_run(run_variables, context)
if script_result:
run_variables = script_result.variables
self.async_set_updated_data( self.async_set_updated_data(
{"run_variables": run_variables, "context": context} {"run_variables": run_variables, "context": context}
) )

View File

@ -22,7 +22,7 @@ from . import (
select as select_platform, select as select_platform,
sensor as sensor_platform, sensor as sensor_platform,
) )
from .const import CONF_TRIGGER, DOMAIN from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN
PACKAGE_MERGE_HINT = "list" PACKAGE_MERGE_HINT = "list"
@ -30,6 +30,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA,
vol.Optional(NUMBER_DOMAIN): vol.All( vol.Optional(NUMBER_DOMAIN): vol.All(
cv.ensure_list, [number_platform.NUMBER_SCHEMA] cv.ensure_list, [number_platform.NUMBER_SCHEMA]
), ),

View File

@ -2,6 +2,7 @@
from homeassistant.const import Platform from homeassistant.const import Platform
CONF_ACTION = "action"
CONF_AVAILABILITY_TEMPLATE = "availability_template" CONF_AVAILABILITY_TEMPLATE = "availability_template"
CONF_ATTRIBUTE_TEMPLATES = "attribute_templates" CONF_ATTRIBUTE_TEMPLATES = "attribute_templates"
CONF_TRIGGER = "trigger" CONF_TRIGGER = "trigger"

View File

@ -713,12 +713,12 @@ async def handle_execute_script(
context = connection.context(msg) context = connection.context(msg)
script_obj = Script(hass, script_config, f"{const.DOMAIN} script", const.DOMAIN) script_obj = Script(hass, script_config, f"{const.DOMAIN} script", const.DOMAIN)
response = await script_obj.async_run(msg.get("variables"), context=context) script_result = await script_obj.async_run(msg.get("variables"), context=context)
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{ {
"context": context, "context": context,
"response": response, "response": script_result.service_response if script_result else None,
}, },
) )

View File

@ -6,6 +6,7 @@ from collections.abc import Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar from contextvars import ContextVar
from copy import copy from copy import copy
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import itertools import itertools
@ -401,7 +402,7 @@ class _ScriptRun:
) )
self._log("Executing step %s%s", self._script.last_action, _timeout) self._log("Executing step %s%s", self._script.last_action, _timeout)
async def async_run(self) -> ServiceResponse: async def async_run(self) -> ScriptRunResult | None:
"""Run script.""" """Run script."""
# Push the script to the script execution stack # Push the script to the script execution stack
if (script_stack := script_stack_cv.get()) is None: if (script_stack := script_stack_cv.get()) is None:
@ -443,7 +444,7 @@ class _ScriptRun:
script_stack.pop() script_stack.pop()
self._finish() self._finish()
return response return ScriptRunResult(response, self._variables)
async def _async_step(self, log_exceptions): async def _async_step(self, log_exceptions):
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False) continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
@ -1189,6 +1190,14 @@ class _IfData(TypedDict):
if_else: Script | None if_else: Script | None
@dataclass
class ScriptRunResult:
"""Container with the result of a script run."""
service_response: ServiceResponse
variables: dict
class Script: class Script:
"""Representation of a script.""" """Representation of a script."""
@ -1480,7 +1489,7 @@ class Script:
run_variables: _VarsType | None = None, run_variables: _VarsType | None = None,
context: Context | None = None, context: Context | None = None,
started_action: Callable[..., Any] | None = None, started_action: Callable[..., Any] | None = None,
) -> ServiceResponse: ) -> ScriptRunResult | None:
"""Run script.""" """Run script."""
if context is None: if context is None:
self._log( self._log(

View File

@ -1582,3 +1582,47 @@ async def test_trigger_entity_restore_state(
assert state.attributes["entity_picture"] == "/local/dogs.png" assert state.attributes["entity_picture"] == "/local/dogs.png"
assert state.attributes["plus_one"] == 3 assert state.attributes["plus_one"] == 3
assert state.attributes["another"] == 1 assert state.attributes["another"] == 1
@pytest.mark.parametrize(("count", "domain"), [(1, "template")])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [
{
"variables": {
"my_variable": "{{ trigger.event.data.beer + 1 }}"
},
},
],
"sensor": [
{
"name": "Hello Name",
"state": "{{ my_variable + 1 }}",
}
],
},
],
},
],
)
async def test_trigger_action(
hass: HomeAssistant, start_ha, entity_registry: er.EntityRegistry
) -> None:
"""Test trigger entity with an action works."""
state = hass.states.get("sensor.hello_name")
assert state is not None
assert state.state == STATE_UNKNOWN
context = Context()
hass.bus.async_fire("test_event", {"beer": 1}, context=context)
await hass.async_block_till_done()
state = hass.states.get("sensor.hello_name")
assert state.state == "3"
assert state.context is context