mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Support decrypting backups when downloading (#135728)
* Support decrypting backups when downloading * Close stream * Use test helper * Wait for worker to finish * Simplify * Update backup.json * Simplify * Revert change from the future
This commit is contained in:
parent
6fdccda225
commit
9db6be11f7
@ -4,18 +4,23 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import cast
|
import threading
|
||||||
|
from typing import IO, cast
|
||||||
|
|
||||||
from aiohttp import BodyPartReader
|
from aiohttp import BodyPartReader
|
||||||
from aiohttp.hdrs import CONTENT_DISPOSITION
|
from aiohttp.hdrs import CONTENT_DISPOSITION
|
||||||
from aiohttp.web import FileResponse, Request, Response, StreamResponse
|
from aiohttp.web import FileResponse, Request, Response, StreamResponse
|
||||||
|
from multidict import istr
|
||||||
|
|
||||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
|
from . import util
|
||||||
|
from .agent import BackupAgent
|
||||||
from .const import DATA_MANAGER
|
from .const import DATA_MANAGER
|
||||||
|
from .manager import BackupManager
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -43,8 +48,13 @@ class DownloadBackupView(HomeAssistantView):
|
|||||||
agent_id = request.query.getone("agent_id")
|
agent_id = request.query.getone("agent_id")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return Response(status=HTTPStatus.BAD_REQUEST)
|
return Response(status=HTTPStatus.BAD_REQUEST)
|
||||||
|
try:
|
||||||
|
password = request.query.getone("password")
|
||||||
|
except KeyError:
|
||||||
|
password = None
|
||||||
|
|
||||||
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
hass = request.app[KEY_HASS]
|
||||||
|
manager = hass.data[DATA_MANAGER]
|
||||||
if agent_id not in manager.backup_agents:
|
if agent_id not in manager.backup_agents:
|
||||||
return Response(status=HTTPStatus.BAD_REQUEST)
|
return Response(status=HTTPStatus.BAD_REQUEST)
|
||||||
agent = manager.backup_agents[agent_id]
|
agent = manager.backup_agents[agent_id]
|
||||||
@ -58,6 +68,24 @@ class DownloadBackupView(HomeAssistantView):
|
|||||||
headers = {
|
headers = {
|
||||||
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
return await self._send_backup_no_password(
|
||||||
|
request, headers, backup_id, agent_id, agent, manager
|
||||||
|
)
|
||||||
|
return await self._send_backup_with_password(
|
||||||
|
hass, request, headers, backup_id, agent_id, password, agent, manager
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _send_backup_no_password(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
headers: dict[istr, str],
|
||||||
|
backup_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
agent: BackupAgent,
|
||||||
|
manager: BackupManager,
|
||||||
|
) -> StreamResponse | FileResponse | Response:
|
||||||
if agent_id in manager.local_backup_agents:
|
if agent_id in manager.local_backup_agents:
|
||||||
local_agent = manager.local_backup_agents[agent_id]
|
local_agent = manager.local_backup_agents[agent_id]
|
||||||
path = local_agent.get_backup_path(backup_id)
|
path = local_agent.get_backup_path(backup_id)
|
||||||
@ -70,6 +98,50 @@ class DownloadBackupView(HomeAssistantView):
|
|||||||
await response.write(chunk)
|
await response.write(chunk)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def _send_backup_with_password(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
request: Request,
|
||||||
|
headers: dict[istr, str],
|
||||||
|
backup_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
password: str,
|
||||||
|
agent: BackupAgent,
|
||||||
|
manager: BackupManager,
|
||||||
|
) -> StreamResponse | FileResponse | Response:
|
||||||
|
reader: IO[bytes]
|
||||||
|
if agent_id in manager.local_backup_agents:
|
||||||
|
local_agent = manager.local_backup_agents[agent_id]
|
||||||
|
path = local_agent.get_backup_path(backup_id)
|
||||||
|
try:
|
||||||
|
reader = await hass.async_add_executor_job(open, path.as_posix(), "rb")
|
||||||
|
except FileNotFoundError:
|
||||||
|
return Response(status=HTTPStatus.NOT_FOUND)
|
||||||
|
else:
|
||||||
|
stream = await agent.async_download_backup(backup_id)
|
||||||
|
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
|
||||||
|
|
||||||
|
worker_done_event = asyncio.Event()
|
||||||
|
|
||||||
|
def on_done() -> None:
|
||||||
|
"""Call by the worker thread when it's done."""
|
||||||
|
hass.loop.call_soon_threadsafe(worker_done_event.set)
|
||||||
|
|
||||||
|
stream = util.AsyncIteratorWriter(hass)
|
||||||
|
worker = threading.Thread(
|
||||||
|
target=util.decrypt_backup, args=[reader, stream, password, on_done]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
worker.start()
|
||||||
|
response = StreamResponse(status=HTTPStatus.OK, headers=headers)
|
||||||
|
await response.prepare(request)
|
||||||
|
async for chunk in stream:
|
||||||
|
await response.write(chunk)
|
||||||
|
return response
|
||||||
|
finally:
|
||||||
|
reader.close()
|
||||||
|
await worker_done_event.wait()
|
||||||
|
|
||||||
|
|
||||||
class UploadBackupView(HomeAssistantView):
|
class UploadBackupView(HomeAssistantView):
|
||||||
"""Generate backup view."""
|
"""Generate backup view."""
|
||||||
|
@ -1033,10 +1033,12 @@ class BackupManager:
|
|||||||
validate_password_stream(reader, password)
|
validate_password_stream(reader, password)
|
||||||
except backup_util.IncorrectPassword as err:
|
except backup_util.IncorrectPassword as err:
|
||||||
raise IncorrectPasswordError from err
|
raise IncorrectPasswordError from err
|
||||||
except backup_util.UnsuppertedSecureTarVersion as err:
|
except backup_util.UnsupportedSecureTarVersion as err:
|
||||||
raise DecryptOnDowloadNotSupported from err
|
raise DecryptOnDowloadNotSupported from err
|
||||||
except backup_util.DecryptError as err:
|
except backup_util.DecryptError as err:
|
||||||
raise BackupManagerError(str(err)) from err
|
raise BackupManagerError(str(err)) from err
|
||||||
|
finally:
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
|
||||||
class KnownBackups:
|
class KnownBackups:
|
||||||
|
@ -3,14 +3,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator, Callable
|
||||||
from pathlib import Path
|
import copy
|
||||||
|
from io import BytesIO
|
||||||
|
import json
|
||||||
|
from pathlib import Path, PurePath
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
import tarfile
|
import tarfile
|
||||||
from typing import IO, cast
|
from typing import IO, Self, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
|
from securetar import (
|
||||||
|
PLAINTEXT_SIZE_HEADER,
|
||||||
|
VERSION_HEADER,
|
||||||
|
SecureTarError,
|
||||||
|
SecureTarFile,
|
||||||
|
SecureTarReadError,
|
||||||
|
)
|
||||||
|
|
||||||
from homeassistant.backup_restore import password_to_key
|
from homeassistant.backup_restore import password_to_key
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -24,7 +33,7 @@ class DecryptError(Exception):
|
|||||||
"""Error during decryption."""
|
"""Error during decryption."""
|
||||||
|
|
||||||
|
|
||||||
class UnsuppertedSecureTarVersion(DecryptError):
|
class UnsupportedSecureTarVersion(DecryptError):
|
||||||
"""Unsupported securetar version."""
|
"""Unsupported securetar version."""
|
||||||
|
|
||||||
|
|
||||||
@ -157,6 +166,33 @@ class AsyncIteratorReader:
|
|||||||
self._buffer = None
|
self._buffer = None
|
||||||
return bytes(result)
|
return bytes(result)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the iterator."""
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIteratorWriter:
|
||||||
|
"""Wrap an AsyncIterator."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the wrapper."""
|
||||||
|
self._hass = hass
|
||||||
|
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||||
|
|
||||||
|
def __aiter__(self) -> Self:
|
||||||
|
"""Return the iterator."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> bytes:
|
||||||
|
"""Get the next chunk from the iterator."""
|
||||||
|
if data := await self._queue.get():
|
||||||
|
return data
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
def write(self, s: bytes, /) -> int:
|
||||||
|
"""Write data to the iterator."""
|
||||||
|
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
|
||||||
|
return len(s)
|
||||||
|
|
||||||
|
|
||||||
def validate_password_stream(
|
def validate_password_stream(
|
||||||
input_stream: IO[bytes],
|
input_stream: IO[bytes],
|
||||||
@ -170,7 +206,7 @@ def validate_password_stream(
|
|||||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||||
continue
|
continue
|
||||||
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||||
raise UnsuppertedSecureTarVersion
|
raise UnsupportedSecureTarVersion
|
||||||
istf = SecureTarFile(
|
istf = SecureTarFile(
|
||||||
None, # Not used
|
None, # Not used
|
||||||
gzip=False,
|
gzip=False,
|
||||||
@ -187,6 +223,68 @@ def validate_password_stream(
|
|||||||
raise BackupEmpty
|
raise BackupEmpty
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_backup(
|
||||||
|
input_stream: IO[bytes],
|
||||||
|
output_stream: IO[bytes],
|
||||||
|
password: str | None,
|
||||||
|
on_done: Callable[[], None],
|
||||||
|
) -> None:
|
||||||
|
"""Decrypt a backup."""
|
||||||
|
try:
|
||||||
|
with (
|
||||||
|
tarfile.open(
|
||||||
|
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||||
|
) as input_tar,
|
||||||
|
tarfile.open(
|
||||||
|
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||||
|
) as output_tar,
|
||||||
|
):
|
||||||
|
_decrypt_backup(input_tar, output_tar, password)
|
||||||
|
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||||
|
LOGGER.warning("Error decrypting backup: %s", err)
|
||||||
|
finally:
|
||||||
|
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||||
|
on_done()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt_backup(
|
||||||
|
input_tar: tarfile.TarFile,
|
||||||
|
output_tar: tarfile.TarFile,
|
||||||
|
password: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Decrypt a backup."""
|
||||||
|
for obj in input_tar:
|
||||||
|
# We compare with PurePath to avoid issues with different path separators,
|
||||||
|
# for example when backup.json is added as "./backup.json"
|
||||||
|
if PurePath(obj.name) == PurePath("backup.json"):
|
||||||
|
# Rewrite the backup.json file to indicate that the backup is decrypted
|
||||||
|
if not (reader := input_tar.extractfile(obj)):
|
||||||
|
raise DecryptError
|
||||||
|
metadata = json_loads_object(reader.read())
|
||||||
|
metadata["protected"] = False
|
||||||
|
updated_metadata_b = json.dumps(metadata).encode()
|
||||||
|
metadata_obj = copy.deepcopy(obj)
|
||||||
|
metadata_obj.size = len(updated_metadata_b)
|
||||||
|
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
|
||||||
|
continue
|
||||||
|
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||||
|
output_tar.addfile(obj, input_tar.extractfile(obj))
|
||||||
|
continue
|
||||||
|
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||||
|
raise UnsupportedSecureTarVersion
|
||||||
|
decrypted_obj = copy.deepcopy(obj)
|
||||||
|
decrypted_obj.size = int(obj.pax_headers[PLAINTEXT_SIZE_HEADER])
|
||||||
|
istf = SecureTarFile(
|
||||||
|
None, # Not used
|
||||||
|
gzip=False,
|
||||||
|
key=password_to_key(password) if password is not None else None,
|
||||||
|
mode="r",
|
||||||
|
fileobj=input_tar.extractfile(obj),
|
||||||
|
)
|
||||||
|
with istf.decrypt(obj) as decrypted:
|
||||||
|
output_tar.addfile(decrypted_obj, decrypted)
|
||||||
|
|
||||||
|
|
||||||
async def receive_file(
|
async def receive_file(
|
||||||
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -9,11 +9,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.backup import DOMAIN
|
||||||
from homeassistant.components.backup.manager import NewBackup, WrittenBackup
|
from homeassistant.components.backup.manager import NewBackup, WrittenBackup
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .common import TEST_BACKUP_PATH_ABC123
|
from .common import TEST_BACKUP_PATH_ABC123
|
||||||
|
|
||||||
|
from tests.common import get_fixture_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mocked_json_bytes")
|
@pytest.fixture(name="mocked_json_bytes")
|
||||||
def mocked_json_bytes_fixture() -> Generator[Mock]:
|
def mocked_json_bytes_fixture() -> Generator[Mock]:
|
||||||
@ -113,3 +116,18 @@ def mock_backup_generation_fixture(
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_backups() -> Generator[None]:
|
||||||
|
"""Fixture to setup test backups."""
|
||||||
|
# pylint: disable-next=import-outside-toplevel
|
||||||
|
from homeassistant.components.backup import backup as core_backup
|
||||||
|
|
||||||
|
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
super().__init__(hass)
|
||||||
|
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
||||||
|
|
||||||
|
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
||||||
|
yield
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
"""Tests for the Backup integration."""
|
"""Tests for the Backup integration."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from io import StringIO
|
from collections.abc import AsyncIterator, Iterable
|
||||||
|
from io import BytesIO, StringIO
|
||||||
|
import json
|
||||||
|
import tarfile
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.backup.const import DATA_MANAGER
|
from homeassistant.components.backup import AddonInfo, AgentBackup, Folder
|
||||||
|
from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration
|
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration
|
||||||
|
|
||||||
from tests.common import MockUser
|
from tests.common import MockUser, get_fixture_path
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
|
||||||
@ -45,8 +50,9 @@ async def test_downloading_remote_backup(
|
|||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test downloading a remote backup."""
|
"""Test downloading a remote backup."""
|
||||||
await setup_backup_integration(hass)
|
await setup_backup_integration(
|
||||||
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest("test")
|
hass, backups={"test.test": [TEST_BACKUP_ABC123]}, remote_agents=["test"]
|
||||||
|
)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
@ -54,11 +60,140 @@ async def test_downloading_remote_backup(
|
|||||||
patch.object(BackupAgentTest, "async_download_backup") as download_mock,
|
patch.object(BackupAgentTest, "async_download_backup") as download_mock,
|
||||||
):
|
):
|
||||||
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
|
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
|
||||||
resp = await client.get("/api/backup/download/abc123?agent_id=domain.test")
|
resp = await client.get("/api/backup/download/abc123?agent_id=test.test")
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert await resp.content.read() == b"backup data"
|
assert await resp.content.read() == b"backup data"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_downloading_local_encrypted_backup_file_not_found(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading a local backup file."""
|
||||||
|
await setup_backup_integration(hass)
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
|
||||||
|
return_value=TEST_BACKUP_ABC123,
|
||||||
|
):
|
||||||
|
resp = await client.get(
|
||||||
|
"/api/backup/download/abc123?agent_id=backup.local&password=blah"
|
||||||
|
)
|
||||||
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_backups")
|
||||||
|
async def test_downloading_local_encrypted_backup(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading a local backup file."""
|
||||||
|
await setup_backup_integration(hass)
|
||||||
|
await _test_downloading_encrypted_backup(hass_client, "backup.local")
|
||||||
|
|
||||||
|
|
||||||
|
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
|
||||||
|
"""Convert an iterable to an async iterator."""
|
||||||
|
for i in iterable:
|
||||||
|
yield i
|
||||||
|
|
||||||
|
|
||||||
|
@patch.object(BackupAgentTest, "async_download_backup")
|
||||||
|
async def test_downloading_remote_encrypted_backup(
|
||||||
|
download_mock,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading a local backup file."""
|
||||||
|
backup_path = get_fixture_path("test_backups/ed1608a9.tar", DOMAIN)
|
||||||
|
await setup_backup_integration(hass)
|
||||||
|
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest(
|
||||||
|
"test",
|
||||||
|
[
|
||||||
|
AgentBackup(
|
||||||
|
addons=[AddonInfo(name="Test", slug="test", version="1.0.0")],
|
||||||
|
backup_id="ed1608a9",
|
||||||
|
database_included=True,
|
||||||
|
date="1970-01-01T00:00:00Z",
|
||||||
|
extra_metadata={},
|
||||||
|
folders=[Folder.MEDIA, Folder.SHARE],
|
||||||
|
homeassistant_included=True,
|
||||||
|
homeassistant_version="2024.12.0",
|
||||||
|
name="Test",
|
||||||
|
protected=True,
|
||||||
|
size=13,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download_backup(backup_id: str, **kwargs: Any) -> AsyncIterator[bytes]:
|
||||||
|
return aiter_from_iter((backup_path.read_bytes(),))
|
||||||
|
|
||||||
|
download_mock.side_effect = download_backup
|
||||||
|
await _test_downloading_encrypted_backup(hass_client, "domain.test")
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_downloading_encrypted_backup(
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
agent_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test downloading an encrypted backup file."""
|
||||||
|
# Try downloading without supplying a password
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get(f"/api/backup/download/ed1608a9?agent_id={agent_id}")
|
||||||
|
assert resp.status == 200
|
||||||
|
backup = await resp.read()
|
||||||
|
# We expect a valid outer tar file, but the inner tar file is encrypted and
|
||||||
|
# can't be read
|
||||||
|
with tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar:
|
||||||
|
enc_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
|
||||||
|
assert enc_metadata["protected"] is True
|
||||||
|
with (
|
||||||
|
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
|
||||||
|
pytest.raises(tarfile.ReadError, match="file could not be opened"),
|
||||||
|
):
|
||||||
|
# pylint: disable-next=consider-using-with
|
||||||
|
tarfile.open(fileobj=inner_tar_file, mode="r")
|
||||||
|
|
||||||
|
# Download with the wrong password
|
||||||
|
resp = await client.get(
|
||||||
|
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=wrong"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
backup = await resp.read()
|
||||||
|
# We expect a truncated outer tar file
|
||||||
|
with (
|
||||||
|
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
|
||||||
|
pytest.raises(tarfile.ReadError, match="unexpected end of data"),
|
||||||
|
):
|
||||||
|
outer_tar.getnames()
|
||||||
|
|
||||||
|
# Finally download with the correct password
|
||||||
|
resp = await client.get(
|
||||||
|
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=hunter2"
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
backup = await resp.read()
|
||||||
|
# We expect a valid outer tar file, the inner tar file is decrypted and can be read
|
||||||
|
with (
|
||||||
|
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
|
||||||
|
):
|
||||||
|
dec_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
|
||||||
|
assert dec_metadata == enc_metadata | {"protected": False}
|
||||||
|
with (
|
||||||
|
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
|
||||||
|
tarfile.open(fileobj=inner_tar_file, mode="r") as inner_tar,
|
||||||
|
):
|
||||||
|
assert inner_tar.getnames() == [
|
||||||
|
".",
|
||||||
|
"README.md",
|
||||||
|
"test_symlink",
|
||||||
|
"test1",
|
||||||
|
"test1/script.sh",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def test_downloading_backup_not_found(
|
async def test_downloading_backup_not_found(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
|
@ -36,7 +36,7 @@ from .common import (
|
|||||||
setup_backup_platform,
|
setup_backup_platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
|
from tests.common import async_fire_time_changed, async_mock_service
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
BACKUP_CALL = call(
|
BACKUP_CALL = call(
|
||||||
@ -2556,21 +2556,6 @@ async def test_subscribe_event(
|
|||||||
assert await client.receive_json() == snapshot
|
assert await client.receive_json() == snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_backups() -> Generator[None]:
|
|
||||||
"""Fixture to setup test backups."""
|
|
||||||
# pylint: disable-next=import-outside-toplevel
|
|
||||||
from homeassistant.components.backup import backup as core_backup
|
|
||||||
|
|
||||||
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
|
||||||
super().__init__(hass)
|
|
||||||
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
|
||||||
|
|
||||||
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("agent_id", "backup_id", "password"),
|
("agent_id", "backup_id", "password"),
|
||||||
[
|
[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user