Fix repeat action when variables present (#38237)

This commit is contained in:
Phil Bruckner 2020-07-27 16:51:34 -05:00 committed by GitHub
parent bea1570354
commit 1158925b53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 27 deletions

View File

@ -4,7 +4,19 @@ from datetime import datetime
from functools import partial from functools import partial
import itertools import itertools
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from types import MappingProxyType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from async_timeout import timeout from async_timeout import timeout
import voluptuous as vol import voluptuous as vol
@ -49,7 +61,7 @@ from homeassistant.helpers.service import (
CONF_SERVICE_DATA, CONF_SERVICE_DATA,
async_prepare_call_from_config, async_prepare_call_from_config,
) )
from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -134,13 +146,13 @@ class _ScriptRun:
self, self,
hass: HomeAssistant, hass: HomeAssistant,
script: "Script", script: "Script",
variables: TemplateVarsType, variables: Dict[str, Any],
context: Optional[Context], context: Optional[Context],
log_exceptions: bool, log_exceptions: bool,
) -> None: ) -> None:
self._hass = hass self._hass = hass
self._script = script self._script = script
self._variables = variables or {} self._variables = variables
self._context = context self._context = context
self._log_exceptions = log_exceptions self._log_exceptions = log_exceptions
self._step = -1 self._step = -1
@ -595,6 +607,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
) )
_VarsType = Union[Dict[str, Any], MappingProxyType]
class Script: class Script:
"""Representation of a script.""" """Representation of a script."""
@ -617,6 +632,7 @@ class Script:
hass.bus.async_listen_once( hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass) EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
) )
self._top_level = top_level
if top_level: if top_level:
all_scripts.append( all_scripts.append(
{"instance": self, "started_before_shutdown": not hass.is_stopping} {"instance": self, "started_before_shutdown": not hass.is_stopping}
@ -745,7 +761,7 @@ class Script:
return referenced return referenced
def run( def run(
self, variables: TemplateVarsType = None, context: Optional[Context] = None self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
) -> None: ) -> None:
"""Run script.""" """Run script."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@ -753,7 +769,7 @@ class Script:
).result() ).result()
async def async_run( async def async_run(
self, variables: TemplateVarsType = None, context: Optional[Context] = None self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
) -> None: ) -> None:
"""Run script.""" """Run script."""
if self.is_running: if self.is_running:
@ -767,11 +783,19 @@ class Script:
self._log("Maximum number of runs exceeded", level=logging.WARNING) self._log("Maximum number of runs exceeded", level=logging.WARNING)
return return
# If this is a top level Script then make a copy of the variables in case they
# are read-only, but more importantly, so as not to leak any variables created
# during the run back to the caller.
if self._top_level:
variables = dict(variables) if variables is not None else {}
if self.script_mode != SCRIPT_MODE_QUEUED: if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun cls = _ScriptRun
else: else:
cls = _QueuedScriptRun cls = _QueuedScriptRun
run = cls(self._hass, self, variables, context, self._log_exceptions) run = cls(
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
self._runs.append(run) self._runs.append(run)
try: try:

View File

@ -4,6 +4,7 @@ import asyncio
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta from datetime import timedelta
import logging import logging
from types import MappingProxyType
from unittest import mock from unittest import mock
import pytest import pytest
@ -122,7 +123,7 @@ async def test_firing_event_template(hass):
) )
script_obj = script.Script(hass, sequence) script_obj = script.Script(hass, sequence)
await script_obj.async_run({"is_world": "yes"}, context=context) await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(events) == 1 assert len(events) == 1
@ -175,7 +176,7 @@ async def test_calling_service_template(hass):
) )
script_obj = script.Script(hass, sequence) script_obj = script.Script(hass, sequence)
await script_obj.async_run({"is_world": "yes"}, context=context) await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
@ -235,15 +236,17 @@ async def test_multiple_runs_no_wait(hass):
logger.debug("starting 1st script") logger.debug("starting 1st script")
hass.async_create_task( hass.async_create_task(
script_obj.async_run( script_obj.async_run(
MappingProxyType(
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"} {"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
) )
) )
)
await asyncio.wait_for(heard_event.wait(), 1) await asyncio.wait_for(heard_event.wait(), 1)
unsub() unsub()
logger.debug("starting 2nd script") logger.debug("starting 2nd script")
await script_obj.async_run( await script_obj.async_run(
{"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"} MappingProxyType({"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"})
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -670,7 +673,9 @@ async def test_wait_template_variables(hass):
try: try:
hass.states.async_set("switch.test", "on") hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run({"data": "switch.test"})) hass.async_create_task(
script_obj.async_run(MappingProxyType({"data": "switch.test"}))
)
await asyncio.wait_for(wait_started_flag.wait(), 1) await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running assert script_obj.is_running
@ -882,7 +887,14 @@ async def test_repeat_var_in_condition(hass, condition):
assert len(events) == 2 assert len(events) == 2
async def test_repeat_nested(hass): @pytest.mark.parametrize(
"variables,first_last,inside_x",
[
(None, {"repeat": "None", "x": "None"}, "None"),
(MappingProxyType({"x": 1}), {"repeat": "None", "x": "1"}, "1"),
],
)
async def test_repeat_nested(hass, variables, first_last, inside_x):
"""Test nested repeats.""" """Test nested repeats."""
event = "test_event" event = "test_event"
events = async_capture_events(hass, event) events = async_capture_events(hass, event)
@ -892,7 +904,8 @@ async def test_repeat_nested(hass):
{ {
"event": event, "event": event,
"event_data_template": { "event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}" "repeat": "{{ None if repeat is not defined else repeat }}",
"x": "{{ None if x is not defined else x }}",
}, },
}, },
{ {
@ -905,6 +918,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}", "first": "{{ repeat.first }}",
"index": "{{ repeat.index }}", "index": "{{ repeat.index }}",
"last": "{{ repeat.last }}", "last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
}, },
}, },
{ {
@ -916,6 +930,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}", "first": "{{ repeat.first }}",
"index": "{{ repeat.index }}", "index": "{{ repeat.index }}",
"last": "{{ repeat.last }}", "last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
}, },
}, },
} }
@ -926,6 +941,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}", "first": "{{ repeat.first }}",
"index": "{{ repeat.index }}", "index": "{{ repeat.index }}",
"last": "{{ repeat.last }}", "last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
}, },
}, },
], ],
@ -934,7 +950,8 @@ async def test_repeat_nested(hass):
{ {
"event": event, "event": event,
"event_data_template": { "event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}" "repeat": "{{ None if repeat is not defined else repeat }}",
"x": "{{ None if x is not defined else x }}",
}, },
}, },
] ]
@ -945,21 +962,21 @@ async def test_repeat_nested(hass):
"homeassistant.helpers.condition._LOGGER.error", "homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"), side_effect=AssertionError("Template Error"),
): ):
await script_obj.async_run() await script_obj.async_run(variables)
assert len(events) == 10 assert len(events) == 10
assert events[0].data == {"repeat": "None"} assert events[0].data == first_last
assert events[-1].data == {"repeat": "None"} assert events[-1].data == first_last
for index, result in enumerate( for index, result in enumerate(
( (
("True", "1", "False"), ("True", "1", "False", inside_x),
("True", "1", "False"), ("True", "1", "False", inside_x),
("False", "2", "True"), ("False", "2", "True", inside_x),
("True", "1", "False"), ("True", "1", "False", inside_x),
("False", "2", "True"), ("False", "2", "True", inside_x),
("True", "1", "False"), ("True", "1", "False", inside_x),
("False", "2", "True"), ("False", "2", "True", inside_x),
("False", "2", "True"), ("False", "2", "True", inside_x),
), ),
1, 1,
): ):
@ -967,6 +984,7 @@ async def test_repeat_nested(hass):
"first": result[0], "first": result[0],
"index": result[1], "index": result[1],
"last": result[2], "last": result[2],
"x": result[3],
} }
@ -998,7 +1016,7 @@ async def test_choose(hass, var, result):
) )
script_obj = script.Script(hass, sequence) script_obj = script.Script(hass, sequence)
await script_obj.async_run({"var": var}) await script_obj.async_run(MappingProxyType({"var": var}))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(events) == 1 assert len(events) == 1