Create variable with result of wait_template and accept template for timeout option (#38634)

This commit is contained in:
Phil Bruckner 2020-08-12 13:42:06 -05:00 committed by GitHub
parent 45526f4e8a
commit 580e229cf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 164 additions and 128 deletions

View File

@ -47,11 +47,7 @@ TRIGGER_SCHEMA = vol.All(
vol.Optional(CONF_BELOW): vol.Coerce(float), vol.Optional(CONF_BELOW): vol.Coerce(float),
vol.Optional(CONF_ABOVE): vol.Coerce(float), vol.Optional(CONF_ABOVE): vol.Coerce(float),
vol.Optional(CONF_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_FOR): vol.Any( vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
} }
), ),
cv.has_at_least_one_key(CONF_BELOW, CONF_ABOVE), cv.has_at_least_one_key(CONF_BELOW, CONF_ABOVE),
@ -141,20 +137,9 @@ async def async_attach_trigger(
} }
try: try:
if isinstance(time_delta, template.Template): period[entity] = cv.positive_time_period(
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)( template.render_complex(time_delta, variables)
time_delta.async_render(variables) )
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(
template.render_complex(time_delta, variables)
)
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta_data
)
else:
period[entity] = time_delta
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error( _LOGGER.error(
"Error rendering '%s' for template: %s", "Error rendering '%s' for template: %s",

View File

@ -33,11 +33,7 @@ TRIGGER_SCHEMA = vol.All(
# These are str on purpose. Want to catch YAML conversions # These are str on purpose. Want to catch YAML conversions
vol.Optional(CONF_FROM): vol.Any(str, [str]), vol.Optional(CONF_FROM): vol.Any(str, [str]),
vol.Optional(CONF_TO): vol.Any(str, [str]), vol.Optional(CONF_TO): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.Any( vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
} }
), ),
cv.key_dependency(CONF_FOR, CONF_TO), cv.key_dependency(CONF_FOR, CONF_TO),
@ -115,18 +111,9 @@ async def async_attach_trigger(
} }
try: try:
if isinstance(time_delta, template.Template): period[entity] = cv.positive_time_period(
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)( template.render_complex(time_delta, variables)
time_delta.async_render(variables) )
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(template.render_complex(time_delta, variables))
period[entity] = vol.All(cv.time_period, cv.positive_timedelta)(
time_delta_data
)
else:
period[entity] = time_delta
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error( _LOGGER.error(
"Error rendering '%s' for template: %s", automation_info["name"], ex "Error rendering '%s' for template: %s", automation_info["name"], ex

View File

@ -17,11 +17,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_PLATFORM): "template", vol.Required(CONF_PLATFORM): "template",
vol.Required(CONF_VALUE_TEMPLATE): cv.template, vol.Required(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_FOR): vol.Any( vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.All(cv.time_period, cv.positive_timedelta),
cv.template,
cv.template_complex,
),
} }
) )
@ -73,16 +69,9 @@ async def async_attach_trigger(
} }
try: try:
if isinstance(time_delta, template.Template): period = cv.positive_time_period(
period = vol.All(cv.time_period, cv.positive_timedelta)( template.render_complex(time_delta, variables)
time_delta.async_render(variables) )
)
elif isinstance(time_delta, dict):
time_delta_data = {}
time_delta_data.update(template.render_complex(time_delta, variables))
period = vol.All(cv.time_period, cv.positive_timedelta)(time_delta_data)
else:
period = time_delta
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error( _LOGGER.error(
"Error rendering '%s' for template: %s", automation_info["name"], ex "Error rendering '%s' for template: %s", automation_info["name"], ex

View File

@ -68,13 +68,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_SENSOR): cv.entity_id, vol.Required(CONF_SENSOR): cv.entity_id,
vol.Optional(CONF_AC_MODE): cv.boolean, vol.Optional(CONF_AC_MODE): cv.boolean,
vol.Optional(CONF_MAX_TEMP): vol.Coerce(float), vol.Optional(CONF_MAX_TEMP): vol.Coerce(float),
vol.Optional(CONF_MIN_DUR): vol.All(cv.time_period, cv.positive_timedelta), vol.Optional(CONF_MIN_DUR): cv.positive_time_period,
vol.Optional(CONF_MIN_TEMP): vol.Coerce(float), vol.Optional(CONF_MIN_TEMP): vol.Coerce(float),
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_COLD_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float), vol.Optional(CONF_COLD_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float),
vol.Optional(CONF_HOT_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float), vol.Optional(CONF_HOT_TOLERANCE, default=DEFAULT_TOLERANCE): vol.Coerce(float),
vol.Optional(CONF_TARGET_TEMP): vol.Coerce(float), vol.Optional(CONF_TARGET_TEMP): vol.Coerce(float),
vol.Optional(CONF_KEEP_ALIVE): vol.All(cv.time_period, cv.positive_timedelta), vol.Optional(CONF_KEEP_ALIVE): cv.positive_time_period,
vol.Optional(CONF_INITIAL_HVAC_MODE): vol.In( vol.Optional(CONF_INITIAL_HVAC_MODE): vol.In(
[HVAC_MODE_COOL, HVAC_MODE_HEAT, HVAC_MODE_OFF] [HVAC_MODE_COOL, HVAC_MODE_HEAT, HVAC_MODE_OFF]
), ),

