diff --git a/homeassistant/components/backup/util.py b/homeassistant/components/backup/util.py index bea3fe1f4ef..2416aa5f28e 100644 --- a/homeassistant/components/backup/util.py +++ b/homeassistant/components/backup/util.py @@ -12,7 +12,6 @@ import os from pathlib import Path, PurePath from queue import SimpleQueue import tarfile -import threading from typing import IO, Any, Self, cast import aiohttp @@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.util.json import JsonObjectType, json_loads_object +from homeassistant.util.thread import ThreadWithException from .const import BUF_SIZE, LOGGER from .models import AddonInfo, AgentBackup, Folder @@ -57,6 +57,12 @@ class BackupEmpty(DecryptError): _message = "No tar files found in the backup." +class AbortCipher(HomeAssistantError): + """Abort the cipher operation.""" + + _message = "Abort cipher operation." + + def make_backup_dir(path: Path) -> None: """Create a backup directory if it does not exist.""" path.mkdir(exist_ok=True) @@ -252,24 +258,29 @@ def decrypt_backup( """Decrypt a backup.""" error: Exception | None = None try: - with ( - tarfile.open( - fileobj=input_stream, mode="r|", bufsize=BUF_SIZE - ) as input_tar, - tarfile.open( - fileobj=output_stream, mode="w|", bufsize=BUF_SIZE - ) as output_tar, - ): - _decrypt_backup(input_tar, output_tar, password) - except (DecryptError, SecureTarError, tarfile.TarError) as err: - LOGGER.warning("Error decrypting backup: %s", err) - error = err - else: - # Pad the output stream to the requested minimum size - padding = max(minimum_size - output_stream.tell(), 0) - output_stream.write(b"\0" * padding) + try: + with ( + tarfile.open( + fileobj=input_stream, mode="r|", bufsize=BUF_SIZE + ) as input_tar, + tarfile.open( + fileobj=output_stream, mode="w|", bufsize=BUF_SIZE + ) as output_tar, + ): + _decrypt_backup(input_tar, output_tar, password) + except (DecryptError, SecureTarError, tarfile.TarError) as err: + LOGGER.warning("Error decrypting backup: %s", err) + error = err + else: + # Pad the output stream to the requested minimum size + padding = max(minimum_size - output_stream.tell(), 0) + output_stream.write(b"\0" * padding) + finally: + # Write an empty chunk to signal the end of the stream + output_stream.write(b"") + except AbortCipher: + LOGGER.debug("Cipher operation aborted") finally: - output_stream.write(b"") # Write an empty chunk to signal the end of the stream on_done(error) @@ -322,24 +333,29 @@ def encrypt_backup( """Encrypt a backup.""" error: Exception | None = None try: - with ( - tarfile.open( - fileobj=input_stream, mode="r|", bufsize=BUF_SIZE - ) as input_tar, - tarfile.open( - fileobj=output_stream, mode="w|", bufsize=BUF_SIZE - ) as output_tar, - ): - _encrypt_backup(input_tar, output_tar, password, nonces) - except (EncryptError, SecureTarError, tarfile.TarError) as err: - LOGGER.warning("Error encrypting backup: %s", err) - error = err - else: - # Pad the output stream to the requested minimum size - padding = max(minimum_size - output_stream.tell(), 0) - output_stream.write(b"\0" * padding) + try: + with ( + tarfile.open( + fileobj=input_stream, mode="r|", bufsize=BUF_SIZE + ) as input_tar, + tarfile.open( + fileobj=output_stream, mode="w|", bufsize=BUF_SIZE + ) as output_tar, + ): + _encrypt_backup(input_tar, output_tar, password, nonces) + except (EncryptError, SecureTarError, tarfile.TarError) as err: + LOGGER.warning("Error encrypting backup: %s", err) + error = err + else: + # Pad the output stream to the requested minimum size + padding = max(minimum_size - output_stream.tell(), 0) + output_stream.write(b"\0" * padding) + finally: + # Write an empty chunk to signal the end of the stream + output_stream.write(b"") + except AbortCipher: + LOGGER.debug("Cipher operation aborted") finally: - output_stream.write(b"") # Write an empty chunk to signal the end of the stream on_done(error) @@ -387,7 +403,7 @@ def _encrypt_backup( class _CipherWorkerStatus: done: asyncio.Event error: Exception | None = None - thread: threading.Thread + thread: ThreadWithException class _CipherBackupStreamer: @@ -440,7 +456,7 @@ class _CipherBackupStreamer: stream = await self._open_stream() reader = AsyncIteratorReader(self._hass, stream) writer = AsyncIteratorWriter(self._hass) - worker = threading.Thread( + worker = ThreadWithException( target=self._cipher_func, args=[reader, writer, self._password, on_done, self.size(), self._nonces], ) @@ -451,6 +467,10 @@ class _CipherBackupStreamer: async def wait(self) -> None: """Wait for the worker threads to finish.""" + for worker in self._workers: + if not worker.thread.is_alive(): + continue + worker.thread.raise_exc(AbortCipher) await asyncio.gather(*(worker.done.wait() for worker in self._workers))