Improve shutdown of _CipherBackupStreamer (#137257)

* Improve shutdown of _CipherBackupStreamer

* Catch the right exception
This commit is contained in:
Erik Montnemery 2025-02-04 12:24:30 +01:00 committed by GitHub
parent e18062bce4
commit ca53d97a6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 185 additions and 10 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy import copy
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from io import BytesIO from io import BytesIO
@ -12,6 +13,7 @@ import os
from pathlib import Path, PurePath from pathlib import Path, PurePath
from queue import SimpleQueue from queue import SimpleQueue
import tarfile import tarfile
import threading
from typing import IO, Any, Self, cast from typing import IO, Any, Self, cast
import aiohttp import aiohttp
@ -22,7 +24,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType, json_loads_object from homeassistant.util.json import JsonObjectType, json_loads_object
from homeassistant.util.thread import ThreadWithException
from .const import BUF_SIZE, LOGGER from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder from .models import AddonInfo, AgentBackup, Folder
@ -167,23 +168,38 @@ class AsyncIteratorReader:
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None: def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
"""Initialize the wrapper.""" """Initialize the wrapper."""
self._aborted = False
self._hass = hass self._hass = hass
self._stream = stream self._stream = stream
self._buffer: bytes | None = None self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0 self._pos: int = 0
async def _next(self) -> bytes | None: async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator.""" """Get the next chunk from the iterator."""
return await anext(self._stream, None) return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes: def read(self, n: int = -1, /) -> bytes:
"""Read data from the iterator.""" """Read data from the iterator."""
result = bytearray() result = bytearray()
while n < 0 or len(result) < n: while n < 0 or len(result) < n:
if not self._buffer: if not self._buffer:
self._buffer = asyncio.run_coroutine_threadsafe( self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._hass.loop self._next(), self._hass.loop
).result() )
if self._aborted:
self._next_future.cancel()
raise AbortCipher
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos = 0 self._pos = 0
if not self._buffer: if not self._buffer:
# The stream is exhausted # The stream is exhausted
@ -205,9 +221,11 @@ class AsyncIteratorWriter:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the wrapper.""" """Initialize the wrapper."""
self._aborted = False
self._hass = hass self._hass = hass
self._pos: int = 0 self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self: def __aiter__(self) -> Self:
"""Return the iterator.""" """Return the iterator."""
@ -219,13 +237,28 @@ class AsyncIteratorWriter:
return data return data
raise StopAsyncIteration raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int: def tell(self) -> int:
"""Return the current position in the iterator.""" """Return the current position in the iterator."""
return self._pos return self._pos
def write(self, s: bytes, /) -> int: def write(self, s: bytes, /) -> int:
"""Write data to the iterator.""" """Write data to the iterator."""
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result() self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._hass.loop
)
if self._aborted:
self._write_future.cancel()
raise AbortCipher
try:
self._write_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos += len(s) self._pos += len(s)
return len(s) return len(s)
@ -415,7 +448,9 @@ def _encrypt_backup(
class _CipherWorkerStatus: class _CipherWorkerStatus:
done: asyncio.Event done: asyncio.Event
error: Exception | None = None error: Exception | None = None
thread: ThreadWithException reader: AsyncIteratorReader
thread: threading.Thread
writer: AsyncIteratorWriter
class _CipherBackupStreamer: class _CipherBackupStreamer:
@ -468,11 +503,13 @@ class _CipherBackupStreamer:
stream = await self._open_stream() stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream) reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass) writer = AsyncIteratorWriter(self._hass)
worker = ThreadWithException( worker = threading.Thread(
target=self._cipher_func, target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces], args=[reader, writer, self._password, on_done, self.size(), self._nonces],
) )
worker_status = _CipherWorkerStatus(done=asyncio.Event(), thread=worker) worker_status = _CipherWorkerStatus(
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
)
self._workers.append(worker_status) self._workers.append(worker_status)
worker.start() worker.start()
return writer return writer
@ -480,9 +517,8 @@ class _CipherBackupStreamer:
async def wait(self) -> None: async def wait(self) -> None:
"""Wait for the worker threads to finish.""" """Wait for the worker threads to finish."""
for worker in self._workers: for worker in self._workers:
if not worker.thread.is_alive(): worker.reader.abort()
continue worker.writer.abort()
worker.thread.raise_exc(AbortCipher)
await asyncio.gather(*(worker.done.wait() for worker in self._workers)) await asyncio.gather(*(worker.done.wait() for worker in self._workers))

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
import dataclasses import dataclasses
import tarfile import tarfile
@ -189,6 +190,73 @@ async def test_decrypted_backup_streamer(hass: HomeAssistant) -> None:
assert decrypted_output == decrypted_backup_data + expected_padding assert decrypted_output == decrypted_backup_data + expected_padding
async def test_decrypted_backup_streamer_interrupt_stuck_reader(
hass: HomeAssistant,
) -> None:
"""Test the decrypted backup streamer."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=encrypted_backup_path.stat().st_size,
)
stuck = asyncio.Event()
async def send_backup() -> AsyncIterator[bytes]:
f = encrypted_backup_path.open("rb")
while chunk := f.read(1024):
await stuck.wait()
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_decrypted_backup_streamer_interrupt_stuck_writer(
hass: HomeAssistant,
) -> None:
"""Test the decrypted backup streamer."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=encrypted_backup_path.stat().st_size,
)
async def send_backup() -> AsyncIterator[bytes]:
f = encrypted_backup_path.open("rb")
while chunk := f.read(1024):
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None: async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None:
"""Test the decrypted backup streamer with wrong password.""" """Test the decrypted backup streamer with wrong password."""
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN) encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
@ -279,6 +347,77 @@ async def test_encrypted_backup_streamer(hass: HomeAssistant) -> None:
assert encrypted_output == encrypted_backup_data + expected_padding assert encrypted_output == encrypted_backup_data + expected_padding
async def test_encrypted_backup_streamer_interrupt_stuck_reader(
hass: HomeAssistant,
) -> None:
"""Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path(
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=False,
size=decrypted_backup_path.stat().st_size,
)
stuck = asyncio.Event()
async def send_backup() -> AsyncIterator[bytes]:
f = decrypted_backup_path.open("rb")
while chunk := f.read(1024):
await stuck.wait()
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_encrypted_backup_streamer_interrupt_stuck_writer(
hass: HomeAssistant,
) -> None:
"""Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path(
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
)
backup = AgentBackup(
addons=["addon_1", "addon_2"],
backup_id="1234",
date="2024-12-02T07:23:58.261875-05:00",
database_included=False,
extra_metadata={},
folders=[],
homeassistant_included=True,
homeassistant_version="2024.12.0.dev0",
name="test",
protected=True,
size=decrypted_backup_path.stat().st_size,
)
async def send_backup() -> AsyncIterator[bytes]:
f = decrypted_backup_path.open("rb")
while chunk := f.read(1024):
yield chunk
async def open_backup() -> AsyncIterator[bytes]:
return send_backup()
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
await decryptor.open_stream()
await decryptor.wait()
async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None: async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None:
"""Test the encrypted backup streamer.""" """Test the encrypted backup streamer."""
decrypted_backup_path = get_fixture_path( decrypted_backup_path = get_fixture_path(