Add support for simultaneous runs of Script helper - Part 3 (#36202)

This commit is contained in:
Phil Bruckner 2020-05-27 17:10:28 -05:00 committed by GitHub
parent f14a4935df
commit 1e9ec917f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -393,58 +393,85 @@ class _ScriptRun(_ScriptRunBase):
except KeyError: except KeyError:
delay = None delay = None
done = asyncio.Event() done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try: try:
async with timeout(delay): async with timeout(delay):
_, pending = await asyncio.wait( await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
{self._stop.wait(), done.wait()},
return_when=asyncio.FIRST_COMPLETED,
)
for pending_task in pending:
pending_task.cancel()
except asyncio.TimeoutError: except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG) self._log(_TIMEOUT_MSG)
raise _StopScript raise _StopScript
finally: finally:
for task in tasks:
task.cancel()
unsub() unsub()
async def _async_call_service_step(self): async def _async_call_service_step(self):
"""Call the service specified in the action.""" """Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step() domain, service, service_data = self._prep_call_service_step()
running_script = (
domain == "automation"
and service == "trigger"
or domain == "python_script"
or domain == "script"
and service != SERVICE_TURN_OFF
)
# If this might start a script then disable the call timeout. # If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit. # Otherwise use the normal service call limit.
if domain == "script" and service != SERVICE_TURN_OFF: if running_script:
limit = None limit = None
else: else:
limit = SERVICE_CALL_LIMIT limit = SERVICE_CALL_LIMIT
coro = self._hass.services.async_call( service_task = self._hass.async_create_task(
domain, self._hass.services.async_call(
service, domain,
service_data, service,
blocking=True, service_data,
context=self._context, blocking=True,
limit=limit, context=self._context,
limit=limit,
)
) )
if limit is not None: if limit is not None:
# There is a call limit, so just wait for it to finish. # There is a call limit, so just wait for it to finish.
await coro await service_task
return return
# No call limit (i.e., potentially starting one or more fully blocking scripts) async def async_cancel_service_task():
# so watch for a stop request. # Stop service task and wait for it to finish.
done, pending = await asyncio.wait( service_task.cancel()
{self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED, try:
) await service_task
# Note that cancelling the service call, if it has not yet returned, will also except Exception: # pylint: disable=broad-except
# stop any non-background script runs that it may have started. pass
for pending_task in pending:
pending_task.cancel() # No call limit so watch for a stop request.
# Propagate any exceptions that might have happened. stop_task = self._hass.async_create_task(self._stop.wait())
for done_task in done: try:
done_task.result() await asyncio.wait(
{service_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
# If our task is cancelled, then cancel service task, too. Note that if service
# task is cancelled otherwise the CancelledError exception will not be raised to
# here due to the call to asyncio.wait(). Rather we'll check for that below.
except asyncio.CancelledError:
await async_cancel_service_task()
raise
finally:
stop_task.cancel()
if service_task.cancelled():
raise asyncio.CancelledError
if service_task.done():
# Propagate any exceptions that occurred.
service_task.result()
elif running_script:
# Stopped before service completed, so cancel service.
await async_cancel_service_task()
class _QueuedScriptRun(_ScriptRun): class _QueuedScriptRun(_ScriptRun):
@ -459,12 +486,18 @@ class _QueuedScriptRun(_ScriptRun):
lock_task = self._hass.async_create_task( lock_task = self._hass.async_create_task(
self._script._queue_lck.acquire() # pylint: disable=protected-access self._script._queue_lck.acquire() # pylint: disable=protected-access
) )
done, pending = await asyncio.wait( stop_task = self._hass.async_create_task(self._stop.wait())
{self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED try:
) await asyncio.wait(
for pending_task in pending: {lock_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
pending_task.cancel() )
self.lock_acquired = lock_task in done except asyncio.CancelledError:
lock_task.cancel()
self._finish()
raise
finally:
stop_task.cancel()
self.lock_acquired = lock_task.done() and not lock_task.cancelled()
# If we've been told to stop, then just finish up. Otherwise, we've acquired the # If we've been told to stop, then just finish up. Otherwise, we've acquired the
# lock so we can go ahead and start the run. # lock so we can go ahead and start the run.