diff --git a/homeassistant/core.py b/homeassistant/core.py index ef9436dd5bc..097e1ed7165 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -713,6 +713,8 @@ class HomeAssistant: task.add_done_callback(self._tasks.remove) task.cancel() + self.exit_code = exit_code + # stage 1 self.state = CoreState.stopping self.bus.async_fire(EVENT_HOMEASSISTANT_STOP) @@ -756,8 +758,6 @@ class HomeAssistant: "Timed out waiting for shutdown stage 3 to complete, the shutdown will" " continue" ) - - self.exit_code = exit_code self.state = CoreState.stopped if self._stopped is not None: diff --git a/tests/test_core.py b/tests/test_core.py index 6268f7f4ac1..6c040621745 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -181,6 +181,46 @@ async def test_stage_shutdown(hass: HomeAssistant) -> None: assert len(test_all) == 2 +async def test_stage_shutdown_with_exit_code(hass): + """Simulate a shutdown, test calling stuff with exit code checks.""" + test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP) + test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE) + test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE) + test_all = async_capture_events(hass, MATCH_ALL) + + event_call_counters = [0, 0, 0] + expected_exit_code = 101 + + async def async_on_stop(event) -> None: + if hass.exit_code == expected_exit_code: + event_call_counters[0] += 1 + + async def async_on_final_write(event) -> None: + if hass.exit_code == expected_exit_code: + event_call_counters[1] += 1 + + async def async_on_close(event) -> None: + if hass.exit_code == expected_exit_code: + event_call_counters[2] += 1 + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_on_stop) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_FINAL_WRITE, async_on_final_write) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, async_on_close) + + await hass.async_stop(expected_exit_code) + + assert len(test_stop) == 1 + assert len(test_close) == 1 + assert len(test_final_write) == 1 + assert len(test_all) == 2 + + assert ( + event_call_counters[0] == 1 + and event_call_counters[1] == 1 + and event_call_counters[2] == 1 + ) + + async def test_shutdown_calls_block_till_done_after_shutdown_run_callback_threadsafe( hass, ):