mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Use the orjson equivalent default encoder when save_json is passed the default encoder (#74377)
This commit is contained in:
parent
4e1359e2cc
commit
8d0e54d776
@ -11,6 +11,10 @@ import orjson
|
|||||||
|
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.json import (
|
||||||
|
JSONEncoder as DefaultHASSJSONEncoder,
|
||||||
|
json_encoder_default as default_hass_orjson_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
from .file import write_utf8_file, write_utf8_file_atomic
|
from .file import write_utf8_file, write_utf8_file_atomic
|
||||||
|
|
||||||
@ -52,6 +56,15 @@ def _orjson_encoder(data: Any) -> str:
|
|||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _orjson_default_encoder(data: Any) -> str:
|
||||||
|
"""JSON encoder that uses orjson with hass defaults."""
|
||||||
|
return orjson.dumps(
|
||||||
|
data,
|
||||||
|
option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS,
|
||||||
|
default=default_hass_orjson_encoder,
|
||||||
|
).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def save_json(
|
def save_json(
|
||||||
filename: str,
|
filename: str,
|
||||||
data: list | dict,
|
data: list | dict,
|
||||||
@ -64,10 +77,20 @@ def save_json(
|
|||||||
|
|
||||||
Returns True on success.
|
Returns True on success.
|
||||||
"""
|
"""
|
||||||
dump: Callable[[Any], Any] = json.dumps
|
dump: Callable[[Any], Any]
|
||||||
try:
|
try:
|
||||||
if encoder:
|
if encoder:
|
||||||
json_data = json.dumps(data, indent=2, cls=encoder)
|
# For backwards compatibility, if they pass in the
|
||||||
|
# default json encoder we use _orjson_default_encoder
|
||||||
|
# which is the orjson equivalent to the default encoder.
|
||||||
|
if encoder is DefaultHASSJSONEncoder:
|
||||||
|
dump = _orjson_default_encoder
|
||||||
|
json_data = _orjson_default_encoder(data)
|
||||||
|
# If they pass a custom encoder that is not the
|
||||||
|
# DefaultHASSJSONEncoder, we use the slow path of json.dumps
|
||||||
|
else:
|
||||||
|
dump = json.dumps
|
||||||
|
json_data = json.dumps(data, indent=2, cls=encoder)
|
||||||
else:
|
else:
|
||||||
dump = _orjson_encoder
|
dump = _orjson_encoder
|
||||||
json_data = _orjson_encoder(data)
|
json_data = _orjson_encoder(data)
|
||||||
|
@ -5,12 +5,13 @@ from json import JSONEncoder, dumps
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.json import JSONEncoder as DefaultHASSJSONEncoder
|
||||||
from homeassistant.helpers.template import TupleWrapper
|
from homeassistant.helpers.template import TupleWrapper
|
||||||
from homeassistant.util.json import (
|
from homeassistant.util.json import (
|
||||||
SerializationError,
|
SerializationError,
|
||||||
@ -127,6 +128,21 @@ def test_custom_encoder():
|
|||||||
assert data == "9"
|
assert data == "9"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_encoder_is_passed():
|
||||||
|
"""Test we use orjson if they pass in the default encoder."""
|
||||||
|
fname = _path_for("test6")
|
||||||
|
with patch(
|
||||||
|
"homeassistant.util.json.orjson.dumps", return_value=b"{}"
|
||||||
|
) as mock_orjson_dumps:
|
||||||
|
save_json(fname, {"any": 1}, encoder=DefaultHASSJSONEncoder)
|
||||||
|
assert len(mock_orjson_dumps.mock_calls) == 1
|
||||||
|
# Patch json.dumps to make sure we are using the orjson path
|
||||||
|
with patch("homeassistant.util.json.json.dumps", side_effect=Exception):
|
||||||
|
save_json(fname, {"any": {1}}, encoder=DefaultHASSJSONEncoder)
|
||||||
|
data = load_json(fname)
|
||||||
|
assert data == {"any": [1]}
|
||||||
|
|
||||||
|
|
||||||
def test_find_unserializable_data():
|
def test_find_unserializable_data():
|
||||||
"""Find unserializeable data."""
|
"""Find unserializeable data."""
|
||||||
assert find_paths_unserializable_data(1) == {}
|
assert find_paths_unserializable_data(1) == {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user