Fix handling of renamed backup files in the core writer (#136898)

* Fix handling of renamed backup files in the core writer

* Adjust mocking

* Raise BackupAgentError instead of KeyError in get_backup_path

* Add specific error indicating backup not found

* Fix tests

* Ensure backups are loaded

* Fix tests
This commit is contained in:
Erik Montnemery 2025-01-30 15:25:16 +01:00 committed by Bram Kragten
parent aed779172d
commit b300fb1fab
9 changed files with 234 additions and 81 deletions

View File

@ -27,6 +27,12 @@ class BackupAgentUnreachableError(BackupAgentError):
_message = "The backup agent is unreachable."
class BackupNotFound(BackupAgentError):
"""Raised when a backup is not found."""
error_code = "backup_not_found"
class BackupAgent(abc.ABC):
"""Backup agent interface."""
@ -94,11 +100,16 @@ class LocalBackupAgent(BackupAgent):
@abc.abstractmethod
def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup.
"""Return the local path to an existing backup.
The method should return the path to the backup file with the specified id.
Raises BackupAgentError if the backup does not exist.
"""
@abc.abstractmethod
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
class BackupAgentPlatformProtocol(Protocol):
"""Define the format of backup platforms which implement backup agents."""

View File

@ -11,7 +11,7 @@ from typing import Any
from homeassistant.core import HomeAssistant
from homeassistant.helpers.hassio import is_hassio
from .agent import BackupAgent, LocalBackupAgent
from .agent import BackupAgent, BackupNotFound, LocalBackupAgent
from .const import DOMAIN, LOGGER
from .models import AgentBackup
from .util import read_backup
@ -39,7 +39,7 @@ class CoreLocalBackupAgent(LocalBackupAgent):
super().__init__()
self._hass = hass
self._backup_dir = Path(hass.config.path("backups"))
self._backups: dict[str, AgentBackup] = {}
self._backups: dict[str, tuple[AgentBackup, Path]] = {}
self._loaded_backups = False
async def _load_backups(self) -> None:
@ -49,13 +49,13 @@ class CoreLocalBackupAgent(LocalBackupAgent):
self._backups = backups
self._loaded_backups = True
def _read_backups(self) -> dict[str, AgentBackup]:
def _read_backups(self) -> dict[str, tuple[AgentBackup, Path]]:
"""Read backups from disk."""
backups: dict[str, AgentBackup] = {}
backups: dict[str, tuple[AgentBackup, Path]] = {}
for backup_path in self._backup_dir.glob("*.tar"):
try:
backup = read_backup(backup_path)
backups[backup.backup_id] = backup
backups[backup.backup_id] = (backup, backup_path)
except (OSError, TarError, json.JSONDecodeError, KeyError) as err:
LOGGER.warning("Unable to read backup %s: %s", backup_path, err)
return backups
@ -76,13 +76,13 @@ class CoreLocalBackupAgent(LocalBackupAgent):
**kwargs: Any,
) -> None:
"""Upload a backup."""
self._backups[backup.backup_id] = backup
self._backups[backup.backup_id] = (backup, self.get_new_backup_path(backup))
async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]:
"""List backups."""
if not self._loaded_backups:
await self._load_backups()
return list(self._backups.values())
return [backup for backup, _ in self._backups.values()]
async def async_get_backup(
self,
@ -93,10 +93,10 @@ class CoreLocalBackupAgent(LocalBackupAgent):
if not self._loaded_backups:
await self._load_backups()
if not (backup := self._backups.get(backup_id)):
if backup_id not in self._backups:
return None
backup_path = self.get_backup_path(backup_id)
backup, backup_path = self._backups[backup_id]
if not await self._hass.async_add_executor_job(backup_path.exists):
LOGGER.debug(
(
@ -112,15 +112,28 @@ class CoreLocalBackupAgent(LocalBackupAgent):
return backup
def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup."""
return self._backup_dir / f"{backup_id}.tar"
"""Return the local path to an existing backup.
Raises BackupAgentError if the backup does not exist.
"""
try:
return self._backups[backup_id][1]
except KeyError as err:
raise BackupNotFound(f"Backup {backup_id} does not exist") from err
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
return self._backup_dir / f"{backup.backup_id}.tar"
async def async_delete_backup(self, backup_id: str, **kwargs: Any) -> None:
"""Delete a backup file."""
if await self.async_get_backup(backup_id) is None:
return
if not self._loaded_backups:
await self._load_backups()
backup_path = self.get_backup_path(backup_id)
try:
backup_path = self.get_backup_path(backup_id)
except BackupNotFound:
return
await self._hass.async_add_executor_job(backup_path.unlink, True)
LOGGER.debug("Deleted backup located at %s", backup_path)
self._backups.pop(backup_id)

View File

@ -1346,10 +1346,24 @@ class CoreBackupReaderWriter(BackupReaderWriter):
if agent_config and not agent_config.protected:
password = None
backup = AgentBackup(
addons=[],
backup_id=backup_id,
database_included=include_database,
date=date_str,
extra_metadata=extra_metadata,
folders=[],
homeassistant_included=True,
homeassistant_version=HAVERSION,
name=backup_name,
protected=password is not None,
size=0,
)
local_agent_tar_file_path = None
if self._local_agent_id in agent_ids:
local_agent = manager.local_backup_agents[self._local_agent_id]
local_agent_tar_file_path = local_agent.get_backup_path(backup_id)
local_agent_tar_file_path = local_agent.get_new_backup_path(backup)
on_progress(
CreateBackupEvent(
@ -1391,19 +1405,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
# ValueError from json_bytes
raise BackupReaderWriterError(str(err)) from err
else:
backup = AgentBackup(
addons=[],
backup_id=backup_id,
database_included=include_database,
date=date_str,
extra_metadata=extra_metadata,
folders=[],
homeassistant_included=True,
homeassistant_version=HAVERSION,
name=backup_name,
protected=password is not None,
size=size_in_bytes,
)
backup = replace(backup, size=size_in_bytes)
async_add_executor_job = self._hass.async_add_executor_job
@ -1517,7 +1519,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
manager = self._hass.data[DATA_MANAGER]
if self._local_agent_id in agent_ids:
local_agent = manager.local_backup_agents[self._local_agent_id]
tar_file_path = local_agent.get_backup_path(backup.backup_id)
tar_file_path = local_agent.get_new_backup_path(backup)
await async_add_executor_job(make_backup_dir, tar_file_path.parent)
await async_add_executor_job(shutil.move, temp_file, tar_file_path)
else:

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable, Coroutine
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable
from pathlib import Path
from typing import Any
from unittest.mock import ANY, AsyncMock, Mock, patch
@ -52,10 +52,17 @@ TEST_BACKUP_DEF456 = AgentBackup(
protected=False,
size=1,
)
TEST_BACKUP_PATH_DEF456 = Path("custom_def456.tar")
TEST_DOMAIN = "test"
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
"""Convert an iterable to an async iterator."""
for i in iterable:
yield i
class BackupAgentTest(BackupAgent):
"""Test backup agent."""
@ -162,7 +169,13 @@ async def setup_backup_integration(
if with_hassio and agent_id == LOCAL_AGENT_ID:
continue
agent = hass.data[DATA_MANAGER].backup_agents[agent_id]
agent._backups = {backups.backup_id: backups for backups in agent_backups}
async def open_stream() -> AsyncIterator[bytes]:
"""Open a stream."""
return aiter_from_iter((b"backup data",))
for backup in agent_backups:
await agent.async_upload_backup(open_stream=open_stream, backup=backup)
if agent_id == LOCAL_AGENT_ID:
agent._loaded_backups = True

View File

@ -13,7 +13,7 @@ from homeassistant.components.backup import DOMAIN
from homeassistant.components.backup.manager import NewBackup, WrittenBackup
from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_PATH_ABC123
from .common import TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456
from tests.common import get_fixture_path
@ -38,10 +38,14 @@ def mocked_tarfile_fixture() -> Generator[Mock]:
@pytest.fixture(name="path_glob")
def path_glob_fixture() -> Generator[MagicMock]:
def path_glob_fixture(hass: HomeAssistant) -> Generator[MagicMock]:
"""Mock path glob."""
with patch(
"pathlib.Path.glob", return_value=[TEST_BACKUP_PATH_ABC123]
"pathlib.Path.glob",
return_value=[
Path(hass.config.path()) / "backups" / TEST_BACKUP_PATH_ABC123,
Path(hass.config.path()) / "backups" / TEST_BACKUP_PATH_DEF456,
],
) as path_glob:
yield path_glob

View File

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_delete_backup[found_backups0-True-1]
# name: test_delete_backup[found_backups0-abc123-1-unlink_path0]
dict({
'id': 1,
'result': dict({
@ -10,7 +10,7 @@
'type': 'result',
})
# ---
# name: test_delete_backup[found_backups1-False-0]
# name: test_delete_backup[found_backups1-def456-1-unlink_path1]
dict({
'id': 1,
'result': dict({
@ -21,7 +21,7 @@
'type': 'result',
})
# ---
# name: test_delete_backup[found_backups2-True-0]
# name: test_delete_backup[found_backups2-abc123-0-None]
dict({
'id': 1,
'result': dict({
@ -32,7 +32,7 @@
'type': 'result',
})
# ---
# name: test_load_backups[None]
# name: test_load_backups[mock_read_backup]
dict({
'id': 1,
'result': dict({
@ -47,7 +47,7 @@
'type': 'result',
})
# ---
# name: test_load_backups[None].1
# name: test_load_backups[mock_read_backup].1
dict({
'id': 2,
'result': dict({
@ -82,6 +82,29 @@
'name': 'Test',
'with_automatic_settings': True,
}),
dict({
'addons': list([
]),
'agents': dict({
'backup.local': dict({
'protected': False,
'size': 1,
}),
}),
'backup_id': 'def456',
'database_included': False,
'date': '1980-01-01T00:00:00.000Z',
'failed_agent_ids': list([
]),
'folders': list([
'media',
'share',
]),
'homeassistant_included': True,
'homeassistant_version': '2024.12.0',
'name': 'Test 2',
'with_automatic_settings': None,
}),
]),
'last_attempted_automatic_backup': None,
'last_completed_automatic_backup': None,

View File

@ -12,21 +12,35 @@ from unittest.mock import MagicMock, mock_open, patch
import pytest
from syrupy import SnapshotAssertion
from homeassistant.components.backup import DOMAIN
from homeassistant.components.backup import DOMAIN, AgentBackup
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .common import TEST_BACKUP_ABC123, TEST_BACKUP_PATH_ABC123
from .common import (
TEST_BACKUP_ABC123,
TEST_BACKUP_DEF456,
TEST_BACKUP_PATH_ABC123,
TEST_BACKUP_PATH_DEF456,
)
from tests.typing import ClientSessionGenerator, WebSocketGenerator
def mock_read_backup(backup_path: Path) -> AgentBackup:
"""Mock read backup."""
mock_backups = {
"abc123": TEST_BACKUP_ABC123,
"custom_def456": TEST_BACKUP_DEF456,
}
return mock_backups[backup_path.stem]
@pytest.fixture(name="read_backup")
def read_backup_fixture(path_glob: MagicMock) -> Generator[MagicMock]:
"""Mock read backup."""
with patch(
"homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123,
side_effect=mock_read_backup,
) as read_backup:
yield read_backup
@ -34,7 +48,7 @@ def read_backup_fixture(path_glob: MagicMock) -> Generator[MagicMock]:
@pytest.mark.parametrize(
"side_effect",
[
None,
mock_read_backup,
OSError("Boom"),
TarError("Boom"),
json.JSONDecodeError("Boom", "test", 1),
@ -94,11 +108,21 @@ async def test_upload(
@pytest.mark.usefixtures("read_backup")
@pytest.mark.parametrize(
("found_backups", "backup_exists", "unlink_calls"),
("found_backups", "backup_id", "unlink_calls", "unlink_path"),
[
([TEST_BACKUP_PATH_ABC123], True, 1),
([TEST_BACKUP_PATH_ABC123], False, 0),
(([], True, 0)),
(
[TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456],
TEST_BACKUP_ABC123.backup_id,
1,
TEST_BACKUP_PATH_ABC123,
),
(
[TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456],
TEST_BACKUP_DEF456.backup_id,
1,
TEST_BACKUP_PATH_DEF456,
),
(([], TEST_BACKUP_ABC123.backup_id, 0, None)),
],
)
async def test_delete_backup(
@ -108,8 +132,9 @@ async def test_delete_backup(
snapshot: SnapshotAssertion,
path_glob: MagicMock,
found_backups: list[Path],
backup_exists: bool,
backup_id: str,
unlink_calls: int,
unlink_path: Path | None,
) -> None:
"""Test delete backup."""
assert await async_setup_component(hass, DOMAIN, {})
@ -118,12 +143,13 @@ async def test_delete_backup(
path_glob.return_value = found_backups
with (
patch("pathlib.Path.exists", return_value=backup_exists),
patch("pathlib.Path.unlink") as unlink,
patch("pathlib.Path.unlink", autospec=True) as unlink,
):
await client.send_json_auto_id(
{"type": "backup/delete", "backup_id": TEST_BACKUP_ABC123.backup_id}
{"type": "backup/delete", "backup_id": backup_id}
)
assert await client.receive_json() == snapshot
assert unlink.call_count == unlink_calls
for call in unlink.mock_calls:
assert call.args[0] == unlink_path

View File

@ -1,7 +1,7 @@
"""Tests for the Backup integration."""
import asyncio
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator
from io import BytesIO, StringIO
import json
import tarfile
@ -15,7 +15,12 @@ from homeassistant.components.backup import AddonInfo, AgentBackup, Folder
from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN
from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration
from .common import (
TEST_BACKUP_ABC123,
BackupAgentTest,
aiter_from_iter,
setup_backup_integration,
)
from tests.common import MockUser, get_fixture_path
from tests.typing import ClientSessionGenerator
@ -35,6 +40,9 @@ async def test_downloading_local_backup(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123,
),
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.get_backup_path",
),
patch("pathlib.Path.exists", return_value=True),
patch(
"homeassistant.components.backup.http.FileResponse",
@ -73,9 +81,14 @@ async def test_downloading_local_encrypted_backup_file_not_found(
await setup_backup_integration(hass)
client = await hass_client()
with patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123,
with (
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123,
),
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.get_backup_path",
),
):
resp = await client.get(
"/api/backup/download/abc123?agent_id=backup.local&password=blah"
@ -93,12 +106,6 @@ async def test_downloading_local_encrypted_backup(
await _test_downloading_encrypted_backup(hass_client, "backup.local")
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
"""Convert an iterable to an async iterator."""
for i in iterable:
yield i
@patch.object(BackupAgentTest, "async_download_backup")
async def test_downloading_remote_encrypted_backup(
download_mock,

View File

@ -54,6 +54,8 @@ from .common import (
LOCAL_AGENT_ID,
TEST_BACKUP_ABC123,
TEST_BACKUP_DEF456,
TEST_BACKUP_PATH_ABC123,
TEST_BACKUP_PATH_DEF456,
BackupAgentTest,
setup_backup_platform,
)
@ -89,6 +91,15 @@ def generate_backup_id_fixture() -> Generator[MagicMock]:
yield mock
def mock_read_backup(backup_path: Path) -> AgentBackup:
"""Mock read backup."""
mock_backups = {
"abc123": TEST_BACKUP_ABC123,
"custom_def456": TEST_BACKUP_DEF456,
}
return mock_backups[backup_path.stem]
@pytest.mark.usefixtures("mock_backup_generation")
async def test_create_backup_service(
hass: HomeAssistant,
@ -1311,7 +1322,11 @@ class LocalBackupAgentTest(BackupAgentTest, LocalBackupAgent):
"""Local backup agent."""
def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup."""
"""Return the local path to an existing backup."""
return Path("test.tar")
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
return Path("test.tar")
@ -2023,10 +2038,6 @@ async def test_receive_backup_file_write_error(
with (
patch("pathlib.Path.open", open_mock),
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=TEST_BACKUP_ABC123,
),
):
resp = await client.post(
"/api/backup/upload?agent_id=test.remote",
@ -2375,18 +2386,61 @@ async def test_receive_backup_file_read_error(
@pytest.mark.usefixtures("path_glob")
@pytest.mark.parametrize(
("agent_id", "password_param", "restore_database", "restore_homeassistant", "dir"),
(
"agent_id",
"backup_id",
"password_param",
"backup_path",
"restore_database",
"restore_homeassistant",
"dir",
),
[
(LOCAL_AGENT_ID, {}, True, False, "backups"),
(LOCAL_AGENT_ID, {"password": "abc123"}, False, True, "backups"),
("test.remote", {}, True, True, "tmp_backups"),
(
LOCAL_AGENT_ID,
TEST_BACKUP_ABC123.backup_id,
{},
TEST_BACKUP_PATH_ABC123,
True,
False,
"backups",
),
(
LOCAL_AGENT_ID,
TEST_BACKUP_DEF456.backup_id,
{},
TEST_BACKUP_PATH_DEF456,
True,
False,
"backups",
),
(
LOCAL_AGENT_ID,
TEST_BACKUP_ABC123.backup_id,
{"password": "abc123"},
TEST_BACKUP_PATH_ABC123,
False,
True,
"backups",
),
(
"test.remote",
TEST_BACKUP_ABC123.backup_id,
{},
TEST_BACKUP_PATH_ABC123,
True,
True,
"tmp_backups",
),
],
)
async def test_restore_backup(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
agent_id: str,
backup_id: str,
password_param: dict[str, str],
backup_path: Path,
restore_database: bool,
restore_homeassistant: bool,
dir: str,
@ -2426,14 +2480,14 @@ async def test_restore_backup(
patch.object(remote_agent, "async_download_backup") as download_mock,
patch(
"homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123,
side_effect=mock_read_backup,
),
):
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
await ws_client.send_json_auto_id(
{
"type": "backup/restore",
"backup_id": TEST_BACKUP_ABC123.backup_id,
"backup_id": backup_id,
"agent_id": agent_id,
"restore_database": restore_database,
"restore_homeassistant": restore_homeassistant,
@ -2473,17 +2527,17 @@ async def test_restore_backup(
result = await ws_client.receive_json()
assert result["success"] is True
backup_path = f"{hass.config.path()}/{dir}/abc123.tar"
full_backup_path = f"{hass.config.path()}/{dir}/{backup_path.name}"
expected_restore_file = json.dumps(
{
"path": backup_path,
"path": full_backup_path,
"password": password,
"remove_after_restore": agent_id != LOCAL_AGENT_ID,
"restore_database": restore_database,
"restore_homeassistant": restore_homeassistant,
}
)
validate_password_mock.assert_called_once_with(Path(backup_path), password)
validate_password_mock.assert_called_once_with(Path(full_backup_path), password)
assert mocked_write_text.call_args[0][0] == expected_restore_file
assert mocked_service_call.called
@ -2533,7 +2587,7 @@ async def test_restore_backup_wrong_password(
patch.object(remote_agent, "async_download_backup") as download_mock,
patch(
"homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123,
side_effect=mock_read_backup,
),
):
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
@ -2581,8 +2635,8 @@ async def test_restore_backup_wrong_password(
("parameters", "expected_error", "expected_reason"),
[
(
{"backup_id": TEST_BACKUP_DEF456.backup_id},
f"Backup def456 not found in agent {LOCAL_AGENT_ID}",
{"backup_id": "no_such_backup"},
f"Backup no_such_backup not found in agent {LOCAL_AGENT_ID}",
"backup_manager_error",
),
(
@ -2629,7 +2683,7 @@ async def test_restore_backup_wrong_parameters(
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
patch(
"homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123,
side_effect=mock_read_backup,
),
):
await ws_client.send_json_auto_id(