diff --git a/supervisor/backups/backup.py b/supervisor/backups/backup.py index 3f92c6b54..b43f01672 100644 --- a/supervisor/backups/backup.py +++ b/supervisor/backups/backup.py @@ -6,11 +6,13 @@ from collections.abc import Awaitable from copy import deepcopy from datetime import timedelta from functools import cached_property +import io import json import logging from pathlib import Path import tarfile from tempfile import TemporaryDirectory +import time from typing import Any from awesomeversion import AwesomeVersion, AwesomeVersionCompareException @@ -51,7 +53,7 @@ from ..jobs.decorator import Job from ..jobs.job_group import JobGroup from ..utils import remove_folder from ..utils.dt import parse_datetime, utcnow -from ..utils.json import write_json_file +from ..utils.json import json_bytes from .const import BUF_SIZE, BackupType from .utils import key_to_iv, password_to_key from .validate import SCHEMA_BACKUP @@ -76,6 +78,8 @@ class Backup(JobGroup): self._tarfile: Path = tar_file self._data: dict[str, Any] = data or {ATTR_SLUG: slug} self._tmp = None + self._outer_secure_tarfile: SecureTarFile | None = None + self._outer_secure_tarfile_tarfile: tarfile.TarFile | None = None self._key: bytes | None = None self._aes: Cipher | None = None @@ -321,13 +325,21 @@ class Backup(JobGroup): async def __aenter__(self): """Async context to open a backup.""" - self._tmp = TemporaryDirectory(dir=str(self.tarfile.parent)) # create a backup if not self.tarfile.is_file(): - return self + self._outer_secure_tarfile = SecureTarFile( + self.tarfile, + "w", + gzip=False, + bufsize=BUF_SIZE, + ) + self._outer_secure_tarfile_tarfile = self._outer_secure_tarfile.__enter__() + return # extract an existing backup + self._tmp = TemporaryDirectory(dir=str(self.tarfile.parent)) + def _extract_backup(): """Extract a backup.""" with tarfile.open(self.tarfile, "r:") as tar: @@ -342,8 +354,26 @@ class Backup(JobGroup): async def __aexit__(self, exception_type, exception_value, traceback): """Async context to close a backup.""" # exists backup or exception on build - if self.tarfile.is_file() or exception_type is not None: - self._tmp.cleanup() + try: + await self._aexit(exception_type, exception_value, traceback) + finally: + if self._tmp: + self._tmp.cleanup() + if self._outer_secure_tarfile: + self._outer_secure_tarfile.__exit__( + exception_type, exception_value, traceback + ) + self._outer_secure_tarfile = None + self._outer_secure_tarfile_tarfile = None + + async def _aexit(self, exception_type, exception_value, traceback): + """Cleanup after backup creation. + + This is a separate method to allow it to be called from __aexit__ to ensure + that cleanup is always performed, even if an exception is raised. + """ + # If we're not creating a new backup, or if an exception was raised, we're done + if not self._outer_secure_tarfile or exception_type is not None: return # validate data @@ -356,19 +386,20 @@ class Backup(JobGroup): raise ValueError("Invalid config") from None # new backup, build it - def _create_backup(): + def _add_backup_json(): """Create a new backup.""" - with tarfile.open(self.tarfile, "w:") as tar: - tar.add(self._tmp.name, arcname=".") + raw_bytes = json_bytes(self._data) + fileobj = io.BytesIO(raw_bytes) + tar_info = tarfile.TarInfo(name="./backup.json") + tar_info.size = len(raw_bytes) + tar_info.mtime = int(time.time()) + self._outer_secure_tarfile_tarfile.addfile(tar_info, fileobj=fileobj) try: - write_json_file(Path(self._tmp.name, "backup.json"), self._data) - await self.sys_run_in_executor(_create_backup) + await self.sys_run_in_executor(_add_backup_json) except (OSError, json.JSONDecodeError) as err: self.sys_jobs.current.capture_error(BackupError("Can't write backup")) _LOGGER.error("Can't write backup: %s", err) - finally: - self._tmp.cleanup() @Job(name="backup_addon_save", cleanup=False) async def _addon_save(self, addon: Addon) -> asyncio.Task | None: @@ -376,14 +407,12 @@ class Backup(JobGroup): self.sys_jobs.current.reference = addon.slug tar_name = f"{addon.slug}.tar{'.gz' if self.compressed else ''}" - addon_file = SecureTarFile( - Path(self._tmp.name, tar_name), - "w", - key=self._key, - gzip=self.compressed, - bufsize=BUF_SIZE, - ) + addon_file = self._outer_secure_tarfile.create_inner_tar( + f"./{tar_name}", + gzip=self.compressed, + key=self._key, + ) # Take backup try: start_task = await addon.backup(addon_file) @@ -493,9 +522,7 @@ class Backup(JobGroup): self.sys_jobs.current.reference = name slug_name = name.replace("/", "_") - tar_name = Path( - self._tmp.name, f"{slug_name}.tar{'.gz' if self.compressed else ''}" - ) + tar_name = f"{slug_name}.tar{'.gz' if self.compressed else ''}" origin_dir = Path(self.sys_config.path_supervisor, name) # Check if exists @@ -506,8 +533,11 @@ class Backup(JobGroup): def _save() -> None: # Take backup _LOGGER.info("Backing up folder %s", name) - with SecureTarFile( - tar_name, "w", key=self._key, gzip=self.compressed, bufsize=BUF_SIZE + + with self._outer_secure_tarfile.create_inner_tar( + f"./{tar_name}", + gzip=self.compressed, + key=self._key, ) as tar_file: atomic_contents_add( tar_file, @@ -677,12 +707,12 @@ class Backup(JobGroup): ATTR_EXCLUDE_DATABASE: exclude_database, } + tar_name = f"homeassistant.tar{'.gz' if self.compressed else ''}" # Backup Home Assistant Core config directory - tar_name = Path( - self._tmp.name, f"homeassistant.tar{'.gz' if self.compressed else ''}" - ) - homeassistant_file = SecureTarFile( - tar_name, "w", key=self._key, gzip=self.compressed, bufsize=BUF_SIZE + homeassistant_file = self._outer_secure_tarfile.create_inner_tar( + f"./{tar_name}", + gzip=self.compressed, + key=self._key, ) await self.sys_homeassistant.backup(homeassistant_file, exclude_database) diff --git a/tests/backups/test_backup.py b/tests/backups/test_backup.py index ab6379333..c80b04177 100644 --- a/tests/backups/test_backup.py +++ b/tests/backups/test_backup.py @@ -16,7 +16,7 @@ async def test_new_backup_stays_in_folder(coresys: CoreSys, tmp_path: Path): async with backup: assert len(listdir(tmp_path)) == 1 - assert not backup.tarfile.exists() + assert backup.tarfile.exists() assert len(listdir(tmp_path)) == 1 assert backup.tarfile.exists()