diff --git a/homeassistant/helpers/json.py b/homeassistant/helpers/json.py index b9862907960..ba2486a196e 100644 --- a/homeassistant/helpers/json.py +++ b/homeassistant/helpers/json.py @@ -148,12 +148,17 @@ JSON_DUMP: Final = json_dumps def _orjson_default_encoder(data: Any) -> str: - """JSON encoder that uses orjson with hass defaults.""" + """JSON encoder that uses orjson with hass defaults and returns a str.""" + return _orjson_bytes_default_encoder(data).decode("utf-8") + + +def _orjson_bytes_default_encoder(data: Any) -> bytes: + """JSON encoder that uses orjson with hass defaults and returns bytes.""" return orjson.dumps( data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS, default=json_encoder_default, - ).decode("utf-8") + ) def save_json( @@ -173,11 +178,13 @@ def save_json( if encoder and encoder is not JSONEncoder: # If they pass a custom encoder that is not the # default JSONEncoder, we use the slow path of json.dumps + mode = "w" dump = json.dumps - json_data = json.dumps(data, indent=2, cls=encoder) + json_data: str | bytes = json.dumps(data, indent=2, cls=encoder) else: + mode = "wb" dump = _orjson_default_encoder - json_data = _orjson_default_encoder(data) + json_data = _orjson_bytes_default_encoder(data) except TypeError as error: formatted_data = format_unserializable_data( find_paths_unserializable_data(data, dump=dump) @@ -186,10 +193,8 @@ def save_json( _LOGGER.error(msg) raise SerializationError(msg) from error - if atomic_writes: - write_utf8_file_atomic(filename, json_data, private) - else: - write_utf8_file(filename, json_data, private) + method = write_utf8_file_atomic if atomic_writes else write_utf8_file + method(filename, json_data, private, mode=mode) def find_paths_unserializable_data( diff --git a/homeassistant/util/file.py b/homeassistant/util/file.py index 06471eaca6a..1af65fa51d7 100644 --- a/homeassistant/util/file.py +++ b/homeassistant/util/file.py @@ -17,9 +17,7 @@ class WriteError(HomeAssistantError): def write_utf8_file_atomic( - filename: str, - utf8_data: str, - private: bool = False, + filename: str, utf8_data: bytes | str, private: bool = False, mode: str = "w" ) -> None: """Write a file and rename it into place using atomicwrites. @@ -34,7 +32,7 @@ def write_utf8_file_atomic( negatively impact performance. """ try: - with AtomicWriter(filename, overwrite=True).open() as fdesc: + with AtomicWriter(filename, mode=mode, overwrite=True).open() as fdesc: if not private: os.fchmod(fdesc.fileno(), 0o644) fdesc.write(utf8_data) @@ -44,20 +42,18 @@ def write_utf8_file_atomic( def write_utf8_file( - filename: str, - utf8_data: str, - private: bool = False, + filename: str, utf8_data: bytes | str, private: bool = False, mode: str = "w" ) -> None: """Write a file and rename it into place. Writes all or nothing. """ - tmp_filename = "" + encoding = "utf-8" if "b" not in mode else None try: # Modern versions of Python tempfile create this file with mode 0o600 with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", dir=os.path.dirname(filename), delete=False + mode=mode, encoding=encoding, dir=os.path.dirname(filename), delete=False ) as fdesc: fdesc.write(utf8_data) tmp_filename = fdesc.name diff --git a/tests/util/test_file.py b/tests/util/test_file.py index 0b87985fe13..dc09ff83e9e 100644 --- a/tests/util/test_file.py +++ b/tests/util/test_file.py @@ -25,6 +25,11 @@ def test_write_utf8_file_atomic_private(tmpdir: py.path.local, func) -> None: assert fh.read() == '{"some":"data"}' assert os.stat(test_file).st_mode & 0o777 == 0o600 + func(test_file, b'{"some":"data"}', True, mode="wb") + with open(test_file) as fh: + assert fh.read() == '{"some":"data"}' + assert os.stat(test_file).st_mode & 0o777 == 0o600 + def test_write_utf8_file_fails_at_creation(tmpdir: py.path.local) -> None: """Test that failed creation of the temp file does not create an empty file."""