diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index ff95d3513dc..34872d8cec1 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -310,8 +310,10 @@ class ConfigEntry: # Function to cancel a scheduled retry self._async_cancel_retry_setup: Callable[[], Any] | None = None - # Hold list for functions to call on unload. - self._on_unload: list[CALLBACK_TYPE] | None = None + # Hold list for actions to call on unload. + self._on_unload: list[ + Callable[[], Coroutine[Any, Any, None] | None] + ] | None = None # Reload lock to prevent conflicting reloads self.reload_lock = asyncio.Lock() @@ -395,7 +397,7 @@ class ConfigEntry: self.domain, error_reason, ) - await self._async_process_on_unload() + await self._async_process_on_unload(hass) result = False except ConfigEntryAuthFailed as ex: message = str(ex) @@ -410,7 +412,7 @@ class ConfigEntry: self.domain, auth_message, ) - await self._async_process_on_unload() + await self._async_process_on_unload(hass) self.async_start_reauth(hass) result = False except ConfigEntryNotReady as ex: @@ -461,7 +463,7 @@ class ConfigEntry: EVENT_HOMEASSISTANT_STARTED, setup_again ) - await self._async_process_on_unload() + await self._async_process_on_unload(hass) return # pylint: disable-next=broad-except except (asyncio.CancelledError, SystemExit, Exception): @@ -544,7 +546,7 @@ class ConfigEntry: if result and integration.domain == self.domain: self.async_set_state(hass, ConfigEntryState.NOT_LOADED, None) - await self._async_process_on_unload() + await self._async_process_on_unload(hass) return result except Exception as ex: # pylint: disable=broad-except @@ -674,17 +676,20 @@ class ConfigEntry: } @callback - def async_on_unload(self, func: CALLBACK_TYPE) -> None: + def async_on_unload( + self, func: Callable[[], Coroutine[Any, Any, None] | None] + ) -> None: """Add a function to call when config entry is unloaded.""" if self._on_unload is None: self._on_unload = [] self._on_unload.append(func) - async def _async_process_on_unload(self) -> None: + async def _async_process_on_unload(self, hass: HomeAssistant) -> None: """Process the on_unload callbacks and wait for pending tasks.""" if self._on_unload is not None: while self._on_unload: - self._on_unload.pop()() + if job := self._on_unload.pop()(): + self._tasks.add(hass.async_create_task(job)) if not self._tasks and not self._background_tasks: return diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 60b9a250c17..6a9258d4ede 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -3867,7 +3867,7 @@ async def test_task_tracking(hass: HomeAssistant) -> None: event = asyncio.Event() results = [] - async def test_task(): + async def test_task() -> None: try: await event.wait() results.append("normal") @@ -3875,9 +3875,14 @@ async def test_task_tracking(hass: HomeAssistant) -> None: results.append("background") raise + async def test_unload() -> None: + await event.wait() + results.append("on_unload") + + entry.async_on_unload(test_unload) entry.async_create_task(hass, test_task()) entry.async_create_background_task(hass, test_task(), "background-task-name") await asyncio.sleep(0) hass.loop.call_soon(event.set) - await entry._async_process_on_unload() - assert results == ["background", "normal"] + await entry._async_process_on_unload(hass) + assert results == ["on_unload", "background", "normal"]