mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 19:09:32 +00:00
Add async_iterator util (#153194)
This commit is contained in:
@@ -17,6 +17,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import frame
|
||||
from homeassistant.util import slugify
|
||||
from homeassistant.util.async_iterator import AsyncIteratorReader, AsyncIteratorWriter
|
||||
|
||||
from . import util
|
||||
from .agent import BackupAgent
|
||||
@@ -144,7 +145,7 @@ class DownloadBackupView(HomeAssistantView):
|
||||
return Response(status=HTTPStatus.NOT_FOUND)
|
||||
else:
|
||||
stream = await agent.async_download_backup(backup_id)
|
||||
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
|
||||
reader = cast(IO[bytes], AsyncIteratorReader(hass.loop, stream))
|
||||
|
||||
worker_done_event = asyncio.Event()
|
||||
|
||||
@@ -152,7 +153,7 @@ class DownloadBackupView(HomeAssistantView):
|
||||
"""Call by the worker thread when it's done."""
|
||||
hass.loop.call_soon_threadsafe(worker_done_event.set)
|
||||
|
||||
stream = util.AsyncIteratorWriter(hass)
|
||||
stream = AsyncIteratorWriter(hass.loop)
|
||||
worker = threading.Thread(
|
||||
target=util.decrypt_backup,
|
||||
args=[backup, reader, stream, password, on_done, 0, []],
|
||||
|
||||
@@ -38,6 +38,7 @@ from homeassistant.helpers import (
|
||||
)
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.util import dt as dt_util, json as json_util
|
||||
from homeassistant.util.async_iterator import AsyncIteratorReader
|
||||
|
||||
from . import util as backup_util
|
||||
from .agent import (
|
||||
@@ -72,7 +73,6 @@ from .models import (
|
||||
)
|
||||
from .store import BackupStore
|
||||
from .util import (
|
||||
AsyncIteratorReader,
|
||||
DecryptedBackupStreamer,
|
||||
EncryptedBackupStreamer,
|
||||
make_backup_dir,
|
||||
@@ -1525,7 +1525,7 @@ class BackupManager:
|
||||
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
|
||||
else:
|
||||
backup_stream = await agent.async_download_backup(backup_id)
|
||||
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
|
||||
reader = cast(IO[bytes], AsyncIteratorReader(self.hass.loop, backup_stream))
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_password_stream, reader, password
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable, Coroutine
|
||||
from concurrent.futures import CancelledError, Future
|
||||
import copy
|
||||
from dataclasses import dataclass, replace
|
||||
from io import BytesIO
|
||||
@@ -14,7 +13,7 @@ from pathlib import Path, PurePath
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
import threading
|
||||
from typing import IO, Any, Self, cast
|
||||
from typing import IO, Any, cast
|
||||
|
||||
import aiohttp
|
||||
from securetar import SecureTarError, SecureTarFile, SecureTarReadError
|
||||
@@ -23,6 +22,11 @@ from homeassistant.backup_restore import password_to_key
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.async_iterator import (
|
||||
Abort,
|
||||
AsyncIteratorReader,
|
||||
AsyncIteratorWriter,
|
||||
)
|
||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||
|
||||
from .const import BUF_SIZE, LOGGER
|
||||
@@ -59,12 +63,6 @@ class BackupEmpty(DecryptError):
|
||||
_message = "No tar files found in the backup."
|
||||
|
||||
|
||||
class AbortCipher(HomeAssistantError):
|
||||
"""Abort the cipher operation."""
|
||||
|
||||
_message = "Abort cipher operation."
|
||||
|
||||
|
||||
def make_backup_dir(path: Path) -> None:
|
||||
"""Create a backup directory if it does not exist."""
|
||||
path.mkdir(exist_ok=True)
|
||||
@@ -166,106 +164,6 @@ def validate_password(path: Path, password: str | None) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AsyncIteratorReader:
|
||||
"""Wrap an AsyncIterator."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._hass = hass
|
||||
self._stream = stream
|
||||
self._buffer: bytes | None = None
|
||||
self._next_future: Future[bytes | None] | None = None
|
||||
self._pos: int = 0
|
||||
|
||||
async def _next(self) -> bytes | None:
|
||||
"""Get the next chunk from the iterator."""
|
||||
return await anext(self._stream, None)
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the reader."""
|
||||
self._aborted = True
|
||||
if self._next_future is not None:
|
||||
self._next_future.cancel()
|
||||
|
||||
def read(self, n: int = -1, /) -> bytes:
|
||||
"""Read data from the iterator."""
|
||||
result = bytearray()
|
||||
while n < 0 or len(result) < n:
|
||||
if not self._buffer:
|
||||
self._next_future = asyncio.run_coroutine_threadsafe(
|
||||
self._next(), self._hass.loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._next_future.cancel()
|
||||
raise AbortCipher
|
||||
try:
|
||||
self._buffer = self._next_future.result()
|
||||
except CancelledError as err:
|
||||
raise AbortCipher from err
|
||||
self._pos = 0
|
||||
if not self._buffer:
|
||||
# The stream is exhausted
|
||||
break
|
||||
chunk = self._buffer[self._pos : self._pos + n]
|
||||
result.extend(chunk)
|
||||
n -= len(chunk)
|
||||
self._pos += len(chunk)
|
||||
if self._pos == len(self._buffer):
|
||||
self._buffer = None
|
||||
return bytes(result)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the iterator."""
|
||||
|
||||
|
||||
class AsyncIteratorWriter:
|
||||
"""Wrap an AsyncIterator."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._hass = hass
|
||||
self._pos: int = 0
|
||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||
self._write_future: Future[bytes | None] | None = None
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Return the iterator."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
"""Get the next chunk from the iterator."""
|
||||
if data := await self._queue.get():
|
||||
return data
|
||||
raise StopAsyncIteration
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the writer."""
|
||||
self._aborted = True
|
||||
if self._write_future is not None:
|
||||
self._write_future.cancel()
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position in the iterator."""
|
||||
return self._pos
|
||||
|
||||
def write(self, s: bytes, /) -> int:
|
||||
"""Write data to the iterator."""
|
||||
self._write_future = asyncio.run_coroutine_threadsafe(
|
||||
self._queue.put(s), self._hass.loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._write_future.cancel()
|
||||
raise AbortCipher
|
||||
try:
|
||||
self._write_future.result()
|
||||
except CancelledError as err:
|
||||
raise AbortCipher from err
|
||||
self._pos += len(s)
|
||||
return len(s)
|
||||
|
||||
|
||||
def validate_password_stream(
|
||||
input_stream: IO[bytes],
|
||||
password: str | None,
|
||||
@@ -342,7 +240,7 @@ def decrypt_backup(
|
||||
finally:
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
except Abort:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
on_done(error)
|
||||
@@ -430,7 +328,7 @@ def encrypt_backup(
|
||||
finally:
|
||||
# Write an empty chunk to signal the end of the stream
|
||||
output_stream.write(b"")
|
||||
except AbortCipher:
|
||||
except Abort:
|
||||
LOGGER.debug("Cipher operation aborted")
|
||||
finally:
|
||||
on_done(error)
|
||||
@@ -557,8 +455,8 @@ class _CipherBackupStreamer:
|
||||
self._hass.loop.call_soon_threadsafe(worker_status.done.set)
|
||||
|
||||
stream = await self._open_stream()
|
||||
reader = AsyncIteratorReader(self._hass, stream)
|
||||
writer = AsyncIteratorWriter(self._hass)
|
||||
reader = AsyncIteratorReader(self._hass.loop, stream)
|
||||
writer = AsyncIteratorWriter(self._hass.loop)
|
||||
worker = threading.Thread(
|
||||
target=self._cipher_func,
|
||||
args=[
|
||||
|
||||
134
homeassistant/util/async_iterator.py
Normal file
134
homeassistant/util/async_iterator.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Async iterator utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from concurrent.futures import CancelledError, Future
|
||||
from typing import Self
|
||||
|
||||
|
||||
class Abort(Exception):
|
||||
"""Raised when abort is requested."""
|
||||
|
||||
|
||||
class AsyncIteratorReader:
|
||||
"""Allow reading from an AsyncIterator using blocking I/O.
|
||||
|
||||
The class implements a blocking read method reading from the async iterator,
|
||||
and a close method.
|
||||
|
||||
In addition, the abort method can be used to abort any ongoing read operation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: AsyncIterator[bytes],
|
||||
) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._loop = loop
|
||||
self._stream = stream
|
||||
self._buffer: bytes | None = None
|
||||
self._next_future: Future[bytes | None] | None = None
|
||||
self._pos: int = 0
|
||||
|
||||
async def _next(self) -> bytes | None:
|
||||
"""Get the next chunk from the iterator."""
|
||||
return await anext(self._stream, None)
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the reader."""
|
||||
self._aborted = True
|
||||
if self._next_future is not None:
|
||||
self._next_future.cancel()
|
||||
|
||||
def read(self, n: int = -1, /) -> bytes:
|
||||
"""Read up to n bytes of data from the iterator.
|
||||
|
||||
The read method returns 0 bytes when the iterator is exhausted.
|
||||
"""
|
||||
result = bytearray()
|
||||
while n < 0 or len(result) < n:
|
||||
if not self._buffer:
|
||||
self._next_future = asyncio.run_coroutine_threadsafe(
|
||||
self._next(), self._loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._next_future.cancel()
|
||||
raise Abort
|
||||
try:
|
||||
self._buffer = self._next_future.result()
|
||||
except CancelledError as err:
|
||||
raise Abort from err
|
||||
self._pos = 0
|
||||
if not self._buffer:
|
||||
# The stream is exhausted
|
||||
break
|
||||
chunk = self._buffer[self._pos : self._pos + n]
|
||||
result.extend(chunk)
|
||||
n -= len(chunk)
|
||||
self._pos += len(chunk)
|
||||
if self._pos == len(self._buffer):
|
||||
self._buffer = None
|
||||
return bytes(result)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the iterator."""
|
||||
|
||||
|
||||
class AsyncIteratorWriter:
|
||||
"""Allow writing to an AsyncIterator using blocking I/O.
|
||||
|
||||
The class implements a blocking write method writing to the async iterator,
|
||||
as well as a close and tell methods.
|
||||
|
||||
In addition, the abort method can be used to abort any ongoing write operation.
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._aborted = False
|
||||
self._loop = loop
|
||||
self._pos: int = 0
|
||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||
self._write_future: Future[bytes | None] | None = None
|
||||
|
||||
def __aiter__(self) -> Self:
|
||||
"""Return the iterator."""
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
"""Get the next chunk from the iterator."""
|
||||
if data := await self._queue.get():
|
||||
return data
|
||||
raise StopAsyncIteration
|
||||
|
||||
def abort(self) -> None:
|
||||
"""Abort the writer."""
|
||||
self._aborted = True
|
||||
if self._write_future is not None:
|
||||
self._write_future.cancel()
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position in the iterator."""
|
||||
return self._pos
|
||||
|
||||
def write(self, s: bytes, /) -> int:
|
||||
"""Write data to the iterator.
|
||||
|
||||
To signal the end of the stream, write a zero-length bytes object.
|
||||
"""
|
||||
self._write_future = asyncio.run_coroutine_threadsafe(
|
||||
self._queue.put(s), self._loop
|
||||
)
|
||||
if self._aborted:
|
||||
self._write_future.cancel()
|
||||
raise Abort
|
||||
try:
|
||||
self._write_future.result()
|
||||
except CancelledError as err:
|
||||
raise Abort from err
|
||||
self._pos += len(s)
|
||||
return len(s)
|
||||
116
tests/util/test_async_iterator.py
Normal file
116
tests/util/test_async_iterator.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Tests for async iterator utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util.async_iterator import (
|
||||
Abort,
|
||||
AsyncIteratorReader,
|
||||
AsyncIteratorWriter,
|
||||
)
|
||||
|
||||
|
||||
def _read_all(reader: AsyncIteratorReader) -> bytes:
|
||||
output = b""
|
||||
while chunk := reader.read(500):
|
||||
output += chunk
|
||||
return output
|
||||
|
||||
|
||||
async def test_async_iterator_reader(hass: HomeAssistant) -> None:
|
||||
"""Test the async iterator reader."""
|
||||
data = b"hello world" * 1000
|
||||
|
||||
async def async_gen() -> AsyncIterator[bytes]:
|
||||
for _ in range(10):
|
||||
yield data
|
||||
|
||||
reader = AsyncIteratorReader(hass.loop, async_gen())
|
||||
assert await hass.async_add_executor_job(_read_all, reader) == data * 10
|
||||
|
||||
|
||||
async def test_async_iterator_reader_abort_early(hass: HomeAssistant) -> None:
|
||||
"""Test abort the async iterator reader."""
|
||||
evt = asyncio.Event()
|
||||
|
||||
async def async_gen() -> AsyncIterator[bytes]:
|
||||
await evt.wait()
|
||||
yield b""
|
||||
|
||||
reader = AsyncIteratorReader(hass.loop, async_gen())
|
||||
reader.abort()
|
||||
fut = hass.async_add_executor_job(_read_all, reader)
|
||||
with pytest.raises(Abort):
|
||||
await fut
|
||||
|
||||
|
||||
async def test_async_iterator_reader_abort_late(hass: HomeAssistant) -> None:
|
||||
"""Test abort the async iterator reader."""
|
||||
evt = asyncio.Event()
|
||||
|
||||
async def async_gen() -> AsyncIterator[bytes]:
|
||||
await evt.wait()
|
||||
yield b""
|
||||
|
||||
reader = AsyncIteratorReader(hass.loop, async_gen())
|
||||
fut = hass.async_add_executor_job(_read_all, reader)
|
||||
await asyncio.sleep(0.1)
|
||||
reader.abort()
|
||||
with pytest.raises(Abort):
|
||||
await fut
|
||||
|
||||
|
||||
def _write_all(writer: AsyncIteratorWriter, data: list[bytes]) -> bytes:
|
||||
for chunk in data:
|
||||
assert writer.write(chunk) == len(chunk)
|
||||
assert writer.write(b"") == 0
|
||||
|
||||
|
||||
async def test_async_iterator_writer(hass: HomeAssistant) -> None:
|
||||
"""Test the async iterator writer."""
|
||||
chunk = b"hello world" * 1000
|
||||
chunks = [chunk] * 10
|
||||
writer = AsyncIteratorWriter(hass.loop)
|
||||
|
||||
fut = hass.async_add_executor_job(_write_all, writer, chunks)
|
||||
|
||||
read = b""
|
||||
async for data in writer:
|
||||
read += data
|
||||
|
||||
await fut
|
||||
|
||||
assert read == chunk * 10
|
||||
assert writer.tell() == len(read)
|
||||
|
||||
|
||||
async def test_async_iterator_writer_abort_early(hass: HomeAssistant) -> None:
|
||||
"""Test the async iterator writer."""
|
||||
chunk = b"hello world" * 1000
|
||||
chunks = [chunk] * 10
|
||||
writer = AsyncIteratorWriter(hass.loop)
|
||||
writer.abort()
|
||||
|
||||
fut = hass.async_add_executor_job(_write_all, writer, chunks)
|
||||
|
||||
with pytest.raises(Abort):
|
||||
await fut
|
||||
|
||||
|
||||
async def test_async_iterator_writer_abort_late(hass: HomeAssistant) -> None:
|
||||
"""Test the async iterator writer."""
|
||||
chunk = b"hello world" * 1000
|
||||
chunks = [chunk] * 10
|
||||
writer = AsyncIteratorWriter(hass.loop)
|
||||
|
||||
fut = hass.async_add_executor_job(_write_all, writer, chunks)
|
||||
await asyncio.sleep(0.1)
|
||||
writer.abort()
|
||||
|
||||
with pytest.raises(Abort):
|
||||
await fut
|
||||
Reference in New Issue
Block a user