From 9db6be11f7d97141fcb18174e2e082e453584098 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 16 Jan 2025 12:36:12 +0100 Subject: [PATCH] Support decrypting backups when downloading (#135728) * Support decrypting backups when downloading * Close stream * Use test helper * Wait for worker to finish * Simplify * Update backup.json * Simplify * Revert change from the future --- homeassistant/components/backup/http.py | 76 ++++++++++- homeassistant/components/backup/manager.py | 4 +- homeassistant/components/backup/util.py | 110 ++++++++++++++- tests/components/backup/conftest.py | 18 +++ tests/components/backup/test_http.py | 147 ++++++++++++++++++++- tests/components/backup/test_websocket.py | 17 +-- 6 files changed, 341 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/backup/http.py b/homeassistant/components/backup/http.py index 73a8c8eb602..b909b2728a7 100644 --- a/homeassistant/components/backup/http.py +++ b/homeassistant/components/backup/http.py @@ -4,18 +4,23 @@ from __future__ import annotations import asyncio from http import HTTPStatus -from typing import cast +import threading +from typing import IO, cast from aiohttp import BodyPartReader from aiohttp.hdrs import CONTENT_DISPOSITION from aiohttp.web import FileResponse, Request, Response, StreamResponse +from multidict import istr from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.util import slugify +from . import util +from .agent import BackupAgent from .const import DATA_MANAGER +from .manager import BackupManager @callback @@ -43,8 +48,13 @@ class DownloadBackupView(HomeAssistantView): agent_id = request.query.getone("agent_id") except KeyError: return Response(status=HTTPStatus.BAD_REQUEST) + try: + password = request.query.getone("password") + except KeyError: + password = None - manager = request.app[KEY_HASS].data[DATA_MANAGER] + hass = request.app[KEY_HASS] + manager = hass.data[DATA_MANAGER] if agent_id not in manager.backup_agents: return Response(status=HTTPStatus.BAD_REQUEST) agent = manager.backup_agents[agent_id] @@ -58,6 +68,24 @@ class DownloadBackupView(HomeAssistantView): headers = { CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar" } + + if not password: + return await self._send_backup_no_password( + request, headers, backup_id, agent_id, agent, manager + ) + return await self._send_backup_with_password( + hass, request, headers, backup_id, agent_id, password, agent, manager + ) + + async def _send_backup_no_password( + self, + request: Request, + headers: dict[istr, str], + backup_id: str, + agent_id: str, + agent: BackupAgent, + manager: BackupManager, + ) -> StreamResponse | FileResponse | Response: if agent_id in manager.local_backup_agents: local_agent = manager.local_backup_agents[agent_id] path = local_agent.get_backup_path(backup_id) @@ -70,6 +98,50 @@ class DownloadBackupView(HomeAssistantView): await response.write(chunk) return response + async def _send_backup_with_password( + self, + hass: HomeAssistant, + request: Request, + headers: dict[istr, str], + backup_id: str, + agent_id: str, + password: str, + agent: BackupAgent, + manager: BackupManager, + ) -> StreamResponse | FileResponse | Response: + reader: IO[bytes] + if agent_id in manager.local_backup_agents: + local_agent = manager.local_backup_agents[agent_id] + path = local_agent.get_backup_path(backup_id) + try: + reader = await hass.async_add_executor_job(open, path.as_posix(), "rb") + except FileNotFoundError: + return Response(status=HTTPStatus.NOT_FOUND) + else: + stream = await agent.async_download_backup(backup_id) + reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream)) + + worker_done_event = asyncio.Event() + + def on_done() -> None: + """Call by the worker thread when it's done.""" + hass.loop.call_soon_threadsafe(worker_done_event.set) + + stream = util.AsyncIteratorWriter(hass) + worker = threading.Thread( + target=util.decrypt_backup, args=[reader, stream, password, on_done] + ) + try: + worker.start() + response = StreamResponse(status=HTTPStatus.OK, headers=headers) + await response.prepare(request) + async for chunk in stream: + await response.write(chunk) + return response + finally: + reader.close() + await worker_done_event.wait() + class UploadBackupView(HomeAssistantView): """Generate backup view.""" diff --git a/homeassistant/components/backup/manager.py b/homeassistant/components/backup/manager.py index 73bbfafdcf8..58600d0a4c0 100644 --- a/homeassistant/components/backup/manager.py +++ b/homeassistant/components/backup/manager.py @@ -1033,10 +1033,12 @@ class BackupManager: validate_password_stream(reader, password) except backup_util.IncorrectPassword as err: raise IncorrectPasswordError from err - except backup_util.UnsuppertedSecureTarVersion as err: + except backup_util.UnsupportedSecureTarVersion as err: raise DecryptOnDowloadNotSupported from err except backup_util.DecryptError as err: raise BackupManagerError(str(err)) from err + finally: + reader.close() class KnownBackups: diff --git a/homeassistant/components/backup/util.py b/homeassistant/components/backup/util.py index ae0244591d8..55f3c3c05c7 100644 --- a/homeassistant/components/backup/util.py +++ b/homeassistant/components/backup/util.py @@ -3,14 +3,23 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator -from pathlib import Path +from collections.abc import AsyncIterator, Callable +import copy +from io import BytesIO +import json +from pathlib import Path, PurePath from queue import SimpleQueue import tarfile -from typing import IO, cast +from typing import IO, Self, cast import aiohttp -from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError +from securetar import ( + PLAINTEXT_SIZE_HEADER, + VERSION_HEADER, + SecureTarError, + SecureTarFile, + SecureTarReadError, +) from homeassistant.backup_restore import password_to_key from homeassistant.core import HomeAssistant @@ -24,7 +33,7 @@ class DecryptError(Exception): """Error during decryption.""" -class UnsuppertedSecureTarVersion(DecryptError): +class UnsupportedSecureTarVersion(DecryptError): """Unsupported securetar version.""" @@ -157,6 +166,33 @@ class AsyncIteratorReader: self._buffer = None return bytes(result) + def close(self) -> None: + """Close the iterator.""" + + +class AsyncIteratorWriter: + """Wrap an AsyncIterator.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the wrapper.""" + self._hass = hass + self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) + + def __aiter__(self) -> Self: + """Return the iterator.""" + return self + + async def __anext__(self) -> bytes: + """Get the next chunk from the iterator.""" + if data := await self._queue.get(): + return data + raise StopAsyncIteration + + def write(self, s: bytes, /) -> int: + """Write data to the iterator.""" + asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result() + return len(s) + def validate_password_stream( input_stream: IO[bytes], @@ -170,7 +206,7 @@ def validate_password_stream( if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): continue if obj.pax_headers.get(VERSION_HEADER) != "2.0": - raise UnsuppertedSecureTarVersion + raise UnsupportedSecureTarVersion istf = SecureTarFile( None, # Not used gzip=False, @@ -187,6 +223,68 @@ def validate_password_stream( raise BackupEmpty +def decrypt_backup( + input_stream: IO[bytes], + output_stream: IO[bytes], + password: str | None, + on_done: Callable[[], None], +) -> None: + """Decrypt a backup.""" + try: + with ( + tarfile.open( + fileobj=input_stream, mode="r|", bufsize=BUF_SIZE + ) as input_tar, + tarfile.open( + fileobj=output_stream, mode="w|", bufsize=BUF_SIZE + ) as output_tar, + ): + _decrypt_backup(input_tar, output_tar, password) + except (DecryptError, SecureTarError, tarfile.TarError) as err: + LOGGER.warning("Error decrypting backup: %s", err) + finally: + output_stream.write(b"") # Write an empty chunk to signal the end of the stream + on_done() + + +def _decrypt_backup( + input_tar: tarfile.TarFile, + output_tar: tarfile.TarFile, + password: str | None, +) -> None: + """Decrypt a backup.""" + for obj in input_tar: + # We compare with PurePath to avoid issues with different path separators, + # for example when backup.json is added as "./backup.json" + if PurePath(obj.name) == PurePath("backup.json"): + # Rewrite the backup.json file to indicate that the backup is decrypted + if not (reader := input_tar.extractfile(obj)): + raise DecryptError + metadata = json_loads_object(reader.read()) + metadata["protected"] = False + updated_metadata_b = json.dumps(metadata).encode() + metadata_obj = copy.deepcopy(obj) + metadata_obj.size = len(updated_metadata_b) + output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b)) + continue + if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): + output_tar.addfile(obj, input_tar.extractfile(obj)) + continue + if obj.pax_headers.get(VERSION_HEADER) != "2.0": + raise UnsupportedSecureTarVersion + decrypted_obj = copy.deepcopy(obj) + decrypted_obj.size = int(obj.pax_headers[PLAINTEXT_SIZE_HEADER]) + istf = SecureTarFile( + None, # Not used + gzip=False, + key=password_to_key(password) if password is not None else None, + mode="r", + fileobj=input_tar.extractfile(obj), + ) + with istf.decrypt(obj) as decrypted: + output_tar.addfile(decrypted_obj, decrypted) + + async def receive_file( hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path ) -> None: diff --git a/tests/components/backup/conftest.py b/tests/components/backup/conftest.py index ee855fb70f2..29a6b27db56 100644 --- a/tests/components/backup/conftest.py +++ b/tests/components/backup/conftest.py @@ -9,11 +9,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from homeassistant.components.backup import DOMAIN from homeassistant.components.backup.manager import NewBackup, WrittenBackup from homeassistant.core import HomeAssistant from .common import TEST_BACKUP_PATH_ABC123 +from tests.common import get_fixture_path + @pytest.fixture(name="mocked_json_bytes") def mocked_json_bytes_fixture() -> Generator[Mock]: @@ -113,3 +116,18 @@ def mock_backup_generation_fixture( ), ): yield + + +@pytest.fixture +def mock_backups() -> Generator[None]: + """Fixture to setup test backups.""" + # pylint: disable-next=import-outside-toplevel + from homeassistant.components.backup import backup as core_backup + + class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent): + def __init__(self, hass: HomeAssistant) -> None: + super().__init__(hass) + self._backup_dir = get_fixture_path("test_backups", DOMAIN) + + with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent): + yield diff --git a/tests/components/backup/test_http.py b/tests/components/backup/test_http.py index c071a0d8386..693434631b9 100644 --- a/tests/components/backup/test_http.py +++ b/tests/components/backup/test_http.py @@ -1,18 +1,23 @@ """Tests for the Backup integration.""" import asyncio -from io import StringIO +from collections.abc import AsyncIterator, Iterable +from io import BytesIO, StringIO +import json +import tarfile +from typing import Any from unittest.mock import patch from aiohttp import web import pytest -from homeassistant.components.backup.const import DATA_MANAGER +from homeassistant.components.backup import AddonInfo, AgentBackup, Folder +from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN from homeassistant.core import HomeAssistant from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration -from tests.common import MockUser +from tests.common import MockUser, get_fixture_path from tests.typing import ClientSessionGenerator @@ -45,8 +50,9 @@ async def test_downloading_remote_backup( hass_client: ClientSessionGenerator, ) -> None: """Test downloading a remote backup.""" - await setup_backup_integration(hass) - hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest("test") + await setup_backup_integration( + hass, backups={"test.test": [TEST_BACKUP_ABC123]}, remote_agents=["test"] + ) client = await hass_client() @@ -54,11 +60,140 @@ async def test_downloading_remote_backup( patch.object(BackupAgentTest, "async_download_backup") as download_mock, ): download_mock.return_value.__aiter__.return_value = iter((b"backup data",)) - resp = await client.get("/api/backup/download/abc123?agent_id=domain.test") + resp = await client.get("/api/backup/download/abc123?agent_id=test.test") assert resp.status == 200 assert await resp.content.read() == b"backup data" +async def test_downloading_local_encrypted_backup_file_not_found( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test downloading a local backup file.""" + await setup_backup_integration(hass) + client = await hass_client() + + with patch( + "homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup", + return_value=TEST_BACKUP_ABC123, + ): + resp = await client.get( + "/api/backup/download/abc123?agent_id=backup.local&password=blah" + ) + assert resp.status == 404 + + +@pytest.mark.usefixtures("mock_backups") +async def test_downloading_local_encrypted_backup( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test downloading a local backup file.""" + await setup_backup_integration(hass) + await _test_downloading_encrypted_backup(hass_client, "backup.local") + + +async def aiter_from_iter(iterable: Iterable) -> AsyncIterator: + """Convert an iterable to an async iterator.""" + for i in iterable: + yield i + + +@patch.object(BackupAgentTest, "async_download_backup") +async def test_downloading_remote_encrypted_backup( + download_mock, + hass: HomeAssistant, + hass_client: ClientSessionGenerator, +) -> None: + """Test downloading a local backup file.""" + backup_path = get_fixture_path("test_backups/ed1608a9.tar", DOMAIN) + await setup_backup_integration(hass) + hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest( + "test", + [ + AgentBackup( + addons=[AddonInfo(name="Test", slug="test", version="1.0.0")], + backup_id="ed1608a9", + database_included=True, + date="1970-01-01T00:00:00Z", + extra_metadata={}, + folders=[Folder.MEDIA, Folder.SHARE], + homeassistant_included=True, + homeassistant_version="2024.12.0", + name="Test", + protected=True, + size=13, + ) + ], + ) + + async def download_backup(backup_id: str, **kwargs: Any) -> AsyncIterator[bytes]: + return aiter_from_iter((backup_path.read_bytes(),)) + + download_mock.side_effect = download_backup + await _test_downloading_encrypted_backup(hass_client, "domain.test") + + +async def _test_downloading_encrypted_backup( + hass_client: ClientSessionGenerator, + agent_id: str, +) -> None: + """Test downloading an encrypted backup file.""" + # Try downloading without supplying a password + client = await hass_client() + resp = await client.get(f"/api/backup/download/ed1608a9?agent_id={agent_id}") + assert resp.status == 200 + backup = await resp.read() + # We expect a valid outer tar file, but the inner tar file is encrypted and + # can't be read + with tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar: + enc_metadata = json.loads(outer_tar.extractfile("./backup.json").read()) + assert enc_metadata["protected"] is True + with ( + outer_tar.extractfile("core.tar.gz") as inner_tar_file, + pytest.raises(tarfile.ReadError, match="file could not be opened"), + ): + # pylint: disable-next=consider-using-with + tarfile.open(fileobj=inner_tar_file, mode="r") + + # Download with the wrong password + resp = await client.get( + f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=wrong" + ) + assert resp.status == 200 + backup = await resp.read() + # We expect a truncated outer tar file + with ( + tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar, + pytest.raises(tarfile.ReadError, match="unexpected end of data"), + ): + outer_tar.getnames() + + # Finally download with the correct password + resp = await client.get( + f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=hunter2" + ) + assert resp.status == 200 + backup = await resp.read() + # We expect a valid outer tar file, the inner tar file is decrypted and can be read + with ( + tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar, + ): + dec_metadata = json.loads(outer_tar.extractfile("./backup.json").read()) + assert dec_metadata == enc_metadata | {"protected": False} + with ( + outer_tar.extractfile("core.tar.gz") as inner_tar_file, + tarfile.open(fileobj=inner_tar_file, mode="r") as inner_tar, + ): + assert inner_tar.getnames() == [ + ".", + "README.md", + "test_symlink", + "test1", + "test1/script.sh", + ] + + async def test_downloading_backup_not_found( hass: HomeAssistant, hass_client: ClientSessionGenerator, diff --git a/tests/components/backup/test_websocket.py b/tests/components/backup/test_websocket.py index 7820408f265..2aa6eca3b95 100644 --- a/tests/components/backup/test_websocket.py +++ b/tests/components/backup/test_websocket.py @@ -36,7 +36,7 @@ from .common import ( setup_backup_platform, ) -from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path +from tests.common import async_fire_time_changed, async_mock_service from tests.typing import WebSocketGenerator BACKUP_CALL = call( @@ -2556,21 +2556,6 @@ async def test_subscribe_event( assert await client.receive_json() == snapshot -@pytest.fixture -def mock_backups() -> Generator[None]: - """Fixture to setup test backups.""" - # pylint: disable-next=import-outside-toplevel - from homeassistant.components.backup import backup as core_backup - - class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent): - def __init__(self, hass: HomeAssistant) -> None: - super().__init__(hass) - self._backup_dir = get_fixture_path("test_backups", DOMAIN) - - with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent): - yield - - @pytest.mark.parametrize( ("agent_id", "backup_id", "password"), [