From 2b4b46eaf8dba83cc80b0caf3f9e30723dffadc1 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 29 Sep 2025 18:54:23 +0200 Subject: [PATCH] Add async_iterator util (#153194) --- homeassistant/components/backup/http.py | 5 +- homeassistant/components/backup/manager.py | 4 +- homeassistant/components/backup/util.py | 122 ++----------------- homeassistant/util/async_iterator.py | 134 +++++++++++++++++++++ tests/util/test_async_iterator.py | 116 ++++++++++++++++++ 5 files changed, 265 insertions(+), 116 deletions(-) create mode 100644 homeassistant/util/async_iterator.py create mode 100644 tests/util/test_async_iterator.py diff --git a/homeassistant/components/backup/http.py b/homeassistant/components/backup/http.py index b71859611b4..b40ea76cd59 100644 --- a/homeassistant/components/backup/http.py +++ b/homeassistant/components/backup/http.py @@ -17,6 +17,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import frame from homeassistant.util import slugify +from homeassistant.util.async_iterator import AsyncIteratorReader, AsyncIteratorWriter from . import util from .agent import BackupAgent @@ -144,7 +145,7 @@ class DownloadBackupView(HomeAssistantView): return Response(status=HTTPStatus.NOT_FOUND) else: stream = await agent.async_download_backup(backup_id) - reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream)) + reader = cast(IO[bytes], AsyncIteratorReader(hass.loop, stream)) worker_done_event = asyncio.Event() @@ -152,7 +153,7 @@ class DownloadBackupView(HomeAssistantView): """Call by the worker thread when it's done.""" hass.loop.call_soon_threadsafe(worker_done_event.set) - stream = util.AsyncIteratorWriter(hass) + stream = AsyncIteratorWriter(hass.loop) worker = threading.Thread( target=util.decrypt_backup, args=[backup, reader, stream, password, on_done, 0, []], diff --git a/homeassistant/components/backup/manager.py b/homeassistant/components/backup/manager.py index 863775a32ed..cba09a078c1 100644 --- a/homeassistant/components/backup/manager.py +++ b/homeassistant/components/backup/manager.py @@ -38,6 +38,7 @@ from homeassistant.helpers import ( ) from homeassistant.helpers.json import json_bytes from homeassistant.util import dt as dt_util, json as json_util +from homeassistant.util.async_iterator import AsyncIteratorReader from . import util as backup_util from .agent import ( @@ -72,7 +73,6 @@ from .models import ( ) from .store import BackupStore from .util import ( - AsyncIteratorReader, DecryptedBackupStreamer, EncryptedBackupStreamer, make_backup_dir, @@ -1525,7 +1525,7 @@ class BackupManager: reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb") else: backup_stream = await agent.async_download_backup(backup_id) - reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream)) + reader = cast(IO[bytes], AsyncIteratorReader(self.hass.loop, backup_stream)) try: await self.hass.async_add_executor_job( validate_password_stream, reader, password diff --git a/homeassistant/components/backup/util.py b/homeassistant/components/backup/util.py index 1a32c938a54..9dfcb36783d 100644 --- a/homeassistant/components/backup/util.py +++ b/homeassistant/components/backup/util.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio from collections.abc import AsyncIterator, Callable, Coroutine -from concurrent.futures import CancelledError, Future import copy from dataclasses import dataclass, replace from io import BytesIO @@ -14,7 +13,7 @@ from pathlib import Path, PurePath from queue import SimpleQueue import tarfile import threading -from typing import IO, Any, Self, cast +from typing import IO, Any, cast import aiohttp from securetar import SecureTarError, SecureTarFile, SecureTarReadError @@ -23,6 +22,11 @@ from homeassistant.backup_restore import password_to_key from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.util import dt as dt_util +from homeassistant.util.async_iterator import ( + Abort, + AsyncIteratorReader, + AsyncIteratorWriter, +) from homeassistant.util.json import JsonObjectType, json_loads_object from .const import BUF_SIZE, LOGGER @@ -59,12 +63,6 @@ class BackupEmpty(DecryptError): _message = "No tar files found in the backup." -class AbortCipher(HomeAssistantError): - """Abort the cipher operation.""" - - _message = "Abort cipher operation." - - def make_backup_dir(path: Path) -> None: """Create a backup directory if it does not exist.""" path.mkdir(exist_ok=True) @@ -166,106 +164,6 @@ def validate_password(path: Path, password: str | None) -> bool: return False -class AsyncIteratorReader: - """Wrap an AsyncIterator.""" - - def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None: - """Initialize the wrapper.""" - self._aborted = False - self._hass = hass - self._stream = stream - self._buffer: bytes | None = None - self._next_future: Future[bytes | None] | None = None - self._pos: int = 0 - - async def _next(self) -> bytes | None: - """Get the next chunk from the iterator.""" - return await anext(self._stream, None) - - def abort(self) -> None: - """Abort the reader.""" - self._aborted = True - if self._next_future is not None: - self._next_future.cancel() - - def read(self, n: int = -1, /) -> bytes: - """Read data from the iterator.""" - result = bytearray() - while n < 0 or len(result) < n: - if not self._buffer: - self._next_future = asyncio.run_coroutine_threadsafe( - self._next(), self._hass.loop - ) - if self._aborted: - self._next_future.cancel() - raise AbortCipher - try: - self._buffer = self._next_future.result() - except CancelledError as err: - raise AbortCipher from err - self._pos = 0 - if not self._buffer: - # The stream is exhausted - break - chunk = self._buffer[self._pos : self._pos + n] - result.extend(chunk) - n -= len(chunk) - self._pos += len(chunk) - if self._pos == len(self._buffer): - self._buffer = None - return bytes(result) - - def close(self) -> None: - """Close the iterator.""" - - -class AsyncIteratorWriter: - """Wrap an AsyncIterator.""" - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize the wrapper.""" - self._aborted = False - self._hass = hass - self._pos: int = 0 - self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) - self._write_future: Future[bytes | None] | None = None - - def __aiter__(self) -> Self: - """Return the iterator.""" - return self - - async def __anext__(self) -> bytes: - """Get the next chunk from the iterator.""" - if data := await self._queue.get(): - return data - raise StopAsyncIteration - - def abort(self) -> None: - """Abort the writer.""" - self._aborted = True - if self._write_future is not None: - self._write_future.cancel() - - def tell(self) -> int: - """Return the current position in the iterator.""" - return self._pos - - def write(self, s: bytes, /) -> int: - """Write data to the iterator.""" - self._write_future = asyncio.run_coroutine_threadsafe( - self._queue.put(s), self._hass.loop - ) - if self._aborted: - self._write_future.cancel() - raise AbortCipher - try: - self._write_future.result() - except CancelledError as err: - raise AbortCipher from err - self._pos += len(s) - return len(s) - - def validate_password_stream( input_stream: IO[bytes], password: str | None, @@ -342,7 +240,7 @@ def decrypt_backup( finally: # Write an empty chunk to signal the end of the stream output_stream.write(b"") - except AbortCipher: + except Abort: LOGGER.debug("Cipher operation aborted") finally: on_done(error) @@ -430,7 +328,7 @@ def encrypt_backup( finally: # Write an empty chunk to signal the end of the stream output_stream.write(b"") - except AbortCipher: + except Abort: LOGGER.debug("Cipher operation aborted") finally: on_done(error) @@ -557,8 +455,8 @@ class _CipherBackupStreamer: self._hass.loop.call_soon_threadsafe(worker_status.done.set) stream = await self._open_stream() - reader = AsyncIteratorReader(self._hass, stream) - writer = AsyncIteratorWriter(self._hass) + reader = AsyncIteratorReader(self._hass.loop, stream) + writer = AsyncIteratorWriter(self._hass.loop) worker = threading.Thread( target=self._cipher_func, args=[ diff --git a/homeassistant/util/async_iterator.py b/homeassistant/util/async_iterator.py new file mode 100644 index 00000000000..b59d8b47416 --- /dev/null +++ b/homeassistant/util/async_iterator.py @@ -0,0 +1,134 @@ +"""Async iterator utilities.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from concurrent.futures import CancelledError, Future +from typing import Self + + +class Abort(Exception): + """Raised when abort is requested.""" + + +class AsyncIteratorReader: + """Allow reading from an AsyncIterator using blocking I/O. + + The class implements a blocking read method reading from the async iterator, + and a close method. + + In addition, the abort method can be used to abort any ongoing read operation. + """ + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + stream: AsyncIterator[bytes], + ) -> None: + """Initialize the wrapper.""" + self._aborted = False + self._loop = loop + self._stream = stream + self._buffer: bytes | None = None + self._next_future: Future[bytes | None] | None = None + self._pos: int = 0 + + async def _next(self) -> bytes | None: + """Get the next chunk from the iterator.""" + return await anext(self._stream, None) + + def abort(self) -> None: + """Abort the reader.""" + self._aborted = True + if self._next_future is not None: + self._next_future.cancel() + + def read(self, n: int = -1, /) -> bytes: + """Read up to n bytes of data from the iterator. + + The read method returns 0 bytes when the iterator is exhausted. + """ + result = bytearray() + while n < 0 or len(result) < n: + if not self._buffer: + self._next_future = asyncio.run_coroutine_threadsafe( + self._next(), self._loop + ) + if self._aborted: + self._next_future.cancel() + raise Abort + try: + self._buffer = self._next_future.result() + except CancelledError as err: + raise Abort from err + self._pos = 0 + if not self._buffer: + # The stream is exhausted + break + chunk = self._buffer[self._pos : self._pos + n] + result.extend(chunk) + n -= len(chunk) + self._pos += len(chunk) + if self._pos == len(self._buffer): + self._buffer = None + return bytes(result) + + def close(self) -> None: + """Close the iterator.""" + + +class AsyncIteratorWriter: + """Allow writing to an AsyncIterator using blocking I/O. + + The class implements a blocking write method writing to the async iterator, + as well as a close and tell methods. + + In addition, the abort method can be used to abort any ongoing write operation. + """ + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + """Initialize the wrapper.""" + self._aborted = False + self._loop = loop + self._pos: int = 0 + self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) + self._write_future: Future[bytes | None] | None = None + + def __aiter__(self) -> Self: + """Return the iterator.""" + return self + + async def __anext__(self) -> bytes: + """Get the next chunk from the iterator.""" + if data := await self._queue.get(): + return data + raise StopAsyncIteration + + def abort(self) -> None: + """Abort the writer.""" + self._aborted = True + if self._write_future is not None: + self._write_future.cancel() + + def tell(self) -> int: + """Return the current position in the iterator.""" + return self._pos + + def write(self, s: bytes, /) -> int: + """Write data to the iterator. + + To signal the end of the stream, write a zero-length bytes object. + """ + self._write_future = asyncio.run_coroutine_threadsafe( + self._queue.put(s), self._loop + ) + if self._aborted: + self._write_future.cancel() + raise Abort + try: + self._write_future.result() + except CancelledError as err: + raise Abort from err + self._pos += len(s) + return len(s) diff --git a/tests/util/test_async_iterator.py b/tests/util/test_async_iterator.py new file mode 100644 index 00000000000..866b0c8c51c --- /dev/null +++ b/tests/util/test_async_iterator.py @@ -0,0 +1,116 @@ +"""Tests for async iterator utility functions.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.util.async_iterator import ( + Abort, + AsyncIteratorReader, + AsyncIteratorWriter, +) + + +def _read_all(reader: AsyncIteratorReader) -> bytes: + output = b"" + while chunk := reader.read(500): + output += chunk + return output + + +async def test_async_iterator_reader(hass: HomeAssistant) -> None: + """Test the async iterator reader.""" + data = b"hello world" * 1000 + + async def async_gen() -> AsyncIterator[bytes]: + for _ in range(10): + yield data + + reader = AsyncIteratorReader(hass.loop, async_gen()) + assert await hass.async_add_executor_job(_read_all, reader) == data * 10 + + +async def test_async_iterator_reader_abort_early(hass: HomeAssistant) -> None: + """Test abort the async iterator reader.""" + evt = asyncio.Event() + + async def async_gen() -> AsyncIterator[bytes]: + await evt.wait() + yield b"" + + reader = AsyncIteratorReader(hass.loop, async_gen()) + reader.abort() + fut = hass.async_add_executor_job(_read_all, reader) + with pytest.raises(Abort): + await fut + + +async def test_async_iterator_reader_abort_late(hass: HomeAssistant) -> None: + """Test abort the async iterator reader.""" + evt = asyncio.Event() + + async def async_gen() -> AsyncIterator[bytes]: + await evt.wait() + yield b"" + + reader = AsyncIteratorReader(hass.loop, async_gen()) + fut = hass.async_add_executor_job(_read_all, reader) + await asyncio.sleep(0.1) + reader.abort() + with pytest.raises(Abort): + await fut + + +def _write_all(writer: AsyncIteratorWriter, data: list[bytes]) -> bytes: + for chunk in data: + assert writer.write(chunk) == len(chunk) + assert writer.write(b"") == 0 + + +async def test_async_iterator_writer(hass: HomeAssistant) -> None: + """Test the async iterator writer.""" + chunk = b"hello world" * 1000 + chunks = [chunk] * 10 + writer = AsyncIteratorWriter(hass.loop) + + fut = hass.async_add_executor_job(_write_all, writer, chunks) + + read = b"" + async for data in writer: + read += data + + await fut + + assert read == chunk * 10 + assert writer.tell() == len(read) + + +async def test_async_iterator_writer_abort_early(hass: HomeAssistant) -> None: + """Test the async iterator writer.""" + chunk = b"hello world" * 1000 + chunks = [chunk] * 10 + writer = AsyncIteratorWriter(hass.loop) + writer.abort() + + fut = hass.async_add_executor_job(_write_all, writer, chunks) + + with pytest.raises(Abort): + await fut + + +async def test_async_iterator_writer_abort_late(hass: HomeAssistant) -> None: + """Test the async iterator writer.""" + chunk = b"hello world" * 1000 + chunks = [chunk] * 10 + writer = AsyncIteratorWriter(hass.loop) + + fut = hass.async_add_executor_job(_write_all, writer, chunks) + await asyncio.sleep(0.1) + writer.abort() + + with pytest.raises(Abort): + await fut