mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Interrupt _CipherBackupStreamer workers (#136845)
* Interrupt _CipherBackupStreamer workers * Fix cleanup * Only abort live threads
This commit is contained in:
parent
3118831557
commit
660653e226
@ -12,7 +12,6 @@ import os
|
|||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
import tarfile
|
import tarfile
|
||||||
import threading
|
|
||||||
from typing import IO, Any, Self, cast
|
from typing import IO, Any, Self, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -22,6 +21,7 @@ from homeassistant.backup_restore import password_to_key
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||||
|
from homeassistant.util.thread import ThreadWithException
|
||||||
|
|
||||||
from .const import BUF_SIZE, LOGGER
|
from .const import BUF_SIZE, LOGGER
|
||||||
from .models import AddonInfo, AgentBackup, Folder
|
from .models import AddonInfo, AgentBackup, Folder
|
||||||
@ -57,6 +57,12 @@ class BackupEmpty(DecryptError):
|
|||||||
_message = "No tar files found in the backup."
|
_message = "No tar files found in the backup."
|
||||||
|
|
||||||
|
|
||||||
|
class AbortCipher(HomeAssistantError):
|
||||||
|
"""Abort the cipher operation."""
|
||||||
|
|
||||||
|
_message = "Abort cipher operation."
|
||||||
|
|
||||||
|
|
||||||
def make_backup_dir(path: Path) -> None:
|
def make_backup_dir(path: Path) -> None:
|
||||||
"""Create a backup directory if it does not exist."""
|
"""Create a backup directory if it does not exist."""
|
||||||
path.mkdir(exist_ok=True)
|
path.mkdir(exist_ok=True)
|
||||||
@ -252,24 +258,29 @@ def decrypt_backup(
|
|||||||
"""Decrypt a backup."""
|
"""Decrypt a backup."""
|
||||||
error: Exception | None = None
|
error: Exception | None = None
|
||||||
try:
|
try:
|
||||||
with (
|
try:
|
||||||
tarfile.open(
|
with (
|
||||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
tarfile.open(
|
||||||
) as input_tar,
|
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||||
tarfile.open(
|
) as input_tar,
|
||||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
tarfile.open(
|
||||||
) as output_tar,
|
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||||
):
|
) as output_tar,
|
||||||
_decrypt_backup(input_tar, output_tar, password)
|
):
|
||||||
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
_decrypt_backup(input_tar, output_tar, password)
|
||||||
LOGGER.warning("Error decrypting backup: %s", err)
|
except (DecryptError, SecureTarError, tarfile.TarError) as err:
|
||||||
error = err
|
LOGGER.warning("Error decrypting backup: %s", err)
|
||||||
else:
|
error = err
|
||||||
# Pad the output stream to the requested minimum size
|
else:
|
||||||
padding = max(minimum_size - output_stream.tell(), 0)
|
# Pad the output stream to the requested minimum size
|
||||||
output_stream.write(b"\0" * padding)
|
padding = max(minimum_size - output_stream.tell(), 0)
|
||||||
|
output_stream.write(b"\0" * padding)
|
||||||
|
finally:
|
||||||
|
# Write an empty chunk to signal the end of the stream
|
||||||
|
output_stream.write(b"")
|
||||||
|
except AbortCipher:
|
||||||
|
LOGGER.debug("Cipher operation aborted")
|
||||||
finally:
|
finally:
|
||||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
|
||||||
on_done(error)
|
on_done(error)
|
||||||
|
|
||||||
|
|
||||||
@ -322,24 +333,29 @@ def encrypt_backup(
|
|||||||
"""Encrypt a backup."""
|
"""Encrypt a backup."""
|
||||||
error: Exception | None = None
|
error: Exception | None = None
|
||||||
try:
|
try:
|
||||||
with (
|
try:
|
||||||
tarfile.open(
|
with (
|
||||||
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
tarfile.open(
|
||||||
) as input_tar,
|
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
|
||||||
tarfile.open(
|
) as input_tar,
|
||||||
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
tarfile.open(
|
||||||
) as output_tar,
|
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
|
||||||
):
|
) as output_tar,
|
||||||
_encrypt_backup(input_tar, output_tar, password, nonces)
|
):
|
||||||
except (EncryptError, SecureTarError, tarfile.TarError) as err:
|
_encrypt_backup(input_tar, output_tar, password, nonces)
|
||||||
LOGGER.warning("Error encrypting backup: %s", err)
|
except (EncryptError, SecureTarError, tarfile.TarError) as err:
|
||||||
error = err
|
LOGGER.warning("Error encrypting backup: %s", err)
|
||||||
else:
|
error = err
|
||||||
# Pad the output stream to the requested minimum size
|
else:
|
||||||
padding = max(minimum_size - output_stream.tell(), 0)
|
# Pad the output stream to the requested minimum size
|
||||||
output_stream.write(b"\0" * padding)
|
padding = max(minimum_size - output_stream.tell(), 0)
|
||||||
|
output_stream.write(b"\0" * padding)
|
||||||
|
finally:
|
||||||
|
# Write an empty chunk to signal the end of the stream
|
||||||
|
output_stream.write(b"")
|
||||||
|
except AbortCipher:
|
||||||
|
LOGGER.debug("Cipher operation aborted")
|
||||||
finally:
|
finally:
|
||||||
output_stream.write(b"") # Write an empty chunk to signal the end of the stream
|
|
||||||
on_done(error)
|
on_done(error)
|
||||||
|
|
||||||
|
|
||||||
@ -387,7 +403,7 @@ def _encrypt_backup(
|
|||||||
class _CipherWorkerStatus:
|
class _CipherWorkerStatus:
|
||||||
done: asyncio.Event
|
done: asyncio.Event
|
||||||
error: Exception | None = None
|
error: Exception | None = None
|
||||||
thread: threading.Thread
|
thread: ThreadWithException
|
||||||
|
|
||||||
|
|
||||||
class _CipherBackupStreamer:
|
class _CipherBackupStreamer:
|
||||||
@ -440,7 +456,7 @@ class _CipherBackupStreamer:
|
|||||||
stream = await self._open_stream()
|
stream = await self._open_stream()
|
||||||
reader = AsyncIteratorReader(self._hass, stream)
|
reader = AsyncIteratorReader(self._hass, stream)
|
||||||
writer = AsyncIteratorWriter(self._hass)
|
writer = AsyncIteratorWriter(self._hass)
|
||||||
worker = threading.Thread(
|
worker = ThreadWithException(
|
||||||
target=self._cipher_func,
|
target=self._cipher_func,
|
||||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||||
)
|
)
|
||||||
@ -451,6 +467,10 @@ class _CipherBackupStreamer:
|
|||||||
|
|
||||||
async def wait(self) -> None:
|
async def wait(self) -> None:
|
||||||
"""Wait for the worker threads to finish."""
|
"""Wait for the worker threads to finish."""
|
||||||
|
for worker in self._workers:
|
||||||
|
if not worker.thread.is_alive():
|
||||||
|
continue
|
||||||
|
worker.thread.raise_exc(AbortCipher)
|
||||||
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user