mirror of
https://github.com/home-assistant/core.git
synced 2025-11-13 21:10:25 +00:00
Add async_iterator util (#153194)
This commit is contained in:
134
homeassistant/util/async_iterator.py
Normal file
134
homeassistant/util/async_iterator.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Async iterator utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from concurrent.futures import CancelledError, Future
|
||||
from typing import Self
|
||||
|
||||
|
||||
class Abort(Exception):
|
||||
"""Raised when abort is requested."""
|
||||
|
||||
|
||||
class AsyncIteratorReader:
|
||||
"""Allow reading from an AsyncIterator using blocking I/O.
|
||||
|
||||
The class implements a blocking read method reading from the async iterator,
|
||||
and a close method.
|
||||
|
||||
In addition, the abort method can be used to abort any ongoing read operation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: AsyncIterator[bytes],
|
||||
) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._loop = loop
|
||||
self._stream = stream
|
||||
self._buffer: bytes | None = None
|
||||
self._next_future: Future[bytes | None] | None = None
|
||||
self._pos: int = 0
|
||||
|
||||
async def _next(self) -> bytes | None:
|
||||
"""Get the next chunk from the iterator."""
|
||||
return await anext(self._stream, None)
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the reader."""
|
||||
self._aborted = True
|
||||
if self._next_future is not None:
|
||||
self._next_future.cancel()
|
||||
|
||||
def read(self, n: int = -1, /) -> bytes:
|
||||
"""Read up to n bytes of data from the iterator.
|
||||
|
||||
The read method returns 0 bytes when the iterator is exhausted.
|
||||
"""
|
||||
result = bytearray()
|
||||
while n < 0 or len(result) < n:
|
||||
if not self._buffer:
|
||||
self._next_future = asyncio.run_coroutine_threadsafe(
|
||||
self._next(), self._loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._next_future.cancel()
|
||||
raise Abort
|
||||
try:
|
||||
self._buffer = self._next_future.result()
|
||||
except CancelledError as err:
|
||||
raise Abort from err
|
||||
self._pos = 0
|
||||
if not self._buffer:
|
||||
# The stream is exhausted
|
||||
break
|
||||
chunk = self._buffer[self._pos : self._pos + n]
|
||||
result.extend(chunk)
|
||||
n -= len(chunk)
|
||||
self._pos += len(chunk)
|
||||
if self._pos == len(self._buffer):
|
||||
self._buffer = None
|
||||
return bytes(result)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the iterator."""
|
||||
|
||||
|
||||
class AsyncIteratorWriter:
|
||||
"""Allow writing to an AsyncIterator using blocking I/O.
|
||||
|
||||
The class implements a blocking write method writing to the async iterator,
|
||||
as well as a close and tell methods.
|
||||
|
||||
In addition, the abort method can be used to abort any ongoing write operation.
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._loop = loop
|
||||
self._pos: int = 0
|
||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||
self._write_future: Future[bytes | None] | None = None
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Return the iterator."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
"""Get the next chunk from the iterator."""
|
||||
if data := await self._queue.get():
|
||||
return data
|
||||
raise StopAsyncIteration
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the writer."""
|
||||
self._aborted = True
|
||||
if self._write_future is not None:
|
||||
self._write_future.cancel()
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position in the iterator."""
|
||||
return self._pos
|
||||
|
||||
def write(self, s: bytes, /) -> int:
|
||||
"""Write data to the iterator.
|
||||
|
||||
To signal the end of the stream, write a zero-length bytes object.
|
||||
"""
|
||||
self._write_future = asyncio.run_coroutine_threadsafe(
|
||||
self._queue.put(s), self._loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._write_future.cancel()
|
||||
raise Abort
|
||||
try:
|
||||
self._write_future.result()
|
||||
except CancelledError as err:
|
||||
raise Abort from err
|
||||
self._pos += len(s)
|
||||
return len(s)
|
||||
Reference in New Issue
Block a user