From 094f2cbad733e63bff20b70caf398c9b2ade1cbb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 31 Jul 2023 09:49:02 -0700 Subject: [PATCH] Fix saving subclassed datetime objects in storage (#97502) --- homeassistant/helpers/json.py | 2 ++ tests/common.py | 11 +++++++++-- tests/helpers/test_json.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index 38c23050885..33054bcb1b0 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -53,6 +53,8 @@ def json_encoder_default(obj: Any) -> Any: return obj.as_dict() if isinstance(obj, Path): return obj.as_posix() + if isinstance(obj, datetime.datetime): + return obj.isoformat() raise TypeError diff --git a/tests/common.py b/tests/common.py index 4fdccced370..542aa0afcee 100644 --- a/tests/common.py +++ b/tests/common.py @@ -67,7 +67,7 @@ from homeassistant.helpers import ( storage, ) from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder from homeassistant.helpers.typing import ConfigType, StateType from homeassistant.setup import setup_component from homeassistant.util.async_ import run_callback_threadsafe @@ -1260,7 +1260,14 @@ def mock_storage( # To ensure that the data can be serialized _LOGGER.debug("Writing data to %s: %s", store.key, data_to_write) raise_contains_mocks(data_to_write) - data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder)) + encoder = store._encoder + if encoder and encoder is not JSONEncoder: + # If they pass a custom encoder that is not the + # default JSONEncoder, we use the slow path of json.dumps + dump = ft.partial(json.dumps, cls=store._encoder) + else: + dump = _orjson_default_encoder + data[store.key] = json.loads(dump(data_to_write)) async def mock_remove(store: storage.Store) -> None: """Remove data.""" diff --git a/tests/helpers/test_json.py b/tests/helpers/test_json.py index 419122b018b..7e248c8c381 100644 --- a/tests/helpers/test_json.py +++ b/tests/helpers/test_json.py @@ -215,6 +215,20 @@ def test_custom_encoder(tmp_path: Path) -> None: assert data == "9" +def test_saving_subclassed_datetime(tmp_path: Path) -> None: + """Test saving subclassed datetime objects.""" + + class SubClassDateTime(datetime.datetime): + """Subclass datetime.""" + + time = SubClassDateTime.fromtimestamp(0) + + fname = tmp_path / "test6.json" + save_json(fname, {"time": time}) + data = load_json(fname) + assert data == {"time": time.isoformat()} + + def test_default_encoder_is_passed(tmp_path: Path) -> None: """Test we use orjson if they pass in the default encoder.""" fname = tmp_path / "test6.json"