From bc8d7fc02e143b678b36007078e655989b65cea8 Mon Sep 17 00:00:00 2001 From: jbouwh Date: Thu, 16 Oct 2025 19:26:34 +0000 Subject: [PATCH] Move logic into Entity class --- homeassistant/helpers/entity.py | 159 +++++++++++++++----------------- tests/helpers/test_entity.py | 4 +- 2 files changed, 74 insertions(+), 89 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index ecb0c74d0e7..8b0a0ba078f 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -526,6 +526,9 @@ class Entity( __capabilities_updated_at_reported: bool = False __remove_future: asyncio.Future[None] | None = None + # Remember we keep track of included entities + __init_track_included_entities: bool = False + # Entity Properties _attr_assumed_state: bool = False _attr_attribution: str | None = None @@ -541,6 +544,8 @@ class Entity( _attr_extra_state_attributes: dict[str, Any] _attr_force_update: bool _attr_icon: str | None + _attr_included_entities: list[str] + _attr_included_unique_ids: list[str] _attr_name: str | None _attr_should_poll: bool = True _attr_state: StateType = STATE_UNKNOWN @@ -1635,6 +1640,74 @@ class Entity( self.hass, integration_domain=platform_name, module=type(self).__module__ ) + @callback + def async_set_included_entities( + self, integration_domain: str, unique_ids: list[str] + ) -> None: + """Set the list of included entities identified by their unique IDs. + + Integrations need to initialize this in entity.async_async_added_to_hass, + and when the list of included entities changes. + The entity ids of included entities will will be looked up and they will be + tracked for changes. + None existing entities for the supplied unique IDs will be ignored. + """ + entity_registry = er.async_get(self.hass) + self._attr_included_unique_ids = unique_ids + assert self.entity_id is not None + + def _update_group_entity_ids() -> None: + self._attr_included_entities = [] + for included_id in self.included_unique_ids: + if entity_id := entity_registry.async_get_entity_id( + self.platform.domain, integration_domain, included_id + ): + self._attr_included_entities.append(entity_id) + + async def _handle_entity_registry_updated(event: Event[Any]) -> None: + """Handle registry create or update event.""" + if ( + event.data["action"] in {"create", "update"} + and (entry := entity_registry.async_get(event.data["entity_id"])) + and entry.unique_id in self.included_unique_ids + ) or ( + event.data["action"] == "remove" + and self.included_entities is not None + and event.data["entity_id"] in self.included_entities + ): + _update_group_entity_ids() + self.async_write_ha_state() + + if not self.__init_track_included_entities: + self.async_on_remove( + self.hass.bus.async_listen( + er.EVENT_ENTITY_REGISTRY_UPDATED, + _handle_entity_registry_updated, + ) + ) + self.__init_track_included_entities = True + _update_group_entity_ids() + + @property + def included_unique_ids(self) -> list[str]: + """Return the list of unique IDs if the entity represents a group. + + The corresponding entities will be shown as members in the UI. + """ + if hasattr(self, "_attr_included_unique_ids"): + return self._attr_included_unique_ids + return [] + + @property + def included_entities(self) -> list[str] | None: + """Return a list of entity IDs if the entity represents a group. + + Included entities will be shown as members in the UI. + """ + if hasattr(self, "_attr_included_entities"): + return self._attr_included_entities + return None + class ToggleEntityDescription(EntityDescription, frozen_or_thawed=True): """A class that describes toggle entities.""" @@ -1699,89 +1772,3 @@ class ToggleEntity( await self.async_turn_off(**kwargs) else: await self.async_turn_on(**kwargs) - - -class IncludedEntitiesMixin(Entity): - """Mixin class to include entities that are contained. - - Integrations can include the this Mixin class to - include the `entity_id` state attribute. - """ - - _attr_included_entities: list[str] - _attr_included_unique_ids: list[str] - __initialized: bool = False - - @callback - def async_set_included_entities( - self, integration_domain: str, unique_ids: list[str] - ) -> None: - """Set the list of included entities identified by their unique IDs. - - Integrations need to initialize this in entity.async_async_added_to_hass, - and when the list of included entities changes. - The entity ids of included entities will will be looked up and they will be - tracked for changes. - None existing entities for the supplied unique IDs will be ignored. - """ - self._integration_domain = integration_domain - self._attr_included_unique_ids = unique_ids - self._monitor_member_updates() - - @property - def included_unique_ids(self) -> list[str]: - """Return the list of unique IDs if the entity represents a group. - - The corresponding entities will be shown as members in the UI. - """ - if hasattr(self, "_attr_included_unique_ids"): - return self._attr_included_unique_ids - return [] - - @property - def included_entities(self) -> list[str] | None: - """Return a list of entity IDs if the entity represents a group. - - Included entities will be shown as members in the UI. - """ - if hasattr(self, "_attr_included_entities"): - return self._attr_included_entities - return None - - @callback - def _monitor_member_updates(self) -> None: - """Update the group members if the entity registry is updated.""" - entity_registry = er.async_get(self.hass) - assert self.entity_id is not None - - def _update_group_entity_ids() -> None: - self._attr_included_entities = [] - for included_id in self.included_unique_ids: - if entity_id := entity_registry.async_get_entity_id( - self.platform.domain, self._integration_domain, included_id - ): - self._attr_included_entities.append(entity_id) - - async def _handle_entity_registry_updated(event: Event[Any]) -> None: - """Handle registry create or update event.""" - if ( - event.data["action"] in {"create", "update"} - and (entry := entity_registry.async_get(event.data["entity_id"])) - and entry.unique_id in self.included_unique_ids - ) or ( - event.data["action"] == "remove" - and self.included_entities is not None - and event.data["entity_id"] in self.included_entities - ): - _update_group_entity_ids() - self.async_write_ha_state() - - if not self.__initialized: - self.async_on_remove( - self.hass.bus.async_listen( - er.EVENT_ENTITY_REGISTRY_UPDATED, - _handle_entity_registry_updated, - ) - ) - self.__initialized = True - _update_group_entity_ids() diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index e081f9bc843..4349328074b 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -2936,9 +2936,7 @@ async def test_included_entities_mixin( return None - class MockHelloIncludedEntitiesClass( - MockHelloBaseClass, entity.IncludedEntitiesMixin - ): + class MockHelloIncludedEntitiesClass(MockHelloBaseClass, entity.Entity): """.Mock hello grouped entity class for a test integration.""" platform = MockEntityPlatform(hass, domain="hello")