diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index d69a4106728..68273c89743 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -11,6 +11,10 @@ import orjson from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.json import ( + JSONEncoder as DefaultHASSJSONEncoder, + json_encoder_default as default_hass_orjson_encoder, +) from .file import write_utf8_file, write_utf8_file_atomic @@ -52,6 +56,15 @@ def _orjson_encoder(data: Any) -> str: ).decode("utf-8") +def _orjson_default_encoder(data: Any) -> str: + """JSON encoder that uses orjson with hass defaults.""" + return orjson.dumps( + data, + option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS, + default=default_hass_orjson_encoder, + ).decode("utf-8") + + def save_json( filename: str, data: list | dict, @@ -64,10 +77,20 @@ def save_json( Returns True on success. """ - dump: Callable[[Any], Any] = json.dumps + dump: Callable[[Any], Any] try: if encoder: - json_data = json.dumps(data, indent=2, cls=encoder) + # For backwards compatibility, if they pass in the + # default json encoder we use _orjson_default_encoder + # which is the orjson equivalent to the default encoder. + if encoder is DefaultHASSJSONEncoder: + dump = _orjson_default_encoder + json_data = _orjson_default_encoder(data) + # If they pass a custom encoder that is not the + # DefaultHASSJSONEncoder, we use the slow path of json.dumps + else: + dump = json.dumps + json_data = json.dumps(data, indent=2, cls=encoder) else: dump = _orjson_encoder json_data = _orjson_encoder(data) diff --git a/tests/util/test_json.py b/tests/util/test_json.py index 9974cbb9628..28d321036c5 100644 --- a/tests/util/test_json.py +++ b/tests/util/test_json.py @@ -5,12 +5,13 @@ from json import JSONEncoder, dumps import math import os from tempfile import mkdtemp -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder from homeassistant.helpers.template import TupleWrapper from homeassistant.util.json import ( SerializationError, @@ -127,6 +128,21 @@ def test_custom_encoder(): assert data == "9" +def test_default_encoder_is_passed(): + """Test we use orjson if they pass in the default encoder.""" + fname = _path_for("test6") + with patch( + "homeassistant.util.json.orjson.dumps", return_value=b"{}" + ) as mock_orjson_dumps: + save_json(fname, {"any": 1}, encoder=DefaultHASSJSONEncoder) + assert len(mock_orjson_dumps.mock_calls) == 1 + # Patch json.dumps to make sure we are using the orjson path + with patch("homeassistant.util.json.json.dumps", side_effect=Exception): + save_json(fname, {"any": {1}}, encoder=DefaultHASSJSONEncoder) + data = load_json(fname) + assert data == {"any": [1]} + + def test_find_unserializable_data(): """Find unserializeable data.""" assert find_paths_unserializable_data(1) == {}