mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Add WS command backup/can_decrypt_on_download (#135662)
* Add WS command backup/can_decrypt_on_download * Wrap errors * Add default messages to exceptions * Improve test coverage
This commit is contained in:
parent
3622e8331b
commit
f36a10126c
@ -14,7 +14,7 @@ from pathlib import Path, PurePath
|
|||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
from typing import IO, TYPE_CHECKING, Any, Protocol, TypedDict, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from securetar import SecureTarFile, atomic_contents_add
|
from securetar import SecureTarFile, atomic_contents_add
|
||||||
@ -31,6 +31,7 @@ from homeassistant.helpers import (
|
|||||||
from homeassistant.helpers.json import json_bytes
|
from homeassistant.helpers.json import json_bytes
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from . import util as backup_util
|
||||||
from .agent import (
|
from .agent import (
|
||||||
BackupAgent,
|
BackupAgent,
|
||||||
BackupAgentError,
|
BackupAgentError,
|
||||||
@ -48,7 +49,13 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .models import AgentBackup, BackupManagerError, Folder
|
from .models import AgentBackup, BackupManagerError, Folder
|
||||||
from .store import BackupStore
|
from .store import BackupStore
|
||||||
from .util import make_backup_dir, read_backup, validate_password
|
from .util import (
|
||||||
|
AsyncIteratorReader,
|
||||||
|
make_backup_dir,
|
||||||
|
read_backup,
|
||||||
|
validate_password,
|
||||||
|
validate_password_stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, kw_only=True, slots=True)
|
@dataclass(frozen=True, kw_only=True, slots=True)
|
||||||
@ -248,6 +255,14 @@ class BackupReaderWriterError(HomeAssistantError):
|
|||||||
class IncorrectPasswordError(BackupReaderWriterError):
|
class IncorrectPasswordError(BackupReaderWriterError):
|
||||||
"""Raised when the password is incorrect."""
|
"""Raised when the password is incorrect."""
|
||||||
|
|
||||||
|
_message = "The password provided is incorrect."
|
||||||
|
|
||||||
|
|
||||||
|
class DecryptOnDowloadNotSupported(BackupManagerError):
|
||||||
|
"""Raised when on-the-fly decryption is not supported."""
|
||||||
|
|
||||||
|
_message = "On-the-fly decryption is not supported for this backup."
|
||||||
|
|
||||||
|
|
||||||
class BackupManager:
|
class BackupManager:
|
||||||
"""Define the format that backup managers can have."""
|
"""Define the format that backup managers can have."""
|
||||||
@ -990,6 +1005,39 @@ class BackupManager:
|
|||||||
translation_placeholders={"failed_agents": ", ".join(agent_errors)},
|
translation_placeholders={"failed_agents": ", ".join(agent_errors)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_can_decrypt_on_download(
|
||||||
|
self,
|
||||||
|
backup_id: str,
|
||||||
|
*,
|
||||||
|
agent_id: str,
|
||||||
|
password: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Check if we are able to decrypt the backup on download."""
|
||||||
|
try:
|
||||||
|
agent = self.backup_agents[agent_id]
|
||||||
|
except KeyError as err:
|
||||||
|
raise BackupManagerError(f"Invalid agent selected: {agent_id}") from err
|
||||||
|
if not await agent.async_get_backup(backup_id):
|
||||||
|
raise BackupManagerError(
|
||||||
|
f"Backup {backup_id} not found in agent {agent_id}"
|
||||||
|
)
|
||||||
|
reader: IO[bytes]
|
||||||
|
if agent_id in self.local_backup_agents:
|
||||||
|
local_agent = self.local_backup_agents[agent_id]
|
||||||
|
path = local_agent.get_backup_path(backup_id)
|
||||||
|
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
|
||||||
|
else:
|
||||||
|
backup_stream = await agent.async_download_backup(backup_id)
|
||||||
|
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
|
||||||
|
try:
|
||||||
|
validate_password_stream(reader, password)
|
||||||
|
except backup_util.IncorrectPassword as err:
|
||||||
|
raise IncorrectPasswordError from err
|
||||||
|
except backup_util.UnsuppertedSecureTarVersion as err:
|
||||||
|
raise DecryptOnDowloadNotSupported from err
|
||||||
|
except backup_util.DecryptError as err:
|
||||||
|
raise BackupManagerError(str(err)) from err
|
||||||
|
|
||||||
|
|
||||||
class KnownBackups:
|
class KnownBackups:
|
||||||
"""Track known backups."""
|
"""Track known backups."""
|
||||||
@ -1372,7 +1420,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
|
|||||||
validate_password, path, password
|
validate_password, path, password
|
||||||
)
|
)
|
||||||
if not password_valid:
|
if not password_valid:
|
||||||
raise IncorrectPasswordError("The password provided is incorrect.")
|
raise IncorrectPasswordError
|
||||||
|
|
||||||
def _write_restore_file() -> None:
|
def _write_restore_file() -> None:
|
||||||
"""Write the restore file."""
|
"""Write the restore file."""
|
||||||
|
@ -3,13 +3,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
import tarfile
|
import tarfile
|
||||||
from typing import cast
|
from typing import IO, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from securetar import SecureTarFile
|
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
|
||||||
|
|
||||||
from homeassistant.backup_restore import password_to_key
|
from homeassistant.backup_restore import password_to_key
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -19,6 +20,22 @@ from .const import BUF_SIZE, LOGGER
|
|||||||
from .models import AddonInfo, AgentBackup, Folder
|
from .models import AddonInfo, AgentBackup, Folder
|
||||||
|
|
||||||
|
|
||||||
|
class DecryptError(Exception):
|
||||||
|
"""Error during decryption."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnsuppertedSecureTarVersion(DecryptError):
|
||||||
|
"""Unsupported securetar version."""
|
||||||
|
|
||||||
|
|
||||||
|
class IncorrectPassword(DecryptError):
|
||||||
|
"""Invalid password or corrupted backup."""
|
||||||
|
|
||||||
|
|
||||||
|
class BackupEmpty(DecryptError):
|
||||||
|
"""No tar files found in the backup."""
|
||||||
|
|
||||||
|
|
||||||
def make_backup_dir(path: Path) -> None:
|
def make_backup_dir(path: Path) -> None:
|
||||||
"""Create a backup directory if it does not exist."""
|
"""Create a backup directory if it does not exist."""
|
||||||
path.mkdir(exist_ok=True)
|
path.mkdir(exist_ok=True)
|
||||||
@ -106,6 +123,70 @@ def validate_password(path: Path, password: str | None) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIteratorReader:
|
||||||
|
"""Wrap an AsyncIterator."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
||||||
|
"""Initialize the wrapper."""
|
||||||
|
self._hass = hass
|
||||||
|
self._stream = stream
|
||||||
|
self._buffer: bytes | None = None
|
||||||
|
self._pos: int = 0
|
||||||
|
|
||||||
|
async def _next(self) -> bytes | None:
|
||||||
|
"""Get the next chunk from the iterator."""
|
||||||
|
return await anext(self._stream, None)
|
||||||
|
|
||||||
|
def read(self, n: int = -1, /) -> bytes:
|
||||||
|
"""Read data from the iterator."""
|
||||||
|
result = bytearray()
|
||||||
|
while n < 0 or len(result) < n:
|
||||||
|
if not self._buffer:
|
||||||
|
self._buffer = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._next(), self._hass.loop
|
||||||
|
).result()
|
||||||
|
self._pos = 0
|
||||||
|
if not self._buffer:
|
||||||
|
# The stream is exhausted
|
||||||
|
break
|
||||||
|
chunk = self._buffer[self._pos : self._pos + n]
|
||||||
|
result.extend(chunk)
|
||||||
|
n -= len(chunk)
|
||||||
|
self._pos += len(chunk)
|
||||||
|
if self._pos == len(self._buffer):
|
||||||
|
self._buffer = None
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_password_stream(
|
||||||
|
input_stream: IO[bytes],
|
||||||
|
password: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Decrypt a backup."""
|
||||||
|
with (
|
||||||
|
tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar,
|
||||||
|
):
|
||||||
|
for obj in input_tar:
|
||||||
|
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||||
|
continue
|
||||||
|
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||||
|
raise UnsuppertedSecureTarVersion
|
||||||
|
istf = SecureTarFile(
|
||||||
|
None, # Not used
|
||||||
|
gzip=False,
|
||||||
|
key=password_to_key(password) if password is not None else None,
|
||||||
|
mode="r",
|
||||||
|
fileobj=input_tar.extractfile(obj),
|
||||||
|
)
|
||||||
|
with istf.decrypt(obj) as decrypted:
|
||||||
|
try:
|
||||||
|
decrypted.read(1) # Read a single byte to trigger the decryption
|
||||||
|
except SecureTarReadError as err:
|
||||||
|
raise IncorrectPassword from err
|
||||||
|
return
|
||||||
|
raise BackupEmpty
|
||||||
|
|
||||||
|
|
||||||
async def receive_file(
|
async def receive_file(
|
||||||
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -9,7 +9,11 @@ from homeassistant.core import HomeAssistant, callback
|
|||||||
|
|
||||||
from .config import ScheduleState
|
from .config import ScheduleState
|
||||||
from .const import DATA_MANAGER, LOGGER
|
from .const import DATA_MANAGER, LOGGER
|
||||||
from .manager import IncorrectPasswordError, ManagerStateEvent
|
from .manager import (
|
||||||
|
DecryptOnDowloadNotSupported,
|
||||||
|
IncorrectPasswordError,
|
||||||
|
ManagerStateEvent,
|
||||||
|
)
|
||||||
from .models import Folder
|
from .models import Folder
|
||||||
|
|
||||||
|
|
||||||
@ -24,6 +28,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) ->
|
|||||||
|
|
||||||
websocket_api.async_register_command(hass, handle_details)
|
websocket_api.async_register_command(hass, handle_details)
|
||||||
websocket_api.async_register_command(hass, handle_info)
|
websocket_api.async_register_command(hass, handle_info)
|
||||||
|
websocket_api.async_register_command(hass, handle_can_decrypt_on_download)
|
||||||
websocket_api.async_register_command(hass, handle_create)
|
websocket_api.async_register_command(hass, handle_create)
|
||||||
websocket_api.async_register_command(hass, handle_create_with_automatic_settings)
|
websocket_api.async_register_command(hass, handle_create_with_automatic_settings)
|
||||||
websocket_api.async_register_command(hass, handle_delete)
|
websocket_api.async_register_command(hass, handle_delete)
|
||||||
@ -147,6 +152,38 @@ async def handle_restore(
|
|||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.require_admin
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "backup/can_decrypt_on_download",
|
||||||
|
vol.Required("backup_id"): str,
|
||||||
|
vol.Required("agent_id"): str,
|
||||||
|
vol.Required("password"): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def handle_can_decrypt_on_download(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Check if the supplied password is correct."""
|
||||||
|
try:
|
||||||
|
await hass.data[DATA_MANAGER].async_can_decrypt_on_download(
|
||||||
|
msg["backup_id"],
|
||||||
|
agent_id=msg["agent_id"],
|
||||||
|
password=msg.get("password"),
|
||||||
|
)
|
||||||
|
except IncorrectPasswordError:
|
||||||
|
connection.send_error(msg["id"], "password_incorrect", "Incorrect password")
|
||||||
|
except DecryptOnDowloadNotSupported:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], "decrypt_not_supported", "Decrypt on download not supported"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.require_admin
|
@websocket_api.require_admin
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
|
BIN
tests/components/backup/fixtures/test_backups/2bcb3113.tar
Normal file
BIN
tests/components/backup/fixtures/test_backups/2bcb3113.tar
Normal file
Binary file not shown.
BIN
tests/components/backup/fixtures/test_backups/ed1608a9.tar
Normal file
BIN
tests/components/backup/fixtures/test_backups/ed1608a9.tar
Normal file
Binary file not shown.
@ -175,6 +175,58 @@
|
|||||||
'type': 'result',
|
'type': 'result',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_can_decrypt_on_download[backup.local-2bcb3113-hunter2]
|
||||||
|
dict({
|
||||||
|
'error': dict({
|
||||||
|
'code': 'decrypt_not_supported',
|
||||||
|
'message': 'Decrypt on download not supported',
|
||||||
|
}),
|
||||||
|
'id': 1,
|
||||||
|
'success': False,
|
||||||
|
'type': 'result',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_can_decrypt_on_download[backup.local-ed1608a9-hunter2]
|
||||||
|
dict({
|
||||||
|
'id': 1,
|
||||||
|
'result': None,
|
||||||
|
'success': True,
|
||||||
|
'type': 'result',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_can_decrypt_on_download[backup.local-ed1608a9-wrong_password]
|
||||||
|
dict({
|
||||||
|
'error': dict({
|
||||||
|
'code': 'password_incorrect',
|
||||||
|
'message': 'Incorrect password',
|
||||||
|
}),
|
||||||
|
'id': 1,
|
||||||
|
'success': False,
|
||||||
|
'type': 'result',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_can_decrypt_on_download[backup.local-no_such_backup-hunter2]
|
||||||
|
dict({
|
||||||
|
'error': dict({
|
||||||
|
'code': 'home_assistant_error',
|
||||||
|
'message': 'Backup no_such_backup not found in agent backup.local',
|
||||||
|
}),
|
||||||
|
'id': 1,
|
||||||
|
'success': False,
|
||||||
|
'type': 'result',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_can_decrypt_on_download[no_such_agent-ed1608a9-hunter2]
|
||||||
|
dict({
|
||||||
|
'error': dict({
|
||||||
|
'code': 'home_assistant_error',
|
||||||
|
'message': 'Invalid agent selected: no_such_agent',
|
||||||
|
}),
|
||||||
|
'id': 1,
|
||||||
|
'success': False,
|
||||||
|
'type': 'result',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_config_info[None]
|
# name: test_config_info[None]
|
||||||
dict({
|
dict({
|
||||||
'id': 1,
|
'id': 1,
|
||||||
|
@ -36,7 +36,7 @@ from .common import (
|
|||||||
setup_backup_platform,
|
setup_backup_platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.common import async_fire_time_changed, async_mock_service
|
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
BACKUP_CALL = call(
|
BACKUP_CALL = call(
|
||||||
@ -2554,3 +2554,56 @@ async def test_subscribe_event(
|
|||||||
CreateBackupEvent(stage=None, state=CreateBackupState.IN_PROGRESS)
|
CreateBackupEvent(stage=None, state=CreateBackupState.IN_PROGRESS)
|
||||||
)
|
)
|
||||||
assert await client.receive_json() == snapshot
|
assert await client.receive_json() == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_backups() -> Generator[None]:
|
||||||
|
"""Fixture to setup test backups."""
|
||||||
|
# pylint: disable-next=import-outside-toplevel
|
||||||
|
from homeassistant.components.backup import backup as core_backup
|
||||||
|
|
||||||
|
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
super().__init__(hass)
|
||||||
|
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
||||||
|
|
||||||
|
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("agent_id", "backup_id", "password"),
|
||||||
|
[
|
||||||
|
# Invalid agent or backup
|
||||||
|
("no_such_agent", "ed1608a9", "hunter2"),
|
||||||
|
("backup.local", "no_such_backup", "hunter2"),
|
||||||
|
# Legacy backup, which can't be streamed
|
||||||
|
("backup.local", "2bcb3113", "hunter2"),
|
||||||
|
# New backup, which can be streamed, try with correct and wrong password
|
||||||
|
("backup.local", "ed1608a9", "hunter2"),
|
||||||
|
("backup.local", "ed1608a9", "wrong_password"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("mock_backups")
|
||||||
|
async def test_can_decrypt_on_download(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
agent_id: str,
|
||||||
|
backup_id: str,
|
||||||
|
password: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test can decrypt on download."""
|
||||||
|
await setup_backup_integration(hass, with_hassio=False)
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "backup/can_decrypt_on_download",
|
||||||
|
"backup_id": backup_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"password": password,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert await client.receive_json() == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user