Support decrypting backups when downloading (#135728)

* Support decrypting backups when downloading

* Close stream

* Use test helper

* Wait for worker to finish

* Simplify

* Update backup.json

* Simplify

* Revert change from the future
This commit is contained in:
Erik Montnemery 2025-01-16 12:36:12 +01:00 committed by GitHub
parent 6fdccda225
commit 9db6be11f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 341 additions and 31 deletions

View File

@ -4,18 +4,23 @@ from __future__ import annotations
import asyncio
from http import HTTPStatus
from typing import cast
import threading
from typing import IO, cast
from aiohttp import BodyPartReader
from aiohttp.hdrs import CONTENT_DISPOSITION
from aiohttp.web import FileResponse, Request, Response, StreamResponse
from multidict import istr
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import slugify
from . import util
from .agent import BackupAgent
from .const import DATA_MANAGER
from .manager import BackupManager
@callback
@ -43,8 +48,13 @@ class DownloadBackupView(HomeAssistantView):
agent_id = request.query.getone("agent_id")
except KeyError:
return Response(status=HTTPStatus.BAD_REQUEST)
try:
password = request.query.getone("password")
except KeyError:
password = None
manager = request.app[KEY_HASS].data[DATA_MANAGER]
hass = request.app[KEY_HASS]
manager = hass.data[DATA_MANAGER]
if agent_id not in manager.backup_agents:
return Response(status=HTTPStatus.BAD_REQUEST)
agent = manager.backup_agents[agent_id]
@ -58,6 +68,24 @@ class DownloadBackupView(HomeAssistantView):
headers = {
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
}
if not password:
return await self._send_backup_no_password(
request, headers, backup_id, agent_id, agent, manager
)
return await self._send_backup_with_password(
hass, request, headers, backup_id, agent_id, password, agent, manager
)
async def _send_backup_no_password(
self,
request: Request,
headers: dict[istr, str],
backup_id: str,
agent_id: str,
agent: BackupAgent,
manager: BackupManager,
) -> StreamResponse | FileResponse | Response:
if agent_id in manager.local_backup_agents:
local_agent = manager.local_backup_agents[agent_id]
path = local_agent.get_backup_path(backup_id)
@ -70,6 +98,50 @@ class DownloadBackupView(HomeAssistantView):
await response.write(chunk)
return response
async def _send_backup_with_password(
self,
hass: HomeAssistant,
request: Request,
headers: dict[istr, str],
backup_id: str,
agent_id: str,
password: str,
agent: BackupAgent,
manager: BackupManager,
) -> StreamResponse | FileResponse | Response:
reader: IO[bytes]
if agent_id in manager.local_backup_agents:
local_agent = manager.local_backup_agents[agent_id]
path = local_agent.get_backup_path(backup_id)
try:
reader = await hass.async_add_executor_job(open, path.as_posix(), "rb")
except FileNotFoundError:
return Response(status=HTTPStatus.NOT_FOUND)
else:
stream = await agent.async_download_backup(backup_id)
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
worker_done_event = asyncio.Event()
def on_done() -> None:
"""Call by the worker thread when it's done."""
hass.loop.call_soon_threadsafe(worker_done_event.set)
stream = util.AsyncIteratorWriter(hass)
worker = threading.Thread(
target=util.decrypt_backup, args=[reader, stream, password, on_done]
)
try:
worker.start()
response = StreamResponse(status=HTTPStatus.OK, headers=headers)
await response.prepare(request)
async for chunk in stream:
await response.write(chunk)
return response
finally:
reader.close()
await worker_done_event.wait()
class UploadBackupView(HomeAssistantView):
"""Generate backup view."""

View File

@ -1033,10 +1033,12 @@ class BackupManager:
validate_password_stream(reader, password)
except backup_util.IncorrectPassword as err:
raise IncorrectPasswordError from err
except backup_util.UnsuppertedSecureTarVersion as err:
except backup_util.UnsupportedSecureTarVersion as err:
raise DecryptOnDowloadNotSupported from err
except backup_util.DecryptError as err:
raise BackupManagerError(str(err)) from err
finally:
reader.close()
class KnownBackups:

View File

@ -3,14 +3,23 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from pathlib import Path
from collections.abc import AsyncIterator, Callable
import copy
from io import BytesIO
import json
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
from typing import IO, cast
from typing import IO, Self, cast
import aiohttp
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
from securetar import (
PLAINTEXT_SIZE_HEADER,
VERSION_HEADER,
SecureTarError,
SecureTarFile,
SecureTarReadError,
)
from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
@ -24,7 +33,7 @@ class DecryptError(Exception):
"""Error during decryption."""
class UnsuppertedSecureTarVersion(DecryptError):
class UnsupportedSecureTarVersion(DecryptError):
"""Unsupported securetar version."""
@ -157,6 +166,33 @@ class AsyncIteratorReader:
self._buffer = None
return bytes(result)
def close(self) -> None:
"""Close the iterator."""
class AsyncIteratorWriter:
"""Wrap an AsyncIterator."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the wrapper."""
self._hass = hass
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
def __aiter__(self) -> Self:
"""Return the iterator."""
return self
async def __anext__(self) -> bytes:
"""Get the next chunk from the iterator."""
if data := await self._queue.get():
return data
raise StopAsyncIteration
def write(self, s: bytes, /) -> int:
"""Write data to the iterator."""
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
return len(s)
def validate_password_stream(
input_stream: IO[bytes],
@ -170,7 +206,7 @@ def validate_password_stream(
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
continue
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
raise UnsuppertedSecureTarVersion
raise UnsupportedSecureTarVersion
istf = SecureTarFile(
None, # Not used
gzip=False,
@ -187,6 +223,68 @@ def validate_password_stream(
raise BackupEmpty
def decrypt_backup(
input_stream: IO[bytes],
output_stream: IO[bytes],
password: str | None,
on_done: Callable[[], None],
) -> None:
"""Decrypt a backup."""
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
on_done()
def _decrypt_backup(
input_tar: tarfile.TarFile,
output_tar: tarfile.TarFile,
password: str | None,
) -> None:
"""Decrypt a backup."""
for obj in input_tar:
# We compare with PurePath to avoid issues with different path separators,
# for example when backup.json is added as "./backup.json"
if PurePath(obj.name) == PurePath("backup.json"):
# Rewrite the backup.json file to indicate that the backup is decrypted
if not (reader := input_tar.extractfile(obj)):
raise DecryptError
metadata = json_loads_object(reader.read())
metadata["protected"] = False
updated_metadata_b = json.dumps(metadata).encode()
metadata_obj = copy.deepcopy(obj)
metadata_obj.size = len(updated_metadata_b)
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
continue
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
output_tar.addfile(obj, input_tar.extractfile(obj))
continue
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
raise UnsupportedSecureTarVersion
decrypted_obj = copy.deepcopy(obj)
decrypted_obj.size = int(obj.pax_headers[PLAINTEXT_SIZE_HEADER])
istf = SecureTarFile(
None, # Not used
gzip=False,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=input_tar.extractfile(obj),
)
with istf.decrypt(obj) as decrypted:
output_tar.addfile(decrypted_obj, decrypted)
async def receive_file(
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
) -> None:

View File

@ -9,11 +9,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from homeassistant.components.backup import DOMAIN
from homeassistant.components.backup.manager import NewBackup, WrittenBackup
from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_PATH_ABC123
from tests.common import get_fixture_path
@pytest.fixture(name="mocked_json_bytes")
def mocked_json_bytes_fixture() -> Generator[Mock]:
@ -113,3 +116,18 @@ def mock_backup_generation_fixture(
),
):
yield
@pytest.fixture
def mock_backups() -> Generator[None]:
"""Fixture to setup test backups."""
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.backup import backup as core_backup
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
def __init__(self, hass: HomeAssistant) -> None:
super().__init__(hass)
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
yield

View File

@ -1,18 +1,23 @@
"""Tests for the Backup integration."""
import asyncio
from io import StringIO
from collections.abc import AsyncIterator, Iterable
from io import BytesIO, StringIO
import json
import tarfile
from typing import Any
from unittest.mock import patch
from aiohttp import web
import pytest
from homeassistant.components.backup.const import DATA_MANAGER
from homeassistant.components.backup import AddonInfo, AgentBackup, Folder
from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN
from homeassistant.core import HomeAssistant
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration
from tests.common import MockUser
from tests.common import MockUser, get_fixture_path
from tests.typing import ClientSessionGenerator
@ -45,8 +50,9 @@ async def test_downloading_remote_backup(
hass_client: ClientSessionGenerator,
) -> None:
"""Test downloading a remote backup."""
await setup_backup_integration(hass)
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest("test")
await setup_backup_integration(
hass, backups={"test.test": [TEST_BACKUP_ABC123]}, remote_agents=["test"]
)
client = await hass_client()
@ -54,11 +60,140 @@ async def test_downloading_remote_backup(
patch.object(BackupAgentTest, "async_download_backup") as download_mock,
):
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
resp = await client.get("/api/backup/download/abc123?agent_id=domain.test")
resp = await client.get("/api/backup/download/abc123?agent_id=test.test")
assert resp.status == 200
assert await resp.content.read() == b"backup data"
async def test_downloading_local_encrypted_backup_file_not_found(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test downloading a local backup file."""
await setup_backup_integration(hass)
client = await hass_client()
with patch(
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
return_value=TEST_BACKUP_ABC123,
):
resp = await client.get(
"/api/backup/download/abc123?agent_id=backup.local&password=blah"
)
assert resp.status == 404
@pytest.mark.usefixtures("mock_backups")
async def test_downloading_local_encrypted_backup(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test downloading a local backup file."""
await setup_backup_integration(hass)
await _test_downloading_encrypted_backup(hass_client, "backup.local")
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
"""Convert an iterable to an async iterator."""
for i in iterable:
yield i
@patch.object(BackupAgentTest, "async_download_backup")
async def test_downloading_remote_encrypted_backup(
download_mock,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test downloading a local backup file."""
backup_path = get_fixture_path("test_backups/ed1608a9.tar", DOMAIN)
await setup_backup_integration(hass)
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest(
"test",
[
AgentBackup(
addons=[AddonInfo(name="Test", slug="test", version="1.0.0")],
backup_id="ed1608a9",
database_included=True,
date="1970-01-01T00:00:00Z",
extra_metadata={},
folders=[Folder.MEDIA, Folder.SHARE],
homeassistant_included=True,
homeassistant_version="2024.12.0",
name="Test",
protected=True,
size=13,
)
],
)
async def download_backup(backup_id: str, **kwargs: Any) -> AsyncIterator[bytes]:
return aiter_from_iter((backup_path.read_bytes(),))
download_mock.side_effect = download_backup
await _test_downloading_encrypted_backup(hass_client, "domain.test")
async def _test_downloading_encrypted_backup(
hass_client: ClientSessionGenerator,
agent_id: str,
) -> None:
"""Test downloading an encrypted backup file."""
# Try downloading without supplying a password
client = await hass_client()
resp = await client.get(f"/api/backup/download/ed1608a9?agent_id={agent_id}")
assert resp.status == 200
backup = await resp.read()
# We expect a valid outer tar file, but the inner tar file is encrypted and
# can't be read
with tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar:
enc_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
assert enc_metadata["protected"] is True
with (
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
pytest.raises(tarfile.ReadError, match="file could not be opened"),
):
# pylint: disable-next=consider-using-with
tarfile.open(fileobj=inner_tar_file, mode="r")
# Download with the wrong password
resp = await client.get(
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=wrong"
)
assert resp.status == 200
backup = await resp.read()
# We expect a truncated outer tar file
with (
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
pytest.raises(tarfile.ReadError, match="unexpected end of data"),
):
outer_tar.getnames()
# Finally download with the correct password
resp = await client.get(
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=hunter2"
)
assert resp.status == 200
backup = await resp.read()
# We expect a valid outer tar file, the inner tar file is decrypted and can be read
with (
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
):
dec_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
assert dec_metadata == enc_metadata | {"protected": False}
with (
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
tarfile.open(fileobj=inner_tar_file, mode="r") as inner_tar,
):
assert inner_tar.getnames() == [
".",
"README.md",
"test_symlink",
"test1",
"test1/script.sh",
]
async def test_downloading_backup_not_found(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,

View File

@ -36,7 +36,7 @@ from .common import (
setup_backup_platform,
)
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
from tests.common import async_fire_time_changed, async_mock_service
from tests.typing import WebSocketGenerator
BACKUP_CALL = call(
@ -2556,21 +2556,6 @@ async def test_subscribe_event(
assert await client.receive_json() == snapshot
@pytest.fixture
def mock_backups() -> Generator[None]:
"""Fixture to setup test backups."""
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.backup import backup as core_backup
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
def __init__(self, hass: HomeAssistant) -> None:
super().__init__(hass)
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
yield
@pytest.mark.parametrize(
("agent_id", "backup_id", "password"),
[