Refactor entity registry JSON cache (#85085)

* Refactor entity registry JSON cache

* Fix generator

* Tweak

* Improve string building

* Improve test coverage

* Override EntityRegistryItems.values to avoid __iter__ overhead
This commit is contained in:
Erik Montnemery 2023-01-09 16:52:52 +01:00 committed by GitHub
parent 174cc23309
commit b933a53aa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 47 deletions

View File

@ -9,12 +9,7 @@ from homeassistant import config_entries
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api import ERR_NOT_FOUND from homeassistant.components.websocket_api import ERR_NOT_FOUND
from homeassistant.components.websocket_api.decorators import require_admin from homeassistant.components.websocket_api.decorators import require_admin
from homeassistant.components.websocket_api.messages import ( from homeassistant.core import HomeAssistant, callback
IDEN_JSON_TEMPLATE,
IDEN_TEMPLATE,
message_to_json,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
@ -25,41 +20,6 @@ from homeassistant.helpers import (
async def async_setup(hass: HomeAssistant) -> bool: async def async_setup(hass: HomeAssistant) -> bool:
"""Enable the Entity Registry views.""" """Enable the Entity Registry views."""
cached_list_entities: str | None = None
@callback
def _async_clear_list_entities_cache(event: Event) -> None:
nonlocal cached_list_entities
cached_list_entities = None
@websocket_api.websocket_command(
{vol.Required("type"): "config/entity_registry/list"}
)
@callback
def websocket_list_entities(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Handle list registry entries command."""
nonlocal cached_list_entities
if not cached_list_entities:
registry = er.async_get(hass)
cached_list_entities = message_to_json(
websocket_api.result_message(
IDEN_TEMPLATE, # type: ignore[arg-type]
[_entry_dict(entry) for entry in registry.entities.values()],
)
)
connection.send_message(
cached_list_entities.replace(IDEN_JSON_TEMPLATE, str(msg["id"]), 1)
)
hass.bus.async_listen(
er.EVENT_ENTITY_REGISTRY_UPDATED,
_async_clear_list_entities_cache,
run_immediately=True,
)
websocket_api.async_register_command(hass, websocket_list_entities) websocket_api.async_register_command(hass, websocket_list_entities)
websocket_api.async_register_command(hass, websocket_get_entity) websocket_api.async_register_command(hass, websocket_get_entity)
websocket_api.async_register_command(hass, websocket_get_entities) websocket_api.async_register_command(hass, websocket_get_entities)
@ -68,6 +28,33 @@ async def async_setup(hass: HomeAssistant) -> bool:
return True return True
@websocket_api.websocket_command({vol.Required("type"): "config/entity_registry/list"})
@callback
def websocket_list_entities(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Handle list registry entries command."""
registry = er.async_get(hass)
# Build start of response message
msg_json_prefix = (
f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",'
f'"success":true,"result": ['
)
# Concatenate cached entity registry item JSON serializations
msg_json = (
msg_json_prefix
+ ",".join(
entry.json_repr
for entry in registry.entities.values()
if entry.json_repr is not None
)
+ "]}"
)
connection.send_message(msg_json)
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "config/entity_registry/get", vol.Required("type"): "config/entity_registry/get",

View File

@ -10,7 +10,7 @@ timer.
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections import UserDict
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping, ValuesView
import logging import logging
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
@ -42,10 +42,15 @@ from homeassistant.core import (
from homeassistant.exceptions import MaxLengthExceeded from homeassistant.exceptions import MaxLengthExceeded
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import slugify, uuid as uuid_util from homeassistant.util import slugify, uuid as uuid_util
from homeassistant.util.json import (
find_paths_unserializable_data,
format_unserializable_data,
)
from . import device_registry as dr, storage from . import device_registry as dr, storage
from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from .frame import report from .frame import report
from .json import JSON_DUMP
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -130,6 +135,8 @@ class RegistryEntry:
translation_key: str | None = attr.ib(default=None) translation_key: str | None = attr.ib(default=None)
unit_of_measurement: str | None = attr.ib(default=None) unit_of_measurement: str | None = attr.ib(default=None)
_json_repr: str | None = attr.ib(cmp=False, default=None, init=False, repr=False)
@domain.default @domain.default
def _domain_default(self) -> str: def _domain_default(self) -> str:
"""Compute domain value.""" """Compute domain value."""
@ -145,6 +152,41 @@ class RegistryEntry:
"""Return if entry is hidden.""" """Return if entry is hidden."""
return self.hidden_by is not None return self.hidden_by is not None
@property
def json_repr(self) -> str | None:
"""Return a cached JSON representation of the entry."""
if self._json_repr is not None:
return self._json_repr
try:
dict_repr = {
"area_id": self.area_id,
"config_entry_id": self.config_entry_id,
"device_id": self.device_id,
"disabled_by": self.disabled_by,
"entity_category": self.entity_category,
"entity_id": self.entity_id,
"has_entity_name": self.has_entity_name,
"hidden_by": self.hidden_by,
"icon": self.icon,
"id": self.id,
"name": self.name,
"original_name": self.original_name,
"platform": self.platform,
"translation_key": self.translation_key,
"unique_id": self.unique_id,
}
object.__setattr__(self, "_json_repr", JSON_DUMP(dict_repr))
except (ValueError, TypeError):
_LOGGER.error(
"Unable to serialize entry %s to JSON. Bad data found at %s",
self.entity_id,
format_unserializable_data(
find_paths_unserializable_data(dict_repr, dump=JSON_DUMP)
),
)
return self._json_repr
@callback @callback
def write_unavailable_state(self, hass: HomeAssistant) -> None: def write_unavailable_state(self, hass: HomeAssistant) -> None:
"""Write the unavailable state to the state machine.""" """Write the unavailable state to the state machine."""
@ -268,6 +310,10 @@ class EntityRegistryItems(UserDict[str, "RegistryEntry"]):
self._entry_ids: dict[str, RegistryEntry] = {} self._entry_ids: dict[str, RegistryEntry] = {}
self._index: dict[tuple[str, str, str], str] = {} self._index: dict[tuple[str, str, str], str] = {}
def values(self) -> ValuesView[RegistryEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: RegistryEntry) -> None: def __setitem__(self, key: str, entry: RegistryEntry) -> None:
"""Add an item.""" """Add an item."""
if key in self: if key in self:

View File

@ -6,7 +6,6 @@ from homeassistant.components.config import entity_registry
from homeassistant.const import ATTR_ICON from homeassistant.const import ATTR_ICON
from homeassistant.helpers.device_registry import DeviceEntryDisabler from homeassistant.helpers.device_registry import DeviceEntryDisabler
from homeassistant.helpers.entity_registry import ( from homeassistant.helpers.entity_registry import (
EVENT_ENTITY_REGISTRY_UPDATED,
RegistryEntry, RegistryEntry,
RegistryEntryDisabler, RegistryEntryDisabler,
RegistryEntryHider, RegistryEntryHider,
@ -95,6 +94,9 @@ async def test_list_entities(hass, client):
}, },
] ]
class Unserializable:
"""Good luck serializing me."""
mock_registry( mock_registry(
hass, hass,
{ {
@ -104,13 +106,15 @@ async def test_list_entities(hass, client):
platform="test_platform", platform="test_platform",
name="Hello World", name="Hello World",
), ),
"test_domain.name_2": RegistryEntry(
entity_id="test_domain.name_2",
unique_id="6789",
platform="test_platform",
name=Unserializable(),
),
}, },
) )
hass.bus.async_fire(
EVENT_ENTITY_REGISTRY_UPDATED,
{"action": "create", "entity_id": "test_domain.no_name"},
)
await client.send_json({"id": 6, "type": "config/entity_registry/list"}) await client.send_json({"id": 6, "type": "config/entity_registry/list"})
msg = await client.receive_json() msg = await client.receive_json()