mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Support decrypting backups when downloading (#135728)
* Support decrypting backups when downloading * Close stream * Use test helper * Wait for worker to finish * Simplify * Update backup.json * Simplify * Revert change from the future
This commit is contained in:
parent
6fdccda225
commit
9db6be11f7
@ -4,18 +4,23 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
import threading
|
||||
from typing import IO, cast
|
||||
|
||||
from aiohttp import BodyPartReader
|
||||
from aiohttp.hdrs import CONTENT_DISPOSITION
|
||||
from aiohttp.web import FileResponse, Request, Response, StreamResponse
|
||||
from multidict import istr
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import slugify
|
||||
|
||||
from . import util
|
||||
from .agent import BackupAgent
|
||||
from .const import DATA_MANAGER
|
||||
from .manager import BackupManager
|
||||
|
||||
|
||||
@callback
|
||||
@ -43,8 +48,13 @@ class DownloadBackupView(HomeAssistantView):
|
||||
agent_id = request.query.getone("agent_id")
|
||||
except KeyError:
|
||||
return Response(status=HTTPStatus.BAD_REQUEST)
|
||||
try:
|
||||
password = request.query.getone("password")
|
||||
except KeyError:
|
||||
password = None
|
||||
|
||||
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
||||
hass = request.app[KEY_HASS]
|
||||
manager = hass.data[DATA_MANAGER]
|
||||
if agent_id not in manager.backup_agents:
|
||||
return Response(status=HTTPStatus.BAD_REQUEST)
|
||||
agent = manager.backup_agents[agent_id]
|
||||
@ -58,6 +68,24 @@ class DownloadBackupView(HomeAssistantView):
|
||||
headers = {
|
||||
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
||||
}
|
||||
|
||||
if not password:
|
||||
return await self._send_backup_no_password(
|
||||
request, headers, backup_id, agent_id, agent, manager
|
||||
)
|
||||
return await self._send_backup_with_password(
|
||||
hass, request, headers, backup_id, agent_id, password, agent, manager
|
||||
)
|
||||
|
||||
async def _send_backup_no_password(
|
||||
self,
|
||||
request: Request,
|
||||
headers: dict[istr, str],
|
||||
backup_id: str,
|
||||
agent_id: str,
|
||||
agent: BackupAgent,
|
||||
manager: BackupManager,
|
||||
) -> StreamResponse | FileResponse | Response:
|
||||
if agent_id in manager.local_backup_agents:
|
||||
local_agent = manager.local_backup_agents[agent_id]
|
||||
path = local_agent.get_backup_path(backup_id)
|
||||
@ -70,6 +98,50 @@ class DownloadBackupView(HomeAssistantView):
|
||||
await response.write(chunk)
|
||||
return response
|
||||
|
||||
async def _send_backup_with_password(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
headers: dict[istr, str],
|
||||
backup_id: str,
|
||||
agent_id: str,
|
||||
password: str,
|
||||
agent: BackupAgent,
|
||||
manager: BackupManager,
|
||||
) -> StreamResponse | FileResponse | Response:
|
||||
reader: IO[bytes]
|
||||
if agent_id in manager.local_backup_agents:
|
||||
local_agent = manager.local_backup_agents[agent_id]
|
||||
path = local_agent.get_backup_path(backup_id)
|
||||
try:
|
||||
reader = await hass.async_add_executor_job(open, path.as_posix(), "rb")
|
||||
except FileNotFoundError:
|
||||
return Response(status=HTTPStatus.NOT_FOUND)
|
||||
else:
|
||||
stream = await agent.async_download_backup(backup_id)
|
||||
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
|
||||
|
||||
worker_done_event = asyncio.Event()
|
||||
|
||||
def on_done() -> None:
|
||||
"""Call by the worker thread when it's done."""
|
||||
hass.loop.call_soon_threadsafe(worker_done_event.set)
|
||||
|
||||
stream = util.AsyncIteratorWriter(hass)
|
||||
worker = threading.Thread(
|
||||
target=util.decrypt_backup, args=[reader, stream, password, on_done]
|
||||
)
|
||||
try:
|
||||
worker.start()
|
||||
response = StreamResponse(status=HTTPStatus.OK, headers=headers)
|
||||
await response.prepare(request)
|
||||
async for chunk in stream:
|
||||
await response.write(chunk)
|
||||
return response
|
||||
finally:
|
||||
reader.close()
|
||||
await worker_done_event.wait()
|
||||
|
||||
|
||||
class UploadBackupView(HomeAssistantView):
|
||||
"""Generate backup view."""
|
||||
|
@ -1033,10 +1033,12 @@ class BackupManager:
|
||||
validate_password_stream(reader, password)
|
||||
except backup_util.IncorrectPassword as err:
|
||||
raise IncorrectPasswordError from err
|
||||
except backup_util.UnsuppertedSecureTarVersion as err:
|
||||
except backup_util.UnsupportedSecureTarVersion as err:
|
||||
raise DecryptOnDowloadNotSupported from err
|
||||
except backup_util.DecryptError as err:
|
||||
raise BackupManagerError(str(err)) from err
|
||||
finally:
|
||||
reader.close()
|
||||
|
||||
|
||||
class KnownBackups:
|
||||
|
@ -3,14 +3,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
import copy
|
||||
from io import BytesIO
|
||||
import json
|
||||
from pathlib import Path, PurePath
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
from typing import IO, cast
|
||||
from typing import IO, Self, cast
|
||||
|
||||
import aiohttp
|
||||
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
|
||||
from securetar import (
|
||||
PLAINTEXT_SIZE_HEADER,
|
||||
VERSION_HEADER,
|
||||
SecureTarError,
|
||||
SecureTarFile,
|
||||
SecureTarReadError,
|
||||
)
|
||||
|
||||
from homeassistant.backup_restore import password_to_key
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -24,7 +33,7 @@ class DecryptError(Exception):
|
||||
"""Error during decryption."""
|
||||
|
||||
|
||||
class UnsuppertedSecureTarVersion(DecryptError):
|
||||
class UnsupportedSecureTarVersion(DecryptError):
|
||||
"""Unsupported securetar version."""
|
||||
|
||||
|
||||
@ -157,6 +166,33 @@ class AsyncIteratorReader:
|
||||
self._buffer = None
|
||||
return bytes(result)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the iterator."""
|
||||
|
||||
|
||||
class AsyncIteratorWriter:
|
||||
"""Wrap an AsyncIterator."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._hass = hass
|
||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Return the iterator."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
"""Get the next chunk from the iterator."""
|
||||
if data := await self._queue.get():
|
||||
return data
|
||||
raise StopAsyncIteration
|
||||
|
||||
def write(self, s: bytes, /) -> int:
|
||||
"""Write data to the iterator."""
|
||||
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
|
||||
return len(s)
|
||||
|
||||
|
||||
def validate_password_stream(
|
||||
input_stream: IO[bytes],
|
||||
@ -170,7 +206,7 @@ def validate_password_stream(
|
||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||
continue
|
||||
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||
raise UnsuppertedSecureTarVersion
|
||||
raise UnsupportedSecureTarVersion
|
||||
istf = SecureTarFile(
|
||||
None, # Not used
|
||||
gzip=False,
|
||||
@ -187,6 +223,68 @@ def validate_password_stream(
|
||||
raise BackupEmpty
|
||||
|
||||
|
||||
def decrypt_backup(
|
||||
input_stream: IO[bytes],
|
||||
output_stream: IO[bytes],
|
||||
password: str | None,
|
||||
on_done: Callable[[], None],
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||
) as input_tar,
|
||||
tarfile.open(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_decrypt_backup(input_tar, output_tar, password)
|
||||
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error decrypting backup: %s", err)
|
||||
finally:
|
||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||
on_done()
|
||||
|
||||
|
||||
def _decrypt_backup(
|
||||
input_tar: tarfile.TarFile,
|
||||
output_tar: tarfile.TarFile,
|
||||
password: str | None,
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
for obj in input_tar:
|
||||
# We compare with PurePath to avoid issues with different path separators,
|
||||
# for example when backup.json is added as "./backup.json"
|
||||
if PurePath(obj.name) == PurePath("backup.json"):
|
||||
# Rewrite the backup.json file to indicate that the backup is decrypted
|
||||
if not (reader := input_tar.extractfile(obj)):
|
||||
raise DecryptError
|
||||
metadata = json_loads_object(reader.read())
|
||||
metadata["protected"] = False
|
||||
updated_metadata_b = json.dumps(metadata).encode()
|
||||
metadata_obj = copy.deepcopy(obj)
|
||||
metadata_obj.size = len(updated_metadata_b)
|
||||
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
|
||||
continue
|
||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||
output_tar.addfile(obj, input_tar.extractfile(obj))
|
||||
continue
|
||||
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||
raise UnsupportedSecureTarVersion
|
||||
decrypted_obj = copy.deepcopy(obj)
|
||||
decrypted_obj.size = int(obj.pax_headers[PLAINTEXT_SIZE_HEADER])
|
||||
istf = SecureTarFile(
|
||||
None, # Not used
|
||||
gzip=False,
|
||||
key=password_to_key(password) if password is not None else None,
|
||||
mode="r",
|
||||
fileobj=input_tar.extractfile(obj),
|
||||
)
|
||||
with istf.decrypt(obj) as decrypted:
|
||||
output_tar.addfile(decrypted_obj, decrypted)
|
||||
|
||||
|
||||
async def receive_file(
|
||||
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
||||
) -> None:
|
||||
|
@ -9,11 +9,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup import DOMAIN
|
||||
from homeassistant.components.backup.manager import NewBackup, WrittenBackup
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .common import TEST_BACKUP_PATH_ABC123
|
||||
|
||||
from tests.common import get_fixture_path
|
||||
|
||||
|
||||
@pytest.fixture(name="mocked_json_bytes")
|
||||
def mocked_json_bytes_fixture() -> Generator[Mock]:
|
||||
@ -113,3 +116,18 @@ def mock_backup_generation_fixture(
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backups() -> Generator[None]:
|
||||
"""Fixture to setup test backups."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components.backup import backup as core_backup
|
||||
|
||||
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
super().__init__(hass)
|
||||
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
||||
|
||||
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
||||
yield
|
||||
|
@ -1,18 +1,23 @@
|
||||
"""Tests for the Backup integration."""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
from collections.abc import AsyncIterator, Iterable
|
||||
from io import BytesIO, StringIO
|
||||
import json
|
||||
import tarfile
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup.const import DATA_MANAGER
|
||||
from homeassistant.components.backup import AddonInfo, AgentBackup, Folder
|
||||
from homeassistant.components.backup.const import DATA_MANAGER, DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .common import TEST_BACKUP_ABC123, BackupAgentTest, setup_backup_integration
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.common import MockUser, get_fixture_path
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@ -45,8 +50,9 @@ async def test_downloading_remote_backup(
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test downloading a remote backup."""
|
||||
await setup_backup_integration(hass)
|
||||
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest("test")
|
||||
await setup_backup_integration(
|
||||
hass, backups={"test.test": [TEST_BACKUP_ABC123]}, remote_agents=["test"]
|
||||
)
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
@ -54,11 +60,140 @@ async def test_downloading_remote_backup(
|
||||
patch.object(BackupAgentTest, "async_download_backup") as download_mock,
|
||||
):
|
||||
download_mock.return_value.__aiter__.return_value = iter((b"backup data",))
|
||||
resp = await client.get("/api/backup/download/abc123?agent_id=domain.test")
|
||||
resp = await client.get("/api/backup/download/abc123?agent_id=test.test")
|
||||
assert resp.status == 200
|
||||
assert await resp.content.read() == b"backup data"
|
||||
|
||||
|
||||
async def test_downloading_local_encrypted_backup_file_not_found(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test downloading a local backup file."""
|
||||
await setup_backup_integration(hass)
|
||||
client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.backup.backup.CoreLocalBackupAgent.async_get_backup",
|
||||
return_value=TEST_BACKUP_ABC123,
|
||||
):
|
||||
resp = await client.get(
|
||||
"/api/backup/download/abc123?agent_id=backup.local&password=blah"
|
||||
)
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_backups")
|
||||
async def test_downloading_local_encrypted_backup(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test downloading a local backup file."""
|
||||
await setup_backup_integration(hass)
|
||||
await _test_downloading_encrypted_backup(hass_client, "backup.local")
|
||||
|
||||
|
||||
async def aiter_from_iter(iterable: Iterable) -> AsyncIterator:
|
||||
"""Convert an iterable to an async iterator."""
|
||||
for i in iterable:
|
||||
yield i
|
||||
|
||||
|
||||
@patch.object(BackupAgentTest, "async_download_backup")
|
||||
async def test_downloading_remote_encrypted_backup(
|
||||
download_mock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test downloading a local backup file."""
|
||||
backup_path = get_fixture_path("test_backups/ed1608a9.tar", DOMAIN)
|
||||
await setup_backup_integration(hass)
|
||||
hass.data[DATA_MANAGER].backup_agents["domain.test"] = BackupAgentTest(
|
||||
"test",
|
||||
[
|
||||
AgentBackup(
|
||||
addons=[AddonInfo(name="Test", slug="test", version="1.0.0")],
|
||||
backup_id="ed1608a9",
|
||||
database_included=True,
|
||||
date="1970-01-01T00:00:00Z",
|
||||
extra_metadata={},
|
||||
folders=[Folder.MEDIA, Folder.SHARE],
|
||||
homeassistant_included=True,
|
||||
homeassistant_version="2024.12.0",
|
||||
name="Test",
|
||||
protected=True,
|
||||
size=13,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def download_backup(backup_id: str, **kwargs: Any) -> AsyncIterator[bytes]:
|
||||
return aiter_from_iter((backup_path.read_bytes(),))
|
||||
|
||||
download_mock.side_effect = download_backup
|
||||
await _test_downloading_encrypted_backup(hass_client, "domain.test")
|
||||
|
||||
|
||||
async def _test_downloading_encrypted_backup(
|
||||
hass_client: ClientSessionGenerator,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Test downloading an encrypted backup file."""
|
||||
# Try downloading without supplying a password
|
||||
client = await hass_client()
|
||||
resp = await client.get(f"/api/backup/download/ed1608a9?agent_id={agent_id}")
|
||||
assert resp.status == 200
|
||||
backup = await resp.read()
|
||||
# We expect a valid outer tar file, but the inner tar file is encrypted and
|
||||
# can't be read
|
||||
with tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar:
|
||||
enc_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
|
||||
assert enc_metadata["protected"] is True
|
||||
with (
|
||||
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
|
||||
pytest.raises(tarfile.ReadError, match="file could not be opened"),
|
||||
):
|
||||
# pylint: disable-next=consider-using-with
|
||||
tarfile.open(fileobj=inner_tar_file, mode="r")
|
||||
|
||||
# Download with the wrong password
|
||||
resp = await client.get(
|
||||
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=wrong"
|
||||
)
|
||||
assert resp.status == 200
|
||||
backup = await resp.read()
|
||||
# We expect a truncated outer tar file
|
||||
with (
|
||||
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
|
||||
pytest.raises(tarfile.ReadError, match="unexpected end of data"),
|
||||
):
|
||||
outer_tar.getnames()
|
||||
|
||||
# Finally download with the correct password
|
||||
resp = await client.get(
|
||||
f"/api/backup/download/ed1608a9?agent_id={agent_id}&password=hunter2"
|
||||
)
|
||||
assert resp.status == 200
|
||||
backup = await resp.read()
|
||||
# We expect a valid outer tar file, the inner tar file is decrypted and can be read
|
||||
with (
|
||||
tarfile.open(fileobj=BytesIO(backup), mode="r") as outer_tar,
|
||||
):
|
||||
dec_metadata = json.loads(outer_tar.extractfile("./backup.json").read())
|
||||
assert dec_metadata == enc_metadata | {"protected": False}
|
||||
with (
|
||||
outer_tar.extractfile("core.tar.gz") as inner_tar_file,
|
||||
tarfile.open(fileobj=inner_tar_file, mode="r") as inner_tar,
|
||||
):
|
||||
assert inner_tar.getnames() == [
|
||||
".",
|
||||
"README.md",
|
||||
"test_symlink",
|
||||
"test1",
|
||||
"test1/script.sh",
|
||||
]
|
||||
|
||||
|
||||
async def test_downloading_backup_not_found(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
|
@ -36,7 +36,7 @@ from .common import (
|
||||
setup_backup_platform,
|
||||
)
|
||||
|
||||
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
|
||||
from tests.common import async_fire_time_changed, async_mock_service
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
BACKUP_CALL = call(
|
||||
@ -2556,21 +2556,6 @@ async def test_subscribe_event(
|
||||
assert await client.receive_json() == snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backups() -> Generator[None]:
|
||||
"""Fixture to setup test backups."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components.backup import backup as core_backup
|
||||
|
||||
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
super().__init__(hass)
|
||||
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
||||
|
||||
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("agent_id", "backup_id", "password"),
|
||||
[
|
||||
|
Loading…
x
Reference in New Issue
Block a user