diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1f9963e184b..6fed54227a3 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -176,8 +176,6 @@ class _ScriptRun: try: if self._stop.is_set(): return - self._script.last_triggered = utcnow() - self._changed() self._log("Running script") for self._step, self._action in enumerate(self._script.sequence): if self._stop.is_set(): @@ -797,6 +795,8 @@ class Script: self._hass, self, cast(dict, variables), context, self._log_exceptions ) self._runs.append(run) + self.last_triggered = utcnow() + self._changed() try: await asyncio.shield(run.async_run()) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 28761c0ba17..fbfd06aa930 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -1290,23 +1290,46 @@ async def test_script_mode_queued(hass): sequence = cv.SCRIPT_SCHEMA( [ {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + { + "wait_template": "{{ states.switch.test.state == 'off' }}", + "alias": "wait_1", + }, {"event": event, "event_data": {"value": 2}}, - {"wait_template": "{{ states.switch.test.state == 'on' }}"}, + { + "wait_template": "{{ states.switch.test.state == 'on' }}", + "alias": "wait_2", + }, ] ) logger = logging.getLogger("TEST") script_obj = script.Script( hass, sequence, script_mode="queued", max_runs=2, logger=logger ) - wait_started_flag = async_watch_for_action(script_obj, "wait") + + watch_messages = [] + + @callback + def check_action(): + for message, flag in watch_messages: + if script_obj.last_action and message in script_obj.last_action: + flag.set() + + script_obj.change_listener = check_action + wait_started_flag_1 = asyncio.Event() + watch_messages.append(("wait_1", wait_started_flag_1)) + wait_started_flag_2 = asyncio.Event() + watch_messages.append(("wait_2", wait_started_flag_2)) try: + assert not script_obj.is_running + assert script_obj.runs == 0 + hass.states.async_set("switch.test", "on") hass.async_create_task(script_obj.async_run()) - await asyncio.wait_for(wait_started_flag.wait(), 1) + await asyncio.wait_for(wait_started_flag_1.wait(), 1) assert script_obj.is_running + assert script_obj.runs == 1 assert len(events) == 1 assert events[0].data["value"] == 1 @@ -1314,25 +1337,26 @@ async def test_script_mode_queued(hass): # This second run should not start until the first run has finished. hass.async_create_task(script_obj.async_run()) - await asyncio.sleep(0) + assert script_obj.is_running + assert script_obj.runs == 2 assert len(events) == 1 - wait_started_flag.clear() hass.states.async_set("switch.test", "off") - await asyncio.wait_for(wait_started_flag.wait(), 1) + await asyncio.wait_for(wait_started_flag_2.wait(), 1) assert script_obj.is_running + assert script_obj.runs == 2 assert len(events) == 2 assert events[1].data["value"] == 2 - wait_started_flag.clear() + wait_started_flag_1.clear() hass.states.async_set("switch.test", "on") - await asyncio.wait_for(wait_started_flag.wait(), 1) + await asyncio.wait_for(wait_started_flag_1.wait(), 1) - await asyncio.sleep(0) assert script_obj.is_running + assert script_obj.runs == 1 assert len(events) == 3 assert events[2].data["value"] == 1 except (AssertionError, asyncio.TimeoutError): @@ -1345,10 +1369,52 @@ async def test_script_mode_queued(hass): await hass.async_block_till_done() assert not script_obj.is_running + assert script_obj.runs == 0 assert len(events) == 4 assert events[3].data["value"] == 2 +async def test_script_mode_queued_cancel(hass): + """Test canceling with a queued run.""" + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA({"wait_template": "{{ false }}"}), + "test", + script_mode="queued", + max_runs=2, + ) + wait_started_flag = async_watch_for_action(script_obj, "wait") + + try: + assert not script_obj.is_running + assert script_obj.runs == 0 + + task1 = hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) + task2 = hass.async_create_task(script_obj.async_run()) + await asyncio.sleep(0) + + assert script_obj.is_running + assert script_obj.runs == 2 + + with pytest.raises(asyncio.CancelledError): + task2.cancel() + await task2 + + assert script_obj.is_running + assert script_obj.runs == 1 + + with pytest.raises(asyncio.CancelledError): + task1.cancel() + await task1 + + assert not script_obj.is_running + assert script_obj.runs == 0 + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + + async def test_script_logging(hass, caplog): """Test script logging.""" script_obj = script.Script(hass, [], "Script with % Name")