mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Interrupt _CipherBackupStreamer workers (#136845)
* Interrupt _CipherBackupStreamer workers * Fix cleanup * Only abort live threads
This commit is contained in:
parent
3118831557
commit
660653e226
@ -12,7 +12,6 @@ import os
|
||||
from pathlib import Path, PurePath
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
import threading
|
||||
from typing import IO, Any, Self, cast
|
||||
|
||||
import aiohttp
|
||||
@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
||||
from .const import BUF_SIZE, LOGGER
|
||||
from .models import AddonInfo, AgentBackup, Folder
|
||||
@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
|
||||
_message = "No tar files found in the backup."
|
||||
|
||||
|
||||
class AbortCipher(HomeAssistantError):
|
||||
"""Abort the cipher operation."""
|
||||
|
||||
_message = "Abort cipher operation."
|
||||
|
||||
|
||||
def make_backup_dir(path: Path) -> None:
|
||||
"""Create a backup directory if it does not exist."""
|
||||
path.mkdir(exist_ok=True)
|
||||
@ -251,6 +257,7 @@ def decrypt_backup(
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
error: Exception | None = None
|
||||
try:
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
@ -269,7 +276,11 @@ def decrypt_backup(
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
finally:
|
||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
on_done(error)
|
||||
|
||||
|
||||
@ -321,6 +332,7 @@ def encrypt_backup(
|
||||
) -> None:
|
||||
"""Encrypt a backup."""
|
||||
error: Exception | None = None
|
||||
try:
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
@ -339,7 +351,11 @@ def encrypt_backup(
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
finally:
|
||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
on_done(error)
|
||||
|
||||
|
||||
@ -387,7 +403,7 @@ def _encrypt_backup(
|
||||
class _CipherWorkerStatus:
|
||||
done: asyncio.Event
|
||||
error: Exception | None = None
|
||||
thread: threading.Thread
|
||||
thread: ThreadWithException
|
||||
|
||||
|
||||
class _CipherBackupStreamer:
|
||||
@ -440,7 +456,7 @@ class _CipherBackupStreamer:
|
||||
stream = await self._open_stream()
|
||||
reader = AsyncIteratorReader(self._hass, stream)
|
||||
writer = AsyncIteratorWriter(self._hass)
|
||||
worker = threading.Thread(
|
||||
worker = ThreadWithException(
|
||||
target=self._cipher_func,
|
||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||
)
|
||||
@ -451,6 +467,10 @@ class _CipherBackupStreamer:
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""Wait for the worker threads to finish."""
|
||||
for worker in self._workers:
|
||||
if not worker.thread.is_alive():
|
||||
continue
|
||||
worker.thread.raise_exc(AbortCipher)
|
||||
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user