mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 19:09:32 +00:00
Move logic into Entity class
This commit is contained in:
@@ -526,6 +526,9 @@ class Entity(
|
||||
__capabilities_updated_at_reported: bool = False
|
||||
__remove_future: asyncio.Future[None] | None = None
|
||||
|
||||
# Remember we keep track of included entities
|
||||
__init_track_included_entities: bool = False
|
||||
|
||||
# Entity Properties
|
||||
_attr_assumed_state: bool = False
|
||||
_attr_attribution: str | None = None
|
||||
@@ -541,6 +544,8 @@ class Entity(
|
||||
_attr_extra_state_attributes: dict[str, Any]
|
||||
_attr_force_update: bool
|
||||
_attr_icon: str | None
|
||||
_attr_included_entities: list[str]
|
||||
_attr_included_unique_ids: list[str]
|
||||
_attr_name: str | None
|
||||
_attr_should_poll: bool = True
|
||||
_attr_state: StateType = STATE_UNKNOWN
|
||||
@@ -1635,6 +1640,74 @@ class Entity(
|
||||
self.hass, integration_domain=platform_name, module=type(self).__module__
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_set_included_entities(
|
||||
self, integration_domain: str, unique_ids: list[str]
|
||||
) -> None:
|
||||
"""Set the list of included entities identified by their unique IDs.
|
||||
|
||||
Integrations need to initialize this in entity.async_async_added_to_hass,
|
||||
and when the list of included entities changes.
|
||||
The entity ids of included entities will will be looked up and they will be
|
||||
tracked for changes.
|
||||
None existing entities for the supplied unique IDs will be ignored.
|
||||
"""
|
||||
entity_registry = er.async_get(self.hass)
|
||||
self._attr_included_unique_ids = unique_ids
|
||||
assert self.entity_id is not None
|
||||
|
||||
def _update_group_entity_ids() -> None:
|
||||
self._attr_included_entities = []
|
||||
for included_id in self.included_unique_ids:
|
||||
if entity_id := entity_registry.async_get_entity_id(
|
||||
self.platform.domain, integration_domain, included_id
|
||||
):
|
||||
self._attr_included_entities.append(entity_id)
|
||||
|
||||
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
|
||||
"""Handle registry create or update event."""
|
||||
if (
|
||||
event.data["action"] in {"create", "update"}
|
||||
and (entry := entity_registry.async_get(event.data["entity_id"]))
|
||||
and entry.unique_id in self.included_unique_ids
|
||||
) or (
|
||||
event.data["action"] == "remove"
|
||||
and self.included_entities is not None
|
||||
and event.data["entity_id"] in self.included_entities
|
||||
):
|
||||
_update_group_entity_ids()
|
||||
self.async_write_ha_state()
|
||||
|
||||
if not self.__init_track_included_entities:
|
||||
self.async_on_remove(
|
||||
self.hass.bus.async_listen(
|
||||
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
_handle_entity_registry_updated,
|
||||
)
|
||||
)
|
||||
self.__init_track_included_entities = True
|
||||
_update_group_entity_ids()
|
||||
|
||||
@property
|
||||
def included_unique_ids(self) -> list[str]:
|
||||
"""Return the list of unique IDs if the entity represents a group.
|
||||
|
||||
The corresponding entities will be shown as members in the UI.
|
||||
"""
|
||||
if hasattr(self, "_attr_included_unique_ids"):
|
||||
return self._attr_included_unique_ids
|
||||
return []
|
||||
|
||||
@property
|
||||
def included_entities(self) -> list[str] | None:
|
||||
"""Return a list of entity IDs if the entity represents a group.
|
||||
|
||||
Included entities will be shown as members in the UI.
|
||||
"""
|
||||
if hasattr(self, "_attr_included_entities"):
|
||||
return self._attr_included_entities
|
||||
return None
|
||||
|
||||
|
||||
class ToggleEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||
"""A class that describes toggle entities."""
|
||||
@@ -1699,89 +1772,3 @@ class ToggleEntity(
|
||||
await self.async_turn_off(**kwargs)
|
||||
else:
|
||||
await self.async_turn_on(**kwargs)
|
||||
|
||||
|
||||
class IncludedEntitiesMixin(Entity):
|
||||
"""Mixin class to include entities that are contained.
|
||||
|
||||
Integrations can include the this Mixin class to
|
||||
include the `entity_id` state attribute.
|
||||
"""
|
||||
|
||||
_attr_included_entities: list[str]
|
||||
_attr_included_unique_ids: list[str]
|
||||
__initialized: bool = False
|
||||
|
||||
@callback
|
||||
def async_set_included_entities(
|
||||
self, integration_domain: str, unique_ids: list[str]
|
||||
) -> None:
|
||||
"""Set the list of included entities identified by their unique IDs.
|
||||
|
||||
Integrations need to initialize this in entity.async_async_added_to_hass,
|
||||
and when the list of included entities changes.
|
||||
The entity ids of included entities will will be looked up and they will be
|
||||
tracked for changes.
|
||||
None existing entities for the supplied unique IDs will be ignored.
|
||||
"""
|
||||
self._integration_domain = integration_domain
|
||||
self._attr_included_unique_ids = unique_ids
|
||||
self._monitor_member_updates()
|
||||
|
||||
@property
|
||||
def included_unique_ids(self) -> list[str]:
|
||||
"""Return the list of unique IDs if the entity represents a group.
|
||||
|
||||
The corresponding entities will be shown as members in the UI.
|
||||
"""
|
||||
if hasattr(self, "_attr_included_unique_ids"):
|
||||
return self._attr_included_unique_ids
|
||||
return []
|
||||
|
||||
@property
|
||||
def included_entities(self) -> list[str] | None:
|
||||
"""Return a list of entity IDs if the entity represents a group.
|
||||
|
||||
Included entities will be shown as members in the UI.
|
||||
"""
|
||||
if hasattr(self, "_attr_included_entities"):
|
||||
return self._attr_included_entities
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _monitor_member_updates(self) -> None:
|
||||
"""Update the group members if the entity registry is updated."""
|
||||
entity_registry = er.async_get(self.hass)
|
||||
assert self.entity_id is not None
|
||||
|
||||
def _update_group_entity_ids() -> None:
|
||||
self._attr_included_entities = []
|
||||
for included_id in self.included_unique_ids:
|
||||
if entity_id := entity_registry.async_get_entity_id(
|
||||
self.platform.domain, self._integration_domain, included_id
|
||||
):
|
||||
self._attr_included_entities.append(entity_id)
|
||||
|
||||
async def _handle_entity_registry_updated(event: Event[Any]) -> None:
|
||||
"""Handle registry create or update event."""
|
||||
if (
|
||||
event.data["action"] in {"create", "update"}
|
||||
and (entry := entity_registry.async_get(event.data["entity_id"]))
|
||||
and entry.unique_id in self.included_unique_ids
|
||||
) or (
|
||||
event.data["action"] == "remove"
|
||||
and self.included_entities is not None
|
||||
and event.data["entity_id"] in self.included_entities
|
||||
):
|
||||
_update_group_entity_ids()
|
||||
self.async_write_ha_state()
|
||||
|
||||
if not self.__initialized:
|
||||
self.async_on_remove(
|
||||
self.hass.bus.async_listen(
|
||||
er.EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
_handle_entity_registry_updated,
|
||||
)
|
||||
)
|
||||
self.__initialized = True
|
||||
_update_group_entity_ids()
|
||||
|
||||
@@ -2936,9 +2936,7 @@ async def test_included_entities_mixin(
|
||||
|
||||
return None
|
||||
|
||||
class MockHelloIncludedEntitiesClass(
|
||||
MockHelloBaseClass, entity.IncludedEntitiesMixin
|
||||
):
|
||||
class MockHelloIncludedEntitiesClass(MockHelloBaseClass, entity.Entity):
|
||||
""".Mock hello grouped entity class for a test integration."""
|
||||
|
||||
platform = MockEntityPlatform(hass, domain="hello")
|
||||
|
||||
Reference in New Issue
Block a user