View File

@ -56,7 +56,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Required(CONF_DEVICE_PORT): cv.port, vol.Required(CONF_DEVICE_PORT): cv.port,
vol.Optional( vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): vol.All(cv.time_period, cv.positive_timedelta), ): cv.positive_time_period,
vol.Optional(CONF_ZONES, default=DEFAULT_ZONES): vol.All( vol.Optional(CONF_ZONES, default=DEFAULT_ZONES): vol.All(
cv.ensure_list, [ZONE_SCHEMA] cv.ensure_list, [ZONE_SCHEMA]
), ),

View File

@ -35,7 +35,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_SERVER_ID): cv.positive_int, vol.Optional(CONF_SERVER_ID): cv.positive_int,
vol.Optional( vol.Optional(
CONF_SCAN_INTERVAL, default=timedelta(minutes=DEFAULT_SCAN_INTERVAL) CONF_SCAN_INTERVAL, default=timedelta(minutes=DEFAULT_SCAN_INTERVAL)
): vol.All(cv.time_period, cv.positive_timedelta), ): cv.positive_time_period,
vol.Optional(CONF_MANUAL, default=False): cv.boolean, vol.Optional(CONF_MANUAL, default=False): cv.boolean,
vol.Optional( vol.Optional(
CONF_MONITORED_CONDITIONS, default=list(SENSOR_TYPES) CONF_MONITORED_CONDITIONS, default=list(SENSOR_TYPES)

View File

@ -49,8 +49,8 @@ SENSOR_SCHEMA = vol.Schema(
vol.Optional(ATTR_FRIENDLY_NAME): cv.string, vol.Optional(ATTR_FRIENDLY_NAME): cv.string,
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_DELAY_ON): vol.All(cv.time_period, cv.positive_timedelta), vol.Optional(CONF_DELAY_ON): cv.positive_time_period,
vol.Optional(CONF_DELAY_OFF): vol.All(cv.time_period, cv.positive_timedelta), vol.Optional(CONF_DELAY_OFF): cv.positive_time_period,
vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_UNIQUE_ID): cv.string,
} }
) )

View File

@ -48,7 +48,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Required(CONF_CLIENT_SECRET): cv.string, vol.Required(CONF_CLIENT_SECRET): cv.string,
vol.Optional( vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): vol.All(cv.time_period, cv.positive_timedelta), ): cv.positive_time_period,
} }
), ),
) )

View File

