Start script runs eagerly (#113190)

This commit is contained in:
J. Nick Koston 2024-03-14 16:53:26 -10:00 committed by GitHub
parent 92e73312ea
commit bdede0e0da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 33 additions and 8 deletions

View File

@ -77,6 +77,7 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import condition, config_validation as cv, service, template from . import condition, config_validation as cv, service, template
@ -1611,7 +1612,7 @@ class Script:
self._changed() self._changed()
try: try:
return await asyncio.shield(run.async_run()) return await asyncio.shield(create_eager_task(run.async_run()))
except asyncio.CancelledError: except asyncio.CancelledError:
await run.async_stop() await run.async_stop()
self._changed() self._changed()

View File

@ -44,6 +44,7 @@ async def async_turn_on(
} }
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, data, blocking=True)
await hass.async_block_till_done()
async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL) -> None: async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL) -> None:
@ -51,6 +52,7 @@ async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL) -> None:
data = {ATTR_ENTITY_ID: entity_id} if entity_id else {} data = {ATTR_ENTITY_ID: entity_id} if entity_id else {}
await hass.services.async_call(DOMAIN, SERVICE_TURN_OFF, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_TURN_OFF, data, blocking=True)
await hass.async_block_till_done()
async def async_oscillate( async def async_oscillate(
@ -67,6 +69,7 @@ async def async_oscillate(
} }
await hass.services.async_call(DOMAIN, SERVICE_OSCILLATE, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_OSCILLATE, data, blocking=True)
await hass.async_block_till_done()
async def async_set_preset_mode( async def async_set_preset_mode(
@ -80,6 +83,7 @@ async def async_set_preset_mode(
} }
await hass.services.async_call(DOMAIN, SERVICE_SET_PRESET_MODE, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_SET_PRESET_MODE, data, blocking=True)
await hass.async_block_till_done()
async def async_set_percentage( async def async_set_percentage(
@ -93,6 +97,7 @@ async def async_set_percentage(
} }
await hass.services.async_call(DOMAIN, SERVICE_SET_PERCENTAGE, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_SET_PERCENTAGE, data, blocking=True)
await hass.async_block_till_done()
async def async_increase_speed( async def async_increase_speed(
@ -109,6 +114,7 @@ async def async_increase_speed(
} }
await hass.services.async_call(DOMAIN, SERVICE_INCREASE_SPEED, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_INCREASE_SPEED, data, blocking=True)
await hass.async_block_till_done()
async def async_decrease_speed( async def async_decrease_speed(
@ -125,6 +131,7 @@ async def async_decrease_speed(
} }
await hass.services.async_call(DOMAIN, SERVICE_DECREASE_SPEED, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_DECREASE_SPEED, data, blocking=True)
await hass.async_block_till_done()
async def async_set_direction( async def async_set_direction(
@ -138,3 +145,4 @@ async def async_set_direction(
} }
await hass.services.async_call(DOMAIN, SERVICE_SET_DIRECTION, data, blocking=True) await hass.services.async_call(DOMAIN, SERVICE_SET_DIRECTION, data, blocking=True)
await hass.async_block_till_done()

View File

@ -1207,7 +1207,10 @@ async def test_if_not_fires_on_entities_change_with_for_after_stop(
"below": below, "below": below,
"for": {"seconds": 5}, "for": {"seconds": 5},
}, },
"action": {"service": "test.automation"}, "action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
} }
}, },
) )
@ -1833,7 +1836,10 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop(
"attribute": "test-measurement", "attribute": "test-measurement",
"for": 5, "for": 5,
}, },
"action": {"service": "test.automation"}, "action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
} }
}, },
) )

View File

@ -666,7 +666,10 @@ async def test_if_not_fires_on_entities_change_with_for_after_stop(
"to": "world", "to": "world",
"for": {"seconds": 5}, "for": {"seconds": 5},
}, },
"action": {"service": "test.automation"}, "action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
} }
}, },
) )
@ -1624,7 +1627,10 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop(
"attribute": "name", "attribute": "name",
"for": 5, "for": 5,
}, },
"action": {"service": "test.automation"}, "action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
} }
}, },
) )

View File

@ -428,6 +428,7 @@ async def test_set_invalid_direction_from_initial_stage(
await common.async_turn_on(hass, _TEST_FAN) await common.async_turn_on(hass, _TEST_FAN)
await common.async_set_direction(hass, _TEST_FAN, "invalid") await common.async_set_direction(hass, _TEST_FAN, "invalid")
assert hass.states.get(_DIRECTION_INPUT_SELECT).state == "" assert hass.states.get(_DIRECTION_INPUT_SELECT).state == ""
_verify(hass, STATE_ON, 0, None, None, None) _verify(hass, STATE_ON, 0, None, None, None)

View File

@ -3441,7 +3441,8 @@ async def test_parallel_loop(
script_obj = script.Script(hass, sequence, "Test Name", "test_domain") script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
hass.async_create_task( hass.async_create_task(
script_obj.async_run(MappingProxyType({"what": "world"}), Context()) script_obj.async_run(MappingProxyType({"what": "world"}), Context()),
eager_start=True,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -3456,7 +3457,6 @@ async def test_parallel_loop(
expected_trace = { expected_trace = {
"0": [{"variables": {"what": "world"}}], "0": [{"variables": {"what": "world"}}],
"0/parallel/0/sequence/0": [{}], "0/parallel/0/sequence/0": [{}],
"0/parallel/1/sequence/0": [{}],
"0/parallel/0/sequence/0/repeat/sequence/0": [ "0/parallel/0/sequence/0/repeat/sequence/0": [
{ {
"variables": { "variables": {
@ -3492,6 +3492,7 @@ async def test_parallel_loop(
"result": {"event": "loop1", "event_data": {"hello1": "loop1_c"}}, "result": {"event": "loop1", "event_data": {"hello1": "loop1_c"}},
}, },
], ],
"0/parallel/1/sequence/0": [{}],
"0/parallel/1/sequence/0/repeat/sequence/0": [ "0/parallel/1/sequence/0/repeat/sequence/0": [
{ {
"variables": { "variables": {
@ -4118,7 +4119,9 @@ async def test_max_exceeded(
) )
hass.states.async_set("switch.test", "on") hass.states.async_set("switch.test", "on")
for _ in range(max_runs + 1): for _ in range(max_runs + 1):
hass.async_create_task(script_obj.async_run(context=Context())) hass.async_create_task(
script_obj.async_run(context=Context()), eager_start=True
)
hass.states.async_set("switch.test", "off") hass.states.async_set("switch.test", "off")
await hass.async_block_till_done() await hass.async_block_till_done()
if max_exceeded is None: if max_exceeded is None: