diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py index 77301532d3d..27acff54f99 100644 --- a/homeassistant/components/history/__init__.py +++ b/homeassistant/components/history/__init__.py @@ -24,10 +24,10 @@ from homeassistant.components.recorder.statistics import ( ) from homeassistant.components.recorder.util import session_scope from homeassistant.components.websocket_api import messages +from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA -from homeassistant.helpers.json import JSON_DUMP from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util diff --git a/homeassistant/components/logbook/websocket_api.py b/homeassistant/components/logbook/websocket_api.py index 461ed018090..82b1db1081c 100644 --- a/homeassistant/components/logbook/websocket_api.py +++ b/homeassistant/components/logbook/websocket_api.py @@ -14,9 +14,9 @@ from homeassistant.components import websocket_api from homeassistant.components.recorder import get_instance from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api.connection import ActiveConnection +from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.helpers.event import async_track_point_in_utc_time -from homeassistant.helpers.json import JSON_DUMP import homeassistant.util.dt as dt_util from .helpers import ( diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py index e94092d2154..e558d19b530 100644 --- a/homeassistant/components/recorder/const.py +++ b/homeassistant/components/recorder/const.py @@ -1,11 +1,12 @@ """Recorder constants.""" +from functools import partial +import json +from typing import Final from homeassistant.backports.enum import StrEnum from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES -from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import - JSON_DUMP, -) +from homeassistant.helpers.json import JSONEncoder DATA_INSTANCE = "recorder_instance" SQLITE_URL_PREFIX = "sqlite://" @@ -26,6 +27,7 @@ MAX_ROWS_TO_PURGE = 998 DB_WORKER_PREFIX = "DbWorker" +JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":")) ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES} diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 8b15e15042f..7df4cf57e56 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -744,12 +744,11 @@ class Recorder(threading.Thread): return try: - shared_data_bytes = EventData.shared_data_bytes_from_event(event) + shared_data = EventData.shared_data_from_event(event) except (TypeError, ValueError) as ex: _LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex) return - shared_data = shared_data_bytes.decode("utf-8") # Matching attributes found in the pending commit if pending_event_data := self._pending_event_data.get(shared_data): dbevent.event_data_rel = pending_event_data @@ -757,7 +756,7 @@ class Recorder(threading.Thread): elif data_id := self._event_data_ids.get(shared_data): dbevent.data_id = data_id else: - data_hash = EventData.hash_shared_data_bytes(shared_data_bytes) + data_hash = EventData.hash_shared_data(shared_data) # Matching attributes found in the database if data_id := self._find_shared_data_in_db(data_hash, shared_data): self._event_data_ids[shared_data] = dbevent.data_id = data_id @@ -776,7 +775,7 @@ class Recorder(threading.Thread): assert self.event_session is not None try: dbstate = States.from_event(event) - shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( + shared_attrs = StateAttributes.shared_attrs_from_event( event, self._exclude_attributes_by_domain ) except (TypeError, ValueError) as ex: @@ -787,7 +786,6 @@ class Recorder(threading.Thread): ) return - shared_attrs = shared_attrs_bytes.decode("utf-8") dbstate.attributes = None # Matching attributes found in the pending commit if pending_attributes := self._pending_state_attributes.get(shared_attrs): @@ -796,7 +794,7 @@ class Recorder(threading.Thread): elif attributes_id := self._state_attributes_ids.get(shared_attrs): dbstate.attributes_id = attributes_id else: - attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes) + attr_hash = StateAttributes.hash_shared_attrs(shared_attrs) # Matching attributes found in the database if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs): dbstate.attributes_id = attributes_id diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index e0a22184cc8..70c816c2af5 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections.abc import Callable from datetime import datetime, timedelta +import json import logging from typing import Any, TypedDict, cast, overload import ciso8601 from fnvhash import fnv1a_32 -import orjson from sqlalchemy import ( JSON, BigInteger, @@ -46,10 +46,9 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id -from homeassistant.helpers.json import JSON_DUMP, json_bytes import homeassistant.util.dt as dt_util -from .const import ALL_DOMAIN_EXCLUDE_ATTRS +from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP # SQLAlchemy Schema # pylint: disable=invalid-name @@ -133,7 +132,7 @@ class JSONLiteral(JSON): # type: ignore[misc] def process(value: Any) -> str: """Dump json.""" - return JSON_DUMP(value) + return json.dumps(value) return process @@ -200,7 +199,7 @@ class Events(Base): # type: ignore[misc,valid-type] try: return Event( self.event_type, - orjson.loads(self.event_data) if self.event_data else {}, + json.loads(self.event_data) if self.event_data else {}, EventOrigin(self.origin) if self.origin else EVENT_ORIGIN_ORDER[self.origin_idx], @@ -208,7 +207,7 @@ class Events(Base): # type: ignore[misc,valid-type] context=context, ) except ValueError: - # When orjson.loads fails + # When json.loads fails _LOGGER.exception("Error converting to event: %s", self) return None @@ -236,26 +235,25 @@ class EventData(Base): # type: ignore[misc,valid-type] @staticmethod def from_event(event: Event) -> EventData: """Create object from an event.""" - shared_data = json_bytes(event.data) + shared_data = JSON_DUMP(event.data) return EventData( - shared_data=shared_data.decode("utf-8"), - hash=EventData.hash_shared_data_bytes(shared_data), + shared_data=shared_data, hash=EventData.hash_shared_data(shared_data) ) @staticmethod - def shared_data_bytes_from_event(event: Event) -> bytes: - """Create shared_data from an event.""" - return json_bytes(event.data) + def shared_data_from_event(event: Event) -> str: + """Create shared_attrs from an event.""" + return JSON_DUMP(event.data) @staticmethod - def hash_shared_data_bytes(shared_data_bytes: bytes) -> int: + def hash_shared_data(shared_data: str) -> int: """Return the hash of json encoded shared data.""" - return cast(int, fnv1a_32(shared_data_bytes)) + return cast(int, fnv1a_32(shared_data.encode("utf-8"))) def to_native(self) -> dict[str, Any]: """Convert to an HA state object.""" try: - return cast(dict[str, Any], orjson.loads(self.shared_data)) + return cast(dict[str, Any], json.loads(self.shared_data)) except ValueError: _LOGGER.exception("Error converting row to event data: %s", self) return {} @@ -342,9 +340,9 @@ class States(Base): # type: ignore[misc,valid-type] parent_id=self.context_parent_id, ) try: - attrs = orjson.loads(self.attributes) if self.attributes else {} + attrs = json.loads(self.attributes) if self.attributes else {} except ValueError: - # When orjson.loads fails + # When json.loads fails _LOGGER.exception("Error converting row to state: %s", self) return None if self.last_changed is None or self.last_changed == self.last_updated: @@ -390,39 +388,40 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] """Create object from a state_changed event.""" state: State | None = event.data.get("new_state") # None state means the state was removed from the state machine - attr_bytes = b"{}" if state is None else json_bytes(state.attributes) - dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8")) - dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes) + dbstate = StateAttributes( + shared_attrs="{}" if state is None else JSON_DUMP(state.attributes) + ) + dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs) return dbstate @staticmethod - def shared_attrs_bytes_from_event( + def shared_attrs_from_event( event: Event, exclude_attrs_by_domain: dict[str, set[str]] - ) -> bytes: + ) -> str: """Create shared_attrs from a state_changed event.""" state: State | None = event.data.get("new_state") # None state means the state was removed from the state machine if state is None: - return b"{}" + return "{}" domain = split_entity_id(state.entity_id)[0] exclude_attrs = ( exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS ) - return json_bytes( + return JSON_DUMP( {k: v for k, v in state.attributes.items() if k not in exclude_attrs} ) @staticmethod - def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: - """Return the hash of orjson encoded shared attributes.""" - return cast(int, fnv1a_32(shared_attrs_bytes)) + def hash_shared_attrs(shared_attrs: str) -> int: + """Return the hash of json encoded shared attributes.""" + return cast(int, fnv1a_32(shared_attrs.encode("utf-8"))) def to_native(self) -> dict[str, Any]: """Convert to an HA state object.""" try: - return cast(dict[str, Any], orjson.loads(self.shared_attrs)) + return cast(dict[str, Any], json.loads(self.shared_attrs)) except ValueError: - # When orjson.loads fails + # When json.loads fails _LOGGER.exception("Error converting row to state attributes: %s", self) return {} @@ -836,7 +835,7 @@ def decode_attributes_from_row( if not source or source == EMPTY_JSON_OBJECT: return {} try: - attr_cache[source] = attributes = orjson.loads(source) + attr_cache[source] = attributes = json.loads(source) except ValueError: _LOGGER.exception("Error converting row to state attributes: %s", source) attr_cache[source] = attributes = {} diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index bea08722eb0..61bcb8badf0 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -29,7 +29,7 @@ from homeassistant.helpers.event import ( TrackTemplateResult, async_track_template_result, ) -from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder +from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations @@ -241,13 +241,13 @@ def handle_get_states( # to succeed for the UI to show. response = messages.result_message(msg["id"], states) try: - connection.send_message(JSON_DUMP(response)) + connection.send_message(const.JSON_DUMP(response)) return except (ValueError, TypeError): connection.logger.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(response, dump=JSON_DUMP) + find_paths_unserializable_data(response, dump=const.JSON_DUMP) ), ) del response @@ -256,13 +256,13 @@ def handle_get_states( serialized = [] for state in states: try: - serialized.append(JSON_DUMP(state)) + serialized.append(const.JSON_DUMP(state)) except (ValueError, TypeError): # Error is already logged above pass # We now have partially serialized states. Craft some JSON. - response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) + response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"])) response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized)) connection.send_message(response2) @@ -315,13 +315,13 @@ def handle_subscribe_entities( # to succeed for the UI to show. response = messages.event_message(msg["id"], data) try: - connection.send_message(JSON_DUMP(response)) + connection.send_message(const.JSON_DUMP(response)) return except (ValueError, TypeError): connection.logger.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(response, dump=JSON_DUMP) + find_paths_unserializable_data(response, dump=const.JSON_DUMP) ), ) del response @@ -330,14 +330,14 @@ def handle_subscribe_entities( cannot_serialize: list[str] = [] for entity_id, state_dict in add_entities.items(): try: - JSON_DUMP(state_dict) + const.JSON_DUMP(state_dict) except (ValueError, TypeError): cannot_serialize.append(entity_id) for entity_id in cannot_serialize: del add_entities[entity_id] - connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data))) + connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data))) @decorators.websocket_command({vol.Required("type"): "get_services"}) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 26c4c6f8321..0280863f83e 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -11,7 +11,6 @@ import voluptuous as vol from homeassistant.auth.models import RefreshToken, User from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized -from homeassistant.helpers.json import JSON_DUMP from . import const, messages @@ -57,7 +56,7 @@ class ActiveConnection: async def send_big_result(self, msg_id: int, result: Any) -> None: """Send a result message that would be expensive to JSON serialize.""" content = await self.hass.async_add_executor_job( - JSON_DUMP, messages.result_message(msg_id, result) + const.JSON_DUMP, messages.result_message(msg_id, result) ) self.send_message(content) diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 60a00126092..107cf6d0270 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -4,9 +4,12 @@ from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable from concurrent import futures +from functools import partial +import json from typing import TYPE_CHECKING, Any, Final from homeassistant.core import HomeAssistant +from homeassistant.helpers.json import JSONEncoder if TYPE_CHECKING: from .connection import ActiveConnection # noqa: F401 @@ -50,6 +53,10 @@ SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected" # Data used to store the current connection list DATA_CONNECTIONS: Final = f"{DOMAIN}.connections" +JSON_DUMP: Final = partial( + json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":") +) + COMPRESSED_STATE_STATE = "s" COMPRESSED_STATE_ATTRIBUTES = "a" COMPRESSED_STATE_CONTEXT = "c" diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index c3e5f6bb5f5..f546ba5eec6 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -9,7 +9,6 @@ import voluptuous as vol from homeassistant.core import Event, State from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.json import JSON_DUMP from homeassistant.util.json import ( find_paths_unserializable_data, format_unserializable_data, @@ -194,15 +193,15 @@ def compressed_state_dict_add(state: State) -> dict[str, Any]: def message_to_json(message: dict[str, Any]) -> str: """Serialize a websocket message to json.""" try: - return JSON_DUMP(message) + return const.JSON_DUMP(message) except (ValueError, TypeError): _LOGGER.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( - find_paths_unserializable_data(message, dump=JSON_DUMP) + find_paths_unserializable_data(message, dump=const.JSON_DUMP) ), ) - return JSON_DUMP( + return const.JSON_DUMP( error_message( message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response" ) diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index 2e56698db41..eaabb002b0a 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -14,7 +14,6 @@ from aiohttp import web from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout import async_timeout -import orjson from homeassistant import config_entries from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ @@ -98,7 +97,6 @@ def _async_create_clientsession( """Create a new ClientSession with kwargs, i.e. for cookies.""" clientsession = aiohttp.ClientSession( connector=_async_get_connector(hass, verify_ssl), - json_serialize=lambda x: orjson.dumps(x).decode("utf-8"), **kwargs, ) # Prevent packages accidentally overriding our default headers diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index 912667a13b5..c581e5a9361 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -1,10 +1,7 @@ """Helpers to help with encoding Home Assistant objects in JSON.""" import datetime import json -from pathlib import Path -from typing import Any, Final - -import orjson +from typing import Any class JSONEncoder(json.JSONEncoder): @@ -25,20 +22,6 @@ class JSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def json_encoder_default(obj: Any) -> Any: - """Convert Home Assistant objects. - - Hand other objects to the original method. - """ - if isinstance(obj, set): - return list(obj) - if hasattr(obj, "as_dict"): - return obj.as_dict() - if isinstance(obj, Path): - return obj.as_posix() - raise TypeError - - class ExtendedJSONEncoder(JSONEncoder): """JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" @@ -57,31 +40,3 @@ class ExtendedJSONEncoder(JSONEncoder): return super().default(o) except TypeError: return {"__type": str(type(o)), "repr": repr(o)} - - -def json_bytes(data: Any) -> bytes: - """Dump json bytes.""" - return orjson.dumps( - data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default - ) - - -def json_dumps(data: Any) -> str: - """Dump json string. - - orjson supports serializing dataclasses natively which - eliminates the need to implement as_dict in many places - when the data is already in a dataclass. This works - well as long as all the data in the dataclass can also - be serialized. - - If it turns out to be a problem we can disable this - with option |= orjson.OPT_PASSTHROUGH_DATACLASS and it - will fallback to as_dict - """ - return orjson.dumps( - data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default - ).decode("utf-8") - - -JSON_DUMP: Final = json_dumps diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index c158d26a9aa..a3d8a00bcfb 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -20,7 +20,6 @@ httpx==0.23.0 ifaddr==0.1.7 jinja2==3.1.2 lru-dict==1.1.7 -orjson==3.6.8 paho-mqtt==1.6.1 pillow==9.1.1 pip>=21.0,<22.2 diff --git a/homeassistant/scripts/benchmark/__init__.py b/homeassistant/scripts/benchmark/__init__.py index efbfec5e961..a681b3e210d 100644 --- a/homeassistant/scripts/benchmark/__init__.py +++ b/homeassistant/scripts/benchmark/__init__.py @@ -12,13 +12,14 @@ from timeit import default_timer as timer from typing import TypeVar from homeassistant import core +from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.helpers.entityfilter import convert_include_exclude_filter from homeassistant.helpers.event import ( async_track_state_change, async_track_state_change_event, ) -from homeassistant.helpers.json import JSON_DUMP, JSONEncoder +from homeassistant.helpers.json import JSONEncoder # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: no-warn-return-any diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index 82ecfd34d6d..fdee7a7a90f 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -7,8 +7,6 @@ import json import logging from typing import Any -import orjson - from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError @@ -32,7 +30,7 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict: """ try: with open(filename, encoding="utf-8") as fdesc: - return orjson.loads(fdesc.read()) # type: ignore[no-any-return] + return json.loads(fdesc.read()) # type: ignore[no-any-return] except FileNotFoundError: # This is not a fatal error _LOGGER.debug("JSON file not found: %s", filename) @@ -58,10 +56,7 @@ def save_json( Returns True on success. """ try: - if encoder: - json_data = json.dumps(data, indent=2, cls=encoder) - else: - json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8") + json_data = json.dumps(data, indent=4, cls=encoder) except TypeError as error: msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}" _LOGGER.error(msg) diff --git a/pyproject.toml b/pyproject.toml index 7e62bafd6af..cc745f58ad6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ "PyJWT==2.4.0", # PyJWT has loose dependency. We want the latest one. "cryptography==36.0.2", - "orjson==3.6.8", "pip>=21.0,<22.2", "python-slugify==4.0.1", "pyyaml==6.0", @@ -120,7 +119,6 @@ extension-pkg-allow-list = [ "av.audio.stream", "av.stream", "ciso8601", - "orjson", "cv2", ] diff --git a/requirements.txt b/requirements.txt index 9805ae7cd47..fe2bf87ad25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,6 @@ ifaddr==0.1.7 jinja2==3.1.2 PyJWT==2.4.0 cryptography==36.0.2 -orjson==3.6.8 pip>=21.0,<22.2 python-slugify==4.0.1 pyyaml==6.0 diff --git a/tests/components/energy/test_validate.py b/tests/components/energy/test_validate.py index e802688daaf..37ebe4147c5 100644 --- a/tests/components/energy/test_validate.py +++ b/tests/components/energy/test_validate.py @@ -4,7 +4,6 @@ from unittest.mock import patch import pytest from homeassistant.components.energy import async_get_manager, validate -from homeassistant.helpers.json import JSON_DUMP from homeassistant.setup import async_setup_component @@ -409,11 +408,7 @@ async def test_validation_grid( }, ) - result = await validate.async_validate(hass) - # verify its also json serializable - JSON_DUMP(result) - - assert result.as_dict() == { + assert (await validate.async_validate(hass)).as_dict() == { "energy_sources": [ [ { diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 0f4695596fc..4d3302f7c13 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -619,15 +619,12 @@ async def test_states_filters_visible(hass, hass_admin_user, websocket_client): async def test_get_states_not_allows_nan(hass, websocket_client): - """Test get_states command converts NaN to None.""" + """Test get_states command not allows NaN floats.""" hass.states.async_set("greeting.hello", "world") hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")}) hass.states.async_set("greeting.bye", "universe") await websocket_client.send_json({"id": 5, "type": "get_states"}) - bad = dict(hass.states.get("greeting.bad").as_dict()) - bad["attributes"] = dict(bad["attributes"]) - bad["attributes"]["hello"] = None msg = await websocket_client.receive_json() assert msg["id"] == 5 @@ -635,7 +632,6 @@ async def test_get_states_not_allows_nan(hass, websocket_client): assert msg["success"] assert msg["result"] == [ hass.states.get("greeting.hello").as_dict(), - bad, hass.states.get("greeting.bye").as_dict(), ]