From ebb25ab0e604cf75ab5a038dc9d363d3894dd888 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 11 Nov 2021 00:19:56 -0600 Subject: [PATCH] Make yaml file writes safer (#59384) --- homeassistant/components/config/__init__.py | 6 +- homeassistant/util/file.py | 54 +++++++++++++++++ homeassistant/util/json.py | 28 +-------- tests/components/config/test_group.py | 50 ++++++++++++++++ tests/util/test_file.py | 65 +++++++++++++++++++++ 5 files changed, 175 insertions(+), 28 deletions(-) create mode 100644 homeassistant/util/file.py create mode 100644 tests/util/test_file.py diff --git a/homeassistant/components/config/__init__.py b/homeassistant/components/config/__init__.py index c39a79f3e4a..43c9cfabd08 100644 --- a/homeassistant/components/config/__init__.py +++ b/homeassistant/components/config/__init__.py @@ -11,6 +11,7 @@ from homeassistant.const import CONF_ID, EVENT_COMPONENT_LOADED from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import ATTR_COMPONENT +from homeassistant.util.file import write_utf8_file from homeassistant.util.yaml import dump, load_yaml DOMAIN = "config" @@ -252,6 +253,5 @@ def _write(path, data): """Write YAML helper.""" # Do it before opening file. If dump causes error it will now not # truncate the file. - data = dump(data) - with open(path, "w", encoding="utf-8") as outfile: - outfile.write(data) + contents = dump(data) + write_utf8_file(path, contents) diff --git a/homeassistant/util/file.py b/homeassistant/util/file.py new file mode 100644 index 00000000000..9c5b11e4807 --- /dev/null +++ b/homeassistant/util/file.py @@ -0,0 +1,54 @@ +"""File utility functions.""" +from __future__ import annotations + +import logging +import os +import tempfile + +from homeassistant.exceptions import HomeAssistantError + +_LOGGER = logging.getLogger(__name__) + + +class WriteError(HomeAssistantError): + """Error writing the data.""" + + +def write_utf8_file( + filename: str, + utf8_data: str, + private: bool = False, +) -> None: + """Write a file and rename it into place. + + Writes all or nothing. + """ + + tmp_filename = "" + tmp_path = os.path.split(filename)[0] + try: + # Modern versions of Python tempfile create this file with mode 0o600 + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", dir=tmp_path, delete=False + ) as fdesc: + fdesc.write(utf8_data) + tmp_filename = fdesc.name + if not private: + os.chmod(tmp_filename, 0o644) + os.replace(tmp_filename, filename) + except OSError as error: + _LOGGER.exception("Saving file failed: %s", filename) + raise WriteError(error) from error + finally: + if os.path.exists(tmp_filename): + try: + os.remove(tmp_filename) + except OSError as err: + # If we are cleaning up then something else went wrong, so + # we should suppress likely follow-on errors in the cleanup + _LOGGER.error( + "File replacement cleanup failed for %s while saving %s: %s", + tmp_filename, + filename, + err, + ) diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index e82bd968754..e3bde277837 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -5,13 +5,13 @@ from collections import deque from collections.abc import Callable import json import logging -import os -import tempfile from typing import Any from homeassistant.core import Event, State from homeassistant.exceptions import HomeAssistantError +from .file import write_utf8_file + _LOGGER = logging.getLogger(__name__) @@ -61,29 +61,7 @@ def save_json( _LOGGER.error(msg) raise SerializationError(msg) from error - tmp_filename = "" - tmp_path = os.path.split(filename)[0] - try: - # Modern versions of Python tempfile create this file with mode 0o600 - with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", dir=tmp_path, delete=False - ) as fdesc: - fdesc.write(json_data) - tmp_filename = fdesc.name - if not private: - os.chmod(tmp_filename, 0o644) - os.replace(tmp_filename, filename) - except OSError as error: - _LOGGER.exception("Saving JSON file failed: %s", filename) - raise WriteError(error) from error - finally: - if os.path.exists(tmp_filename): - try: - os.remove(tmp_filename) - except OSError as err: - # If we are cleaning up then something else went wrong, so - # we should suppress likely follow-on errors in the cleanup - _LOGGER.error("JSON replacement cleanup failed: %s", err) + write_utf8_file(filename, json_data, private) def format_unserializable_data(data: dict[str, Any]) -> str: diff --git a/tests/components/config/test_group.py b/tests/components/config/test_group.py index 72a9a00cbea..4d1d28020bb 100644 --- a/tests/components/config/test_group.py +++ b/tests/components/config/test_group.py @@ -1,10 +1,14 @@ """Test Group config panel.""" from http import HTTPStatus import json +from pathlib import Path from unittest.mock import AsyncMock, patch from homeassistant.bootstrap import async_setup_component from homeassistant.components import config +from homeassistant.components.config import group +from homeassistant.util.file import write_utf8_file +from homeassistant.util.yaml import dump, load_yaml VIEW_NAME = "api:config:group:config" @@ -113,3 +117,49 @@ async def test_update_device_config_invalid_json(hass, hass_client): resp = await client.post("/api/config/group/config/hello_beer", data="not json") assert resp.status == HTTPStatus.BAD_REQUEST + + +async def test_update_config_write_to_temp_file(hass, hass_client, tmpdir): + """Test config with a temp file.""" + test_dir = await hass.async_add_executor_job(tmpdir.mkdir, "files") + group_yaml = Path(test_dir / "group.yaml") + + with patch.object(group, "GROUP_CONFIG_PATH", group_yaml), patch.object( + config, "SECTIONS", ["group"] + ): + await async_setup_component(hass, "config", {}) + + client = await hass_client() + + orig_data = { + "hello.beer": {"ignored": True}, + "other.entity": {"polling_intensity": 2}, + } + contents = dump(orig_data) + await hass.async_add_executor_job(write_utf8_file, group_yaml, contents) + + mock_call = AsyncMock() + + with patch.object(hass.services, "async_call", mock_call): + resp = await client.post( + "/api/config/group/config/hello_beer", + data=json.dumps( + {"name": "Beer", "entities": ["light.top", "light.bottom"]} + ), + ) + await hass.async_block_till_done() + + assert resp.status == HTTPStatus.OK + result = await resp.json() + assert result == {"result": "ok"} + + new_data = await hass.async_add_executor_job(load_yaml, group_yaml) + + assert new_data == { + **orig_data, + "hello_beer": { + "name": "Beer", + "entities": ["light.top", "light.bottom"], + }, + } + mock_call.assert_called_once_with("group", "reload") diff --git a/tests/util/test_file.py b/tests/util/test_file.py new file mode 100644 index 00000000000..109645a839a --- /dev/null +++ b/tests/util/test_file.py @@ -0,0 +1,65 @@ +"""Test Home Assistant file utility functions.""" +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from homeassistant.util.file import WriteError, write_utf8_file + + +def test_write_utf8_file_private(tmpdir): + """Test files can be written as 0o600 or 0o644.""" + test_dir = tmpdir.mkdir("files") + test_file = Path(test_dir / "test.json") + + write_utf8_file(test_file, '{"some":"data"}', False) + with open(test_file) as fh: + assert fh.read() == '{"some":"data"}' + assert os.stat(test_file).st_mode & 0o777 == 0o644 + + write_utf8_file(test_file, '{"some":"data"}', True) + with open(test_file) as fh: + assert fh.read() == '{"some":"data"}' + assert os.stat(test_file).st_mode & 0o777 == 0o600 + + +def test_write_utf8_file_fails_at_creation(tmpdir): + """Test that failed creation of the temp file does not create an empty file.""" + test_dir = tmpdir.mkdir("files") + test_file = Path(test_dir / "test.json") + + with pytest.raises(WriteError), patch( + "homeassistant.util.file.tempfile.NamedTemporaryFile", side_effect=OSError + ): + write_utf8_file(test_file, '{"some":"data"}', False) + + assert not os.path.exists(test_file) + + +def test_write_utf8_file_fails_at_rename(tmpdir, caplog): + """Test that if rename fails not not remove, we do not log the failed cleanup.""" + test_dir = tmpdir.mkdir("files") + test_file = Path(test_dir / "test.json") + + with pytest.raises(WriteError), patch( + "homeassistant.util.file.os.replace", side_effect=OSError + ): + write_utf8_file(test_file, '{"some":"data"}', False) + + assert not os.path.exists(test_file) + + assert "File replacement cleanup failed" not in caplog.text + + +def test_write_utf8_file_fails_at_rename_and_remove(tmpdir, caplog): + """Test that if rename and remove both fail, we log the failed cleanup.""" + test_dir = tmpdir.mkdir("files") + test_file = Path(test_dir / "test.json") + + with pytest.raises(WriteError), patch( + "homeassistant.util.file.os.remove", side_effect=OSError + ), patch("homeassistant.util.file.os.replace", side_effect=OSError): + write_utf8_file(test_file, '{"some":"data"}', False) + + assert "File replacement cleanup failed" in caplog.text