Improve shutdown of _CipherBackupStreamer (#137257)

* Improve shutdown of _CipherBackupStreamer

* Catch the right exception
This commit is contained in:
Erik Montnemery 2025-02-04 12:24:30 +01:00 committed by GitHub
parent e18062bce4
commit ca53d97a6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 185 additions and 10 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy
from dataclasses import dataclass, replace
from io import BytesIO
@ -12,6 +13,7 @@ import os
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast
import aiohttp
@ -22,7 +24,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType, json_loads_object
from homeassistant.util.thread import ThreadWithException
from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder
@ -167,23 +168,38 @@ class AsyncIteratorReader:
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._stream = stream
self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes:
"""Read data from the iterator."""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._buffer = asyncio.run_coroutine_threadsafe(
self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._hass.loop
).result()
)
if self._aborted:
self._next_future.cancel()
raise AbortCipher
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos = 0
if not self._buffer:
# The stream is exhausted
@ -205,9 +221,11 @@ class AsyncIteratorWriter:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self:
"""Return the iterator."""
@ -219,13 +237,28 @@ class AsyncIteratorWriter:
return data
raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int:
"""Return the current position in the iterator."""
return self._pos
def write(self, s: bytes, /) -> int:
"""Write data to the iterator."""
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._hass.loop
)
if self._aborted:
self._write_future.cancel()
raise AbortCipher
try:
self._write_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos += len(s)
return len(s)
@ -415,7 +448,9 @@ def _encrypt_backup(
class _CipherWorkerStatus:
done: asyncio.Event
error: Exception | None = None
thread: ThreadWithException
reader: AsyncIteratorReader
thread: threading.Thread
writer: AsyncIteratorWriter
class _CipherBackupStreamer:
@ -468,11 +503,13 @@ class _CipherBackupStreamer:
stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass)
worker = ThreadWithException(
worker = threading.Thread(
target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
)
worker_status = _CipherWorkerStatus(done=asyncio.Event(), thread=worker)
worker_status = _CipherWorkerStatus(
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
)
self._workers.append(worker_status)
worker.start()
return writer
@ -480,9 +517,8 @@ class _CipherBackupStreamer:
async def wait(self) -> None:
"""Wait for the worker threads to finish."""
for worker in self._workers:
if not worker.thread.is_alive():
continue
worker.thread.raise_exc(AbortCipher)
worker.reader.abort()
worker.writer.abort()
await asyncio.gather(*(worker.done.wait() for worker in self._workers))

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
import dataclasses
import tarfile
@ -189,6 +190,73 @@ async def test_decrypted_backup_streamer(hass: HomeAssistant) -> None:
assert decrypted_output == decrypted_backup_data + expected_padding
async def test_decrypted_backup_streamer_interrupt_stuck_reader(
hass: HomeAssistant,
) -> None:
"""Test the decrypted backup streamer."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=encrypted_backup_path.stat().st_size,
)
stuck = asyncio.Event()
async def send_backup() -> AsyncIterator[bytes]:
f = encrypted_backup_path.open("rb")
while chunk := f.read(1024):
await stuck.wait()
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_decrypted_backup_streamer_interrupt_stuck_writer(
hass: HomeAssistant,
) -> None:
"""Test the decrypted backup streamer."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=encrypted_backup_path.stat().st_size,
)
async def send_backup() -> AsyncIterator[bytes]:
f = encrypted_backup_path.open("rb")
while chunk := f.read(1024):
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None:
"""Test the decrypted backup streamer with wrong password."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
@ -279,6 +347,77 @@ async def test_encrypted_backup_streamer(hass: HomeAssistant) -> None:
assert encrypted_output == encrypted_backup_data + expected_padding
async def test_encrypted_backup_streamer_interrupt_stuck_reader(
hass: HomeAssistant,
) -> None:
"""Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path(
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=False,
size=decrypted_backup_path.stat().st_size,
)
stuck = asyncio.Event()
async def send_backup() -> AsyncIterator[bytes]:
f = decrypted_backup_path.open("rb")
while chunk := f.read(1024):
await stuck.wait()
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_encrypted_backup_streamer_interrupt_stuck_writer(
hass: HomeAssistant,
) -> None:
"""Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path(
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=decrypted_backup_path.stat().st_size,
)
async def send_backup() -> AsyncIterator[bytes]:
f = decrypted_backup_path.open("rb")
while chunk := f.read(1024):
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None:
"""Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path(