Add async_iterator util (#153194)

This commit is contained in:
Erik Montnemery
2025-09-29 18:54:23 +02:00
committed by GitHub
parent 40b9dae608
commit 2b4b46eaf8
5 changed files with 265 additions and 116 deletions

View File

@@ -17,6 +17,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import frame
from homeassistant.util import slugify
from homeassistant.util.async_iterator import AsyncIteratorReader, AsyncIteratorWriter
from . import util
from .agent import BackupAgent
@@ -144,7 +145,7 @@ class DownloadBackupView(HomeAssistantView):
return Response(status=HTTPStatus.NOT_FOUND)
else:
stream = await agent.async_download_backup(backup_id)
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
reader = cast(IO[bytes], AsyncIteratorReader(hass.loop, stream))
worker_done_event = asyncio.Event()
@@ -152,7 +153,7 @@ class DownloadBackupView(HomeAssistantView):
"""Call by the worker thread when it's done."""
hass.loop.call_soon_threadsafe(worker_done_event.set)
stream = util.AsyncIteratorWriter(hass)
stream = AsyncIteratorWriter(hass.loop)
worker = threading.Thread(
target=util.decrypt_backup,
args=[backup, reader, stream, password, on_done, 0, []],

View File

@@ -38,6 +38,7 @@ from homeassistant.helpers import (
)
from homeassistant.helpers.json import json_bytes
from homeassistant.util import dt as dt_util, json as json_util
from homeassistant.util.async_iterator import AsyncIteratorReader
from . import util as backup_util
from .agent import (
@@ -72,7 +73,6 @@ from .models import (
)
from .store import BackupStore
from .util import (
AsyncIteratorReader,
DecryptedBackupStreamer,
EncryptedBackupStreamer,
make_backup_dir,
@@ -1525,7 +1525,7 @@ class BackupManager:
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
else:
backup_stream = await agent.async_download_backup(backup_id)
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
reader = cast(IO[bytes], AsyncIteratorReader(self.hass.loop, backup_stream))
try:
await self.hass.async_add_executor_job(
validate_password_stream, reader, password

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy
from dataclasses import dataclass, replace
from io import BytesIO
@@ -14,7 +13,7 @@ from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast
from typing import IO, Any, cast
import aiohttp
from securetar import SecureTarError, SecureTarFile, SecureTarReadError
@@ -23,6 +22,11 @@ from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.async_iterator import (
Abort,
AsyncIteratorReader,
AsyncIteratorWriter,
)
from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import BUF_SIZE, LOGGER
@@ -59,12 +63,6 @@ class BackupEmpty(DecryptError):
_message = "No tar files found in the backup."
class AbortCipher(HomeAssistantError):
"""Abort the cipher operation."""
_message = "Abort cipher operation."
def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True)
@@ -166,106 +164,6 @@ def validate_password(path: Path, password: str | None) -> bool:
return False
class AsyncIteratorReader:
"""Wrap an AsyncIterator."""
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._stream = stream
self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes:
"""Read data from the iterator."""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._hass.loop
)
if self._aborted:
self._next_future.cancel()
raise AbortCipher
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos = 0
if not self._buffer:
# The stream is exhausted
break
chunk = self._buffer[self._pos : self._pos + n]
result.extend(chunk)
n -= len(chunk)
self._pos += len(chunk)
if self._pos == len(self._buffer):
self._buffer = None
return bytes(result)
def close(self) -> None:
"""Close the iterator."""
class AsyncIteratorWriter:
"""Wrap an AsyncIterator."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self:
"""Return the iterator."""
return self
async def __anext__(self) -> bytes:
"""Get the next chunk from the iterator."""
if data := await self._queue.get():
return data
raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int:
"""Return the current position in the iterator."""
return self._pos
def write(self, s: bytes, /) -> int:
"""Write data to the iterator."""
self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._hass.loop
)
if self._aborted:
self._write_future.cancel()
raise AbortCipher
try:
self._write_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos += len(s)
return len(s)
def validate_password_stream(
input_stream: IO[bytes],
password: str | None,
@@ -342,7 +240,7 @@ def decrypt_backup(
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
except Abort:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error)
@@ -430,7 +328,7 @@ def encrypt_backup(
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
except Abort:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error)
@@ -557,8 +455,8 @@ class _CipherBackupStreamer:
self._hass.loop.call_soon_threadsafe(worker_status.done.set)
stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass)
reader = AsyncIteratorReader(self._hass.loop, stream)
writer = AsyncIteratorWriter(self._hass.loop)
worker = threading.Thread(
target=self._cipher_func,
args=[

View File

@@ -0,0 +1,134 @@
"""Async iterator utilities."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from concurrent.futures import CancelledError, Future
from typing import Self
class Abort(Exception):
    """Signal that a blocking read or write was aborted on request."""
class AsyncIteratorReader:
"""Allow reading from an AsyncIterator using blocking I/O.
The class implements a blocking read method reading from the async iterator,
and a close method.
In addition, the abort method can be used to abort any ongoing read operation.
"""
def __init__(
self,
loop: asyncio.AbstractEventLoop,
stream: AsyncIterator[bytes],
) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._loop = loop
self._stream = stream
self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes:
"""Read up to n bytes of data from the iterator.
The read method returns 0 bytes when the iterator is exhausted.
"""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._loop
)
if self._aborted:
self._next_future.cancel()
raise Abort
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise Abort from err
self._pos = 0
if not self._buffer:
# The stream is exhausted
break
chunk = self._buffer[self._pos : self._pos + n]
result.extend(chunk)
n -= len(chunk)
self._pos += len(chunk)
if self._pos == len(self._buffer):
self._buffer = None
return bytes(result)
def close(self) -> None:
"""Close the iterator."""
class AsyncIteratorWriter:
"""Allow writing to an AsyncIterator using blocking I/O.
The class implements a blocking write method writing to the async iterator,
as well as a close and tell methods.
In addition, the abort method can be used to abort any ongoing write operation.
"""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._loop = loop
self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self:
"""Return the iterator."""
return self
async def __anext__(self) -> bytes:
"""Get the next chunk from the iterator."""
if data := await self._queue.get():
return data
raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int:
"""Return the current position in the iterator."""
return self._pos
def write(self, s: bytes, /) -> int:
"""Write data to the iterator.
To signal the end of the stream, write a zero-length bytes object.
"""
self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._loop
)
if self._aborted:
self._write_future.cancel()
raise Abort
try:
self._write_future.result()
except CancelledError as err:
raise Abort from err
self._pos += len(s)
return len(s)

View File

@@ -0,0 +1,116 @@
"""Tests for async iterator utility functions."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.util.async_iterator import (
Abort,
AsyncIteratorReader,
AsyncIteratorWriter,
)
def _read_all(reader: AsyncIteratorReader) -> bytes:
output = b""
while chunk := reader.read(500):
output += chunk
return output
async def test_async_iterator_reader(hass: HomeAssistant) -> None:
    """Read a multi-chunk stream to exhaustion from an executor thread."""
    payload = b"hello world" * 1000
    repeats = 10

    async def source() -> AsyncIterator[bytes]:
        for _ in range(repeats):
            yield payload

    reader = AsyncIteratorReader(hass.loop, source())
    received = await hass.async_add_executor_job(_read_all, reader)
    assert received == payload * repeats
async def test_async_iterator_reader_abort_early(hass: HomeAssistant) -> None:
    """Abort the reader before any read has been started."""
    never_set = asyncio.Event()

    async def stalled_source() -> AsyncIterator[bytes]:
        # The event is never set, so no chunk is ever produced
        await never_set.wait()
        yield b""

    reader = AsyncIteratorReader(hass.loop, stalled_source())
    reader.abort()

    read_job = hass.async_add_executor_job(_read_all, reader)
    with pytest.raises(Abort):
        await read_job
async def test_async_iterator_reader_abort_late(hass: HomeAssistant) -> None:
    """Abort the reader while a read is blocked waiting for data."""
    never_set = asyncio.Event()

    async def stalled_source() -> AsyncIterator[bytes]:
        # The event is never set, so no chunk is ever produced
        await never_set.wait()
        yield b""

    reader = AsyncIteratorReader(hass.loop, stalled_source())
    read_job = hass.async_add_executor_job(_read_all, reader)

    # Give the executor thread time to start its blocking read
    await asyncio.sleep(0.1)
    reader.abort()

    with pytest.raises(Abort):
        await read_job
def _write_all(writer: AsyncIteratorWriter, data: list[bytes]) -> bytes:
for chunk in data:
assert writer.write(chunk) == len(chunk)
assert writer.write(b"") == 0
async def test_async_iterator_writer(hass: HomeAssistant) -> None:
    """Write chunks from an executor thread and consume them on the loop."""
    piece = b"hello world" * 1000
    pieces = [piece] * 10

    writer = AsyncIteratorWriter(hass.loop)
    write_job = hass.async_add_executor_job(_write_all, writer, pieces)

    received = b"".join([data async for data in writer])
    await write_job

    assert received == piece * 10
    assert writer.tell() == len(received)
async def test_async_iterator_writer_abort_early(hass: HomeAssistant) -> None:
    """Abort the writer before any write has been started."""
    piece = b"hello world" * 1000
    pieces = [piece] * 10

    writer = AsyncIteratorWriter(hass.loop)
    writer.abort()

    write_job = hass.async_add_executor_job(_write_all, writer, pieces)
    with pytest.raises(Abort):
        await write_job
async def test_async_iterator_writer_abort_late(hass: HomeAssistant) -> None:
    """Abort the writer while a write is blocked waiting for a consumer."""
    piece = b"hello world" * 1000
    pieces = [piece] * 10

    writer = AsyncIteratorWriter(hass.loop)
    write_job = hass.async_add_executor_job(_write_all, writer, pieces)

    # Nothing consumes the writer, so the executor thread ends up blocked
    await asyncio.sleep(0.1)
    writer.abort()

    with pytest.raises(Abort):
        await write_job