mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Fix repeat action when variables present (#38237)
This commit is contained in:
parent
bea1570354
commit
1158925b53
@ -4,7 +4,19 @@ from datetime import datetime
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from async_timeout import timeout
|
||||
import voluptuous as vol
|
||||
@ -49,7 +61,7 @@ from homeassistant.helpers.service import (
|
||||
CONF_SERVICE_DATA,
|
||||
async_prepare_call_from_config,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
@ -134,13 +146,13 @@ class _ScriptRun:
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: TemplateVarsType,
|
||||
variables: Dict[str, Any],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
) -> None:
|
||||
self._hass = hass
|
||||
self._script = script
|
||||
self._variables = variables or {}
|
||||
self._variables = variables
|
||||
self._context = context
|
||||
self._log_exceptions = log_exceptions
|
||||
self._step = -1
|
||||
@ -595,6 +607,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
|
||||
)
|
||||
|
||||
|
||||
_VarsType = Union[Dict[str, Any], MappingProxyType]
|
||||
|
||||
|
||||
class Script:
|
||||
"""Representation of a script."""
|
||||
|
||||
@ -617,6 +632,7 @@ class Script:
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
|
||||
)
|
||||
self._top_level = top_level
|
||||
if top_level:
|
||||
all_scripts.append(
|
||||
{"instance": self, "started_before_shutdown": not hass.is_stopping}
|
||||
@ -745,7 +761,7 @@ class Script:
|
||||
return referenced
|
||||
|
||||
def run(
|
||||
self, variables: TemplateVarsType = None, context: Optional[Context] = None
|
||||
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
|
||||
) -> None:
|
||||
"""Run script."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
@ -753,7 +769,7 @@ class Script:
|
||||
).result()
|
||||
|
||||
async def async_run(
|
||||
self, variables: TemplateVarsType = None, context: Optional[Context] = None
|
||||
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
|
||||
) -> None:
|
||||
"""Run script."""
|
||||
if self.is_running:
|
||||
@ -767,11 +783,19 @@ class Script:
|
||||
self._log("Maximum number of runs exceeded", level=logging.WARNING)
|
||||
return
|
||||
|
||||
# If this is a top level Script then make a copy of the variables in case they
|
||||
# are read-only, but more importantly, so as not to leak any variables created
|
||||
# during the run back to the caller.
|
||||
if self._top_level:
|
||||
variables = dict(variables) if variables is not None else {}
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
cls = _QueuedScriptRun
|
||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
||||
run = cls(
|
||||
self._hass, self, cast(dict, variables), context, self._log_exceptions
|
||||
)
|
||||
self._runs.append(run)
|
||||
|
||||
try:
|
||||
|
@ -4,6 +4,7 @@ import asyncio
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -122,7 +123,7 @@ async def test_firing_event_template(hass):
|
||||
)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
||||
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
@ -175,7 +176,7 @@ async def test_calling_service_template(hass):
|
||||
)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
||||
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
@ -235,7 +236,9 @@ async def test_multiple_runs_no_wait(hass):
|
||||
logger.debug("starting 1st script")
|
||||
hass.async_create_task(
|
||||
script_obj.async_run(
|
||||
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
|
||||
MappingProxyType(
|
||||
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
|
||||
)
|
||||
)
|
||||
)
|
||||
await asyncio.wait_for(heard_event.wait(), 1)
|
||||
@ -243,7 +246,7 @@ async def test_multiple_runs_no_wait(hass):
|
||||
|
||||
logger.debug("starting 2nd script")
|
||||
await script_obj.async_run(
|
||||
{"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"}
|
||||
MappingProxyType({"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"})
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
@ -670,7 +673,9 @@ async def test_wait_template_variables(hass):
|
||||
|
||||
try:
|
||||
hass.states.async_set("switch.test", "on")
|
||||
hass.async_create_task(script_obj.async_run({"data": "switch.test"}))
|
||||
hass.async_create_task(
|
||||
script_obj.async_run(MappingProxyType({"data": "switch.test"}))
|
||||
)
|
||||
await asyncio.wait_for(wait_started_flag.wait(), 1)
|
||||
|
||||
assert script_obj.is_running
|
||||
@ -882,7 +887,14 @@ async def test_repeat_var_in_condition(hass, condition):
|
||||
assert len(events) == 2
|
||||
|
||||
|
||||
async def test_repeat_nested(hass):
|
||||
@pytest.mark.parametrize(
|
||||
"variables,first_last,inside_x",
|
||||
[
|
||||
(None, {"repeat": "None", "x": "None"}, "None"),
|
||||
(MappingProxyType({"x": 1}), {"repeat": "None", "x": "1"}, "1"),
|
||||
],
|
||||
)
|
||||
async def test_repeat_nested(hass, variables, first_last, inside_x):
|
||||
"""Test nested repeats."""
|
||||
event = "test_event"
|
||||
events = async_capture_events(hass, event)
|
||||
@ -892,7 +904,8 @@ async def test_repeat_nested(hass):
|
||||
{
|
||||
"event": event,
|
||||
"event_data_template": {
|
||||
"repeat": "{{ None if repeat is not defined else repeat }}"
|
||||
"repeat": "{{ None if repeat is not defined else repeat }}",
|
||||
"x": "{{ None if x is not defined else x }}",
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -905,6 +918,7 @@ async def test_repeat_nested(hass):
|
||||
"first": "{{ repeat.first }}",
|
||||
"index": "{{ repeat.index }}",
|
||||
"last": "{{ repeat.last }}",
|
||||
"x": "{{ None if x is not defined else x }}",
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -916,6 +930,7 @@ async def test_repeat_nested(hass):
|
||||
"first": "{{ repeat.first }}",
|
||||
"index": "{{ repeat.index }}",
|
||||
"last": "{{ repeat.last }}",
|
||||
"x": "{{ None if x is not defined else x }}",
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -926,6 +941,7 @@ async def test_repeat_nested(hass):
|
||||
"first": "{{ repeat.first }}",
|
||||
"index": "{{ repeat.index }}",
|
||||
"last": "{{ repeat.last }}",
|
||||
"x": "{{ None if x is not defined else x }}",
|
||||
},
|
||||
},
|
||||
],
|
||||
@ -934,7 +950,8 @@ async def test_repeat_nested(hass):
|
||||
{
|
||||
"event": event,
|
||||
"event_data_template": {
|
||||
"repeat": "{{ None if repeat is not defined else repeat }}"
|
||||
"repeat": "{{ None if repeat is not defined else repeat }}",
|
||||
"x": "{{ None if x is not defined else x }}",
|
||||
},
|
||||
},
|
||||
]
|
||||
@ -945,21 +962,21 @@ async def test_repeat_nested(hass):
|
||||
"homeassistant.helpers.condition._LOGGER.error",
|
||||
side_effect=AssertionError("Template Error"),
|
||||
):
|
||||
await script_obj.async_run()
|
||||
await script_obj.async_run(variables)
|
||||
|
||||
assert len(events) == 10
|
||||
assert events[0].data == {"repeat": "None"}
|
||||
assert events[-1].data == {"repeat": "None"}
|
||||
assert events[0].data == first_last
|
||||
assert events[-1].data == first_last
|
||||
for index, result in enumerate(
|
||||
(
|
||||
("True", "1", "False"),
|
||||
("True", "1", "False"),
|
||||
("False", "2", "True"),
|
||||
("True", "1", "False"),
|
||||
("False", "2", "True"),
|
||||
("True", "1", "False"),
|
||||
("False", "2", "True"),
|
||||
("False", "2", "True"),
|
||||
("True", "1", "False", inside_x),
|
||||
("True", "1", "False", inside_x),
|
||||
("False", "2", "True", inside_x),
|
||||
("True", "1", "False", inside_x),
|
||||
("False", "2", "True", inside_x),
|
||||
("True", "1", "False", inside_x),
|
||||
("False", "2", "True", inside_x),
|
||||
("False", "2", "True", inside_x),
|
||||
),
|
||||
1,
|
||||
):
|
||||
@ -967,6 +984,7 @@ async def test_repeat_nested(hass):
|
||||
"first": result[0],
|
||||
"index": result[1],
|
||||
"last": result[2],
|
||||
"x": result[3],
|
||||
}
|
||||
|
||||
|
||||
@ -998,7 +1016,7 @@ async def test_choose(hass, var, result):
|
||||
)
|
||||
script_obj = script.Script(hass, sequence)
|
||||
|
||||
await script_obj.async_run({"var": var})
|
||||
await script_obj.async_run(MappingProxyType({"var": var}))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user