@ -102,7 +102,7 @@ SERVICE_SCHEMA_SET_SCENE = XIAOMI_MIIO_SERVICE_SCHEMA.extend(
) )
SERVICE_SCHEMA_SET_DELAYED_TURN_OFF = XIAOMI_MIIO_SERVICE_SCHEMA.extend( SERVICE_SCHEMA_SET_DELAYED_TURN_OFF = XIAOMI_MIIO_SERVICE_SCHEMA.extend(
{vol.Required(ATTR_TIME_PERIOD): vol.All(cv.time_period, cv.positive_timedelta)} {vol.Required(ATTR_TIME_PERIOD): cv.positive_time_period}
) )
SERVICE_TO_METHOD = { SERVICE_TO_METHOD = {

View File

@ -402,6 +402,7 @@ def positive_timedelta(value: timedelta) -> timedelta:
positive_time_period_dict = vol.All(time_period_dict, positive_timedelta) positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
positive_time_period = vol.All(time_period, positive_timedelta)
def remove_falsy(value: List[T]) -> List[T]: def remove_falsy(value: List[T]) -> List[T]:
@ -530,6 +531,11 @@ def template_complex(value: Any) -> Any:
return value return value
positive_time_period_template = vol.Any(
positive_time_period, template, template_complex
)
def datetime(value: Any) -> datetime_sys: def datetime(value: Any) -> datetime_sys:
"""Validate datetime.""" """Validate datetime."""
if isinstance(value, datetime_sys): if isinstance(value, datetime_sys):
@ -876,7 +882,7 @@ STATE_CONDITION_SCHEMA = vol.All(
vol.Required(CONF_CONDITION): "state", vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids, vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Required(CONF_STATE): vol.Any(str, [str]), vol.Required(CONF_STATE): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta), vol.Optional(CONF_FOR): positive_time_period,
# To support use_trigger_value in automation # To support use_trigger_value in automation
# Deprecated 2016/04/25 # Deprecated 2016/04/25
vol.Optional("from"): str, vol.Optional("from"): str,
@ -992,9 +998,7 @@ CONDITION_SCHEMA: vol.Schema = key_value_schemas(
_SCRIPT_DELAY_SCHEMA = vol.Schema( _SCRIPT_DELAY_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_DELAY): vol.Any( vol.Required(CONF_DELAY): positive_time_period_template,
vol.All(time_period, positive_timedelta), template, template_complex
),
} }
) )
@ -1002,7 +1006,7 @@ _SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_ALIAS): string, vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_WAIT_TEMPLATE): template, vol.Required(CONF_WAIT_TEMPLATE): template,
vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta), vol.Optional(CONF_TIMEOUT): positive_time_period_template,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean, vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
} }
) )

View File

