Move setup code to add_to_platform_finish

This commit is contained in:
jbouwh
2025-11-11 11:14:17 +00:00
parent 253189805e
commit cffceffe04

View File

@@ -526,9 +526,6 @@ class Entity(
__capabilities_updated_at_reported: bool = False
__remove_future: asyncio.Future[None] | None = None
# Remember we keep track of included entities
__init_track_included_entities: bool = False
# Entity Properties
_attr_assumed_state: bool = False
_attr_attribution: str | None = None
@@ -1028,7 +1025,8 @@ class Entity(
self._async_verify_state_writable()
if self.hass.loop_thread_id != threading.get_ident():
report_non_thread_safe_operation("async_write_ha_state")
self._async_set_included_entities()
if self.included_unique_ids:
self._update_group_entity_ids()
self._async_write_ha_state()
def _stringify_state(self, available: bool) -> str:
@@ -1382,8 +1380,41 @@ class Entity(
self.platform = None # type: ignore[assignment]
self.parallel_updates = None
def _update_group_entity_ids(self) -> None:
entity_registry = er.async_get(self.hass)
self._attr_included_entities = []
for included_id in self.included_unique_ids:
if entity_id := entity_registry.async_get_entity_id(
self.platform.domain, self.platform.platform_name, included_id
):
self._attr_included_entities.append(entity_id)
async def add_to_platform_finish(self) -> None:
"""Finish adding an entity to a platform."""
entity_registry = er.async_get(self.hass)
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
"""Handle registry create or update event."""
if (
event.data["action"] in {"create", "update"}
and (entry := entity_registry.async_get(event.data["entity_id"]))
and entry.unique_id in self.included_unique_ids
) or (
event.data["action"] == "remove"
and hasattr(self, "_attr_included_entities")
and event.data["entity_id"] in self._attr_included_entities
):
self._update_group_entity_ids()
self.async_write_ha_state()
if self.included_unique_ids:
self.async_on_remove(
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
_handle_entity_registry_updated,
)
)
await self.async_internal_added_to_hass()
await self.async_added_to_hass()
self._platform_state = EntityPlatformState.ADDED
@@ -1643,50 +1674,6 @@ class Entity(
self.hass, integration_domain=platform_name, module=type(self).__module__
)
@callback
def _async_set_included_entities(self) -> None:
"""Set the list of included entities identified by their unique IDs.
This is called just before the entity state is written.
"""
if not self.included_unique_ids:
return
entity_registry = er.async_get(self.hass)
assert self.entity_id is not None
def _update_group_entity_ids() -> None:
self._attr_included_entities = []
for included_id in self.included_unique_ids:
if entity_id := entity_registry.async_get_entity_id(
self.platform.domain, self.platform.platform_name, included_id
):
self._attr_included_entities.append(entity_id)
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
"""Handle registry create or update event."""
if (
event.data["action"] in {"create", "update"}
and (entry := entity_registry.async_get(event.data["entity_id"]))
and entry.unique_id in self.included_unique_ids
) or (
event.data["action"] == "remove"
and hasattr(self, "_attr_included_entities")
and event.data["entity_id"] in self._attr_included_entities
):
_update_group_entity_ids()
self.async_write_ha_state()
if not self.__init_track_included_entities:
self.async_on_remove(
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
_handle_entity_registry_updated,
)
)
self.__init_track_included_entities = True
_update_group_entity_ids()
@cached_property
def included_unique_ids(self) -> list[str]:
"""Return the list of unique IDs if the entity represents a group.