Fix repeat action when variables present (#38237)

This commit is contained in:
Phil Bruckner 2020-07-27 16:51:34 -05:00 committed by GitHub
parent bea1570354
commit 1158925b53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 27 deletions

View File

@ -4,7 +4,19 @@ from datetime import datetime
from functools import partial
import itertools
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from types import MappingProxyType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from async_timeout import timeout
import voluptuous as vol
@ -49,7 +61,7 @@ from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow
@ -134,13 +146,13 @@ class _ScriptRun:
self,
hass: HomeAssistant,
script: "Script",
variables: TemplateVarsType,
variables: Dict[str, Any],
context: Optional[Context],
log_exceptions: bool,
) -> None:
self._hass = hass
self._script = script
self._variables = variables or {}
self._variables = variables
self._context = context
self._log_exceptions = log_exceptions
self._step = -1
@ -595,6 +607,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
)
_VarsType = Union[Dict[str, Any], MappingProxyType]
class Script:
"""Representation of a script."""
@ -617,6 +632,7 @@ class Script:
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
)
self._top_level = top_level
if top_level:
all_scripts.append(
{"instance": self, "started_before_shutdown": not hass.is_stopping}
@ -745,7 +761,7 @@ class Script:
return referenced
def run(
self, variables: TemplateVarsType = None, context: Optional[Context] = None
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
) -> None:
"""Run script."""
asyncio.run_coroutine_threadsafe(
@ -753,7 +769,7 @@ class Script:
).result()
async def async_run(
self, variables: TemplateVarsType = None, context: Optional[Context] = None
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
) -> None:
"""Run script."""
if self.is_running:
@ -767,11 +783,19 @@ class Script:
self._log("Maximum number of runs exceeded", level=logging.WARNING)
return
# If this is a top level Script then make a copy of the variables in case they
# are read-only, but more importantly, so as not to leak any variables created
# during the run back to the caller.
if self._top_level:
variables = dict(variables) if variables is not None else {}
if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun
else:
cls = _QueuedScriptRun
run = cls(self._hass, self, variables, context, self._log_exceptions)
run = cls(
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
self._runs.append(run)
try:

View File

@ -4,6 +4,7 @@ import asyncio
from contextlib import contextmanager
from datetime import timedelta
import logging
from types import MappingProxyType
from unittest import mock
import pytest
@ -122,7 +123,7 @@ async def test_firing_event_template(hass):
)
script_obj = script.Script(hass, sequence)
await script_obj.async_run({"is_world": "yes"}, context=context)
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
await hass.async_block_till_done()
assert len(events) == 1
@ -175,7 +176,7 @@ async def test_calling_service_template(hass):
)
script_obj = script.Script(hass, sequence)
await script_obj.async_run({"is_world": "yes"}, context=context)
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
await hass.async_block_till_done()
assert len(calls) == 1
@ -235,7 +236,9 @@ async def test_multiple_runs_no_wait(hass):
logger.debug("starting 1st script")
hass.async_create_task(
script_obj.async_run(
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
MappingProxyType(
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
)
)
)
await asyncio.wait_for(heard_event.wait(), 1)
@ -243,7 +246,7 @@ async def test_multiple_runs_no_wait(hass):
logger.debug("starting 2nd script")
await script_obj.async_run(
{"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"}
MappingProxyType({"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"})
)
await hass.async_block_till_done()
@ -670,7 +673,9 @@ async def test_wait_template_variables(hass):
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run({"data": "switch.test"}))
hass.async_create_task(
script_obj.async_run(MappingProxyType({"data": "switch.test"}))
)
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
@ -882,7 +887,14 @@ async def test_repeat_var_in_condition(hass, condition):
assert len(events) == 2
async def test_repeat_nested(hass):
@pytest.mark.parametrize(
"variables,first_last,inside_x",
[
(None, {"repeat": "None", "x": "None"}, "None"),
(MappingProxyType({"x": 1}), {"repeat": "None", "x": "1"}, "1"),
],
)
async def test_repeat_nested(hass, variables, first_last, inside_x):
"""Test nested repeats."""
event = "test_event"
events = async_capture_events(hass, event)
@ -892,7 +904,8 @@ async def test_repeat_nested(hass):
{
"event": event,
"event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}"
"repeat": "{{ None if repeat is not defined else repeat }}",
"x": "{{ None if x is not defined else x }}",
},
},
{
@ -905,6 +918,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
},
},
{
@ -916,6 +930,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
},
},
}
@ -926,6 +941,7 @@ async def test_repeat_nested(hass):
"first": "{{ repeat.first }}",
"index": "{{ repeat.index }}",
"last": "{{ repeat.last }}",
"x": "{{ None if x is not defined else x }}",
},
},
],
@ -934,7 +950,8 @@ async def test_repeat_nested(hass):
{
"event": event,
"event_data_template": {
"repeat": "{{ None if repeat is not defined else repeat }}"
"repeat": "{{ None if repeat is not defined else repeat }}",
"x": "{{ None if x is not defined else x }}",
},
},
]
@ -945,21 +962,21 @@ async def test_repeat_nested(hass):
"homeassistant.helpers.condition._LOGGER.error",
side_effect=AssertionError("Template Error"),
):
await script_obj.async_run()
await script_obj.async_run(variables)
assert len(events) == 10
assert events[0].data == {"repeat": "None"}
assert events[-1].data == {"repeat": "None"}
assert events[0].data == first_last
assert events[-1].data == first_last
for index, result in enumerate(
(
("True", "1", "False"),
("True", "1", "False"),
("False", "2", "True"),
("True", "1", "False"),
("False", "2", "True"),
("True", "1", "False"),
("False", "2", "True"),
("False", "2", "True"),
("True", "1", "False", inside_x),
("True", "1", "False", inside_x),
("False", "2", "True", inside_x),
("True", "1", "False", inside_x),
("False", "2", "True", inside_x),
("True", "1", "False", inside_x),
("False", "2", "True", inside_x),
("False", "2", "True", inside_x),
),
1,
):
@ -967,6 +984,7 @@ async def test_repeat_nested(hass):
"first": result[0],
"index": result[1],
"last": result[2],
"x": result[3],
}
@ -998,7 +1016,7 @@ async def test_choose(hass, var, result):
)
script_obj = script.Script(hass, sequence)
await script_obj.async_run({"var": var})
await script_obj.async_run(MappingProxyType({"var": var}))
await hass.async_block_till_done()
assert len(events) == 1