Schedule tasks eagerly when called from hass.add_job (#113014)

This commit is contained in:
J. Nick Koston 2024-03-10 21:19:49 -10:00 committed by GitHub
parent cede16fc40
commit 3387892f59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 25 deletions

View File

@ -524,30 +524,43 @@ class HomeAssistant:
if target is None: if target is None:
raise ValueError("Don't call add_job with None") raise ValueError("Don't call add_job with None")
if asyncio.iscoroutine(target): if asyncio.iscoroutine(target):
self.loop.call_soon_threadsafe(self.async_add_job, target) self.loop.call_soon_threadsafe(
functools.partial(self.async_add_job, target, eager_start=True)
)
return return
if TYPE_CHECKING: if TYPE_CHECKING:
target = cast(Callable[..., Any], target) target = cast(Callable[..., Any], target)
self.loop.call_soon_threadsafe(self.async_add_job, target, *args) self.loop.call_soon_threadsafe(
functools.partial(self.async_add_job, target, *args, eager_start=True)
)
@overload @overload
@callback @callback
def async_add_job( def async_add_job(
self, target: Callable[..., Coroutine[Any, Any, _R]], *args: Any self,
target: Callable[..., Coroutine[Any, Any, _R]],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@overload @overload
@callback @callback
def async_add_job( def async_add_job(
self, target: Callable[..., Coroutine[Any, Any, _R] | _R], *args: Any self,
target: Callable[..., Coroutine[Any, Any, _R] | _R],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@overload @overload
@callback @callback
def async_add_job( def async_add_job(
self, target: Coroutine[Any, Any, _R], *args: Any self,
target: Coroutine[Any, Any, _R],
*args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
... ...
@ -556,6 +569,7 @@ class HomeAssistant:
self, self,
target: Callable[..., Coroutine[Any, Any, _R] | _R] | Coroutine[Any, Any, _R], target: Callable[..., Coroutine[Any, Any, _R] | _R] | Coroutine[Any, Any, _R],
*args: Any, *args: Any,
eager_start: bool = False,
) -> asyncio.Future[_R] | None: ) -> asyncio.Future[_R] | None:
"""Add a job to be executed by the event loop or by an executor. """Add a job to be executed by the event loop or by an executor.
@ -571,7 +585,7 @@ class HomeAssistant:
raise ValueError("Don't call async_add_job with None") raise ValueError("Don't call async_add_job with None")
if asyncio.iscoroutine(target): if asyncio.iscoroutine(target):
return self.async_create_task(target) return self.async_create_task(target, eager_start=eager_start)
# This code path is performance sensitive and uses # This code path is performance sensitive and uses
# if TYPE_CHECKING to avoid the overhead of constructing # if TYPE_CHECKING to avoid the overhead of constructing
@ -579,7 +593,7 @@ class HomeAssistant:
# https://github.com/home-assistant/core/pull/71960 # https://github.com/home-assistant/core/pull/71960
if TYPE_CHECKING: if TYPE_CHECKING:
target = cast(Callable[..., Coroutine[Any, Any, _R] | _R], target) target = cast(Callable[..., Coroutine[Any, Any, _R] | _R], target)
return self.async_add_hass_job(HassJob(target), *args) return self.async_add_hass_job(HassJob(target), *args, eager_start=eager_start)
@overload @overload
@callback @callback

View File

@ -234,7 +234,7 @@ async def async_test_home_assistant(
orig_async_create_task = hass.async_create_task orig_async_create_task = hass.async_create_task
orig_tz = dt_util.DEFAULT_TIME_ZONE orig_tz = dt_util.DEFAULT_TIME_ZONE
def async_add_job(target, *args): def async_add_job(target, *args, eager_start: bool = False):
"""Add job.""" """Add job."""
check_target = target check_target = target
while isinstance(check_target, ft.partial): while isinstance(check_target, ft.partial):
@ -245,7 +245,7 @@ async def async_test_home_assistant(
fut.set_result(target(*args)) fut.set_result(target(*args))
return fut return fut
return orig_async_add_job(target, *args) return orig_async_add_job(target, *args, eager_start=eager_start)
def async_add_executor_job(target, *args): def async_add_executor_job(target, *args):
"""Add executor job.""" """Add executor job."""

View File

@ -736,6 +736,20 @@ async def test_pending_scheduler(hass: HomeAssistant) -> None:
assert len(call_count) == 3 assert len(call_count) == 3
def test_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
"""Add a coro to pending tasks."""
async def test_coro():
"""Test Coro."""
pass
for _ in range(2):
hass.add_job(test_coro())
# Ensure add_job does not run immediately
assert len(hass._tasks) == 0
async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None: async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
"""Add a coro to pending tasks.""" """Add a coro to pending tasks."""
call_count = [] call_count = []
@ -745,18 +759,12 @@ async def test_async_add_job_pending_tasks_coro(hass: HomeAssistant) -> None:
call_count.append("call") call_count.append("call")
for _ in range(2): for _ in range(2):
hass.add_job(test_coro()) hass.async_add_job(test_coro())
async def wait_finish_callback():
"""Wait until all stuff is scheduled."""
await asyncio.sleep(0)
await asyncio.sleep(0)
await wait_finish_callback()
assert len(hass._tasks) == 2 assert len(hass._tasks) == 2
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(call_count) == 2 assert len(call_count) == 2
assert len(hass._tasks) == 0
async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None: async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None:
@ -768,18 +776,12 @@ async def test_async_create_task_pending_tasks_coro(hass: HomeAssistant) -> None
call_count.append("call") call_count.append("call")
for _ in range(2): for _ in range(2):
hass.create_task(test_coro()) hass.async_create_task(test_coro())
async def wait_finish_callback():
"""Wait until all stuff is scheduled."""
await asyncio.sleep(0)
await asyncio.sleep(0)
await wait_finish_callback()
assert len(hass._tasks) == 2 assert len(hass._tasks) == 2
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(call_count) == 2 assert len(call_count) == 2
assert len(hass._tasks) == 0
async def test_async_add_job_pending_tasks_executor(hass: HomeAssistant) -> None: async def test_async_add_job_pending_tasks_executor(hass: HomeAssistant) -> None: