Files
core/homeassistant/util/async_iterator.py
2025-09-29 18:54:23 +02:00

135 lines
4.1 KiB
Python

"""Async iterator utilities."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from concurrent.futures import CancelledError, Future
from typing import Self
class Abort(Exception):
"""Raised when abort is requested."""
class AsyncIteratorReader:
"""Allow reading from an AsyncIterator using blocking I/O.
The class implements a blocking read method reading from the async iterator,
and a close method.
In addition, the abort method can be used to abort any ongoing read operation.
"""
def __init__(
self,
loop: asyncio.AbstractEventLoop,
stream: AsyncIterator[bytes],
) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._loop = loop
self._stream = stream
self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes:
"""Read up to n bytes of data from the iterator.
The read method returns 0 bytes when the iterator is exhausted.
"""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._loop
)
if self._aborted:
self._next_future.cancel()
raise Abort
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise Abort from err
self._pos = 0
if not self._buffer:
# The stream is exhausted
break
chunk = self._buffer[self._pos : self._pos + n]
result.extend(chunk)
n -= len(chunk)
self._pos += len(chunk)
if self._pos == len(self._buffer):
self._buffer = None
return bytes(result)
def close(self) -> None:
"""Close the iterator."""
class AsyncIteratorWriter:
"""Allow writing to an AsyncIterator using blocking I/O.
The class implements a blocking write method writing to the async iterator,
as well as a close and tell methods.
In addition, the abort method can be used to abort any ongoing write operation.
"""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._loop = loop
self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self:
"""Return the iterator."""
return self
async def __anext__(self) -> bytes:
"""Get the next chunk from the iterator."""
if data := await self._queue.get():
return data
raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int:
"""Return the current position in the iterator."""
return self._pos
def write(self, s: bytes, /) -> int:
"""Write data to the iterator.
To signal the end of the stream, write a zero-length bytes object.
"""
self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._loop
)
if self._aborted:
self._write_future.cancel()
raise Abort
try:
self._write_future.result()
except CancelledError as err:
raise Abort from err
self._pos += len(s)
return len(s)