diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index cad0a272e47..1f9963e184b 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -4,7 +4,19 @@ from datetime import datetime from functools import partial import itertools import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from types import MappingProxyType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) from async_timeout import timeout import voluptuous as vol @@ -49,7 +61,7 @@ from homeassistant.helpers.service import ( CONF_SERVICE_DATA, async_prepare_call_from_config, ) -from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from homeassistant.helpers.typing import ConfigType from homeassistant.util import slugify from homeassistant.util.dt import utcnow @@ -134,13 +146,13 @@ class _ScriptRun: self, hass: HomeAssistant, script: "Script", - variables: TemplateVarsType, + variables: Dict[str, Any], context: Optional[Context], log_exceptions: bool, ) -> None: self._hass = hass self._script = script - self._variables = variables or {} + self._variables = variables self._context = context self._log_exceptions = log_exceptions self._step = -1 @@ -595,6 +607,9 @@ async def _async_stop_scripts_at_shutdown(hass, event): ) +_VarsType = Union[Dict[str, Any], MappingProxyType] + + class Script: """Representation of a script.""" @@ -617,6 +632,7 @@ class Script: hass.bus.async_listen_once( EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass) ) + self._top_level = top_level if top_level: all_scripts.append( {"instance": self, "started_before_shutdown": not hass.is_stopping} @@ -745,7 +761,7 @@ class Script: return referenced def run( - self, variables: TemplateVarsType = None, context: Optional[Context] = None + self, variables: Optional[_VarsType] = None, context: Optional[Context] = None ) -> None: """Run script.""" asyncio.run_coroutine_threadsafe( @@ -753,7 +769,7 @@ class Script: ).result() async def async_run( - self, variables: TemplateVarsType = None, context: Optional[Context] = None + self, variables: Optional[_VarsType] = None, context: Optional[Context] = None ) -> None: """Run script.""" if self.is_running: @@ -767,11 +783,19 @@ class Script: self._log("Maximum number of runs exceeded", level=logging.WARNING) return + # If this is a top level Script then make a copy of the variables in case they + # are read-only, but more importantly, so as not to leak any variables created + # during the run back to the caller. + if self._top_level: + variables = dict(variables) if variables is not None else {} + if self.script_mode != SCRIPT_MODE_QUEUED: cls = _ScriptRun else: cls = _QueuedScriptRun - run = cls(self._hass, self, variables, context, self._log_exceptions) + run = cls( + self._hass, self, cast(dict, variables), context, self._log_exceptions + ) self._runs.append(run) try: diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 8b61a5db64b..28761c0ba17 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -4,6 +4,7 @@ import asyncio from contextlib import contextmanager from datetime import timedelta import logging +from types import MappingProxyType from unittest import mock import pytest @@ -122,7 +123,7 @@ async def test_firing_event_template(hass): ) script_obj = script.Script(hass, sequence) - await script_obj.async_run({"is_world": "yes"}, context=context) + await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context) await hass.async_block_till_done() assert len(events) == 1 @@ -175,7 +176,7 @@ async def test_calling_service_template(hass): ) script_obj = script.Script(hass, sequence) - await script_obj.async_run({"is_world": "yes"}, context=context) + await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context) await hass.async_block_till_done() assert len(calls) == 1 @@ -235,7 +236,9 @@ async def test_multiple_runs_no_wait(hass): logger.debug("starting 1st script") hass.async_create_task( script_obj.async_run( - {"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"} + MappingProxyType( + {"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"} + ) ) ) await asyncio.wait_for(heard_event.wait(), 1) @@ -243,7 +246,7 @@ async def test_multiple_runs_no_wait(hass): logger.debug("starting 2nd script") await script_obj.async_run( - {"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"} + MappingProxyType({"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"}) ) await hass.async_block_till_done() @@ -670,7 +673,9 @@ async def test_wait_template_variables(hass): try: hass.states.async_set("switch.test", "on") - hass.async_create_task(script_obj.async_run({"data": "switch.test"})) + hass.async_create_task( + script_obj.async_run(MappingProxyType({"data": "switch.test"})) + ) await asyncio.wait_for(wait_started_flag.wait(), 1) assert script_obj.is_running @@ -882,7 +887,14 @@ async def test_repeat_var_in_condition(hass, condition): assert len(events) == 2 -async def test_repeat_nested(hass): +@pytest.mark.parametrize( + "variables,first_last,inside_x", + [ + (None, {"repeat": "None", "x": "None"}, "None"), + (MappingProxyType({"x": 1}), {"repeat": "None", "x": "1"}, "1"), + ], +) +async def test_repeat_nested(hass, variables, first_last, inside_x): """Test nested repeats.""" event = "test_event" events = async_capture_events(hass, event) @@ -892,7 +904,8 @@ async def test_repeat_nested(hass): { "event": event, "event_data_template": { - "repeat": "{{ None if repeat is not defined else repeat }}" + "repeat": "{{ None if repeat is not defined else repeat }}", + "x": "{{ None if x is not defined else x }}", }, }, { @@ -905,6 +918,7 @@ async def test_repeat_nested(hass): "first": "{{ repeat.first }}", "index": "{{ repeat.index }}", "last": "{{ repeat.last }}", + "x": "{{ None if x is not defined else x }}", }, }, { @@ -916,6 +930,7 @@ async def test_repeat_nested(hass): "first": "{{ repeat.first }}", "index": "{{ repeat.index }}", "last": "{{ repeat.last }}", + "x": "{{ None if x is not defined else x }}", }, }, } @@ -926,6 +941,7 @@ async def test_repeat_nested(hass): "first": "{{ repeat.first }}", "index": "{{ repeat.index }}", "last": "{{ repeat.last }}", + "x": "{{ None if x is not defined else x }}", }, }, ], @@ -934,7 +950,8 @@ async def test_repeat_nested(hass): { "event": event, "event_data_template": { - "repeat": "{{ None if repeat is not defined else repeat }}" + "repeat": "{{ None if repeat is not defined else repeat }}", + "x": "{{ None if x is not defined else x }}", }, }, ] @@ -945,21 +962,21 @@ async def test_repeat_nested(hass): "homeassistant.helpers.condition._LOGGER.error", side_effect=AssertionError("Template Error"), ): - await script_obj.async_run() + await script_obj.async_run(variables) assert len(events) == 10 - assert events[0].data == {"repeat": "None"} - assert events[-1].data == {"repeat": "None"} + assert events[0].data == first_last + assert events[-1].data == first_last for index, result in enumerate( ( - ("True", "1", "False"), - ("True", "1", "False"), - ("False", "2", "True"), - ("True", "1", "False"), - ("False", "2", "True"), - ("True", "1", "False"), - ("False", "2", "True"), - ("False", "2", "True"), + ("True", "1", "False", inside_x), + ("True", "1", "False", inside_x), + ("False", "2", "True", inside_x), + ("True", "1", "False", inside_x), + ("False", "2", "True", inside_x), + ("True", "1", "False", inside_x), + ("False", "2", "True", inside_x), + ("False", "2", "True", inside_x), ), 1, ): @@ -967,6 +984,7 @@ async def test_repeat_nested(hass): "first": result[0], "index": result[1], "last": result[2], + "x": result[3], } @@ -998,7 +1016,7 @@ async def test_choose(hass, var, result): ) script_obj = script.Script(hass, sequence) - await script_obj.async_run({"var": var}) + await script_obj.async_run(MappingProxyType({"var": var})) await hass.async_block_till_done() assert len(events) == 1