diff --git a/tests/conftest.py b/tests/conftest.py index 9b861d5bde5..2c23270daee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,11 +42,14 @@ import respx from syrupy.assertion import SnapshotAssertion from syrupy.session import SnapshotSession +# Setup patching of JSON functions before any other Home Assistant imports +from . import patch_json # isort:skip + from homeassistant import block_async_io from homeassistant.exceptions import ServiceNotFound # Setup patching of recorder functions before any other Home Assistant imports -from . import patch_recorder +from . import patch_recorder # isort:skip # Setup patching of dt_util time functions before any other Home Assistant imports from . import patch_time # noqa: F401, isort:skip @@ -449,6 +452,12 @@ def reset_globals() -> Generator[None]: frame.async_setup(None) frame._REPORTED_INTEGRATIONS.clear() + # Reset patch_json + if patch_json.mock_objects: + obj = patch_json.mock_objects.pop() + patch_json.mock_objects.clear() + pytest.fail(f"Test attempted to serialize mock object {obj}") + @pytest.fixture(autouse=True, scope="session") def bcrypt_cost() -> Generator[None]: diff --git a/tests/patch_json.py b/tests/patch_json.py new file mode 100644 index 00000000000..e741ba1a816 --- /dev/null +++ b/tests/patch_json.py @@ -0,0 +1,37 @@ +"""Patch JSON related functions.""" + +from __future__ import annotations + +import functools +from typing import Any +from unittest import mock + +import orjson + +from homeassistant.helpers import json as json_helper + +real_json_encoder_default = json_helper.json_encoder_default + +mock_objects = [] + + +def json_encoder_default(obj: Any) -> Any: + """Convert Home Assistant objects. + + Hand other objects to the original method. + """ + if isinstance(obj, mock.Base): + mock_objects.append(obj) + raise TypeError(f"Attempting to serialize mock object {obj}") + return real_json_encoder_default(obj) + + +json_helper.json_encoder_default = json_encoder_default +json_helper.json_bytes = functools.partial( + orjson.dumps, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default +) +json_helper.json_bytes_sorted = functools.partial( + orjson.dumps, + option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SORT_KEYS, + default=json_encoder_default, +)