mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Improve shutdown of _CipherBackupStreamer (#137257)
* Improve shutdown of _CipherBackupStreamer * Catch the right exception
This commit is contained in:
parent
e18062bce4
commit
ca53d97a6d
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable, Coroutine
|
||||
from concurrent.futures import CancelledError, Future
|
||||
import copy
|
||||
from dataclasses import dataclass, replace
|
||||
from io import BytesIO
|
||||
@ -12,6 +13,7 @@ import os
|
||||
from pathlib import Path, PurePath
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
import threading
|
||||
from typing import IO, Any, Self, cast
|
||||
|
||||
import aiohttp
|
||||
@ -22,7 +24,6 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
||||
from .const import BUF_SIZE, LOGGER
|
||||
from .models import AddonInfo, AgentBackup, Folder
|
||||
@ -167,23 +168,38 @@ class AsyncIteratorReader:
|
||||
|
||||
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._hass = hass
|
||||
self._stream = stream
|
||||
self._buffer: bytes | None = None
|
||||
self._next_future: Future[bytes | None] | None = None
|
||||
self._pos: int = 0
|
||||
|
||||
async def _next(self) -> bytes | None:
|
||||
"""Get the next chunk from the iterator."""
|
||||
return await anext(self._stream, None)
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the reader."""
|
||||
self._aborted = True
|
||||
if self._next_future is not None:
|
||||
self._next_future.cancel()
|
||||
|
||||
def read(self, n: int = -1, /) -> bytes:
|
||||
"""Read data from the iterator."""
|
||||
result = bytearray()
|
||||
while n < 0 or len(result) < n:
|
||||
if not self._buffer:
|
||||
self._buffer = asyncio.run_coroutine_threadsafe(
|
||||
self._next_future = asyncio.run_coroutine_threadsafe(
|
||||
self._next(), self._hass.loop
|
||||
).result()
|
||||
)
|
||||
if self._aborted:
|
||||
self._next_future.cancel()
|
||||
raise AbortCipher
|
||||
try:
|
||||
self._buffer = self._next_future.result()
|
||||
except CancelledError as err:
|
||||
raise AbortCipher from err
|
||||
self._pos = 0
|
||||
if not self._buffer:
|
||||
# The stream is exhausted
|
||||
@ -205,9 +221,11 @@ class AsyncIteratorWriter:
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._hass = hass
|
||||
self._pos: int = 0
|
||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||
self._write_future: Future[bytes | None] | None = None
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Return the iterator."""
|
||||
@ -219,13 +237,28 @@ class AsyncIteratorWriter:
|
||||
return data
|
||||
raise StopAsyncIteration
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the writer."""
|
||||
self._aborted = True
|
||||
if self._write_future is not None:
|
||||
self._write_future.cancel()
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position in the iterator."""
|
||||
return self._pos
|
||||
|
||||
def write(self, s: bytes, /) -> int:
|
||||
"""Write data to the iterator."""
|
||||
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
|
||||
self._write_future = asyncio.run_coroutine_threadsafe(
|
||||
self._queue.put(s), self._hass.loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._write_future.cancel()
|
||||
raise AbortCipher
|
||||
try:
|
||||
self._write_future.result()
|
||||
except CancelledError as err:
|
||||
raise AbortCipher from err
|
||||
self._pos += len(s)
|
||||
return len(s)
|
||||
|
||||
@ -415,7 +448,9 @@ def _encrypt_backup(
|
||||
class _CipherWorkerStatus:
|
||||
done: asyncio.Event
|
||||
error: Exception | None = None
|
||||
thread: ThreadWithException
|
||||
reader: AsyncIteratorReader
|
||||
thread: threading.Thread
|
||||
writer: AsyncIteratorWriter
|
||||
|
||||
|
||||
class _CipherBackupStreamer:
|
||||
@ -468,11 +503,13 @@ class _CipherBackupStreamer:
|
||||
stream = await self._open_stream()
|
||||
reader = AsyncIteratorReader(self._hass, stream)
|
||||
writer = AsyncIteratorWriter(self._hass)
|
||||
worker = ThreadWithException(
|
||||
worker = threading.Thread(
|
||||
target=self._cipher_func,
|
||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||
)
|
||||
worker_status = _CipherWorkerStatus(done=asyncio.Event(), thread=worker)
|
||||
worker_status = _CipherWorkerStatus(
|
||||
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
|
||||
)
|
||||
self._workers.append(worker_status)
|
||||
worker.start()
|
||||
return writer
|
||||
@ -480,9 +517,8 @@ class _CipherBackupStreamer:
|
||||
async def wait(self) -> None:
|
||||
"""Wait for the worker threads to finish."""
|
||||
for worker in self._workers:
|
||||
if not worker.thread.is_alive():
|
||||
continue
|
||||
worker.thread.raise_exc(AbortCipher)
|
||||
worker.reader.abort()
|
||||
worker.writer.abort()
|
||||
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
import dataclasses
|
||||
import tarfile
|
||||
@ -189,6 +190,73 @@ async def test_decrypted_backup_streamer(hass: HomeAssistant) -> None:
|
||||
assert decrypted_output == decrypted_backup_data + expected_padding
|
||||
|
||||
|
||||
async def test_decrypted_backup_streamer_interrupt_stuck_reader(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the decrypted backup streamer."""
|
||||
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||
backup = AgentBackup(
|
||||
addons=["addon_1", "addon_2"],
|
||||
backup_id="1234",
|
||||
date="2024-12-02T07:23:58.261875-05:00",
|
||||
database_included=False,
|
||||
extra_metadata={},
|
||||
folders=[],
|
||||
homeassistant_included=True,
|
||||
homeassistant_version="2024.12.0.dev0",
|
||||
name="test",
|
||||
protected=True,
|
||||
size=encrypted_backup_path.stat().st_size,
|
||||
)
|
||||
|
||||
stuck = asyncio.Event()
|
||||
|
||||
async def send_backup() -> AsyncIterator[bytes]:
|
||||
f = encrypted_backup_path.open("rb")
|
||||
while chunk := f.read(1024):
|
||||
await stuck.wait()
|
||||
yield chunk
|
||||
|
||||
async def open_backup() -> AsyncIterator[bytes]:
|
||||
return send_backup()
|
||||
|
||||
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||
await decryptor.open_stream()
|
||||
await decryptor.wait()
|
||||
|
||||
|
||||
async def test_decrypted_backup_streamer_interrupt_stuck_writer(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the decrypted backup streamer."""
|
||||
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||
backup = AgentBackup(
|
||||
addons=["addon_1", "addon_2"],
|
||||
backup_id="1234",
|
||||
date="2024-12-02T07:23:58.261875-05:00",
|
||||
database_included=False,
|
||||
extra_metadata={},
|
||||
folders=[],
|
||||
homeassistant_included=True,
|
||||
homeassistant_version="2024.12.0.dev0",
|
||||
name="test",
|
||||
protected=True,
|
||||
size=encrypted_backup_path.stat().st_size,
|
||||
)
|
||||
|
||||
async def send_backup() -> AsyncIterator[bytes]:
|
||||
f = encrypted_backup_path.open("rb")
|
||||
while chunk := f.read(1024):
|
||||
yield chunk
|
||||
|
||||
async def open_backup() -> AsyncIterator[bytes]:
|
||||
return send_backup()
|
||||
|
||||
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||
await decryptor.open_stream()
|
||||
await decryptor.wait()
|
||||
|
||||
|
||||
async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None:
|
||||
"""Test the decrypted backup streamer with wrong password."""
|
||||
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||
@ -279,6 +347,77 @@ async def test_encrypted_backup_streamer(hass: HomeAssistant) -> None:
|
||||
assert encrypted_output == encrypted_backup_data + expected_padding
|
||||
|
||||
|
||||
async def test_encrypted_backup_streamer_interrupt_stuck_reader(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the encrypted backup streamer."""
|
||||
decrypted_backup_path = get_fixture_path(
|
||||
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
|
||||
)
|
||||
backup = AgentBackup(
|
||||
addons=["addon_1", "addon_2"],
|
||||
backup_id="1234",
|
||||
date="2024-12-02T07:23:58.261875-05:00",
|
||||
database_included=False,
|
||||
extra_metadata={},
|
||||
folders=[],
|
||||
homeassistant_included=True,
|
||||
homeassistant_version="2024.12.0.dev0",
|
||||
name="test",
|
||||
protected=False,
|
||||
size=decrypted_backup_path.stat().st_size,
|
||||
)
|
||||
|
||||
stuck = asyncio.Event()
|
||||
|
||||
async def send_backup() -> AsyncIterator[bytes]:
|
||||
f = decrypted_backup_path.open("rb")
|
||||
while chunk := f.read(1024):
|
||||
await stuck.wait()
|
||||
yield chunk
|
||||
|
||||
async def open_backup() -> AsyncIterator[bytes]:
|
||||
return send_backup()
|
||||
|
||||
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||
await decryptor.open_stream()
|
||||
await decryptor.wait()
|
||||
|
||||
|
||||
async def test_encrypted_backup_streamer_interrupt_stuck_writer(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the encrypted backup streamer."""
|
||||
decrypted_backup_path = get_fixture_path(
|
||||
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
|
||||
)
|
||||
backup = AgentBackup(
|
||||
addons=["addon_1", "addon_2"],
|
||||
backup_id="1234",
|
||||
date="2024-12-02T07:23:58.261875-05:00",
|
||||
database_included=False,
|
||||
extra_metadata={},
|
||||
folders=[],
|
||||
homeassistant_included=True,
|
||||
homeassistant_version="2024.12.0.dev0",
|
||||
name="test",
|
||||
protected=True,
|
||||
size=decrypted_backup_path.stat().st_size,
|
||||
)
|
||||
|
||||
async def send_backup() -> AsyncIterator[bytes]:
|
||||
f = decrypted_backup_path.open("rb")
|
||||
while chunk := f.read(1024):
|
||||
yield chunk
|
||||
|
||||
async def open_backup() -> AsyncIterator[bytes]:
|
||||
return send_backup()
|
||||
|
||||
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||
await decryptor.open_stream()
|
||||
await decryptor.wait()
|
||||
|
||||
|
||||
async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None:
|
||||
"""Test the encrypted backup streamer."""
|
||||
decrypted_backup_path = get_fixture_path(
|
||||
|
Loading…
x
Reference in New Issue
Block a user