From af715a4b9a21a7c855afc67bec005d0e977c8987 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 8 Dec 2023 18:13:34 +0100 Subject: [PATCH] Add workaround for orjson not handling subclasses of str (#105314) Co-authored-by: Franck Nijhof --- homeassistant/util/json.py | 14 +++++++++++--- tests/util/test_json.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index ac18d43727c..1af35c604eb 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -33,9 +33,17 @@ class SerializationError(HomeAssistantError): """Error serializing the data to JSON.""" -json_loads: Callable[[bytes | bytearray | memoryview | str], JsonValueType] -json_loads = orjson.loads -"""Parse JSON data.""" +def json_loads(__obj: bytes | bytearray | memoryview | str) -> JsonValueType: + """Parse JSON data. + + This adds a workaround for orjson not handling subclasses of str, + https://github.com/ijl/orjson/issues/445. + """ + if type(__obj) in (bytes, bytearray, memoryview, str): + return orjson.loads(__obj) # type:ignore[no-any-return] + if isinstance(__obj, str): + return orjson.loads(str(__obj)) # type:ignore[no-any-return] + return orjson.loads(__obj) # type:ignore[no-any-return] def json_loads_array(__obj: bytes | bytearray | memoryview | str) -> JsonArrayType: diff --git a/tests/util/test_json.py b/tests/util/test_json.py index b3bccf71b58..ff0f1ed8392 100644 --- a/tests/util/test_json.py +++ b/tests/util/test_json.py @@ -1,10 +1,12 @@ """Test Home Assistant json utility functions.""" from pathlib import Path +import orjson import pytest from homeassistant.exceptions import HomeAssistantError from homeassistant.util.json import ( + json_loads, json_loads_array, json_loads_object, load_json, @@ -153,3 +155,20 @@ async def test_deprecated_save_json( save_json(fname, TEST_JSON_A) assert "uses save_json from homeassistant.util.json" in caplog.text assert "should be updated to use homeassistant.helpers.json module" in caplog.text + + +async def test_loading_derived_class(): + """Test loading data from classes derived from str.""" + + class MyStr(str): + pass + + class MyBytes(bytes): + pass + + assert json_loads('"abc"') == "abc" + assert json_loads(MyStr('"abc"')) == "abc" + + assert json_loads(b'"abc"') == "abc" + with pytest.raises(orjson.JSONDecodeError): + assert json_loads(MyBytes(b'"abc"')) == "abc"