mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Fix repeat action when variables present (#38237)
This commit is contained in:
parent
bea1570354
commit
1158925b53
@ -4,7 +4,19 @@ from datetime import datetime
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
from types import MappingProxyType
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from async_timeout import timeout
|
from async_timeout import timeout
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -49,7 +61,7 @@ from homeassistant.helpers.service import (
|
|||||||
CONF_SERVICE_DATA,
|
CONF_SERVICE_DATA,
|
||||||
async_prepare_call_from_config,
|
async_prepare_call_from_config,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
@ -134,13 +146,13 @@ class _ScriptRun:
|
|||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
script: "Script",
|
script: "Script",
|
||||||
variables: TemplateVarsType,
|
variables: Dict[str, Any],
|
||||||
context: Optional[Context],
|
context: Optional[Context],
|
||||||
log_exceptions: bool,
|
log_exceptions: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._script = script
|
self._script = script
|
||||||
self._variables = variables or {}
|
self._variables = variables
|
||||||
self._context = context
|
self._context = context
|
||||||
self._log_exceptions = log_exceptions
|
self._log_exceptions = log_exceptions
|
||||||
self._step = -1
|
self._step = -1
|
||||||
@ -595,6 +607,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_VarsType = Union[Dict[str, Any], MappingProxyType]
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
"""Representation of a script."""
|
"""Representation of a script."""
|
||||||
|
|
||||||
@ -617,6 +632,7 @@ class Script:
|
|||||||
hass.bus.async_listen_once(
|
hass.bus.async_listen_once(
|
||||||
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
|
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
|
||||||
)
|
)
|
||||||
|
self._top_level = top_level
|
||||||
if top_level:
|
if top_level:
|
||||||
all_scripts.append(
|
all_scripts.append(
|
||||||
{"instance": self, "started_before_shutdown": not hass.is_stopping}
|
{"instance": self, "started_before_shutdown": not hass.is_stopping}
|
||||||
@ -745,7 +761,7 @@ class Script:
|
|||||||
return referenced
|
return referenced
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self, variables: TemplateVarsType = None, context: Optional[Context] = None
|
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run script."""
|
"""Run script."""
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
@ -753,7 +769,7 @@ class Script:
|
|||||||
).result()
|
).result()
|
||||||
|
|
||||||
async def async_run(
|
async def async_run(
|
||||||
self, variables: TemplateVarsType = None, context: Optional[Context] = None
|
self, variables: Optional[_VarsType] = None, context: Optional[Context] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run script."""
|
"""Run script."""
|
||||||
if self.is_running:
|
if self.is_running:
|
||||||
@ -767,11 +783,19 @@ class Script:
|
|||||||
self._log("Maximum number of runs exceeded", level=logging.WARNING)
|
self._log("Maximum number of runs exceeded", level=logging.WARNING)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If this is a top level Script then make a copy of the variables in case they
|
||||||
|
# are read-only, but more importantly, so as not to leak any variables created
|
||||||
|
# during the run back to the caller.
|
||||||
|
if self._top_level:
|
||||||
|
variables = dict(variables) if variables is not None else {}
|
||||||
|
|
||||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||||
cls = _ScriptRun
|
cls = _ScriptRun
|
||||||
else:
|
else:
|
||||||
cls = _QueuedScriptRun
|
cls = _QueuedScriptRun
|
||||||
run = cls(self._hass, self, variables, context, self._log_exceptions)
|
run = cls(
|
||||||
|
self._hass, self, cast(dict, variables), context, self._log_exceptions
|
||||||
|
)
|
||||||
self._runs.append(run)
|
self._runs.append(run)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -4,6 +4,7 @@ import asyncio
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
|
from types import MappingProxyType
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -122,7 +123,7 @@ async def test_firing_event_template(hass):
|
|||||||
)
|
)
|
||||||
script_obj = script.Script(hass, sequence)
|
script_obj = script.Script(hass, sequence)
|
||||||
|
|
||||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
@ -175,7 +176,7 @@ async def test_calling_service_template(hass):
|
|||||||
)
|
)
|
||||||
script_obj = script.Script(hass, sequence)
|
script_obj = script.Script(hass, sequence)
|
||||||
|
|
||||||
await script_obj.async_run({"is_world": "yes"}, context=context)
|
await script_obj.async_run(MappingProxyType({"is_world": "yes"}), context=context)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
@ -235,15 +236,17 @@ async def test_multiple_runs_no_wait(hass):
|
|||||||
logger.debug("starting 1st script")
|
logger.debug("starting 1st script")
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
script_obj.async_run(
|
script_obj.async_run(
|
||||||
|
MappingProxyType(
|
||||||
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
|
{"fire1": "1", "listen1": "2", "fire2": "3", "listen2": "4"}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
)
|
||||||
await asyncio.wait_for(heard_event.wait(), 1)
|
await asyncio.wait_for(heard_event.wait(), 1)
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
logger.debug("starting 2nd script")
|
logger.debug("starting 2nd script")
|
||||||
await script_obj.async_run(
|
await script_obj.async_run(
|
||||||
{"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"}
|
MappingProxyType({"fire1": "2", "listen1": "3", "fire2": "4", "listen2": "4"})
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
@ -670,7 +673,9 @@ async def test_wait_template_variables(hass):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
hass.states.async_set("switch.test", "on")
|
hass.states.async_set("switch.test", "on")
|
||||||
hass.async_create_task(script_obj.async_run({"data": "switch.test"}))
|
hass.async_create_task(
|
||||||
|
script_obj.async_run(MappingProxyType({"data": "switch.test"}))
|
||||||
|
)
|
||||||
await asyncio.wait_for(wait_started_flag.wait(), 1)
|
await asyncio.wait_for(wait_started_flag.wait(), 1)
|
||||||
|
|
||||||
assert script_obj.is_running
|
assert script_obj.is_running
|
||||||
@ -882,7 +887,14 @@ async def test_repeat_var_in_condition(hass, condition):
|
|||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
|
||||||
|
|
||||||
async def test_repeat_nested(hass):
|
@pytest.mark.parametrize(
|
||||||
|
"variables,first_last,inside_x",
|
||||||
|
[
|
||||||
|
(None, {"repeat": "None", "x": "None"}, "None"),
|
||||||
|
(MappingProxyType({"x": 1}), {"repeat": "None", "x": "1"}, "1"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_repeat_nested(hass, variables, first_last, inside_x):
|
||||||
"""Test nested repeats."""
|
"""Test nested repeats."""
|
||||||
event = "test_event"
|
event = "test_event"
|
||||||
events = async_capture_events(hass, event)
|
events = async_capture_events(hass, event)
|
||||||
@ -892,7 +904,8 @@ async def test_repeat_nested(hass):
|
|||||||
{
|
{
|
||||||
"event": event,
|
"event": event,
|
||||||
"event_data_template": {
|
"event_data_template": {
|
||||||
"repeat": "{{ None if repeat is not defined else repeat }}"
|
"repeat": "{{ None if repeat is not defined else repeat }}",
|
||||||
|
"x": "{{ None if x is not defined else x }}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -905,6 +918,7 @@ async def test_repeat_nested(hass):
|
|||||||
"first": "{{ repeat.first }}",
|
"first": "{{ repeat.first }}",
|
||||||
"index": "{{ repeat.index }}",
|
"index": "{{ repeat.index }}",
|
||||||
"last": "{{ repeat.last }}",
|
"last": "{{ repeat.last }}",
|
||||||
|
"x": "{{ None if x is not defined else x }}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -916,6 +930,7 @@ async def test_repeat_nested(hass):
|
|||||||
"first": "{{ repeat.first }}",
|
"first": "{{ repeat.first }}",
|
||||||
"index": "{{ repeat.index }}",
|
"index": "{{ repeat.index }}",
|
||||||
"last": "{{ repeat.last }}",
|
"last": "{{ repeat.last }}",
|
||||||
|
"x": "{{ None if x is not defined else x }}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -926,6 +941,7 @@ async def test_repeat_nested(hass):
|
|||||||
"first": "{{ repeat.first }}",
|
"first": "{{ repeat.first }}",
|
||||||
"index": "{{ repeat.index }}",
|
"index": "{{ repeat.index }}",
|
||||||
"last": "{{ repeat.last }}",
|
"last": "{{ repeat.last }}",
|
||||||
|
"x": "{{ None if x is not defined else x }}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -934,7 +950,8 @@ async def test_repeat_nested(hass):
|
|||||||
{
|
{
|
||||||
"event": event,
|
"event": event,
|
||||||
"event_data_template": {
|
"event_data_template": {
|
||||||
"repeat": "{{ None if repeat is not defined else repeat }}"
|
"repeat": "{{ None if repeat is not defined else repeat }}",
|
||||||
|
"x": "{{ None if x is not defined else x }}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
@ -945,21 +962,21 @@ async def test_repeat_nested(hass):
|
|||||||
"homeassistant.helpers.condition._LOGGER.error",
|
"homeassistant.helpers.condition._LOGGER.error",
|
||||||
side_effect=AssertionError("Template Error"),
|
side_effect=AssertionError("Template Error"),
|
||||||
):
|
):
|
||||||
await script_obj.async_run()
|
await script_obj.async_run(variables)
|
||||||
|
|
||||||
assert len(events) == 10
|
assert len(events) == 10
|
||||||
assert events[0].data == {"repeat": "None"}
|
assert events[0].data == first_last
|
||||||
assert events[-1].data == {"repeat": "None"}
|
assert events[-1].data == first_last
|
||||||
for index, result in enumerate(
|
for index, result in enumerate(
|
||||||
(
|
(
|
||||||
("True", "1", "False"),
|
("True", "1", "False", inside_x),
|
||||||
("True", "1", "False"),
|
("True", "1", "False", inside_x),
|
||||||
("False", "2", "True"),
|
("False", "2", "True", inside_x),
|
||||||
("True", "1", "False"),
|
("True", "1", "False", inside_x),
|
||||||
("False", "2", "True"),
|
("False", "2", "True", inside_x),
|
||||||
("True", "1", "False"),
|
("True", "1", "False", inside_x),
|
||||||
("False", "2", "True"),
|
("False", "2", "True", inside_x),
|
||||||
("False", "2", "True"),
|
("False", "2", "True", inside_x),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
):
|
):
|
||||||
@ -967,6 +984,7 @@ async def test_repeat_nested(hass):
|
|||||||
"first": result[0],
|
"first": result[0],
|
||||||
"index": result[1],
|
"index": result[1],
|
||||||
"last": result[2],
|
"last": result[2],
|
||||||
|
"x": result[3],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -998,7 +1016,7 @@ async def test_choose(hass, var, result):
|
|||||||
)
|
)
|
||||||
script_obj = script.Script(hass, sequence)
|
script_obj = script.Script(hass, sequence)
|
||||||
|
|
||||||
await script_obj.async_run({"var": var})
|
await script_obj.async_run(MappingProxyType({"var": var}))
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user