diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index c724b9e890d..107cd9e2106 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -393,58 +393,85 @@ class _ScriptRun(_ScriptRunBase): except KeyError: delay = None done = asyncio.Event() + tasks = [ + self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) + ] try: async with timeout(delay): - _, pending = await asyncio.wait( - {self._stop.wait(), done.wait()}, - return_when=asyncio.FIRST_COMPLETED, - ) - for pending_task in pending: - pending_task.cancel() + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) except asyncio.TimeoutError: if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) raise _StopScript finally: + for task in tasks: + task.cancel() unsub() async def _async_call_service_step(self): """Call the service specified in the action.""" domain, service, service_data = self._prep_call_service_step() + running_script = ( + domain == "automation" + and service == "trigger" + or domain == "python_script" + or domain == "script" + and service != SERVICE_TURN_OFF + ) # If this might start a script then disable the call timeout. # Otherwise use the normal service call limit. - if domain == "script" and service != SERVICE_TURN_OFF: + if running_script: limit = None else: limit = SERVICE_CALL_LIMIT - coro = self._hass.services.async_call( - domain, - service, - service_data, - blocking=True, - context=self._context, - limit=limit, + service_task = self._hass.async_create_task( + self._hass.services.async_call( + domain, + service, + service_data, + blocking=True, + context=self._context, + limit=limit, + ) ) - if limit is not None: # There is a call limit, so just wait for it to finish. - await coro + await service_task return - # No call limit (i.e., potentially starting one or more fully blocking scripts) - # so watch for a stop request. - done, pending = await asyncio.wait( - {self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED, - ) - # Note that cancelling the service call, if it has not yet returned, will also - # stop any non-background script runs that it may have started. - for pending_task in pending: - pending_task.cancel() - # Propagate any exceptions that might have happened. - for done_task in done: - done_task.result() + async def async_cancel_service_task(): + # Stop service task and wait for it to finish. + service_task.cancel() + try: + await service_task + except Exception: # pylint: disable=broad-except + pass + + # No call limit so watch for a stop request. + stop_task = self._hass.async_create_task(self._stop.wait()) + try: + await asyncio.wait( + {service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + # If our task is cancelled, then cancel service task, too. Note that if service + # task is cancelled otherwise the CancelledError exception will not be raised to + # here due to the call to asyncio.wait(). Rather we'll check for that below. + except asyncio.CancelledError: + await async_cancel_service_task() + raise + finally: + stop_task.cancel() + + if service_task.cancelled(): + raise asyncio.CancelledError + if service_task.done(): + # Propagate any exceptions that occurred. + service_task.result() + elif running_script: + # Stopped before service completed, so cancel service. + await async_cancel_service_task() class _QueuedScriptRun(_ScriptRun): @@ -459,12 +486,18 @@ class _QueuedScriptRun(_ScriptRun): lock_task = self._hass.async_create_task( self._script._queue_lck.acquire() # pylint: disable=protected-access ) - done, pending = await asyncio.wait( - {self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED - ) - for pending_task in pending: - pending_task.cancel() - self.lock_acquired = lock_task in done + stop_task = self._hass.async_create_task(self._stop.wait()) + try: + await asyncio.wait( + {lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED + ) + except asyncio.CancelledError: + lock_task.cancel() + self._finish() + raise + finally: + stop_task.cancel() + self.lock_acquired = lock_task.done() and not lock_task.cancelled() # If we've been told to stop, then just finish up. Otherwise, we've acquired the # lock so we can go ahead and start the run.