Interrupt _CipherBackupStreamer workers (#136845)

* Interrupt _CipherBackupStreamer workers

* Fix cleanup

* Only abort live threads
This commit is contained in:
Erik Montnemery 2025-01-29 17:44:29 +01:00 committed by GitHub
parent 3118831557
commit 660653e226
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,7 +12,6 @@ import os
from pathlib import Path, PurePath from pathlib import Path, PurePath
from queue import SimpleQueue from queue import SimpleQueue
import tarfile import tarfile
import threading
from typing import IO, Any, Self, cast from typing import IO, Any, Self, cast
import aiohttp import aiohttp
@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from homeassistant.util.thread import ThreadWithException
from .const import BUF_SIZE, LOGGER from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder from .models import AddonInfo, AgentBackup, Folder
@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
_message = "No tar files found in the backup." _message = "No tar files found in the backup."
class AbortCipher(HomeAssistantError):
"""Abort the cipher operation."""
_message = "Abort cipher operation."
def make_backup_dir(path: Path) -> None: def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist.""" """Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True) path.mkdir(exist_ok=True)
@ -251,6 +257,7 @@ def decrypt_backup(
) -> None: ) -> None:
"""Decrypt a backup.""" """Decrypt a backup."""
error: Exception | None = None error: Exception | None = None
try:
try: try:
with ( with (
tarfile.open( tarfile.open(
@ -269,7 +276,11 @@ def decrypt_backup(
padding = max(minimum_size - output_stream.tell(), 0) padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding) output_stream.write(b"\0" * padding)
finally: finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream # Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error) on_done(error)
@ -321,6 +332,7 @@ def encrypt_backup(
) -> None: ) -> None:
"""Encrypt a backup.""" """Encrypt a backup."""
error: Exception | None = None error: Exception | None = None
try:
try: try:
with ( with (
tarfile.open( tarfile.open(
@ -339,7 +351,11 @@ def encrypt_backup(
padding = max(minimum_size - output_stream.tell(), 0) padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding) output_stream.write(b"\0" * padding)
finally: finally:
output_stream.write(b"") # Write an empty chunk to signal the end of the stream # Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error) on_done(error)
@ -387,7 +403,7 @@ def _encrypt_backup(
class _CipherWorkerStatus: class _CipherWorkerStatus:
done: asyncio.Event done: asyncio.Event
error: Exception | None = None error: Exception | None = None
thread: threading.Thread thread: ThreadWithException
class _CipherBackupStreamer: class _CipherBackupStreamer:
@ -440,7 +456,7 @@ class _CipherBackupStreamer:
stream = await self._open_stream() stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream) reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass) writer = AsyncIteratorWriter(self._hass)
worker = threading.Thread( worker = ThreadWithException(
target=self._cipher_func, target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces], args=[reader, writer, self._password, on_done, self.size(), self._nonces],
) )
@ -451,6 +467,10 @@ class _CipherBackupStreamer:
async def wait(self) -> None: async def wait(self) -> None:
"""Wait for the worker threads to finish.""" """Wait for the worker threads to finish."""
for worker in self._workers:
if not worker.thread.is_alive():
continue
worker.thread.raise_exc(AbortCipher)
await asyncio.gather(*(worker.done.wait() for worker in self._workers)) await asyncio.gather(*(worker.done.wait() for worker in self._workers))