"""Local backup support for Core and Container installations."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy
from dataclasses import dataclass, replace
from io import BytesIO
import json
import os
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast

import aiohttp
from securetar import SecureTarError, SecureTarFile, SecureTarReadError

from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType, json_loads_object

from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder


class DecryptError(HomeAssistantError):
    """Base error for failures while decrypting a backup."""

    _message = "Unexpected error during decryption."


class EncryptError(HomeAssistantError):
    """Base error for failures while encrypting a backup."""

    _message = "Unexpected error during encryption."


class UnsupportedSecureTarVersion(DecryptError):
    """The inner securetar file does not report a plaintext size."""

    _message = "Unsupported securetar version."


class IncorrectPassword(DecryptError):
    """The password is invalid or the backup is corrupted."""

    _message = "Invalid password or corrupted backup."


class BackupEmpty(DecryptError):
    """No inner tar files were found in the backup."""

    _message = "No tar files found in the backup."


class AbortCipher(HomeAssistantError):
    """Raised in a cipher worker to abort the running cipher operation."""

    _message = "Abort cipher operation."
def make_backup_dir(path: Path) -> None: """Create a backup directory if it does not exist.""" path.mkdir(exist_ok=True) def read_backup(backup_path: Path) -> AgentBackup: """Read a backup from disk.""" with tarfile.open(backup_path, "r:", bufsize=BUF_SIZE) as backup_file: if not (data_file := backup_file.extractfile("./backup.json")): raise KeyError("backup.json not found in tar file") data = json_loads_object(data_file.read()) addons = [ AddonInfo( name=cast(str, addon["name"]), slug=cast(str, addon["slug"]), version=cast(str, addon["version"]), ) for addon in cast(list[JsonObjectType], data.get("addons", [])) ] folders = [ Folder(folder) for folder in cast(list[str], data.get("folders", [])) if folder != "homeassistant" ] homeassistant_included = False homeassistant_version: str | None = None database_included = False if ( homeassistant := cast(JsonObjectType, data.get("homeassistant")) ) and "version" in homeassistant: homeassistant_included = True homeassistant_version = cast(str, homeassistant["version"]) database_included = not cast( bool, homeassistant.get("exclude_database", False) ) extra_metadata = cast(dict[str, bool | str], data.get("extra", {})) date = extra_metadata.get("supervisor.backup_request_date", data["date"]) return AgentBackup( addons=addons, backup_id=cast(str, data["slug"]), database_included=database_included, date=cast(str, date), extra_metadata=extra_metadata, folders=folders, homeassistant_included=homeassistant_included, homeassistant_version=homeassistant_version, name=cast(str, data["name"]), protected=cast(bool, data.get("protected", False)), size=backup_path.stat().st_size, ) def suggested_filename_from_name_date(name: str, date_str: str) -> str: """Suggest a filename for the backup.""" date = dt_util.parse_datetime(date_str, raise_on_error=True) return "_".join(f"{name} {date.strftime('%Y-%m-%d %H.%M %S%f')}.tar".split()) def suggested_filename(backup: AgentBackup) -> str: """Suggest a filename for the backup.""" return 
suggested_filename_from_name_date(backup.name, backup.date) def validate_password(path: Path, password: str | None) -> bool: """Validate the password.""" with tarfile.open(path, "r:", bufsize=BUF_SIZE) as backup_file: compressed = False ha_tar_name = "homeassistant.tar" try: ha_tar = backup_file.extractfile(ha_tar_name) except KeyError: compressed = True ha_tar_name = "homeassistant.tar.gz" try: ha_tar = backup_file.extractfile(ha_tar_name) except KeyError: LOGGER.error("No homeassistant.tar or homeassistant.tar.gz found") return False try: with SecureTarFile( path, # Not used gzip=compressed, key=password_to_key(password) if password is not None else None, mode="r", fileobj=ha_tar, ): # If we can read the tar file, the password is correct return True except tarfile.ReadError: LOGGER.debug("Invalid password") return False except Exception: # noqa: BLE001 LOGGER.exception("Unexpected error validating password") return False class AsyncIteratorReader: """Wrap an AsyncIterator.""" def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None: """Initialize the wrapper.""" self._aborted = False self._hass = hass self._stream = stream self._buffer: bytes | None = None self._next_future: Future[bytes | None] | None = None self._pos: int = 0 async def _next(self) -> bytes | None: """Get the next chunk from the iterator.""" return await anext(self._stream, None) def abort(self) -> None: """Abort the reader.""" self._aborted = True if self._next_future is not None: self._next_future.cancel() def read(self, n: int = -1, /) -> bytes: """Read data from the iterator.""" result = bytearray() while n < 0 or len(result) < n: if not self._buffer: self._next_future = asyncio.run_coroutine_threadsafe( self._next(), self._hass.loop ) if self._aborted: self._next_future.cancel() raise AbortCipher try: self._buffer = self._next_future.result() except CancelledError as err: raise AbortCipher from err self._pos = 0 if not self._buffer: # The stream is exhausted break chunk 
= self._buffer[self._pos : self._pos + n] result.extend(chunk) n -= len(chunk) self._pos += len(chunk) if self._pos == len(self._buffer): self._buffer = None return bytes(result) def close(self) -> None: """Close the iterator.""" class AsyncIteratorWriter: """Wrap an AsyncIterator.""" def __init__(self, hass: HomeAssistant) -> None: """Initialize the wrapper.""" self._aborted = False self._hass = hass self._pos: int = 0 self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) self._write_future: Future[bytes | None] | None = None def __aiter__(self) -> Self: """Return the iterator.""" return self async def __anext__(self) -> bytes: """Get the next chunk from the iterator.""" if data := await self._queue.get(): return data raise StopAsyncIteration def abort(self) -> None: """Abort the writer.""" self._aborted = True if self._write_future is not None: self._write_future.cancel() def tell(self) -> int: """Return the current position in the iterator.""" return self._pos def write(self, s: bytes, /) -> int: """Write data to the iterator.""" self._write_future = asyncio.run_coroutine_threadsafe( self._queue.put(s), self._hass.loop ) if self._aborted: self._write_future.cancel() raise AbortCipher try: self._write_future.result() except CancelledError as err: raise AbortCipher from err self._pos += len(s) return len(s) def validate_password_stream( input_stream: IO[bytes], password: str | None, ) -> None: """Decrypt a backup.""" with ( tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar, ): for obj in input_tar: if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): continue istf = SecureTarFile( None, # Not used gzip=False, key=password_to_key(password) if password is not None else None, mode="r", fileobj=input_tar.extractfile(obj), ) with istf.decrypt(obj) as decrypted: if istf.securetar_header.plaintext_size is None: raise UnsupportedSecureTarVersion try: decrypted.read(1) # Read a single byte to trigger the decryption except 
SecureTarReadError as err: raise IncorrectPassword from err return raise BackupEmpty def decrypt_backup( input_stream: IO[bytes], output_stream: IO[bytes], password: str | None, on_done: Callable[[Exception | None], None], minimum_size: int, nonces: list[bytes], ) -> None: """Decrypt a backup.""" error: Exception | None = None try: try: with ( tarfile.open( fileobj=input_stream, mode="r|", bufsize=BUF_SIZE ) as input_tar, tarfile.open( fileobj=output_stream, mode="w|", bufsize=BUF_SIZE ) as output_tar, ): _decrypt_backup(input_tar, output_tar, password) except (DecryptError, SecureTarError, tarfile.TarError) as err: LOGGER.warning("Error decrypting backup: %s", err) error = err else: # Pad the output stream to the requested minimum size padding = max(minimum_size - output_stream.tell(), 0) output_stream.write(b"\0" * padding) finally: # Write an empty chunk to signal the end of the stream output_stream.write(b"") except AbortCipher: LOGGER.debug("Cipher operation aborted") finally: on_done(error) def _decrypt_backup( input_tar: tarfile.TarFile, output_tar: tarfile.TarFile, password: str | None, ) -> None: """Decrypt a backup.""" for obj in input_tar: # We compare with PurePath to avoid issues with different path separators, # for example when backup.json is added as "./backup.json" if PurePath(obj.name) == PurePath("backup.json"): # Rewrite the backup.json file to indicate that the backup is decrypted if not (reader := input_tar.extractfile(obj)): raise DecryptError metadata = json_loads_object(reader.read()) metadata["protected"] = False updated_metadata_b = json.dumps(metadata).encode() metadata_obj = copy.deepcopy(obj) metadata_obj.size = len(updated_metadata_b) output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b)) continue if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): output_tar.addfile(obj, input_tar.extractfile(obj)) continue istf = SecureTarFile( None, # Not used gzip=False, key=password_to_key(password) if password is not None else None, 
mode="r", fileobj=input_tar.extractfile(obj), ) with istf.decrypt(obj) as decrypted: if (plaintext_size := istf.securetar_header.plaintext_size) is None: raise UnsupportedSecureTarVersion decrypted_obj = copy.deepcopy(obj) decrypted_obj.size = plaintext_size output_tar.addfile(decrypted_obj, decrypted) def encrypt_backup( input_stream: IO[bytes], output_stream: IO[bytes], password: str | None, on_done: Callable[[Exception | None], None], minimum_size: int, nonces: list[bytes], ) -> None: """Encrypt a backup.""" error: Exception | None = None try: try: with ( tarfile.open( fileobj=input_stream, mode="r|", bufsize=BUF_SIZE ) as input_tar, tarfile.open( fileobj=output_stream, mode="w|", bufsize=BUF_SIZE ) as output_tar, ): _encrypt_backup(input_tar, output_tar, password, nonces) except (EncryptError, SecureTarError, tarfile.TarError) as err: LOGGER.warning("Error encrypting backup: %s", err) error = err else: # Pad the output stream to the requested minimum size padding = max(minimum_size - output_stream.tell(), 0) output_stream.write(b"\0" * padding) finally: # Write an empty chunk to signal the end of the stream output_stream.write(b"") except AbortCipher: LOGGER.debug("Cipher operation aborted") finally: on_done(error) def _encrypt_backup( input_tar: tarfile.TarFile, output_tar: tarfile.TarFile, password: str | None, nonces: list[bytes], ) -> None: """Encrypt a backup.""" inner_tar_idx = 0 for obj in input_tar: # We compare with PurePath to avoid issues with different path separators, # for example when backup.json is added as "./backup.json" if PurePath(obj.name) == PurePath("backup.json"): # Rewrite the backup.json file to indicate that the backup is encrypted if not (reader := input_tar.extractfile(obj)): raise EncryptError metadata = json_loads_object(reader.read()) metadata["protected"] = True updated_metadata_b = json.dumps(metadata).encode() metadata_obj = copy.deepcopy(obj) metadata_obj.size = len(updated_metadata_b) output_tar.addfile(metadata_obj, 
BytesIO(updated_metadata_b)) continue if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): output_tar.addfile(obj, input_tar.extractfile(obj)) continue istf = SecureTarFile( None, # Not used gzip=False, key=password_to_key(password) if password is not None else None, mode="r", fileobj=input_tar.extractfile(obj), nonce=nonces[inner_tar_idx], ) inner_tar_idx += 1 with istf.encrypt(obj) as encrypted: encrypted_obj = copy.deepcopy(obj) encrypted_obj.size = encrypted.encrypted_size output_tar.addfile(encrypted_obj, encrypted) @dataclass(kw_only=True) class _CipherWorkerStatus: done: asyncio.Event error: Exception | None = None reader: AsyncIteratorReader thread: threading.Thread writer: AsyncIteratorWriter class _CipherBackupStreamer: """Encrypt or decrypt a backup.""" _cipher_func: Callable[ [ IO[bytes], IO[bytes], str | None, Callable[[Exception | None], None], int, list[bytes], ], None, ] def __init__( self, hass: HomeAssistant, backup: AgentBackup, open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]], password: str | None, ) -> None: """Initialize.""" self._workers: list[_CipherWorkerStatus] = [] self._backup = backup self._hass = hass self._open_stream = open_stream self._password = password self._nonces: list[bytes] = [] def size(self) -> int: """Return the maximum size of the decrypted or encrypted backup.""" return self._backup.size + self._num_tar_files() * tarfile.RECORDSIZE def _num_tar_files(self) -> int: """Return the number of inner tar files.""" b = self._backup return len(b.addons) + len(b.folders) + b.homeassistant_included + 1 async def open_stream(self) -> AsyncIterator[bytes]: """Open a stream.""" def on_done(error: Exception | None) -> None: """Call by the worker thread when it's done.""" worker_status.error = error self._hass.loop.call_soon_threadsafe(worker_status.done.set) stream = await self._open_stream() reader = AsyncIteratorReader(self._hass, stream) writer = AsyncIteratorWriter(self._hass) worker = threading.Thread( 
target=self._cipher_func, args=[reader, writer, self._password, on_done, self.size(), self._nonces], ) worker_status = _CipherWorkerStatus( done=asyncio.Event(), reader=reader, thread=worker, writer=writer ) self._workers.append(worker_status) worker.start() return writer async def wait(self) -> None: """Wait for the worker threads to finish.""" for worker in self._workers: worker.reader.abort() worker.writer.abort() await asyncio.gather(*(worker.done.wait() for worker in self._workers)) class DecryptedBackupStreamer(_CipherBackupStreamer): """Decrypt a backup.""" _cipher_func = staticmethod(decrypt_backup) def backup(self) -> AgentBackup: """Return the decrypted backup.""" return replace(self._backup, protected=False, size=self.size()) class EncryptedBackupStreamer(_CipherBackupStreamer): """Encrypt a backup.""" def __init__( self, hass: HomeAssistant, backup: AgentBackup, open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]], password: str | None, ) -> None: """Initialize.""" super().__init__(hass, backup, open_stream, password) self._nonces = [os.urandom(16) for _ in range(self._num_tar_files())] _cipher_func = staticmethod(encrypt_backup) def backup(self) -> AgentBackup: """Return the encrypted backup.""" return replace(self._backup, protected=True, size=self.size()) async def receive_file( hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path ) -> None: """Receive a file from a stream and write it to a file.""" queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = SimpleQueue() def _sync_queue_consumer() -> None: with path.open("wb") as file_handle: while True: if (_chunk_future := queue.get()) is None: break _chunk, _future = _chunk_future if _future is not None: hass.loop.call_soon_threadsafe(_future.set_result, None) file_handle.write(_chunk) fut: asyncio.Future[None] | None = None try: fut = hass.async_add_executor_job(_sync_queue_consumer) megabytes_sending = 0 while chunk := await contents.read_chunk(BUF_SIZE): 
megabytes_sending += 1 if megabytes_sending % 5 != 0: queue.put_nowait((chunk, None)) continue chunk_future = hass.loop.create_future() queue.put_nowait((chunk, chunk_future)) await asyncio.wait( (fut, chunk_future), return_when=asyncio.FIRST_COMPLETED, ) if fut.done(): # The executor job failed break queue.put_nowait(None) # terminate queue consumer finally: if fut is not None: await fut