Move logic into Entity class

This commit is contained in:
jbouwh
2025-10-16 19:26:34 +00:00
parent 20a494e4f8
commit bc8d7fc02e
2 changed files with 74 additions and 89 deletions

View File

@@ -526,6 +526,9 @@ class Entity(
__capabilities_updated_at_reported: bool = False
__remove_future: asyncio.Future[None] | None = None
# Remember we keep track of included entities
__init_track_included_entities: bool = False
# Entity Properties
_attr_assumed_state: bool = False
_attr_attribution: str | None = None
@@ -541,6 +544,8 @@ class Entity(
_attr_extra_state_attributes: dict[str, Any]
_attr_force_update: bool
_attr_icon: str | None
_attr_included_entities: list[str]
_attr_included_unique_ids: list[str]
_attr_name: str | None
_attr_should_poll: bool = True
_attr_state: StateType = STATE_UNKNOWN
@@ -1635,6 +1640,74 @@ class Entity(
self.hass, integration_domain=platform_name, module=type(self).__module__
)
@callback
def async_set_included_entities(
self, integration_domain: str, unique_ids: list[str]
) -> None:
"""Set the list of included entities identified by their unique IDs.
Integrations need to initialize this in entity.async_async_added_to_hass,
and when the list of included entities changes.
The entity ids of included entities will will be looked up and they will be
tracked for changes.
None existing entities for the supplied unique IDs will be ignored.
"""
entity_registry = er.async_get(self.hass)
self._attr_included_unique_ids = unique_ids
assert self.entity_id is not None
def _update_group_entity_ids() -> None:
self._attr_included_entities = []
for included_id in self.included_unique_ids:
if entity_id := entity_registry.async_get_entity_id(
self.platform.domain, integration_domain, included_id
):
self._attr_included_entities.append(entity_id)
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
"""Handle registry create or update event."""
if (
event.data["action"] in {"create", "update"}
and (entry := entity_registry.async_get(event.data["entity_id"]))
and entry.unique_id in self.included_unique_ids
) or (
event.data["action"] == "remove"
and self.included_entities is not None
and event.data["entity_id"] in self.included_entities
):
_update_group_entity_ids()
self.async_write_ha_state()
if not self.__init_track_included_entities:
self.async_on_remove(
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
_handle_entity_registry_updated,
)
)
self.__init_track_included_entities = True
_update_group_entity_ids()
@property
def included_unique_ids(self) -> list[str]:
"""Return the list of unique IDs if the entity represents a group.
The corresponding entities will be shown as members in the UI.
"""
if hasattr(self, "_attr_included_unique_ids"):
return self._attr_included_unique_ids
return []
@property
def included_entities(self) -> list[str] | None:
"""Return a list of entity IDs if the entity represents a group.
Included entities will be shown as members in the UI.
"""
if hasattr(self, "_attr_included_entities"):
return self._attr_included_entities
return None
class ToggleEntityDescription(EntityDescription, frozen_or_thawed=True):
"""A class that describes toggle entities."""
@@ -1699,89 +1772,3 @@ class ToggleEntity(
await self.async_turn_off(**kwargs)
else:
await self.async_turn_on(**kwargs)
class IncludedEntitiesMixin(Entity):
"""Mixin class to include entities that are contained.
Integrations can include the this Mixin class to
include the `entity_id` state attribute.
"""
_attr_included_entities: list[str]
_attr_included_unique_ids: list[str]
__initialized: bool = False
@callback
def async_set_included_entities(
self, integration_domain: str, unique_ids: list[str]
) -> None:
"""Set the list of included entities identified by their unique IDs.
Integrations need to initialize this in entity.async_async_added_to_hass,
and when the list of included entities changes.
The entity ids of included entities will will be looked up and they will be
tracked for changes.
None existing entities for the supplied unique IDs will be ignored.
"""
self._integration_domain = integration_domain
self._attr_included_unique_ids = unique_ids
self._monitor_member_updates()
@property
def included_unique_ids(self) -> list[str]:
"""Return the list of unique IDs if the entity represents a group.
The corresponding entities will be shown as members in the UI.
"""
if hasattr(self, "_attr_included_unique_ids"):
return self._attr_included_unique_ids
return []
@property
def included_entities(self) -> list[str] | None:
"""Return a list of entity IDs if the entity represents a group.
Included entities will be shown as members in the UI.
"""
if hasattr(self, "_attr_included_entities"):
return self._attr_included_entities
return None
@callback
def _monitor_member_updates(self) -> None:
"""Update the group members if the entity registry is updated."""
entity_registry = er.async_get(self.hass)
assert self.entity_id is not None
def _update_group_entity_ids() -> None:
self._attr_included_entities = []
for included_id in self.included_unique_ids:
if entity_id := entity_registry.async_get_entity_id(
self.platform.domain, self._integration_domain, included_id
):
self._attr_included_entities.append(entity_id)
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
"""Handle registry create or update event."""
if (
event.data["action"] in {"create", "update"}
and (entry := entity_registry.async_get(event.data["entity_id"]))
and entry.unique_id in self.included_unique_ids
) or (
event.data["action"] == "remove"
and self.included_entities is not None
and event.data["entity_id"] in self.included_entities
):
_update_group_entity_ids()
self.async_write_ha_state()
if not self.__initialized:
self.async_on_remove(
self.hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
_handle_entity_registry_updated,
)
)
self.__initialized = True
_update_group_entity_ids()

View File

@@ -2936,9 +2936,7 @@ async def test_included_entities_mixin(
return None
class MockHelloIncludedEntitiesClass(
MockHelloBaseClass, entity.IncludedEntitiesMixin
):
class MockHelloIncludedEntitiesClass(MockHelloBaseClass, entity.Entity):
""".Mock hello grouped entity class for a test integration."""
platform = MockEntityPlatform(hass, domain="hello")