From cffceffe0400d82292794da99886c70220bb7a53 Mon Sep 17 00:00:00 2001 From: jbouwh Date: Tue, 11 Nov 2025 11:14:17 +0000 Subject: [PATCH] Move setup code to add_to_platform_finish --- homeassistant/helpers/entity.py | 83 ++++++++++++++------------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 79457a212db..4fbc8744572 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -526,9 +526,6 @@ class Entity( __capabilities_updated_at_reported: bool = False __remove_future: asyncio.Future[None] | None = None - # Remember we keep track of included entities - __init_track_included_entities: bool = False - # Entity Properties _attr_assumed_state: bool = False _attr_attribution: str | None = None @@ -1028,7 +1025,8 @@ class Entity( self._async_verify_state_writable() if self.hass.loop_thread_id != threading.get_ident(): report_non_thread_safe_operation("async_write_ha_state") - self._async_set_included_entities() + if self.included_unique_ids: + self._update_group_entity_ids() self._async_write_ha_state() def _stringify_state(self, available: bool) -> str: @@ -1382,8 +1380,41 @@ class Entity( self.platform = None # type: ignore[assignment] self.parallel_updates = None + def _update_group_entity_ids(self) -> None: + entity_registry = er.async_get(self.hass) + self._attr_included_entities = [] + for included_id in self.included_unique_ids: + if entity_id := entity_registry.async_get_entity_id( + self.platform.domain, self.platform.platform_name, included_id + ): + self._attr_included_entities.append(entity_id) + async def add_to_platform_finish(self) -> None: """Finish adding an entity to a platform.""" + entity_registry = er.async_get(self.hass) + + async def _handle_entity_registry_updated(event: Event[Any]) -> None: + """Handle registry create or update event.""" + if ( + event.data["action"] in {"create", "update"} + and (entry := entity_registry.async_get(event.data["entity_id"])) + and entry.unique_id in self.included_unique_ids + ) or ( + event.data["action"] == "remove" + and hasattr(self, "_attr_included_entities") + and event.data["entity_id"] in self._attr_included_entities + ): + self._update_group_entity_ids() + self.async_write_ha_state() + + if self.included_unique_ids: + self.async_on_remove( + self.hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, + _handle_entity_registry_updated, + ) + ) + await self.async_internal_added_to_hass() await self.async_added_to_hass() self._platform_state = EntityPlatformState.ADDED @@ -1643,50 +1674,6 @@ class Entity( self.hass, integration_domain=platform_name, module=type(self).__module__ ) - @callback - def _async_set_included_entities(self) -> None: - """Set the list of included entities identified by their unique IDs. - - This is called just before the entity state is written. - """ - if not self.included_unique_ids: - return - - entity_registry = er.async_get(self.hass) - assert self.entity_id is not None - - def _update_group_entity_ids() -> None: - self._attr_included_entities = [] - for included_id in self.included_unique_ids: - if entity_id := entity_registry.async_get_entity_id( - self.platform.domain, self.platform.platform_name, included_id - ): - self._attr_included_entities.append(entity_id) - - async def _handle_entity_registry_updated(event: Event[Any]) -> None: - """Handle registry create or update event.""" - if ( - event.data["action"] in {"create", "update"} - and (entry := entity_registry.async_get(event.data["entity_id"])) - and entry.unique_id in self.included_unique_ids - ) or ( - event.data["action"] == "remove" - and hasattr(self, "_attr_included_entities") - and event.data["entity_id"] in self._attr_included_entities - ): - _update_group_entity_ids() - self.async_write_ha_state() - - if not self.__init_track_included_entities: - self.async_on_remove( - self.hass.bus.async_listen( - er.EVENT_ENTITY_REGISTRY_UPDATED, - _handle_entity_registry_updated, - ) - ) - self.__init_track_included_entities = True - _update_group_entity_ids() - @cached_property def included_unique_ids(self) -> list[str]: """Return the list of unique IDs if the entity represents a group.