diff --git a/homeassistant/helpers/start.py b/homeassistant/helpers/start.py index 805ac193834..4560119a685 100644 --- a/homeassistant/helpers/start.py +++ b/homeassistant/helpers/start.py @@ -4,23 +4,24 @@ from __future__ import annotations from collections.abc import Awaitable, Callable from homeassistant.const import EVENT_HOMEASSISTANT_START -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import Event, HassJob, HomeAssistant, callback @callback def async_at_start( - hass: HomeAssistant, at_start_cb: Callable[[HomeAssistant], Awaitable] + hass: HomeAssistant, at_start_cb: Callable[[HomeAssistant], Awaitable[None] | None] ) -> None: """Execute something when Home Assistant is started. Will execute it now if Home Assistant is already started. """ + at_start_job = HassJob(at_start_cb) if hass.is_running: - hass.async_create_task(at_start_cb(hass)) + hass.async_run_hass_job(at_start_job, hass) return async def _matched_event(event: Event) -> None: """Call the callback when Home Assistant started.""" - await at_start_cb(hass) + hass.async_run_hass_job(at_start_job, hass) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _matched_event) diff --git a/tests/helpers/test_start.py b/tests/helpers/test_start.py index 35838f1ceaa..55f98cf60eb 100644 --- a/tests/helpers/test_start.py +++ b/tests/helpers/test_start.py @@ -4,8 +4,9 @@ from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.helpers import start -async def test_at_start_when_running(hass): +async def test_at_start_when_running_awaitable(hass): """Test at start when already running.""" + assert hass.state == core.CoreState.running assert hass.is_running calls = [] @@ -18,8 +19,37 @@ async def test_at_start_when_running(hass): await hass.async_block_till_done() assert len(calls) == 1 + hass.state = core.CoreState.starting + assert hass.is_running -async def test_at_start_when_starting(hass): + start.async_at_start(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 2 + + +async def test_at_start_when_running_callback(hass): + """Test at start when already running.""" + assert hass.state == core.CoreState.running + assert hass.is_running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_start(hass, cb_at_start) + assert len(calls) == 1 + + hass.state = core.CoreState.starting + assert hass.is_running + + start.async_at_start(hass, cb_at_start) + assert len(calls) == 2 + + +async def test_at_start_when_starting_awaitable(hass): """Test at start when yet to start.""" hass.state = core.CoreState.not_running assert not hass.is_running @@ -37,3 +67,24 @@ async def test_at_start_when_starting(hass): hass.bus.async_fire(EVENT_HOMEASSISTANT_START) await hass.async_block_till_done() assert len(calls) == 1 + + +async def test_at_start_when_starting_callback(hass): + """Test at start when yet to start.""" + hass.state = core.CoreState.not_running + assert not hass.is_running + + calls = [] + + @core.callback + def cb_at_start(hass): + """Home Assistant is started.""" + calls.append(1) + + start.async_at_start(hass, cb_at_start) + await hass.async_block_till_done() + assert len(calls) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert len(calls) == 1