Add BufferedAsyncIteratorToSyncStream util

Franck Nijhof 2024-12-26 11:57:23 +00:00
parent 4c9014a8a0
commit d1d1910445
2 changed files with 97 additions and 4 deletions


@@ -18,6 +18,7 @@ from homeassistant.core import HomeAssistant, callback
from . import BackblazeConfigEntry
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, SEPARATOR
from .util import BufferedAsyncIteratorToSyncStream
async def async_get_backup_agents(
@@ -98,6 +99,8 @@ class BackblazeBackupAgent(BackupAgent):
"""Upload a backup."""
# Prepare file info metadata to store with the backup in Backblaze
# Backblaze can only store a mapping of strings to strings, so we need
# to serialize the metadata into a string format.
file_info = {
"backup_id": backup.backup_id,
"database_included": str(backup.database_included).lower(),
@@ -119,12 +122,15 @@ class BackblazeBackupAgent(BackupAgent):
if backup.folders:
file_info["folders"] = ",".join(folder.value for folder in backup.folders)
stream: AsyncIterator[bytes] = await open_stream()
iterator = await open_stream()
stream = BufferedAsyncIteratorToSyncStream(
iterator,
buffer_size=8 * 1024 * 1024, # Buffer up to 8MB
)
try:
await self._hass.async_add_executor_job(
self._bucket.upload_bytes,
b"".join([chunk async for chunk in stream]),
self._bucket.upload_unbound_stream,
stream,
f"{backup.backup_id}.tar",
"application/octet-stream",
file_info,
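
The executor call above works because the new stream wrapper bridges the worker thread back to the running event loop: the thread blocks inside read() while the loop keeps driving the async iterator that produces backup chunks. A minimal, self-contained sketch of that bridge, using asyncio.run_coroutine_threadsafe with hypothetical produce() and worker() helpers that are not part of this change:

import asyncio

async def produce() -> bytes:
    """Stand-in for awaiting the next chunk from an async iterator."""
    await asyncio.sleep(0)  # pretend to do some async I/O on the loop
    return b"chunk"

def worker(loop: asyncio.AbstractEventLoop) -> bytes:
    """Runs in an executor thread; asks the loop to run the coroutine."""
    future = asyncio.run_coroutine_threadsafe(produce(), loop)
    return future.result()  # blocks this thread, not the event loop

async def main() -> None:
    loop = asyncio.get_running_loop()
    data = await loop.run_in_executor(None, worker, loop)
    assert data == b"chunk"

asyncio.run(main())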


@@ -0,0 +1,87 @@
"""Utilities for the Backblaze B2 integration."""
import asyncio
from collections.abc import AsyncIterator
from concurrent.futures import Future
import io
class BufferedAsyncIteratorToSyncStream(io.RawIOBase):
"""An wrapper to make an AsyncIterator[bytes] a buffered synchronous readable stream."""
_done: bool = False
_read_future: Future[bytes] | None = None
def __init__(self, iterator: AsyncIterator[bytes], buffer_size: int = 1024) -> None:
"""Initialize the stream."""
self._buffer = bytearray()
self._buffer_size = buffer_size
self._iterator = iterator
self._loop = asyncio.get_running_loop()
def readable(self) -> bool:
"""Mark the stream as readable."""
return True
def _load_next_chunk(self) -> None:
"""Load the next chunk into the buffer."""
if self._done:
return
if not self._read_future:
# Fetch a larger chunk asynchronously
self._read_future = asyncio.run_coroutine_threadsafe(
self._fetch_next_chunk(), self._loop
)
if self._read_future.done():
try:
data = self._read_future.result()
if data:
self._buffer.extend(data)
else:
self._done = True
except StopAsyncIteration:
self._done = True
except Exception as err: # noqa: BLE001
raise io.BlockingIOError(f"Failed to load chunk: {err}") from err
finally:
self._read_future = None
async def _fetch_next_chunk(self) -> bytes:
"""Fetch multiple chunks until buffer size is filled."""
chunks = []
total_size = 0
try:
# Fill the buffer up to the specified size
while total_size < self._buffer_size:
chunk = await anext(self._iterator)
chunks.append(chunk)
total_size += len(chunk)
except StopAsyncIteration:
pass # The end, return what we have
return b"".join(chunks)
def read(self, size: int = -1) -> bytes:
"""Read bytes."""
if size == -1:
# Read all remaining data
while not self._done:
self._load_next_chunk()
size = len(self._buffer)
# Ensure enough data in the buffer
while len(self._buffer) < size and not self._done:
self._load_next_chunk()
# Return requested data
data = self._buffer[:size]
self._buffer = self._buffer[size:]
return bytes(data)
def close(self) -> None:
"""Close the stream."""
self._done = True
super().close()
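
A hedged usage sketch of the new wrapper, assuming the class above is in scope; fake_chunks() is a hypothetical stand-in for the backup's open_stream(), and blocking_reader() stands in for the b2sdk upload call that consumes the stream from an executor thread:

import asyncio
from collections.abc import AsyncIterator

async def fake_chunks() -> AsyncIterator[bytes]:
    """Hypothetical chunk source; a real backup yields many larger chunks."""
    for part in (b"part-1", b"part-2", b"part-3"):
        yield part

def blocking_reader(stream: BufferedAsyncIteratorToSyncStream) -> bytes:
    """Read the whole stream synchronously, as the upload call would."""
    return stream.read()

async def main() -> None:
    # Must be constructed on the running event loop (it captures the loop).
    stream = BufferedAsyncIteratorToSyncStream(fake_chunks(), buffer_size=1024)
    loop = asyncio.get_running_loop()
    data = await loop.run_in_executor(None, blocking_reader, stream)
    assert data == b"part-1part-2part-3"

asyncio.run(main())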