mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Make yaml file writes safer (#59384)
This commit is contained in:
parent
751098c220
commit
ebb25ab0e6
@ -11,6 +11,7 @@ from homeassistant.const import CONF_ID, EVENT_COMPONENT_LOADED
|
|||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.setup import ATTR_COMPONENT
|
from homeassistant.setup import ATTR_COMPONENT
|
||||||
|
from homeassistant.util.file import write_utf8_file
|
||||||
from homeassistant.util.yaml import dump, load_yaml
|
from homeassistant.util.yaml import dump, load_yaml
|
||||||
|
|
||||||
DOMAIN = "config"
|
DOMAIN = "config"
|
||||||
@ -252,6 +253,5 @@ def _write(path, data):
|
|||||||
"""Write YAML helper."""
|
"""Write YAML helper."""
|
||||||
# Do it before opening file. If dump causes error it will now not
|
# Do it before opening file. If dump causes error it will now not
|
||||||
# truncate the file.
|
# truncate the file.
|
||||||
data = dump(data)
|
contents = dump(data)
|
||||||
with open(path, "w", encoding="utf-8") as outfile:
|
write_utf8_file(path, contents)
|
||||||
outfile.write(data)
|
|
||||||
|
54
homeassistant/util/file.py
Normal file
54
homeassistant/util/file.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""File utility functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteError(HomeAssistantError):
|
||||||
|
"""Error writing the data."""
|
||||||
|
|
||||||
|
|
||||||
|
def write_utf8_file(
|
||||||
|
filename: str,
|
||||||
|
utf8_data: str,
|
||||||
|
private: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Write a file and rename it into place.
|
||||||
|
|
||||||
|
Writes all or nothing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tmp_filename = ""
|
||||||
|
tmp_path = os.path.split(filename)[0]
|
||||||
|
try:
|
||||||
|
# Modern versions of Python tempfile create this file with mode 0o600
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", encoding="utf-8", dir=tmp_path, delete=False
|
||||||
|
) as fdesc:
|
||||||
|
fdesc.write(utf8_data)
|
||||||
|
tmp_filename = fdesc.name
|
||||||
|
if not private:
|
||||||
|
os.chmod(tmp_filename, 0o644)
|
||||||
|
os.replace(tmp_filename, filename)
|
||||||
|
except OSError as error:
|
||||||
|
_LOGGER.exception("Saving file failed: %s", filename)
|
||||||
|
raise WriteError(error) from error
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_filename):
|
||||||
|
try:
|
||||||
|
os.remove(tmp_filename)
|
||||||
|
except OSError as err:
|
||||||
|
# If we are cleaning up then something else went wrong, so
|
||||||
|
# we should suppress likely follow-on errors in the cleanup
|
||||||
|
_LOGGER.error(
|
||||||
|
"File replacement cleanup failed for %s while saving %s: %s",
|
||||||
|
tmp_filename,
|
||||||
|
filename,
|
||||||
|
err,
|
||||||
|
)
|
@ -5,13 +5,13 @@ from collections import deque
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
from .file import write_utf8_file
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -61,29 +61,7 @@ def save_json(
|
|||||||
_LOGGER.error(msg)
|
_LOGGER.error(msg)
|
||||||
raise SerializationError(msg) from error
|
raise SerializationError(msg) from error
|
||||||
|
|
||||||
tmp_filename = ""
|
write_utf8_file(filename, json_data, private)
|
||||||
tmp_path = os.path.split(filename)[0]
|
|
||||||
try:
|
|
||||||
# Modern versions of Python tempfile create this file with mode 0o600
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w", encoding="utf-8", dir=tmp_path, delete=False
|
|
||||||
) as fdesc:
|
|
||||||
fdesc.write(json_data)
|
|
||||||
tmp_filename = fdesc.name
|
|
||||||
if not private:
|
|
||||||
os.chmod(tmp_filename, 0o644)
|
|
||||||
os.replace(tmp_filename, filename)
|
|
||||||
except OSError as error:
|
|
||||||
_LOGGER.exception("Saving JSON file failed: %s", filename)
|
|
||||||
raise WriteError(error) from error
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_filename):
|
|
||||||
try:
|
|
||||||
os.remove(tmp_filename)
|
|
||||||
except OSError as err:
|
|
||||||
# If we are cleaning up then something else went wrong, so
|
|
||||||
# we should suppress likely follow-on errors in the cleanup
|
|
||||||
_LOGGER.error("JSON replacement cleanup failed: %s", err)
|
|
||||||
|
|
||||||
|
|
||||||
def format_unserializable_data(data: dict[str, Any]) -> str:
|
def format_unserializable_data(data: dict[str, Any]) -> str:
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
"""Test Group config panel."""
|
"""Test Group config panel."""
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from homeassistant.bootstrap import async_setup_component
|
from homeassistant.bootstrap import async_setup_component
|
||||||
from homeassistant.components import config
|
from homeassistant.components import config
|
||||||
|
from homeassistant.components.config import group
|
||||||
|
from homeassistant.util.file import write_utf8_file
|
||||||
|
from homeassistant.util.yaml import dump, load_yaml
|
||||||
|
|
||||||
VIEW_NAME = "api:config:group:config"
|
VIEW_NAME = "api:config:group:config"
|
||||||
|
|
||||||
@ -113,3 +117,49 @@ async def test_update_device_config_invalid_json(hass, hass_client):
|
|||||||
resp = await client.post("/api/config/group/config/hello_beer", data="not json")
|
resp = await client.post("/api/config/group/config/hello_beer", data="not json")
|
||||||
|
|
||||||
assert resp.status == HTTPStatus.BAD_REQUEST
|
assert resp.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_config_write_to_temp_file(hass, hass_client, tmpdir):
|
||||||
|
"""Test config with a temp file."""
|
||||||
|
test_dir = await hass.async_add_executor_job(tmpdir.mkdir, "files")
|
||||||
|
group_yaml = Path(test_dir / "group.yaml")
|
||||||
|
|
||||||
|
with patch.object(group, "GROUP_CONFIG_PATH", group_yaml), patch.object(
|
||||||
|
config, "SECTIONS", ["group"]
|
||||||
|
):
|
||||||
|
await async_setup_component(hass, "config", {})
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
orig_data = {
|
||||||
|
"hello.beer": {"ignored": True},
|
||||||
|
"other.entity": {"polling_intensity": 2},
|
||||||
|
}
|
||||||
|
contents = dump(orig_data)
|
||||||
|
await hass.async_add_executor_job(write_utf8_file, group_yaml, contents)
|
||||||
|
|
||||||
|
mock_call = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(hass.services, "async_call", mock_call):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/config/group/config/hello_beer",
|
||||||
|
data=json.dumps(
|
||||||
|
{"name": "Beer", "entities": ["light.top", "light.bottom"]}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
result = await resp.json()
|
||||||
|
assert result == {"result": "ok"}
|
||||||
|
|
||||||
|
new_data = await hass.async_add_executor_job(load_yaml, group_yaml)
|
||||||
|
|
||||||
|
assert new_data == {
|
||||||
|
**orig_data,
|
||||||
|
"hello_beer": {
|
||||||
|
"name": "Beer",
|
||||||
|
"entities": ["light.top", "light.bottom"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_call.assert_called_once_with("group", "reload")
|
||||||
|
65
tests/util/test_file.py
Normal file
65
tests/util/test_file.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""Test Home Assistant file utility functions."""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.util.file import WriteError, write_utf8_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_private(tmpdir):
|
||||||
|
"""Test files can be written as 0o600 or 0o644."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
with open(test_file) as fh:
|
||||||
|
assert fh.read() == '{"some":"data"}'
|
||||||
|
assert os.stat(test_file).st_mode & 0o777 == 0o644
|
||||||
|
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', True)
|
||||||
|
with open(test_file) as fh:
|
||||||
|
assert fh.read() == '{"some":"data"}'
|
||||||
|
assert os.stat(test_file).st_mode & 0o777 == 0o600
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_fails_at_creation(tmpdir):
|
||||||
|
"""Test that failed creation of the temp file does not create an empty file."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(WriteError), patch(
|
||||||
|
"homeassistant.util.file.tempfile.NamedTemporaryFile", side_effect=OSError
|
||||||
|
):
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert not os.path.exists(test_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_fails_at_rename(tmpdir, caplog):
|
||||||
|
"""Test that if rename fails not not remove, we do not log the failed cleanup."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(WriteError), patch(
|
||||||
|
"homeassistant.util.file.os.replace", side_effect=OSError
|
||||||
|
):
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert not os.path.exists(test_file)
|
||||||
|
|
||||||
|
assert "File replacement cleanup failed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_fails_at_rename_and_remove(tmpdir, caplog):
|
||||||
|
"""Test that if rename and remove both fail, we log the failed cleanup."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(WriteError), patch(
|
||||||
|
"homeassistant.util.file.os.remove", side_effect=OSError
|
||||||
|
), patch("homeassistant.util.file.os.replace", side_effect=OSError):
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert "File replacement cleanup failed" in caplog.text
|
Loading…
x
Reference in New Issue
Block a user