"""Backup platform for the cloud integration.""" from __future__ import annotations import base64 from collections.abc import AsyncIterator, Callable, Coroutine, Mapping import hashlib import logging from typing import Any, Self from aiohttp import ClientError, ClientTimeout, StreamReader from hass_nabucasa import Cloud, CloudError from hass_nabucasa.cloud_api import ( async_files_delete_file, async_files_download_details, async_files_list, async_files_upload_details, ) from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from .client import CloudClient from .const import DATA_CLOUD, DOMAIN, EVENT_CLOUD_EVENT _LOGGER = logging.getLogger(__name__) _STORAGE_BACKUP = "backup" async def _b64md5(stream: AsyncIterator[bytes]) -> str: """Calculate the MD5 hash of a file.""" file_hash = hashlib.md5() async for chunk in stream: file_hash.update(chunk) return base64.b64encode(file_hash.digest()).decode() async def async_get_backup_agents( hass: HomeAssistant, **kwargs: Any, ) -> list[BackupAgent]: """Return the cloud backup agent.""" cloud = hass.data[DATA_CLOUD] if not cloud.is_logged_in: return [] return [CloudBackupAgent(hass=hass, cloud=cloud)] @callback def async_register_backup_agents_listener( hass: HomeAssistant, *, listener: Callable[[], None], **kwargs: Any, ) -> Callable[[], None]: """Register a listener to be called when agents are added or removed.""" @callback def unsub() -> None: """Unsubscribe from events.""" unsub_signal() @callback def handle_event(data: Mapping[str, Any]) -> None: """Handle event.""" if data["type"] not in ("login", "logout"): return listener() unsub_signal = async_dispatcher_connect(hass, EVENT_CLOUD_EVENT, handle_event) return unsub class ChunkAsyncStreamIterator: """Async iterator for chunked streams. Based on aiohttp.streams.ChunkTupleAsyncStreamIterator, but yields bytes instead of tuple[bytes, bool]. """ __slots__ = ("_stream",) def __init__(self, stream: StreamReader) -> None: """Initialize.""" self._stream = stream def __aiter__(self) -> Self: """Iterate.""" return self async def __anext__(self) -> bytes: """Yield next chunk.""" rv = await self._stream.readchunk() if rv == (b"", False): raise StopAsyncIteration return rv[0] class CloudBackupAgent(BackupAgent): """Cloud backup agent.""" domain = DOMAIN name = DOMAIN def __init__(self, hass: HomeAssistant, cloud: Cloud[CloudClient]) -> None: """Initialize the cloud backup sync agent.""" super().__init__() self._cloud = cloud self._hass = hass @callback def _get_backup_filename(self) -> str: """Return the backup filename.""" return f"{self._cloud.client.prefs.instance_id}.tar" async def async_download_backup( self, backup_id: str, **kwargs: Any, ) -> AsyncIterator[bytes]: """Download a backup file. :param backup_id: The ID of the backup that was returned in async_list_backups. :return: An async iterator that yields bytes. """ if not await self.async_get_backup(backup_id): raise BackupAgentError("Backup not found") try: details = await async_files_download_details( self._cloud, storage_type=_STORAGE_BACKUP, filename=self._get_backup_filename(), ) except (ClientError, CloudError) as err: raise BackupAgentError("Failed to get download details") from err try: resp = await self._cloud.websession.get(details["url"]) resp.raise_for_status() except ClientError as err: raise BackupAgentError("Failed to download backup") from err return ChunkAsyncStreamIterator(resp.content) async def async_upload_backup( self, *, open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]], backup: AgentBackup, **kwargs: Any, ) -> None: """Upload a backup. :param open_stream: A function returning an async iterator that yields bytes. :param backup: Metadata about the backup that should be uploaded. """ if not backup.protected: raise BackupAgentError("Cloud backups must be protected") base64md5hash = await _b64md5(await open_stream()) try: details = await async_files_upload_details( self._cloud, storage_type=_STORAGE_BACKUP, filename=self._get_backup_filename(), metadata=backup.as_dict(), size=backup.size, base64md5hash=base64md5hash, ) except (ClientError, CloudError) as err: raise BackupAgentError("Failed to get upload details") from err try: upload_status = await self._cloud.websession.put( details["url"], data=await open_stream(), headers=details["headers"] | {"content-length": str(backup.size)}, timeout=ClientTimeout(connect=10.0, total=43200.0), # 43200s == 12h ) _LOGGER.log( logging.DEBUG if upload_status.status < 400 else logging.WARNING, "Backup upload status: %s", upload_status.status, ) upload_status.raise_for_status() except (TimeoutError, ClientError) as err: raise BackupAgentError("Failed to upload backup") from err async def async_delete_backup( self, backup_id: str, **kwargs: Any, ) -> None: """Delete a backup file. :param backup_id: The ID of the backup that was returned in async_list_backups. """ if not await self.async_get_backup(backup_id): return try: await async_files_delete_file( self._cloud, storage_type=_STORAGE_BACKUP, filename=self._get_backup_filename(), ) except (ClientError, CloudError) as err: raise BackupAgentError("Failed to delete backup") from err async def async_list_backups(self, **kwargs: Any) -> list[AgentBackup]: """List backups.""" try: backups = await async_files_list(self._cloud, storage_type=_STORAGE_BACKUP) _LOGGER.debug("Cloud backups: %s", backups) except (ClientError, CloudError) as err: raise BackupAgentError("Failed to list backups") from err return [AgentBackup.from_dict(backup["Metadata"]) for backup in backups] async def async_get_backup( self, backup_id: str, **kwargs: Any, ) -> AgentBackup | None: """Return a backup.""" backups = await self.async_list_backups() for backup in backups: if backup.backup_id == backup_id: return backup return None