@ -1,6 +1,6 @@
"""Helpers to execute scripts.""" """Helpers to execute scripts."""
import asyncio import asyncio
from datetime import datetime from datetime import datetime, timedelta
from functools import partial from functools import partial
import itertools import itertools
import logging import logging
@ -241,21 +241,25 @@ class _ScriptRun:
level=level, level=level,
) )
async def _async_delay_step(self): def _get_pos_time_period_template(self, key):
"""Handle delay."""
try: try:
delay = vol.All(cv.time_period, cv.positive_timedelta)( return cv.positive_time_period(
template.render_complex(self._action[CONF_DELAY], self._variables) template.render_complex(self._action[key], self._variables)
) )
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
self._log( self._log(
"Error rendering %s delay template: %s", "Error rendering %s %s template: %s",
self._script.name, self._script.name,
key,
ex, ex,
level=logging.ERROR, level=logging.ERROR,
) )
raise _StopScript raise _StopScript
async def _async_delay_step(self):
"""Handle delay."""
delay = self._get_pos_time_period_template(CONF_DELAY)
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}") self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action) self._log("Executing step %s", self._script.last_action)
@ -269,41 +273,55 @@ class _ScriptRun:
async def _async_wait_template_step(self): async def _async_wait_template_step(self):
"""Handle a wait template.""" """Handle a wait template."""
if CONF_TIMEOUT in self._action:
delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
delay = None
self._script.last_action = self._action.get(CONF_ALIAS, "wait template") self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
self._log("Executing step %s", self._script.last_action) self._log(
"Executing step %s%s",
self._script.last_action,
"" if delay is None else f" (timeout: {timedelta(seconds=delay)})",
)
self._variables["wait"] = {"remaining": delay, "completed": False}
wait_template = self._action[CONF_WAIT_TEMPLATE] wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass wait_template.hass = self._hass
# check if condition already okay # check if condition already okay
if condition.async_template(self._hass, wait_template, self._variables): if condition.async_template(self._hass, wait_template, self._variables):
self._variables["wait"]["completed"] = True
return return
@callback @callback
def async_script_wait(entity_id, from_s, to_s): def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true.""" """Handle script after template condition is true."""
self._variables["wait"] = {
"remaining": to_context.remaining if to_context else delay,
"completed": True,
}
done.set() done.set()
to_context = None
unsub = async_track_template( unsub = async_track_template(
self._hass, wait_template, async_script_wait, self._variables self._hass, wait_template, async_script_wait, self._variables
) )
self._changed() self._changed()
try:
delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
delay = None
done = asyncio.Event() done = asyncio.Event()
tasks = [ tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
] ]
try: try:
async with timeout(delay): async with timeout(delay) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG) self._log(_TIMEOUT_MSG)
raise _StopScript raise _StopScript
self._variables["wait"]["remaining"] = 0.0
finally: finally:
for task in tasks: for task in tasks:
task.cancel() task.cancel()

View File

@ -16,7 +16,6 @@ import homeassistant.components.scene as scene
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import Context, CoreState, callback from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import config_validation as cv, script from homeassistant.helpers import config_validation as cv, script
from homeassistant.helpers.event import async_call_later
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.async_mock import patch from tests.async_mock import patch
@ -29,49 +28,6 @@ from tests.common import (
ENTITY_ID = "script.test" ENTITY_ID = "script.test"
@pytest.fixture
def mock_timeout(hass, monkeypatch):
"""Mock async_timeout.timeout."""
class MockTimeout:
def __init__(self, timeout):
self._timeout = timeout
self._loop = asyncio.get_event_loop()
self._task = None
self._cancelled = False
self._unsub = None
async def __aenter__(self):
if self._timeout is None:
return self
self._task = asyncio.Task.current_task()
if self._timeout <= 0:
self._loop.call_soon(self._cancel_task)
return self
# Wait for a time_changed event instead of real time passing.
self._unsub = async_call_later(hass, self._timeout, self._cancel_task)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is asyncio.CancelledError and self._cancelled:
self._unsub = None
self._task = None
raise asyncio.TimeoutError
if self._timeout is not None and self._unsub:
self._unsub()
self._unsub = None
self._task = None
return None
@callback
def _cancel_task(self, now=None):
if self._task is not None:
self._task.cancel()
self._cancelled = True
monkeypatch.setattr(script, "timeout", MockTimeout)
def async_watch_for_action(script_obj, message): def async_watch_for_action(script_obj, message):
"""Watch for message in last_action.""" """Watch for message in last_action."""
flag = asyncio.Event() flag = asyncio.Event()
@ -326,7 +282,7 @@ async def test_stop_no_wait(hass, count):
assert len(events) == 0 assert len(events) == 0
async def test_delay_basic(hass, mock_timeout): async def test_delay_basic(hass):
"""Test the delay.""" """Test the delay."""
delay_alias = "delay step" delay_alias = "delay step"
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias}) sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": 5}, "alias": delay_alias})
@ -350,7 +306,7 @@ async def test_delay_basic(hass, mock_timeout):
assert script_obj.last_action is None assert script_obj.last_action is None
async def test_multiple_runs_delay(hass, mock_timeout): async def test_multiple_runs_delay(hass):
"""Test multiple runs with delay in script.""" """Test multiple runs with delay in script."""
event = "test_event" event = "test_event"
events = async_capture_events(hass, event) events = async_capture_events(hass, event)
@ -393,7 +349,7 @@ async def test_multiple_runs_delay(hass, mock_timeout):
assert events[-1].data["value"] == 2 assert events[-1].data["value"] == 2
async def test_delay_template_ok(hass, mock_timeout): async def test_delay_template_ok(hass):
"""Test the delay as a template.""" """Test the delay as a template."""
sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"}) sequence = cv.SCRIPT_SCHEMA({"delay": "00:00:{{ 5 }}"})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
@ -441,7 +397,7 @@ async def test_delay_template_invalid(hass, caplog):
assert len(events) == 1 assert len(events) == 1
async def test_delay_template_complex_ok(hass, mock_timeout): async def test_delay_template_complex_ok(hass):
"""Test the delay with a working complex template.""" """Test the delay with a working complex template."""
sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}}) sequence = cv.SCRIPT_SCHEMA({"delay": {"seconds": "{{ 5 }}"}})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
@ -647,11 +603,56 @@ async def test_wait_template_not_schedule(hass):
assert len(events) == 2 assert len(events) == 2
@pytest.mark.parametrize(
"timeout_param", [5, "{{ 5 }}", {"seconds": 5}, {"seconds": "{{ 5 }}"}]
)
async def test_wait_template_timeout(hass, caplog, timeout_param):
"""Test the wait timeout option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"wait_template": "{{ states.switch.test.state == 'off' }}",
"timeout": timeout_param,
"continue_on_timeout": True,
},
{"event": event},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert len(events) == 0
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
cur_time = dt_util.utcnow()
async_fire_time_changed(hass, cur_time + timedelta(seconds=4))
await asyncio.sleep(0)
assert len(events) == 0
async_fire_time_changed(hass, cur_time + timedelta(seconds=5))
await hass.async_block_till_done()
assert not script_obj.is_running
assert len(events) == 1
assert "(timeout: 0:00:05)" in caplog.text
@pytest.mark.parametrize( @pytest.mark.parametrize(
"continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)] "continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)]
) )
async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_events): async def test_wait_template_continue_on_timeout(hass, continue_on_timeout, n_events):
"""Test the wait template, halt on timeout.""" """Test the wait template continue_on_timeout option."""
event = "test_event" event = "test_event"
events = async_capture_events(hass, event) events = async_capture_events(hass, event)
sequence = [ sequence = [
@ -682,8 +683,8 @@ async def test_wait_template_timeout(hass, mock_timeout, continue_on_timeout, n_
assert len(events) == n_events assert len(events) == n_events
async def test_wait_template_variables(hass): async def test_wait_template_variables_in(hass):
"""Test the wait template with variables.""" """Test the wait template with input variables."""
sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"}) sequence = cv.SCRIPT_SCHEMA({"wait_template": "{{ is_state(data, 'off') }}"})
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait") wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -706,6 +707,58 @@ async def test_wait_template_variables(hass):
assert not script_obj.is_running assert not script_obj.is_running
@pytest.mark.parametrize("mode", ["no_timeout", "timeout_finish", "timeout_not_finish"])
async def test_wait_template_variables_out(hass, mode):
"""Test the wait template output variable."""
event = "test_event"
events = async_capture_events(hass, event)
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
if mode != "no_timeout":
action["timeout"] = 5
action["continue_on_timeout"] = True
sequence = [
action,
{
"event": event,
"event_data_template": {
"completed": "{{ wait.completed }}",
"remaining": "{{ wait.remaining }}",
},
},
]
sequence = cv.SCRIPT_SCHEMA(sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert len(events) == 0
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
if mode == "timeout_not_finish":
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5))
else:
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
assert not script_obj.is_running
assert len(events) == 1
assert events[0].data["completed"] == str(mode != "timeout_not_finish")
remaining = events[0].data["remaining"]
if mode == "no_timeout":
assert remaining == "None"
elif mode == "timeout_finish":
assert 0.0 < float(remaining) < 5
else:
assert float(remaining) == 0.0
async def test_condition_basic(hass): async def test_condition_basic(hass):
"""Test if we can use conditions in a script.""" """Test if we can use conditions in a script."""
event = "test_event" event = "test_event"