Keep capabilities up to date in the entity registry (#101748)

* Keep capabilities up to date in the entity registry

* Warn if entities update their capabilities very often

* Fix updating of device class

* Stop tracking capability updates once flooding is logged

* Only sync registry if state changed

* Improve test

* Revert "Only sync registry if state changed"

This reverts commit 1c52571596c06444df234d4b088242b494b630f2.

* Avoid calculating device class twice

* Address review comments

* Revert using dataclass

* Fix unintended revert

* Add helper method
This commit is contained in:
Erik Montnemery 2023-12-13 17:27:26 +01:00 committed by GitHub
parent 4f9f548929
commit dd5a48996a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 257 additions and 15 deletions

View File

@ -509,7 +509,8 @@ class GroupEntity(Entity):
self.async_update_supported_features( self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"] event.data["entity_id"], event.data["new_state"]
) )
preview_callback(*self._async_generate_attributes()) calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
async_state_changed_listener(None) async_state_changed_listener(None)
return async_track_state_change_event( return async_track_state_change_event(

View File

@ -236,7 +236,8 @@ class MediaPlayerGroup(MediaPlayerEntity):
) -> None: ) -> None:
"""Handle child updates.""" """Handle child updates."""
self.async_update_group_state() self.async_update_group_state()
preview_callback(*self._async_generate_attributes()) calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
async_state_changed_listener(None) async_state_changed_listener(None)
return async_track_state_change_event( return async_track_state_change_event(

View File

@ -430,14 +430,17 @@ class TemplateEntity(Entity):
return return
try: try:
state, attrs = self._async_generate_attributes() calculated_state = self._async_calculate_state()
validate_state(state) validate_state(calculated_state.state)
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
self._preview_callback(None, None, None, str(err)) self._preview_callback(None, None, None, str(err))
else: else:
assert self._template_result_info assert self._template_result_info
self._preview_callback( self._preview_callback(
state, attrs, self._template_result_info.listeners, None calculated_state.state,
calculated_state.attributes,
self._template_result_info.listeners,
None,
) )
@callback @callback

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from abc import ABC from abc import ABC
import asyncio import asyncio
from collections import deque
from collections.abc import Coroutine, Iterable, Mapping, MutableMapping from collections.abc import Coroutine, Iterable, Mapping, MutableMapping
import dataclasses import dataclasses
from datetime import timedelta from datetime import timedelta
@ -75,6 +76,9 @@ DATA_ENTITY_SOURCE = "entity_info"
# epsilon to make the string representation readable # epsilon to make the string representation readable
FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1 FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1
# How many times per hour we allow capabilities to be updated before logging a warning
CAPABILITIES_UPDATE_LIMIT = 100
@callback @callback
def async_setup(hass: HomeAssistant) -> None: def async_setup(hass: HomeAssistant) -> None:
@ -237,6 +241,22 @@ class EntityDescription(metaclass=FrozenOrThawed, frozen_or_thawed=True):
unit_of_measurement: str | None = None unit_of_measurement: str | None = None
@dataclasses.dataclass(frozen=True, slots=True)
class CalculatedState:
"""Container with state and attributes.
Returned by Entity._async_calculate_state.
"""
state: str
# The union of all attributes, after overriding with entity registry settings
attributes: dict[str, Any]
# Capability attributes returned by the capability_attributes property
capability_attributes: Mapping[str, Any] | None
# Attributes which may be overridden by the entity registry
shadowed_attributes: Mapping[str, Any]
class Entity(ABC): class Entity(ABC):
"""An abstract class for Home Assistant entities.""" """An abstract class for Home Assistant entities."""
@ -311,6 +331,8 @@ class Entity(ABC):
# and removes the need for constant None checks or asserts. # and removes the need for constant None checks or asserts.
_state_info: StateInfo = None # type: ignore[assignment] _state_info: StateInfo = None # type: ignore[assignment]
__capabilities_updated_at: deque[float]
__capabilities_updated_at_reported: bool = False
__remove_event: asyncio.Event | None = None __remove_event: asyncio.Event | None = None
# Entity Properties # Entity Properties
@ -775,12 +797,29 @@ class Entity(ABC):
return f"{device_name} {name}" if device_name else name return f"{device_name} {name}" if device_name else name
@callback @callback
def _async_generate_attributes(self) -> tuple[str, dict[str, Any]]: def _async_calculate_state(self) -> CalculatedState:
"""Calculate state string and attribute mapping.""" """Calculate state string and attribute mapping."""
return CalculatedState(*self.__async_calculate_state())
def __async_calculate_state(
self,
) -> tuple[str, dict[str, Any], Mapping[str, Any] | None, Mapping[str, Any]]:
"""Calculate state string and attribute mapping.
Returns a tuple (state, attr, capability_attr, shadowed_attr).
state - the stringified state
attr - the attribute dictionary
capability_attr - a mapping with capability attributes
shadowed_attr - a mapping with attributes which may be overridden
This method is called when writing the state to avoid the overhead of creating
a dataclass object.
"""
entry = self.registry_entry entry = self.registry_entry
attr = self.capability_attributes capability_attr = self.capability_attributes
attr = dict(attr) if attr else {} attr = dict(capability_attr) if capability_attr else {}
shadowed_attr = {}
available = self.available # only call self.available once per update cycle available = self.available # only call self.available once per update cycle
state = self._stringify_state(available) state = self._stringify_state(available)
@ -797,26 +836,30 @@ class Entity(ABC):
if (attribution := self.attribution) is not None: if (attribution := self.attribution) is not None:
attr[ATTR_ATTRIBUTION] = attribution attr[ATTR_ATTRIBUTION] = attribution
shadowed_attr[ATTR_DEVICE_CLASS] = self.device_class
if ( if (
device_class := (entry and entry.device_class) or self.device_class device_class := (entry and entry.device_class)
or shadowed_attr[ATTR_DEVICE_CLASS]
) is not None: ) is not None:
attr[ATTR_DEVICE_CLASS] = str(device_class) attr[ATTR_DEVICE_CLASS] = str(device_class)
if (entity_picture := self.entity_picture) is not None: if (entity_picture := self.entity_picture) is not None:
attr[ATTR_ENTITY_PICTURE] = entity_picture attr[ATTR_ENTITY_PICTURE] = entity_picture
if (icon := (entry and entry.icon) or self.icon) is not None: shadowed_attr[ATTR_ICON] = self.icon
if (icon := (entry and entry.icon) or shadowed_attr[ATTR_ICON]) is not None:
attr[ATTR_ICON] = icon attr[ATTR_ICON] = icon
shadowed_attr[ATTR_FRIENDLY_NAME] = self._friendly_name_internal()
if ( if (
name := (entry and entry.name) or self._friendly_name_internal() name := (entry and entry.name) or shadowed_attr[ATTR_FRIENDLY_NAME]
) is not None: ) is not None:
attr[ATTR_FRIENDLY_NAME] = name attr[ATTR_FRIENDLY_NAME] = name
if (supported_features := self.supported_features) is not None: if (supported_features := self.supported_features) is not None:
attr[ATTR_SUPPORTED_FEATURES] = supported_features attr[ATTR_SUPPORTED_FEATURES] = supported_features
return (state, attr) return (state, attr, capability_attr, shadowed_attr)
@callback @callback
def _async_write_ha_state(self) -> None: def _async_write_ha_state(self) -> None:
@ -842,9 +885,45 @@ class Entity(ABC):
return return
start = timer() start = timer()
state, attr = self._async_generate_attributes() state, attr, capabilities, shadowed_attr = self.__async_calculate_state()
end = timer() end = timer()
if entry:
# Make sure capabilities in the entity registry are up to date. Capabilities
# include capability attributes, device class and supported features
original_device_class: str | None = shadowed_attr[ATTR_DEVICE_CLASS]
supported_features: int = attr.get(ATTR_SUPPORTED_FEATURES) or 0
if (
capabilities != entry.capabilities
or original_device_class != entry.original_device_class
or supported_features != entry.supported_features
):
if not self.__capabilities_updated_at_reported:
time_now = hass.loop.time()
capabilities_updated_at = self.__capabilities_updated_at
capabilities_updated_at.append(time_now)
while time_now - capabilities_updated_at[0] > 3600:
capabilities_updated_at.popleft()
if len(capabilities_updated_at) > CAPABILITIES_UPDATE_LIMIT:
self.__capabilities_updated_at_reported = True
report_issue = self._suggest_report_issue()
_LOGGER.warning(
(
"Entity %s (%s) is updating its capabilities too often,"
" please %s"
),
entity_id,
type(self),
report_issue,
)
entity_registry = er.async_get(self.hass)
self.registry_entry = entity_registry.async_update_entity(
self.entity_id,
capabilities=capabilities,
original_device_class=original_device_class,
supported_features=supported_features,
)
if end - start > 0.4 and not self._slow_reported: if end - start > 0.4 and not self._slow_reported:
self._slow_reported = True self._slow_reported = True
report_issue = self._suggest_report_issue() report_issue = self._suggest_report_issue()
@ -1118,6 +1197,8 @@ class Entity(ABC):
) )
self._async_subscribe_device_updates() self._async_subscribe_device_updates()
self.__capabilities_updated_at = deque(maxlen=CAPABILITIES_UPDATE_LIMIT + 1)
async def async_internal_will_remove_from_hass(self) -> None: async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass. """Run when entity will be removed from hass.

View File

@ -8,6 +8,7 @@ import threading
from typing import Any from typing import Any
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import MagicMock, PropertyMock, patch
from freezegun.api import FrozenDateTimeFactory
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
@ -1412,8 +1413,8 @@ async def test_repr_using_stringify_state() -> None:
"""Return the state.""" """Return the state."""
raise ValueError("Boom") raise ValueError("Boom")
entity = MyEntity(entity_id="test.test", available=False) my_entity = MyEntity(entity_id="test.test", available=False)
assert str(entity) == "<entity test.test=unavailable>" assert str(my_entity) == "<entity test.test=unavailable>"
async def test_warn_using_async_update_ha_state( async def test_warn_using_async_update_ha_state(
@ -1761,3 +1762,158 @@ def test_extending_entity_description(snapshot: SnapshotAssertion):
assert obj == snapshot assert obj == snapshot
assert obj == CustomInitEntityDescription(key="blah", extra="foo", name="name") assert obj == CustomInitEntityDescription(key="blah", extra="foo", name="name")
assert repr(obj) == snapshot assert repr(obj) == snapshot
async def test_update_capabilities(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = 127
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == 127
ent._values["capability_attributes"] = None
ent._values["device_class"] = None
ent._values["supported_features"] = None
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class is None
assert entry.supported_features == 0
# Device class can be overridden by user, make sure that does not break the
# automatic updating.
entity_registry.async_update_entity(ent.entity_id, device_class="set_by_user")
await hass.async_block_till_done()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class is None
assert entry.supported_features == 0
# This will not trigger a state change because the device class is shadowed
# by the entity registry
ent._values["device_class"] = "some_class"
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.original_device_class == "some_class"
assert entry.supported_features == 0
async def test_update_capabilities_no_unique_id(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
platform = MockEntityPlatform(hass)
ent = MockEntity()
await platform.async_add_entities([ent])
assert entity_registry.async_get(ent.entity_id) is None
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["supported_features"] = 127
ent.async_write_ha_state()
assert entity_registry.async_get(ent.entity_id) is None
async def test_update_capabilities_too_often(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
entity_registry: er.EntityRegistry,
) -> None:
"""Test entity capabilities are updated automatically."""
capabilities_too_often_warning = "is updating its capabilities too often"
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
for supported_features in range(1, entity.CAPABILITIES_UPDATE_LIMIT + 1):
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features
assert capabilities_too_often_warning not in caplog.text
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features + 1
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features + 1
assert capabilities_too_often_warning in caplog.text
async def test_update_capabilities_too_often_cooldown(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
entity_registry: er.EntityRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test entity capabilities are updated automatically."""
capabilities_too_often_warning = "is updating its capabilities too often"
platform = MockEntityPlatform(hass)
ent = MockEntity(unique_id="qwer")
await platform.async_add_entities([ent])
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities is None
assert entry.device_class is None
assert entry.supported_features == 0
for supported_features in range(1, entity.CAPABILITIES_UPDATE_LIMIT + 1):
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features
assert capabilities_too_often_warning not in caplog.text
freezer.tick(timedelta(minutes=60) + timedelta(seconds=1))
ent._values["capability_attributes"] = {"bla": "blu"}
ent._values["device_class"] = "some_class"
ent._values["supported_features"] = supported_features + 1
ent.async_write_ha_state()
entry = entity_registry.async_get(ent.entity_id)
assert entry.capabilities == {"bla": "blu"}
assert entry.original_device_class == "some_class"
assert entry.supported_features == supported_features + 1
assert capabilities_too_often_warning not in caplog.text