Fix handling of renamed backup files in the core writer (#136898)

* Fix handling of renamed backup files in the core writer

* Adjust mocking

* Raise BackupAgentError instead of KeyError in get_backup_path

* Add specific error indicating backup not found

* Fix tests

* Ensure backups are loaded

* Fix tests
This commit is contained in:
Erik Montnemery 2025-01-30 15:25:16 +01:00 committed by GitHub
parent 1c4ddb36d5
commit bab616fa61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 234 additions and 81 deletions

View File

@ -27,6 +27,12 @@ class BackupAgentUnreachableError(BackupAgentError):
_message = "The backup agent is unreachable." _message = "The backup agent is unreachable."
class BackupNotFound(BackupAgentError):
"""Raised when a backup is not found."""
error_code = "backup_not_found"
class BackupAgent(abc.ABC): class BackupAgent(abc.ABC):
"""Backup agent interface.""" """Backup agent interface."""
@ -94,11 +100,16 @@ class LocalBackupAgent(BackupAgent):
@abc.abstractmethod @abc.abstractmethod
def get_backup_path(self, backup_id: str) -> Path: def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup. """Return the local path to an existing backup.
The method should return the path to the backup file with the specified id. The method should return the path to the backup file with the specified id.
Raises BackupAgentError if the backup does not exist.
""" """
@abc.abstractmethod
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
class BackupAgentPlatformProtocol(Protocol): class BackupAgentPlatformProtocol(Protocol):
"""Define the format of backup platforms which implement backup agents.""" """Define the format of backup platforms which implement backup agents."""

View File

@ -11,7 +11,7 @@ from typing import Any
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.hassio import is_hassio from homeassistant.helpers.hassio import is_hassio
from .agent import BackupAgent, LocalBackupAgent from .agent import BackupAgent, BackupNotFound, LocalBackupAgent
from .const import DOMAIN, LOGGER from .const import DOMAIN, LOGGER
from .models import AgentBackup from .models import AgentBackup
from .util import read_backup from .util import read_backup
@ -39,7 +39,7 @@ class CoreLocalBackupAgent(LocalBackupAgent):
super().__init__() super().__init__()
self._hass = hass self._hass = hass
self._backup_dir = Path(hass.config.path("backups")) self._backup_dir = Path(hass.config.path("backups"))
self._backups: dict[str, AgentBackup] = {} self._backups: dict[str, tuple[AgentBackup, Path]] = {}
self._loaded_backups = False self._loaded_backups = False
async def _load_backups(self) -> None: async def _load_backups(self) -> None:
@ -49,13 +49,13 @@ class CoreLocalBackupAgent(LocalBackupAgent):
self._backups = backups self._backups = backups
self._loaded_backups = True self._loaded_backups = True
def _read_backups(self) -> dict[str, AgentBackup]: def _read_backups(self) -> dict[str, tuple[AgentBackup, Path]]:
"""Read backups from disk.""" """Read backups from disk."""
backups: dict[str, AgentBackup] = {} backups: dict[str, tuple[AgentBackup, Path]] = {}
for backup_path in self._backup_dir.glob("*.tar"): for backup_path in self._backup_dir.glob("*.tar"):
try: try:
backup = read_backup(backup_path) backup = read_backup(backup_path)
backups[backup.backup_id] = backup backups[backup.backup_id] = (backup, backup_path)
except (OSError, TarError, json.JSONDecodeError, KeyError) as err: except (OSError, TarError, json.JSONDecodeError, KeyError) as err:
LOGGER.warning("Unable to read backup %s: %s", backup_path, err) LOGGER.warning("Unable to read backup %s: %s", backup_path, err)
return backups return backups
@ -76,13 +76,13 @@ class CoreLocalBackupAgent(LocalBackupAgent):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Upload a backup.""" """Upload a backup."""
self._backups[backup.backup_id] = backup self._backups[backup.backup_id] = (backup, self.get_new_backup_path(backup))
async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]: async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]:
"""List backups.""" """List backups."""
if not self._loaded_backups: if not self._loaded_backups:
await self._load_backups() await self._load_backups()
return list(self._backups.values()) return [backup for backup, _ in self._backups.values()]
async def async_get_backup( async def async_get_backup(
self, self,
@ -93,10 +93,10 @@ class CoreLocalBackupAgent(LocalBackupAgent):
if not self._loaded_backups: if not self._loaded_backups:
await self._load_backups() await self._load_backups()
if not (backup := self._backups.get(backup_id)): if backup_id not in self._backups:
return None return None
backup_path = self.get_backup_path(backup_id) backup, backup_path = self._backups[backup_id]
if not await self._hass.async_add_executor_job(backup_path.exists): if not await self._hass.async_add_executor_job(backup_path.exists):
LOGGER.debug( LOGGER.debug(
( (
@ -112,15 +112,28 @@ class CoreLocalBackupAgent(LocalBackupAgent):
return backup return backup
def get_backup_path(self, backup_id: str) -> Path: def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup.""" """Return the local path to an existing backup.
return self._backup_dir / f"{backup_id}.tar"
Raises BackupAgentError if the backup does not exist.
"""
try:
return self._backups[backup_id][1]
except KeyError as err:
raise BackupNotFound(f"Backup {backup_id} does not exist") from err
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
return self._backup_dir / f"{backup.backup_id}.tar"
async def async_delete_backup(self, backup_id: str, **kwargs: Any) -> None: async def async_delete_backup(self, backup_id: str, **kwargs: Any) -> None:
"""Delete a backup file.""" """Delete a backup file."""
if await self.async_get_backup(backup_id) is None: if not self._loaded_backups:
return await self._load_backups()
try:
backup_path = self.get_backup_path(backup_id) backup_path = self.get_backup_path(backup_id)
except BackupNotFound:
return
await self._hass.async_add_executor_job(backup_path.unlink, True) await self._hass.async_add_executor_job(backup_path.unlink, True)
LOGGER.debug("Deleted backup located at %s", backup_path) LOGGER.debug("Deleted backup located at %s", backup_path)
self._backups.pop(backup_id) self._backups.pop(backup_id)

View File

@ -1346,10 +1346,24 @@ class CoreBackupReaderWriter(BackupReaderWriter):
if agent_config and not agent_config.protected: if agent_config and not agent_config.protected:
password = None password = None
backup = AgentBackup(
addons=[],
backup_id=backup_id,
database_included=include_database,
date=date_str,
extra_metadata=extra_metadata,
folders=[],
homeassistant_included=True,
homeassistant_version=HAVERSION,
name=backup_name,
protected=password is not None,
size=0,
)
local_agent_tar_file_path = None local_agent_tar_file_path = None
if self._local_agent_id in agent_ids: if self._local_agent_id in agent_ids:
local_agent = manager.local_backup_agents[self._local_agent_id] local_agent = manager.local_backup_agents[self._local_agent_id]
local_agent_tar_file_path = local_agent.get_backup_path(backup_id) local_agent_tar_file_path = local_agent.get_new_backup_path(backup)
on_progress( on_progress(
CreateBackupEvent( CreateBackupEvent(
@ -1391,19 +1405,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
# ValueError from json_bytes # ValueError from json_bytes
raise BackupReaderWriterError(str(err)) from err raise BackupReaderWriterError(str(err)) from err
else: else:
backup = AgentBackup( backup = replace(backup, size=size_in_bytes)
addons=[],
backup_id=backup_id,
database_included=include_database,
date=date_str,
extra_metadata=extra_metadata,
folders=[],
homeassistant_included=True,
homeassistant_version=HAVERSION,
name=backup_name,
protected=password is not None,
size=size_in_bytes,
)
async_add_executor_job = self._hass.async_add_executor_job async_add_executor_job = self._hass.async_add_executor_job
@ -1517,7 +1519,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
manager = self._hass.data[DATA_MANAGER] manager = self._hass.data[DATA_MANAGER]
if self._local_agent_id in agent_ids: if self._local_agent_id in agent_ids:
local_agent = manager.local_backup_agents[self._local_agent_id] local_agent = manager.local_backup_agents[self._local_agent_id]
tar_file_path = local_agent.get_backup_path(backup.backup_id) tar_file_path = local_agent.get_new_backup_path(backup)
await async_add_executor_job(make_backup_dir, tar_file_path.parent) await async_add_executor_job(make_backup_dir, tar_file_path.parent)
await async_add_executor_job(shutil.move, temp_file, tar_file_path) await async_add_executor_job(shutil.move, temp_file, tar_file_path)
else: else:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator, Callable, Coroutine from collections.abc import AsyncIterator, Callable, Coroutine, Iterable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import ANY, AsyncMock, Mock, patch from unittest.mock import ANY, AsyncMock, Mock, patch
@ -52,10 +52,17 @@ TEST_BACKUP_DEF456 = AgentBackup(
protected=False, protected=False,
size=1, size=1,
) )
TEST_BACKUP_PATH_DEF456 = Path("custom_def456.tar")
TEST_DOMAIN = "test" TEST_DOMAIN = "test"
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
"""Convert an iterable to an async iterator."""
for i in iterable:
yield i
class BackupAgentTest(BackupAgent): class BackupAgentTest(BackupAgent):
"""Test backup agent.""" """Test backup agent."""
@ -162,7 +169,13 @@ async def setup_backup_integration(
if with_hassio and agent_id == LOCAL_AGENT_ID: if with_hassio and agent_id == LOCAL_AGENT_ID:
continue continue
agent = hass.data[DATA_MANAGER].backup_agents[agent_id] agent = hass.data[DATA_MANAGER].backup_agents[agent_id]
agent._backups = {backups.backup_id: backups for backups in agent_backups}
async def open_stream() -> AsyncIterator[bytes]:
"""Open a stream."""
return aiter_from_iter((b"backup data",))
for backup in agent_backups:
await agent.async_upload_backup(open_stream=open_stream, backup=backup)
if agent_id == LOCAL_AGENT_ID: if agent_id == LOCAL_AGENT_ID:
agent._loaded_backups = True agent._loaded_backups = True

View File

@ -13,7 +13,7 @@ from homeassistant.components.backup import DOMAIN
from homeassistant.components.backup.manager import NewBackup, WrittenBackup from homeassistant.components.backup.manager import NewBackup, WrittenBackup
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_PATH_ABC123 from .common import TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456
from tests.common import get_fixture_path from tests.common import get_fixture_path
@ -38,10 +38,14 @@ def mocked_tarfile_fixture() -> Generator[Mock]:
@pytest.fixture(name="path_glob") @pytest.fixture(name="path_glob")
def path_glob_fixture() -> Generator[MagicMock]: def path_glob_fixture(hass: HomeAssistant) -> Generator[MagicMock]:
"""Mock path glob.""" """Mock path glob."""
with patch( with patch(
"pathlib.Path.glob", return_value=[TEST_BACKUP_PATH_ABC123] "pathlib.Path.glob",
return_value=[
Path(hass.config.path()) / "backups" / TEST_BACKUP_PATH_ABC123,
Path(hass.config.path()) / "backups" / TEST_BACKUP_PATH_DEF456,
],
) as path_glob: ) as path_glob:
yield path_glob yield path_glob

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_delete_backup[found_backups0-True-1] # name: test_delete_backup[found_backups0-abc123-1-unlink_path0]
dict({ dict({
'id': 1, 'id': 1,
'result': dict({ 'result': dict({
@ -10,7 +10,7 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_delete_backup[found_backups1-False-0] # name: test_delete_backup[found_backups1-def456-1-unlink_path1]
dict({ dict({
'id': 1, 'id': 1,
'result': dict({ 'result': dict({
@ -21,7 +21,7 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_delete_backup[found_backups2-True-0] # name: test_delete_backup[found_backups2-abc123-0-None]
dict({ dict({
'id': 1, 'id': 1,
'result': dict({ 'result': dict({
@ -32,7 +32,7 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_load_backups[None] # name: test_load_backups[mock_read_backup]
dict({ dict({
'id': 1, 'id': 1,
'result': dict({ 'result': dict({
@ -47,7 +47,7 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_load_backups[None].1 # name: test_load_backups[mock_read_backup].1
dict({ dict({
'id': 2, 'id': 2,
'result': dict({ 'result': dict({
@ -82,6 +82,29 @@
'name': 'Test', 'name': 'Test',
'with_automatic_settings': True, 'with_automatic_settings': True,
}), }),
dict({
'addons': list([
]),
'agents': dict({
'backup.local': dict({
'protected': False,
'size': 1,
}),
}),
'backup_id': 'def456',
'database_included': False,
'date': '1980-01-01T00:00:00.000Z',
'failed_agent_ids': list([
]),
'folders': list([
'media',
'share',
]),
'homeassistant_included': True,
'homeassistant_version': '2024.12.0',
'name': 'Test 2',
'with_automatic_settings': None,
}),
]), ]),
'last_attempted_automatic_backup': None, 'last_attempted_automatic_backup': None,
'last_completed_automatic_backup': None, 'last_completed_automatic_backup': None,

View File

@ -12,21 +12,35 @@ from unittest.mock import MagicMock, mock_open, patch
import pytest import pytest
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from homeassistant.components.backup import DOMAIN from homeassistant.components.backup import DOMAIN, AgentBackup
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import TEST_BACKUP_ABC123, TEST_BACKUP_PATH_ABC123 from .common import (
TEST_BACKUP_ABC123,
TEST_BACKUP_DEF456,
TEST_BACKUP_PATH_ABC123,
TEST_BACKUP_PATH_DEF456,
)
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
def mock_read_backup(backup_path: Path) -> AgentBackup:
"""Mock read backup."""
mock_backups = {
"abc123": TEST_BACKUP_ABC123,
"custom_def456": TEST_BACKUP_DEF456,
}
return mock_backups[backup_path.stem]
@pytest.fixture(name="read_backup") @pytest.fixture(name="read_backup")
def read_backup_fixture(path_glob: MagicMock) -> Generator[MagicMock]: def read_backup_fixture(path_glob: MagicMock) -> Generator[MagicMock]:
"""Mock read backup.""" """Mock read backup."""
with patch( with patch(
"homeassistant.components.backup.backup.read_backup", "homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123, side_effect=mock_read_backup,
) as read_backup: ) as read_backup:
yield read_backup yield read_backup
@ -34,7 +48,7 @@ def read_backup_fixture(path_glob: MagicMock) -> Generator[MagicMock]:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"side_effect", "side_effect",
[ [
None, mock_read_backup,
OSError("Boom"), OSError("Boom"),
TarError("Boom"), TarError("Boom"),
json.JSONDecodeError("Boom", "test", 1), json.JSONDecodeError("Boom", "test", 1),
@ -94,11 +108,21 @@ async def test_upload(
@pytest.mark.usefixtures("read_backup") @pytest.mark.usefixtures("read_backup")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("found_backups", "backup_exists", "unlink_calls"), ("found_backups", "backup_id", "unlink_calls", "unlink_path"),
[ [
([TEST_BACKUP_PATH_ABC123], True, 1), (
([TEST_BACKUP_PATH_ABC123], False, 0), [TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456],
(([], True, 0)), TEST_BACKUP_ABC123.backup_id,
1,
TEST_BACKUP_PATH_ABC123,
),
(
[TEST_BACKUP_PATH_ABC123, TEST_BACKUP_PATH_DEF456],
TEST_BACKUP_DEF456.backup_id,
1,
TEST_BACKUP_PATH_DEF456,
),
(([], TEST_BACKUP_ABC123.backup_id, 0, None)),
], ],
) )
async def test_delete_backup( async def test_delete_backup(
@ -108,8 +132,9 @@ async def test_delete_backup(
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
path_glob: MagicMock, path_glob: MagicMock,
found_backups: list[Path], found_backups: list[Path],
backup_exists: bool, backup_id: str,
unlink_calls: int, unlink_calls: int,
unlink_path: Path | None,
) -> None: ) -> None:
"""Test delete backup.""" """Test delete backup."""
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
@ -118,12 +143,13 @@ async def test_delete_backup(
path_glob.return_value = found_backups path_glob.return_value = found_backups
with ( with (
patch("pathlib.Path.exists", return_value=backup_exists), patch("pathlib.Path.unlink", autospec=True) as unlink,
patch("pathlib.Path.unlink") as unlink,
): ):
await client.send_json_auto_id( await client.send_json_auto_id(
{"type": "backup/delete", "backup_id": TEST_BACKUP_ABC123.backup_id} {"type": "backup/delete", "backup_id": backup_id}
) )
assert await client.receive_json() == snapshot assert await client.receive_json() == snapshot
assert unlink.call_count == unlink_calls assert unlink.call_count == unlink_calls
for call in unlink.mock_calls:
assert call.args[0] == unlink_path

View File

@ -1,7 +1,7 @@
"""Tests for the Backup integration.""" """Tests for the Backup integration."""
import asyncio import asyncio
from collections.abc import AsyncIterator, Iterable from collections.abc import AsyncIterator
from io import BytesIO, StringIO from io import BytesIO, StringIO
import json import json
import tarfile import tarfile
@ -15,7 +15,12 @@ from homeassistant.components.backup import AddonInfo, AgentBackup, Folder
from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration from .common import (
TEST_BACKUP_ABC123,
BackupAgentTest,
aiter_from_iter,
setup_backup_integration,
)
from tests.common import MockUser, get_fixture_path from tests.common import MockUser, get_fixture_path
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@ -35,6 +40,9 @@ async def test_downloading_local_backup(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup", "homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123, return_value=TEST_BACKUP_ABC123,
), ),
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.get_backup_path",
),
patch("pathlib.Path.exists", return_value=True), patch("pathlib.Path.exists", return_value=True),
patch( patch(
"homeassistant.components.backup.http.FileResponse", "homeassistant.components.backup.http.FileResponse",
@ -73,9 +81,14 @@ async def test_downloading_local_encrypted_backup_file_not_found(
await setup_backup_integration(hass) await setup_backup_integration(hass)
client = await hass_client() client = await hass_client()
with patch( with (
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup", "homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123, return_value=TEST_BACKUP_ABC123,
),
patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.get_backup_path",
),
): ):
resp = await client.get( resp = await client.get(
"/api/backup/download/abc123?agent_id=backup.local&password=blah" "/api/backup/download/abc123?agent_id=backup.local&password=blah"
@ -93,12 +106,6 @@ async def test_downloading_local_encrypted_backup(
await _test_downloading_encrypted_backup(hass_client, "backup.local") await _test_downloading_encrypted_backup(hass_client, "backup.local")
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
"""Convert an iterable to an async iterator."""
for i in iterable:
yield i
@patch.object(BackupAgentTest, "async_download_backup") @patch.object(BackupAgentTest, "async_download_backup")
async def test_downloading_remote_encrypted_backup( async def test_downloading_remote_encrypted_backup(
download_mock, download_mock,

View File

@ -54,6 +54,8 @@ from .common import (
LOCAL_AGENT_ID, LOCAL_AGENT_ID,
TEST_BACKUP_ABC123, TEST_BACKUP_ABC123,
TEST_BACKUP_DEF456, TEST_BACKUP_DEF456,
TEST_BACKUP_PATH_ABC123,
TEST_BACKUP_PATH_DEF456,
BackupAgentTest, BackupAgentTest,
setup_backup_platform, setup_backup_platform,
) )
@ -89,6 +91,15 @@ def generate_backup_id_fixture() -> Generator[MagicMock]:
yield mock yield mock
def mock_read_backup(backup_path: Path) -> AgentBackup:
"""Mock read backup."""
mock_backups = {
"abc123": TEST_BACKUP_ABC123,
"custom_def456": TEST_BACKUP_DEF456,
}
return mock_backups[backup_path.stem]
@pytest.mark.usefixtures("mock_backup_generation") @pytest.mark.usefixtures("mock_backup_generation")
async def test_create_backup_service( async def test_create_backup_service(
hass: HomeAssistant, hass: HomeAssistant,
@ -1311,7 +1322,11 @@ class LocalBackupAgentTest(BackupAgentTest, LocalBackupAgent):
"""Local backup agent.""" """Local backup agent."""
def get_backup_path(self, backup_id: str) -> Path: def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup.""" """Return the local path to an existing backup."""
return Path("test.tar")
def get_new_backup_path(self, backup: AgentBackup) -> Path:
"""Return the local path to a new backup."""
return Path("test.tar") return Path("test.tar")
@ -2023,10 +2038,6 @@ async def test_receive_backup_file_write_error(
with ( with (
patch("pathlib.Path.open", open_mock), patch("pathlib.Path.open", open_mock),
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=TEST_BACKUP_ABC123,
),
): ):
resp = await client.post( resp = await client.post(
"/api/backup/upload?agent_id=test.remote", "/api/backup/upload?agent_id=test.remote",
@ -2375,18 +2386,61 @@ async def test_receive_backup_file_read_error(
@pytest.mark.usefixtures("path_glob") @pytest.mark.usefixtures("path_glob")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("agent_id", "password_param", "restore_database", "restore_homeassistant", "dir"), (
"agent_id",
"backup_id",
"password_param",
"backup_path",
"restore_database",
"restore_homeassistant",
"dir",
),
[ [
(LOCAL_AGENT_ID, {}, True, False, "backups"), (
(LOCAL_AGENT_ID, {"password": "abc123"}, False, True, "backups"), LOCAL_AGENT_ID,
("test.remote", {}, True, True, "tmp_backups"), TEST_BACKUP_ABC123.backup_id,
{},
TEST_BACKUP_PATH_ABC123,
True,
False,
"backups",
),
(
LOCAL_AGENT_ID,
TEST_BACKUP_DEF456.backup_id,
{},
TEST_BACKUP_PATH_DEF456,
True,
False,
"backups",
),
(
LOCAL_AGENT_ID,
TEST_BACKUP_ABC123.backup_id,
{"password": "abc123"},
TEST_BACKUP_PATH_ABC123,
False,
True,
"backups",
),
(
"test.remote",
TEST_BACKUP_ABC123.backup_id,
{},
TEST_BACKUP_PATH_ABC123,
True,
True,
"tmp_backups",
),
], ],
) )
async def test_restore_backup( async def test_restore_backup(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
agent_id: str, agent_id: str,
backup_id: str,
password_param: dict[str, str], password_param: dict[str, str],
backup_path: Path,
restore_database: bool, restore_database: bool,
restore_homeassistant: bool, restore_homeassistant: bool,
dir: str, dir: str,
@ -2426,14 +2480,14 @@ async def test_restore_backup(
patch.object(remote_agent, "async_download_backup") as download_mock, patch.object(remote_agent, "async_download_backup") as download_mock,
patch( patch(
"homeassistant.components.backup.backup.read_backup", "homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123, side_effect=mock_read_backup,
), ),
): ):
download_mock.return_value.__aiter__.return_value = iter((b"backup data",)) download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
await ws_client.send_json_auto_id( await ws_client.send_json_auto_id(
{ {
"type": "backup/restore", "type": "backup/restore",
"backup_id": TEST_BACKUP_ABC123.backup_id, "backup_id": backup_id,
"agent_id": agent_id, "agent_id": agent_id,
"restore_database": restore_database, "restore_database": restore_database,
"restore_homeassistant": restore_homeassistant, "restore_homeassistant": restore_homeassistant,
@ -2473,17 +2527,17 @@ async def test_restore_backup(
result = await ws_client.receive_json() result = await ws_client.receive_json()
assert result["success"] is True assert result["success"] is True
backup_path = f"{hass.config.path()}/{dir}/abc123.tar" full_backup_path = f"{hass.config.path()}/{dir}/{backup_path.name}"
expected_restore_file = json.dumps( expected_restore_file = json.dumps(
{ {
"path": backup_path, "path": full_backup_path,
"password": password, "password": password,
"remove_after_restore": agent_id != LOCAL_AGENT_ID, "remove_after_restore": agent_id != LOCAL_AGENT_ID,
"restore_database": restore_database, "restore_database": restore_database,
"restore_homeassistant": restore_homeassistant, "restore_homeassistant": restore_homeassistant,
} }
) )
validate_password_mock.assert_called_once_with(Path(backup_path), password) validate_password_mock.assert_called_once_with(Path(full_backup_path), password)
assert mocked_write_text.call_args[0][0] == expected_restore_file assert mocked_write_text.call_args[0][0] == expected_restore_file
assert mocked_service_call.called assert mocked_service_call.called
@ -2533,7 +2587,7 @@ async def test_restore_backup_wrong_password(
patch.object(remote_agent, "async_download_backup") as download_mock, patch.object(remote_agent, "async_download_backup") as download_mock,
patch( patch(
"homeassistant.components.backup.backup.read_backup", "homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123, side_effect=mock_read_backup,
), ),
): ):
download_mock.return_value.__aiter__.return_value = iter((b"backup data",)) download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
@ -2581,8 +2635,8 @@ async def test_restore_backup_wrong_password(
("parameters", "expected_error", "expected_reason"), ("parameters", "expected_error", "expected_reason"),
[ [
( (
{"backup_id": TEST_BACKUP_DEF456.backup_id}, {"backup_id": "no_such_backup"},
f"Backup def456 not found in agent {LOCAL_AGENT_ID}", f"Backup no_such_backup not found in agent {LOCAL_AGENT_ID}",
"backup_manager_error", "backup_manager_error",
), ),
( (
@ -2629,7 +2683,7 @@ async def test_restore_backup_wrong_parameters(
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call, patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
patch( patch(
"homeassistant.components.backup.backup.read_backup", "homeassistant.components.backup.backup.read_backup",
return_value=TEST_BACKUP_ABC123, side_effect=mock_read_backup,
), ),
): ):
await ws_client.send_json_auto_id( await ws_client.send_json_auto_id(