diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index dc6cc84c09c..dd3f6cdbef0 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Callable -from contextlib import suppress import datetime as dt from functools import lru_cache import json @@ -50,6 +49,17 @@ from . import const, decorators, messages from .connection import ActiveConnection from .const import ERR_NOT_FOUND +_STATES_TEMPLATE = "__STATES__" +_STATES_JSON_TEMPLATE = '"__STATES__"' +_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP( + messages.event_message( + messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE} + ) +) +_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP( + messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE) +) + @callback def async_register_commands( @@ -242,33 +252,43 @@ def handle_get_states( """Handle get states command.""" states = _async_get_allowed_states(hass, connection) - # JSON serialize here so we can recover if it blows up due to the - # state machine containing unserializable data. This command is required - # to succeed for the UI to show. - response = messages.result_message(msg["id"], states) try: - connection.send_message(JSON_DUMP(response)) - return + serialized_states = [state.as_dict_json() for state in states] except (ValueError, TypeError): - connection.logger.error( - "Unable to serialize to JSON. Bad data found at %s", - format_unserializable_data( - find_paths_unserializable_data(response, dump=JSON_DUMP) - ), - ) - del response + pass + else: + _send_handle_get_states_response(connection, msg["id"], serialized_states) + return # If we can't serialize, we'll filter out unserializable states - serialized = [] + serialized_states = [] for state in states: - # Error is already logged above - with suppress(ValueError, TypeError): - serialized.append(JSON_DUMP(state)) + try: + serialized_states.append(state.as_dict_json()) + except (ValueError, TypeError): + connection.logger.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(state, dump=JSON_DUMP) + ), + ) - # We now have partially serialized states. Craft some JSON. - response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) - response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized)) - connection.send_message(response2) + _send_handle_get_states_response(connection, msg["id"], serialized_states) + + +def _send_handle_get_states_response( + connection: ActiveConnection, msg_id: int, serialized_states: list[str] +) -> None: + """Send handle get states response.""" + connection.send_message( + _HANDLE_GET_STATES_TEMPLATE.replace( + messages.IDEN_JSON_TEMPLATE, str(msg_id), 1 + ).replace( + _STATES_JSON_TEMPLATE, + "[" + ",".join(serialized_states) + "]", + 1, + ) + ) @callback @@ -304,42 +324,50 @@ def handle_subscribe_entities( EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True ) connection.send_result(msg["id"]) - data: dict[str, dict[str, dict]] = { - messages.ENTITY_EVENT_ADD: { - state.entity_id: state.as_compressed_state() - for state in states - if not entity_ids or state.entity_id in entity_ids - } - } # JSON serialize here so we can recover if it blows up due to the # state machine containing unserializable data. This command is required # to succeed for the UI to show. - response = messages.event_message(msg["id"], data) try: - connection.send_message(JSON_DUMP(response)) - return + serialized_states = [ + state.as_compressed_state_json() + for state in states + if not entity_ids or state.entity_id in entity_ids + ] except (ValueError, TypeError): - connection.logger.error( - "Unable to serialize to JSON. Bad data found at %s", - format_unserializable_data( - find_paths_unserializable_data(response, dump=JSON_DUMP) - ), - ) - del response + pass + else: + _send_handle_entities_init_response(connection, msg["id"], serialized_states) + return - add_entities = data[messages.ENTITY_EVENT_ADD] - cannot_serialize: list[str] = [] - for entity_id, state_dict in add_entities.items(): + serialized_states = [] + for state in states: try: - JSON_DUMP(state_dict) + serialized_states.append(state.as_compressed_state_json()) except (ValueError, TypeError): - cannot_serialize.append(entity_id) + connection.logger.error( + "Unable to serialize to JSON. Bad data found at %s", + format_unserializable_data( + find_paths_unserializable_data(state, dump=JSON_DUMP) + ), + ) - for entity_id in cannot_serialize: - del add_entities[entity_id] + _send_handle_entities_init_response(connection, msg["id"], serialized_states) - connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data))) + +def _send_handle_entities_init_response( + connection: ActiveConnection, msg_id: int, serialized_states: list[str] +) -> None: + """Send handle entities init response.""" + connection.send_message( + _HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace( + messages.IDEN_JSON_TEMPLATE, str(msg_id), 1 + ).replace( + _STATES_JSON_TEMPLATE, + "{" + ",".join(serialized_states) + "}", + 1, + ) + ) @decorators.websocket_command({vol.Required("type"): "get_services"}) diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 341aebdf9e4..4fc604df2dc 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -44,7 +44,7 @@ ENTITY_EVENT_REMOVE = "r" ENTITY_EVENT_CHANGE = "c" -def result_message(iden: int, result: Any = None) -> dict[str, Any]: +def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]: """Return a success result message.""" return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} diff --git a/homeassistant/core.py b/homeassistant/core.py index 2fed56644e4..9560d0f1031 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -80,6 +80,7 @@ from .exceptions import ( Unauthorized, ) from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior +from .helpers.json import json_dumps from .util import dt as dt_util, location, ulid as ulid_util from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe from .util.read_only_dict import ReadOnlyDict @@ -1224,6 +1225,8 @@ class State: "object_id", "_as_dict", "_as_compressed_state", + "_as_dict_json", + "_as_compressed_state_json", ) def __init__( @@ -1260,6 +1263,8 @@ class State: self.domain, self.object_id = split_entity_id(self.entity_id) self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None self._as_compressed_state: dict[str, Any] | None = None + self._as_dict_json: str | None = None + self._as_compressed_state_json: str | None = None @property def name(self) -> str: @@ -1294,6 +1299,12 @@ class State: ) return self._as_dict + def as_dict_json(self) -> str: + """Return a JSON string of the State.""" + if not self._as_dict_json: + self._as_dict_json = json_dumps(self.as_dict()) + return self._as_dict_json + def as_compressed_state(self) -> dict[str, Any]: """Build a compressed dict of a state for adds. @@ -1321,6 +1332,19 @@ class State: self._as_compressed_state = compressed_state return compressed_state + def as_compressed_state_json(self) -> str: + """Build a compressed JSON key value pair of a state for adds. + + The JSON string is a key value pair of the entity_id and the compressed state. + + It is used for sending multiple states in a single message. + """ + if not self._as_compressed_state_json: + self._as_compressed_state_json = json_dumps( + {self.entity_id: self.as_compressed_state()} + )[1:-1] + return self._as_compressed_state_json + @classmethod def from_dict(cls, json_dict: dict[str, Any]) -> Self | None: """Initialize a state from a dict. diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index c15436ed2c1..5f8a6894121 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -9,7 +9,6 @@ from typing import Any, Final import orjson -from homeassistant.core import Event, State from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401 JSON_DECODE_EXCEPTIONS, @@ -189,6 +188,11 @@ def find_paths_unserializable_data( This method is slow! Only use for error handling. """ + from homeassistant.core import ( # pylint: disable=import-outside-toplevel + Event, + State, + ) + to_process = deque([(bad_data, "$")]) invalid = {} diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 7c008f7515b..02384aace89 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -188,10 +188,9 @@ async def test_non_json_message( assert msg["type"] == const.TYPE_RESULT assert msg["success"] assert msg["result"] == [] - assert ( - f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(" - in caplog.text - ) + assert "Unable to serialize to JSON. Bad data found" in caplog.text + assert "State: test_domain.entity" in caplog.text + assert "bad= None: assert state.as_dict() is as_dict_1 +def test_state_as_dict_json() -> None: + """Test a State as JSON.""" + last_time = datetime(1984, 12, 8, 12, 0, 0) + state = ha.State( + "happy.happy", + "on", + {"pig": "dog"}, + last_updated=last_time, + last_changed=last_time, + context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"), + ) + expected = ( + '{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},' + '"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",' + '"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}' + ) + as_dict_json_1 = state.as_dict_json() + assert as_dict_json_1 == expected + # 2nd time to verify cache + assert state.as_dict_json() == expected + assert state.as_dict_json() is as_dict_json_1 + + def test_state_as_compressed_state() -> None: """Test a State as compressed state.""" last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC) @@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None: assert state.as_compressed_state() is as_compressed_state +def test_state_as_compressed_state_json() -> None: + """Test a State as a JSON compressed state.""" + last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC) + state = ha.State( + "happy.happy", + "on", + {"pig": "dog"}, + last_updated=last_time, + last_changed=last_time, + context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"), + ) + expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}' + as_compressed_state = state.as_compressed_state_json() + # We are not too concerned about these being ReadOnlyDict + # since we don't expect them to be called by external callers + assert as_compressed_state == expected + # 2nd time to verify cache + assert state.as_compressed_state_json() == expected + assert state.as_compressed_state_json() is as_compressed_state + + async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None: """Test remove_listener method.""" old_count = len(hass.bus.async_listeners())