diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 2aa5b1b8c62..c832bab7eb4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -73,6 +73,7 @@ PATH_CONFIG = ".config_entries.json" SAVE_DELAY = 1 _T = TypeVar("_T", bound="ConfigEntryState") +_R = TypeVar("_R") class ConfigEntryState(Enum): @@ -193,6 +194,7 @@ class ConfigEntry: "_async_cancel_retry_setup", "_on_unload", "reload_lock", + "_pending_tasks", ) def __init__( @@ -285,6 +287,8 @@ class ConfigEntry: # Reload lock to prevent conflicting reloads self.reload_lock = asyncio.Lock() + self._pending_tasks: list[asyncio.Future[Any]] = [] + async def async_setup( self, hass: HomeAssistant, @@ -366,7 +370,7 @@ class ConfigEntry: self.domain, auth_message, ) - self._async_process_on_unload() + await self._async_process_on_unload() self.async_start_reauth(hass) result = False except ConfigEntryNotReady as ex: @@ -406,7 +410,7 @@ class ConfigEntry: EVENT_HOMEASSISTANT_STARTED, setup_again ) - self._async_process_on_unload() + await self._async_process_on_unload() return except Exception: # pylint: disable=broad-except _LOGGER.exception( @@ -494,7 +498,7 @@ class ConfigEntry: self.state = ConfigEntryState.NOT_LOADED self.reason = None - self._async_process_on_unload() + await self._async_process_on_unload() # https://github.com/python/mypy/issues/11839 return result # type: ignore[no-any-return] @@ -619,13 +623,18 @@ class ConfigEntry: self._on_unload = [] self._on_unload.append(func) - @callback - def _async_process_on_unload(self) -> None: - """Process the on_unload callbacks.""" + async def _async_process_on_unload(self) -> None: + """Process the on_unload callbacks and wait for pending tasks.""" if self._on_unload is not None: while self._on_unload: self._on_unload.pop()() + while self._pending_tasks: + pending = [task for task in self._pending_tasks if not task.done()] + self._pending_tasks.clear() + if pending: + await asyncio.gather(*pending) + @callback def async_start_reauth(self, hass: HomeAssistant) -> None: """Start a reauth flow.""" @@ -648,6 +657,22 @@ class ConfigEntry: ) ) + @callback + def async_create_task( + self, hass: HomeAssistant, target: Coroutine[Any, Any, _R] + ) -> asyncio.Task[_R]: + """Create a task from within the eventloop. + + This method must be run in the event loop. + + target: target to call. + """ + task = hass.async_create_task(target) + + self._pending_tasks.append(task) + + return task + current_entry: ContextVar[ConfigEntry | None] = ContextVar( "current_entry", default=None diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index ec71778af12..6253b939bed 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -214,8 +214,9 @@ class EntityPlatform: def async_create_setup_task() -> Coroutine: """Get task to set up platform.""" config_entries.current_entry.set(config_entry) + return platform.async_setup_entry( # type: ignore[no-any-return,union-attr] - self.hass, config_entry, self._async_schedule_add_entities + self.hass, config_entry, self._async_schedule_add_entities_for_entry ) return await self._async_setup_platform(async_create_setup_task) @@ -334,6 +335,20 @@ class EntityPlatform: if not self._setup_complete: self._tasks.append(task) + @callback + def _async_schedule_add_entities_for_entry( + self, new_entities: Iterable[Entity], update_before_add: bool = False + ) -> None: + """Schedule adding entities for a single platform async and track the task.""" + assert self.config_entry + task = self.config_entry.async_create_task( + self.hass, + self.async_add_entities(new_entities, update_before_add=update_before_add), + ) + + if not self._setup_complete: + self._tasks.append(task) + def add_entities( self, new_entities: Iterable[Entity], update_before_add: bool = False ) -> None: diff --git a/tests/components/cast/test_media_player.py b/tests/components/cast/test_media_player.py index e4df84f6443..00626cc8c16 100644 --- a/tests/components/cast/test_media_player.py +++ b/tests/components/cast/test_media_player.py @@ -127,7 +127,7 @@ async def async_setup_cast(hass, config=None): config = {} data = {**{"ignore_cec": [], "known_hosts": [], "uuid": []}, **config} with patch( - "homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities" + "homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry" ) as add_entities: entry = MockConfigEntry(data=data, domain="cast") entry.add_to_hass(hass)