Interrupt _CipherBackupStreamer workers (#136845)

* Interrupt _CipherBackupStreamer workers

* Fix cleanup

* Only abort live threads
This commit is contained in:
Erik Montnemery 2025-01-29 17:44:29 +01:00 committed by GitHub
parent 3118831557
commit 660653e226
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,7 +12,6 @@ import os
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast
import aiohttp
@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import JsonObjectType, json_loads_object
from homeassistant.util.thread import ThreadWithException
from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder
@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
_message = "No tar files found in the backup."
class AbortCipher(HomeAssistantError):
"""Abort the cipher operation."""
_message = "Abort cipher operation."
def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True)
@ -252,24 +258,29 @@ def decrypt_backup(
"""Decrypt a backup."""
error: Exception | None = None
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
on_done(error)
@ -322,24 +333,29 @@ def encrypt_backup(
"""Encrypt a backup."""
error: Exception | None = None
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_encrypt_backup(input_tar, output_tar, password, nonces)
except (EncryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error encrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_encrypt_backup(input_tar, output_tar, password, nonces)
except (EncryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error encrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
on_done(error)
@ -387,7 +403,7 @@ def _encrypt_backup(
class _CipherWorkerStatus:
done: asyncio.Event
error: Exception | None = None
thread: threading.Thread
thread: ThreadWithException
class _CipherBackupStreamer:
@ -440,7 +456,7 @@ class _CipherBackupStreamer:
stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass)
worker = threading.Thread(
worker = ThreadWithException(
target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
)
@ -451,6 +467,10 @@ class _CipherBackupStreamer:
async def wait(self) -> None:
"""Wait for the worker threads to finish."""
for worker in self._workers:
if not worker.thread.is_alive():
continue
worker.thread.raise_exc(AbortCipher)
await asyncio.gather(*(worker.done.wait() for worker in self._workers))