mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Interrupt _CipherBackupStreamer workers (#136845)
* Interrupt _CipherBackupStreamer workers * Fix cleanup * Only abort live threads
This commit is contained in:
parent
3118831557
commit
660653e226
@ -12,7 +12,6 @@ import os
|
||||
from pathlib import Path, PurePath
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
import threading
|
||||
from typing import IO, Any, Self, cast
|
||||
|
||||
import aiohttp
|
||||
@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
||||
from .const import BUF_SIZE, LOGGER
|
||||
from .models import AddonInfo, AgentBackup, Folder
|
||||
@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
|
||||
_message = "No tar files found in the backup."
|
||||
|
||||
|
||||
class AbortCipher(HomeAssistantError):
|
||||
"""Abort the cipher operation."""
|
||||
|
||||
_message = "Abort cipher operation."
|
||||
|
||||
|
||||
def make_backup_dir(path: Path) -> None:
|
||||
"""Create a backup directory if it does not exist."""
|
||||
path.mkdir(exist_ok=True)
|
||||
@ -252,24 +258,29 @@ def decrypt_backup(
|
||||
"""Decrypt a backup."""
|
||||
error: Exception | None = None
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||
) as input_tar,
|
||||
tarfile.open(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_decrypt_backup(input_tar, output_tar, password)
|
||||
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error decrypting backup: %s", err)
|
||||
error = err
|
||||
else:
|
||||
# Pad the output stream to the requested minimum size
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||
) as input_tar,
|
||||
tarfile.open(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_decrypt_backup(input_tar, output_tar, password)
|
||||
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error decrypting backup: %s", err)
|
||||
error = err
|
||||
else:
|
||||
# Pad the output stream to the requested minimum size
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
finally:
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||
on_done(error)
|
||||
|
||||
|
||||
@ -322,24 +333,29 @@ def encrypt_backup(
|
||||
"""Encrypt a backup."""
|
||||
error: Exception | None = None
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||
) as input_tar,
|
||||
tarfile.open(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_encrypt_backup(input_tar, output_tar, password, nonces)
|
||||
except (EncryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error encrypting backup: %s", err)
|
||||
error = err
|
||||
else:
|
||||
# Pad the output stream to the requested minimum size
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
try:
|
||||
with (
|
||||
tarfile.open(
|
||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||
) as input_tar,
|
||||
tarfile.open(
|
||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||
) as output_tar,
|
||||
):
|
||||
_encrypt_backup(input_tar, output_tar, password, nonces)
|
||||
except (EncryptError, SecureTarError, tarfile.TarError) as err:
|
||||
LOGGER.warning("Error encrypting backup: %s", err)
|
||||
error = err
|
||||
else:
|
||||
# Pad the output stream to the requested minimum size
|
||||
padding = max(minimum_size - output_stream.tell(), 0)
|
||||
output_stream.write(b"\0" * padding)
|
||||
finally:
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
||||
on_done(error)
|
||||
|
||||
|
||||
@ -387,7 +403,7 @@ def _encrypt_backup(
|
||||
class _CipherWorkerStatus:
|
||||
done: asyncio.Event
|
||||
error: Exception | None = None
|
||||
thread: threading.Thread
|
||||
thread: ThreadWithException
|
||||
|
||||
|
||||
class _CipherBackupStreamer:
|
||||
@ -440,7 +456,7 @@ class _CipherBackupStreamer:
|
||||
stream = await self._open_stream()
|
||||
reader = AsyncIteratorReader(self._hass, stream)
|
||||
writer = AsyncIteratorWriter(self._hass)
|
||||
worker = threading.Thread(
|
||||
worker = ThreadWithException(
|
||||
target=self._cipher_func,
|
||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||
)
|
||||
@ -451,6 +467,10 @@ class _CipherBackupStreamer:
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""Wait for the worker threads to finish."""
|
||||
for worker in self._workers:
|
||||
if not worker.thread.is_alive():
|
||||
continue
|
||||
worker.thread.raise_exc(AbortCipher)
|
||||
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user