diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index 42d2386977f..74c15da2f00 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -1,22 +1,17 @@ """HTTP views to interact with the device registry.""" from __future__ import annotations -from typing import Any +from typing import Any, cast import voluptuous as vol from homeassistant import loader from homeassistant.components import websocket_api from homeassistant.components.websocket_api.decorators import require_admin -from homeassistant.components.websocket_api.messages import ( - IDEN_JSON_TEMPLATE, - IDEN_TEMPLATE, - message_to_json, -) -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.device_registry import ( - EVENT_DEVICE_REGISTRY_UPDATED, + DeviceEntry, DeviceEntryDisabler, async_get, ) @@ -25,44 +20,6 @@ from homeassistant.helpers.device_registry import ( async def async_setup(hass): """Enable the Device Registry views.""" - cached_list_devices: str | None = None - - @callback - def _async_clear_list_device_cache(event: Event) -> None: - nonlocal cached_list_devices - cached_list_devices = None - - @callback - @websocket_api.websocket_command( - { - vol.Required("type"): "config/device_registry/list", - } - ) - def websocket_list_devices( - hass: HomeAssistant, - connection: websocket_api.ActiveConnection, - msg: dict[str, Any], - ) -> None: - """Handle list devices command.""" - nonlocal cached_list_devices - if not cached_list_devices: - registry = async_get(hass) - cached_list_devices = message_to_json( - websocket_api.result_message( - IDEN_TEMPLATE, # type: ignore[arg-type] - [_entry_dict(entry) for entry in registry.devices.values()], - ) - ) - connection.send_message( - cached_list_devices.replace(IDEN_JSON_TEMPLATE, str(msg["id"]), 1) - ) - - hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, - _async_clear_list_device_cache, - run_immediately=True, - ) - websocket_api.async_register_command(hass, websocket_list_devices) websocket_api.async_register_command(hass, websocket_update_device) websocket_api.async_register_command( @@ -71,6 +28,37 @@ async def async_setup(hass): return True +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "config/device_registry/list", + } +) +def websocket_list_devices( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Handle list devices command.""" + registry = async_get(hass) + # Build start of response message + msg_json_prefix = ( + f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",' + f'"success":true,"result": [' + ) + # Concatenate cached entity registry item JSON serializations + msg_json = ( + msg_json_prefix + + ",".join( + entry.json_repr + for entry in registry.devices.values() + if entry.json_repr is not None + ) + + "]}" + ) + connection.send_message(msg_json) + + @require_admin @websocket_api.websocket_command( { @@ -98,9 +86,9 @@ def websocket_update_device( if msg.get("disabled_by") is not None: msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"]) - entry = registry.async_update_device(**msg) + entry = cast(DeviceEntry, registry.async_update_device(**msg)) - connection.send_message(websocket_api.result_message(msg_id, _entry_dict(entry))) + connection.send_message(websocket_api.result_message(msg_id, entry.dict_repr)) @websocket_api.require_admin @@ -151,28 +139,6 @@ async def websocket_remove_config_entry_from_device( device_id, remove_config_entry_id=config_entry_id ) - entry_as_dict = _entry_dict(entry) if entry else None + entry_as_dict = entry.dict_repr if entry else None connection.send_message(websocket_api.result_message(msg["id"], entry_as_dict)) - - -@callback -def _entry_dict(entry): - """Convert entry to API format.""" - return { - "area_id": entry.area_id, - "configuration_url": entry.configuration_url, - "config_entries": list(entry.config_entries), - "connections": list(entry.connections), - "disabled_by": entry.disabled_by, - "entry_type": entry.entry_type, - "hw_version": entry.hw_version, - "id": entry.id, - "identifiers": list(entry.identifiers), - "manufacturer": entry.manufacturer, - "model": entry.model, - "name_by_user": entry.name_by_user, - "name": entry.name, - "sw_version": entry.sw_version, - "via_device_id": entry.via_device_id, - } diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 6f2dc22f1dd..b63814b8960 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import UserDict -from collections.abc import Coroutine +from collections.abc import Coroutine, ValuesView import logging import time from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -14,11 +14,16 @@ from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, RequiredParameterMissing from homeassistant.loader import bind_hass +from homeassistant.util.json import ( + find_paths_unserializable_data, + format_unserializable_data, +) import homeassistant.util.uuid as uuid_util from . import storage from .debounce import Debouncer from .frame import report +from .json import JSON_DUMP from .typing import UNDEFINED, UndefinedType if TYPE_CHECKING: @@ -89,11 +94,53 @@ class DeviceEntry: # This value is not stored, just used to keep track of events to fire. is_new: bool = attr.ib(default=False) + _json_repr: str | None = attr.ib(cmp=False, default=None, init=False, repr=False) + @property def disabled(self) -> bool: """Return if entry is disabled.""" return self.disabled_by is not None + @property + def dict_repr(self) -> dict[str, Any]: + """Return a dict representation of the entry.""" + return { + "area_id": self.area_id, + "configuration_url": self.configuration_url, + "config_entries": list(self.config_entries), + "connections": list(self.connections), + "disabled_by": self.disabled_by, + "entry_type": self.entry_type, + "hw_version": self.hw_version, + "id": self.id, + "identifiers": list(self.identifiers), + "manufacturer": self.manufacturer, + "model": self.model, + "name_by_user": self.name_by_user, + "name": self.name, + "sw_version": self.sw_version, + "via_device_id": self.via_device_id, + } + + @property + def json_repr(self) -> str | None: + """Return a cached JSON representation of the entry.""" + if self._json_repr is not None: + return self._json_repr + + try: + dict_repr = self.dict_repr + object.__setattr__(self, "_json_repr", JSON_DUMP(dict_repr)) + except (ValueError, TypeError): + _LOGGER.error( + "Unable to serialize entry %s to JSON. Bad data found at %s", + self.id, + format_unserializable_data( + find_paths_unserializable_data(dict_repr, dump=JSON_DUMP) + ), + ) + return self._json_repr + @attr.s(slots=True, frozen=True) class DeletedDeviceEntry: @@ -199,6 +246,10 @@ class DeviceRegistryItems(UserDict[str, _EntryTypeT]): self._connections: dict[tuple[str, str], _EntryTypeT] = {} self._identifiers: dict[tuple[str, str], _EntryTypeT] = {} + def values(self) -> ValuesView[_EntryTypeT]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + def __setitem__(self, key: str, entry: _EntryTypeT) -> None: """Add an item.""" if key in self: diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index 4f47e463751..487658ddb27 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -85,7 +85,10 @@ async def test_list_devices(hass, client, registry): }, ] - registry.async_remove_device(device2.id) + class Unserializable: + """Good luck serializing me.""" + + registry.async_update_device(device2.id, name=Unserializable()) await hass.async_block_till_done() await client.send_json({"id": 6, "type": "config/device_registry/list"}) @@ -111,6 +114,9 @@ async def test_list_devices(hass, client, registry): } ] + # Remove the bad device to avoid errors when test is being torn down + registry.async_remove_device(device2.id) + @pytest.mark.parametrize( "payload_key,payload_value",