mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 06:17:07 +00:00
Extract variable rendering (#39934)
This commit is contained in:
parent
b107e87d38
commit
8ef04268be
@ -45,6 +45,7 @@ from homeassistant.helpers.script import (
|
|||||||
Script,
|
Script,
|
||||||
make_script_schema,
|
make_script_schema,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.script_variables import ScriptVariables
|
||||||
from homeassistant.helpers.service import async_register_admin_service
|
from homeassistant.helpers.service import async_register_admin_service
|
||||||
from homeassistant.helpers.trigger import async_initialize_triggers
|
from homeassistant.helpers.trigger import async_initialize_triggers
|
||||||
from homeassistant.helpers.typing import TemplateVarsType
|
from homeassistant.helpers.typing import TemplateVarsType
|
||||||
@ -256,8 +257,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||||||
self._referenced_entities: Optional[Set[str]] = None
|
self._referenced_entities: Optional[Set[str]] = None
|
||||||
self._referenced_devices: Optional[Set[str]] = None
|
self._referenced_devices: Optional[Set[str]] = None
|
||||||
self._logger = _LOGGER
|
self._logger = _LOGGER
|
||||||
self._variables = variables
|
self._variables: ScriptVariables = variables
|
||||||
self._variables_dynamic = template.is_complex(variables)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -334,9 +334,6 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||||||
"""Startup with initial state or previous state."""
|
"""Startup with initial state or previous state."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
if self._variables_dynamic:
|
|
||||||
template.attach(cast(HomeAssistant, self.hass), self._variables)
|
|
||||||
|
|
||||||
self._logger = logging.getLogger(
|
self._logger = logging.getLogger(
|
||||||
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
|
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
|
||||||
)
|
)
|
||||||
@ -392,15 +389,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
if self._variables:
|
if self._variables:
|
||||||
if self._variables_dynamic:
|
try:
|
||||||
variables = template.render_complex(self._variables, run_variables)
|
variables = self._variables.async_render(self.hass, run_variables)
|
||||||
|
except template.TemplateError as err:
|
||||||
|
self._logger.error("Error rendering variables: %s", err)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
variables = dict(self._variables)
|
variables = run_variables
|
||||||
else:
|
|
||||||
variables = {}
|
|
||||||
|
|
||||||
if run_variables:
|
|
||||||
variables.update(run_variables)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_condition
|
not skip_condition
|
||||||
|
@ -77,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
|
|
||||||
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
||||||
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
|
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
|
||||||
{vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
|
{vol.Optional(ATTR_VARIABLES): {str: cv.match_all}}
|
||||||
)
|
)
|
||||||
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
||||||
|
|
||||||
|
@ -81,7 +81,10 @@ from homeassistant.const import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import split_entity_id, valid_entity_id
|
from homeassistant.core import split_entity_id, valid_entity_id
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.exceptions import TemplateError
|
||||||
from homeassistant.helpers import template as template_helper
|
from homeassistant.helpers import (
|
||||||
|
script_variables as script_variables_helper,
|
||||||
|
template as template_helper,
|
||||||
|
)
|
||||||
from homeassistant.helpers.logging import KeywordStyleAdapter
|
from homeassistant.helpers.logging import KeywordStyleAdapter
|
||||||
from homeassistant.util import slugify as util_slugify
|
from homeassistant.util import slugify as util_slugify
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
@ -863,7 +866,11 @@ def make_entity_service_schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
|
SCRIPT_VARIABLES_SCHEMA = vol.All(
|
||||||
|
vol.Schema({str: template_complex}),
|
||||||
|
# pylint: disable=unnecessary-lambda
|
||||||
|
lambda val: script_variables_helper.ScriptVariables(val),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def script_action(value: Any) -> dict:
|
def script_action(value: Any) -> dict:
|
||||||
|
@ -55,6 +55,7 @@ from homeassistant.const import (
|
|||||||
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
|
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
|
||||||
from homeassistant.helpers import condition, config_validation as cv, template
|
from homeassistant.helpers import condition, config_validation as cv, template
|
||||||
from homeassistant.helpers.event import async_call_later, async_track_template
|
from homeassistant.helpers.event import async_call_later, async_track_template
|
||||||
|
from homeassistant.helpers.script_variables import ScriptVariables
|
||||||
from homeassistant.helpers.service import (
|
from homeassistant.helpers.service import (
|
||||||
CONF_SERVICE_DATA,
|
CONF_SERVICE_DATA,
|
||||||
async_prepare_call_from_config,
|
async_prepare_call_from_config,
|
||||||
@ -717,7 +718,7 @@ class Script:
|
|||||||
logger: Optional[logging.Logger] = None,
|
logger: Optional[logging.Logger] = None,
|
||||||
log_exceptions: bool = True,
|
log_exceptions: bool = True,
|
||||||
top_level: bool = True,
|
top_level: bool = True,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[ScriptVariables] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the script."""
|
"""Initialize the script."""
|
||||||
all_scripts = hass.data.get(DATA_SCRIPTS)
|
all_scripts = hass.data.get(DATA_SCRIPTS)
|
||||||
@ -900,15 +901,19 @@ class Script:
|
|||||||
# during the run back to the caller.
|
# during the run back to the caller.
|
||||||
if self._top_level:
|
if self._top_level:
|
||||||
if self.variables:
|
if self.variables:
|
||||||
if self._variables_dynamic:
|
try:
|
||||||
variables = template.render_complex(self.variables, run_variables)
|
variables = self.variables.async_render(
|
||||||
else:
|
self._hass,
|
||||||
variables = dict(self.variables)
|
run_variables,
|
||||||
|
)
|
||||||
|
except template.TemplateError as err:
|
||||||
|
self._log("Error rendering variables: %s", err, level=logging.ERROR)
|
||||||
|
raise
|
||||||
|
elif run_variables:
|
||||||
|
variables = dict(run_variables)
|
||||||
else:
|
else:
|
||||||
variables = {}
|
variables = {}
|
||||||
|
|
||||||
if run_variables:
|
|
||||||
variables.update(run_variables)
|
|
||||||
variables["context"] = context
|
variables["context"] = context
|
||||||
else:
|
else:
|
||||||
variables = cast(dict, run_variables)
|
variables = cast(dict, run_variables)
|
||||||
|
57
homeassistant/helpers/script_variables.py
Normal file
57
homeassistant/helpers/script_variables.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Script variables."""
|
||||||
|
from typing import Any, Dict, Mapping, Optional
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
|
from . import template
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptVariables:
|
||||||
|
"""Class to hold and render script variables."""
|
||||||
|
|
||||||
|
def __init__(self, variables: Dict[str, Any]):
|
||||||
|
"""Initialize script variables."""
|
||||||
|
self.variables = variables
|
||||||
|
self._has_template: Optional[bool] = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_render(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
run_variables: Optional[Mapping[str, Any]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Render script variables.
|
||||||
|
|
||||||
|
The run variables are used to compute the static variables, but afterwards will also
|
||||||
|
be merged on top of the static variables.
|
||||||
|
"""
|
||||||
|
if self._has_template is None:
|
||||||
|
self._has_template = template.is_complex(self.variables)
|
||||||
|
template.attach(hass, self.variables)
|
||||||
|
|
||||||
|
if not self._has_template:
|
||||||
|
rendered_variables = dict(self.variables)
|
||||||
|
|
||||||
|
if run_variables is not None:
|
||||||
|
rendered_variables.update(run_variables)
|
||||||
|
|
||||||
|
return rendered_variables
|
||||||
|
|
||||||
|
rendered_variables = {} if run_variables is None else dict(run_variables)
|
||||||
|
|
||||||
|
for key, value in self.variables.items():
|
||||||
|
# We can skip if we're going to override this key with
|
||||||
|
# run variables anyway
|
||||||
|
if key in rendered_variables:
|
||||||
|
continue
|
||||||
|
|
||||||
|
rendered_variables[key] = template.render_complex(value, rendered_variables)
|
||||||
|
|
||||||
|
if run_variables:
|
||||||
|
rendered_variables.update(run_variables)
|
||||||
|
|
||||||
|
return rendered_variables
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
"""Return dict version of this class."""
|
||||||
|
return self.variables
|
@ -1136,7 +1136,7 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
|||||||
assert event2["entity_id"] == "automation.bye"
|
assert event2["entity_id"] == "automation.bye"
|
||||||
|
|
||||||
|
|
||||||
async def test_automation_variables(hass):
|
async def test_automation_variables(hass, caplog):
|
||||||
"""Test automation variables."""
|
"""Test automation variables."""
|
||||||
calls = async_mock_service(hass, "test", "automation")
|
calls = async_mock_service(hass, "test", "automation")
|
||||||
|
|
||||||
@ -1172,6 +1172,15 @@ async def test_automation_variables(hass):
|
|||||||
"service": "test.automation",
|
"service": "test.automation",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"variables": {
|
||||||
|
"test_var": "{{ trigger.event.data.break + 1 }}",
|
||||||
|
},
|
||||||
|
"trigger": {"platform": "event", "event_type": "test_event_3"},
|
||||||
|
"action": {
|
||||||
|
"service": "test.automation",
|
||||||
|
},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -1188,3 +1197,13 @@ async def test_automation_variables(hass):
|
|||||||
hass.bus.async_fire("test_event_2", {"pass_condition": True})
|
hass.bus.async_fire("test_event_2", {"pass_condition": True})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
assert "Error rendering variables" not in caplog.text
|
||||||
|
hass.bus.async_fire("test_event_3")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert "Error rendering variables" in caplog.text
|
||||||
|
|
||||||
|
hass.bus.async_fire("test_event_3", {"break": 0})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(calls) == 3
|
||||||
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import Context, callback, split_entity_id
|
from homeassistant.core import Context, callback, split_entity_id
|
||||||
from homeassistant.exceptions import ServiceNotFound
|
from homeassistant.exceptions import ServiceNotFound
|
||||||
|
from homeassistant.helpers import template
|
||||||
from homeassistant.helpers.event import async_track_state_change
|
from homeassistant.helpers.event import async_track_state_change
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
@ -617,7 +618,7 @@ async def test_concurrent_script(hass, concurrently):
|
|||||||
assert not script.is_on(hass, "script.script2")
|
assert not script.is_on(hass, "script.script2")
|
||||||
|
|
||||||
|
|
||||||
async def test_script_variables(hass):
|
async def test_script_variables(hass, caplog):
|
||||||
"""Test defining scripts."""
|
"""Test defining scripts."""
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
hass,
|
hass,
|
||||||
@ -652,6 +653,19 @@ async def test_script_variables(hass):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"script3": {
|
||||||
|
"variables": {
|
||||||
|
"test_var": "{{ break + 1 }}",
|
||||||
|
},
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"service": "test.script",
|
||||||
|
"data": {
|
||||||
|
"value": "{{ test_var }}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -681,3 +695,14 @@ async def test_script_variables(hass):
|
|||||||
|
|
||||||
assert len(mock_calls) == 3
|
assert len(mock_calls) == 3
|
||||||
assert mock_calls[2].data["value"] == "from_service"
|
assert mock_calls[2].data["value"] == "from_service"
|
||||||
|
|
||||||
|
assert "Error rendering variables" not in caplog.text
|
||||||
|
with pytest.raises(template.TemplateError):
|
||||||
|
await hass.services.async_call("script", "script3", blocking=True)
|
||||||
|
assert "Error rendering variables" in caplog.text
|
||||||
|
assert len(mock_calls) == 3
|
||||||
|
|
||||||
|
await hass.services.async_call("script", "script3", {"break": 0}, blocking=True)
|
||||||
|
|
||||||
|
assert len(mock_calls) == 4
|
||||||
|
assert mock_calls[3].data["value"] == "1"
|
||||||
|
60
tests/helpers/test_script_variables.py
Normal file
60
tests/helpers/test_script_variables.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Test script variables."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.helpers import config_validation as cv, template
|
||||||
|
|
||||||
|
|
||||||
|
async def test_static_vars():
|
||||||
|
"""Test static vars."""
|
||||||
|
orig = {"hello": "world"}
|
||||||
|
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||||
|
rendered = var.async_render(None, None)
|
||||||
|
assert rendered is not orig
|
||||||
|
assert rendered == orig
|
||||||
|
|
||||||
|
|
||||||
|
async def test_static_vars_run_args():
|
||||||
|
"""Test static vars."""
|
||||||
|
orig = {"hello": "world"}
|
||||||
|
orig_copy = dict(orig)
|
||||||
|
var = cv.SCRIPT_VARIABLES_SCHEMA(orig)
|
||||||
|
rendered = var.async_render(None, {"hello": "override", "run": "var"})
|
||||||
|
assert rendered == {"hello": "override", "run": "var"}
|
||||||
|
# Make sure we don't change original vars
|
||||||
|
assert orig == orig_copy
|
||||||
|
|
||||||
|
|
||||||
|
async def test_template_vars(hass):
|
||||||
|
"""Test template vars."""
|
||||||
|
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ 1 + 1 }}"})
|
||||||
|
rendered = var.async_render(hass, None)
|
||||||
|
assert rendered == {"hello": "2"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_template_vars_run_args(hass):
|
||||||
|
"""Test template vars."""
|
||||||
|
var = cv.SCRIPT_VARIABLES_SCHEMA(
|
||||||
|
{
|
||||||
|
"something": "{{ run_var_ex + 1 }}",
|
||||||
|
"something_2": "{{ run_var_ex + 1 }}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
rendered = var.async_render(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
"run_var_ex": 5,
|
||||||
|
"something_2": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert rendered == {
|
||||||
|
"run_var_ex": 5,
|
||||||
|
"something": "6",
|
||||||
|
"something_2": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_template_vars_error(hass):
|
||||||
|
"""Test template vars."""
|
||||||
|
var = cv.SCRIPT_VARIABLES_SCHEMA({"hello": "{{ canont.work }}"})
|
||||||
|
with pytest.raises(template.TemplateError):
|
||||||
|
var.async_render(hass, None)
|
Loading…
x
Reference in New Issue
Block a user