Refactor device registry JSON cache (#85539)

This commit is contained in:
Erik Montnemery 2023-01-09 20:50:27 +01:00 committed by GitHub
parent a8f95c36a6
commit 8983f665cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 73 deletions

View File

@ -1,22 +1,17 @@
"""HTTP views to interact with the device registry.""" """HTTP views to interact with the device registry."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant import loader from homeassistant import loader
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_admin from homeassistant.components.websocket_api.decorators import require_admin
from homeassistant.components.websocket_api.messages import ( from homeassistant.core import HomeAssistant, callback
IDEN_JSON_TEMPLATE,
IDEN_TEMPLATE,
message_to_json,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.device_registry import ( from homeassistant.helpers.device_registry import (
EVENT_DEVICE_REGISTRY_UPDATED, DeviceEntry,
DeviceEntryDisabler, DeviceEntryDisabler,
async_get, async_get,
) )
@ -25,44 +20,6 @@ from homeassistant.helpers.device_registry import (
async def async_setup(hass): async def async_setup(hass):
"""Enable the Device Registry views.""" """Enable the Device Registry views."""
cached_list_devices: str | None = None
@callback
def _async_clear_list_device_cache(event: Event) -> None:
nonlocal cached_list_devices
cached_list_devices = None
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "config/device_registry/list",
}
)
def websocket_list_devices(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Handle list devices command."""
nonlocal cached_list_devices
if not cached_list_devices:
registry = async_get(hass)
cached_list_devices = message_to_json(
websocket_api.result_message(
IDEN_TEMPLATE, # type: ignore[arg-type]
[_entry_dict(entry) for entry in registry.devices.values()],
)
)
connection.send_message(
cached_list_devices.replace(IDEN_JSON_TEMPLATE, str(msg["id"]), 1)
)
hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED,
_async_clear_list_device_cache,
run_immediately=True,
)
websocket_api.async_register_command(hass, websocket_list_devices) websocket_api.async_register_command(hass, websocket_list_devices)
websocket_api.async_register_command(hass, websocket_update_device) websocket_api.async_register_command(hass, websocket_update_device)
websocket_api.async_register_command( websocket_api.async_register_command(
@ -71,6 +28,37 @@ async def async_setup(hass):
return True return True
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "config/device_registry/list",
}
)
def websocket_list_devices(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Handle list devices command."""
registry = async_get(hass)
# Build start of response message
msg_json_prefix = (
f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",'
f'"success":true,"result": ['
)
# Concatenate cached entity registry item JSON serializations
msg_json = (
msg_json_prefix
+ ",".join(
entry.json_repr
for entry in registry.devices.values()
if entry.json_repr is not None
)
+ "]}"
)
connection.send_message(msg_json)
@require_admin @require_admin
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
@ -98,9 +86,9 @@ def websocket_update_device(
if msg.get("disabled_by") is not None: if msg.get("disabled_by") is not None:
msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"]) msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"])
entry = registry.async_update_device(**msg) entry = cast(DeviceEntry, registry.async_update_device(**msg))
connection.send_message(websocket_api.result_message(msg_id, _entry_dict(entry))) connection.send_message(websocket_api.result_message(msg_id, entry.dict_repr))
@websocket_api.require_admin @websocket_api.require_admin
@ -151,28 +139,6 @@ async def websocket_remove_config_entry_from_device(
device_id, remove_config_entry_id=config_entry_id device_id, remove_config_entry_id=config_entry_id
) )
entry_as_dict = _entry_dict(entry) if entry else None entry_as_dict = entry.dict_repr if entry else None
connection.send_message(websocket_api.result_message(msg["id"], entry_as_dict)) connection.send_message(websocket_api.result_message(msg["id"], entry_as_dict))
@callback
def _entry_dict(entry):
"""Convert entry to API format."""
return {
"area_id": entry.area_id,
"configuration_url": entry.configuration_url,
"config_entries": list(entry.config_entries),
"connections": list(entry.connections),
"disabled_by": entry.disabled_by,
"entry_type": entry.entry_type,
"hw_version": entry.hw_version,
"id": entry.id,
"identifiers": list(entry.identifiers),
"manufacturer": entry.manufacturer,
"model": entry.model,
"name_by_user": entry.name_by_user,
"name": entry.name,
"sw_version": entry.sw_version,
"via_device_id": entry.via_device_id,
}

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections import UserDict
from collections.abc import Coroutine from collections.abc import Coroutine, ValuesView
import logging import logging
import time import time
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
@ -14,11 +14,16 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, RequiredParameterMissing from homeassistant.exceptions import HomeAssistantError, RequiredParameterMissing
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.json import (
find_paths_unserializable_data,
format_unserializable_data,
)
import homeassistant.util.uuid as uuid_util import homeassistant.util.uuid as uuid_util
from . import storage from . import storage
from .debounce import Debouncer from .debounce import Debouncer
from .frame import report from .frame import report
from .json import JSON_DUMP
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -89,11 +94,53 @@ class DeviceEntry:
# This value is not stored, just used to keep track of events to fire. # This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False) is_new: bool = attr.ib(default=False)
_json_repr: str | None = attr.ib(cmp=False, default=None, init=False, repr=False)
@property @property
def disabled(self) -> bool: def disabled(self) -> bool:
"""Return if entry is disabled.""" """Return if entry is disabled."""
return self.disabled_by is not None return self.disabled_by is not None
@property
def dict_repr(self) -> dict[str, Any]:
"""Return a dict representation of the entry."""
return {
"area_id": self.area_id,
"configuration_url": self.configuration_url,
"config_entries": list(self.config_entries),
"connections": list(self.connections),
"disabled_by": self.disabled_by,
"entry_type": self.entry_type,
"hw_version": self.hw_version,
"id": self.id,
"identifiers": list(self.identifiers),
"manufacturer": self.manufacturer,
"model": self.model,
"name_by_user": self.name_by_user,
"name": self.name,
"sw_version": self.sw_version,
"via_device_id": self.via_device_id,
}
@property
def json_repr(self) -> str | None:
"""Return a cached JSON representation of the entry."""
if self._json_repr is not None:
return self._json_repr
try:
dict_repr = self.dict_repr
object.__setattr__(self, "_json_repr", JSON_DUMP(dict_repr))
except (ValueError, TypeError):
_LOGGER.error(
"Unable to serialize entry %s to JSON. Bad data found at %s",
self.id,
format_unserializable_data(
find_paths_unserializable_data(dict_repr, dump=JSON_DUMP)
),
)
return self._json_repr
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class DeletedDeviceEntry: class DeletedDeviceEntry:
@ -199,6 +246,10 @@ class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
self._connections: dict[tuple[str, str], _EntryTypeT] = {} self._connections: dict[tuple[str, str], _EntryTypeT] = {}
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
def values(self) -> ValuesView[_EntryTypeT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: _EntryTypeT) -> None: def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
"""Add an item.""" """Add an item."""
if key in self: if key in self:

View File

@ -85,7 +85,10 @@ async def test_list_devices(hass, client, registry):
}, },
] ]
registry.async_remove_device(device2.id) class Unserializable:
"""Good luck serializing me."""
registry.async_update_device(device2.id, name=Unserializable())
await hass.async_block_till_done() await hass.async_block_till_done()
await client.send_json({"id": 6, "type": "config/device_registry/list"}) await client.send_json({"id": 6, "type": "config/device_registry/list"})
@ -111,6 +114,9 @@ async def test_list_devices(hass, client, registry):
} }
] ]
# Remove the bad device to avoid errors when test is being torn down
registry.async_remove_device(device2.id)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"payload_key,payload_value", "payload_key,payload_value",