mirror of
https://github.com/home-assistant/core.git
synced 2025-08-02 10:08:23 +00:00
Add BufferedAsyncIteratorToSyncStream util
This commit is contained in:
parent
65fce5f056
commit
d7738ee732
@ -18,6 +18,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||||||
|
|
||||||
from . import BackblazeConfigEntry
|
from . import BackblazeConfigEntry
|
||||||
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, SEPARATOR
|
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN, SEPARATOR
|
||||||
|
from .util import BufferedAsyncIteratorToSyncStream
|
||||||
|
|
||||||
|
|
||||||
async def async_get_backup_agents(
|
async def async_get_backup_agents(
|
||||||
@ -98,6 +99,8 @@ class BackblazeBackupAgent(BackupAgent):
|
|||||||
"""Upload a backup."""
|
"""Upload a backup."""
|
||||||
|
|
||||||
# Prepare file info metadata to store with the backup in Backblaze
|
# Prepare file info metadata to store with the backup in Backblaze
|
||||||
|
# Backblaze can only store a mapping of strings to strings, so we need
|
||||||
|
# to serialize the metadata into a string format.
|
||||||
file_info = {
|
file_info = {
|
||||||
"backup_id": backup.backup_id,
|
"backup_id": backup.backup_id,
|
||||||
"database_included": str(backup.database_included).lower(),
|
"database_included": str(backup.database_included).lower(),
|
||||||
@ -119,12 +122,15 @@ class BackblazeBackupAgent(BackupAgent):
|
|||||||
if backup.folders:
|
if backup.folders:
|
||||||
file_info["folders"] = ",".join(folder.value for folder in backup.folders)
|
file_info["folders"] = ",".join(folder.value for folder in backup.folders)
|
||||||
|
|
||||||
stream: AsyncIterator[bytes] = await open_stream()
|
iterator = await open_stream()
|
||||||
|
stream = BufferedAsyncIteratorToSyncStream(
|
||||||
|
iterator,
|
||||||
|
buffer_size=8 * 1024 * 1024, # Buffer up to 8MB
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await self._hass.async_add_executor_job(
|
await self._hass.async_add_executor_job(
|
||||||
self._bucket.upload_bytes,
|
self._bucket.upload_unbound_stream,
|
||||||
b"".join([chunk async for chunk in stream]),
|
stream,
|
||||||
f"{backup.backup_id}.tar",
|
f"{backup.backup_id}.tar",
|
||||||
"application/octet-stream",
|
"application/octet-stream",
|
||||||
file_info,
|
file_info,
|
||||||
|
87
homeassistant/components/backblaze/util.py
Normal file
87
homeassistant/components/backblaze/util.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""Utilities for the Backblaze B2 integration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from concurrent.futures import Future
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
class BufferedAsyncIteratorToSyncStream(io.RawIOBase):
|
||||||
|
"""An wrapper to make an AsyncIterator[bytes] a buffered synchronous readable stream."""
|
||||||
|
|
||||||
|
_done: bool = False
|
||||||
|
_read_future: Future[bytes] | None = None
|
||||||
|
|
||||||
|
def __init__(self, iterator: AsyncIterator[bytes], buffer_size: int = 1024) -> None:
|
||||||
|
"""Initialize the stream."""
|
||||||
|
self._buffer = bytearray()
|
||||||
|
self._buffer_size = buffer_size
|
||||||
|
self._iterator = iterator
|
||||||
|
self._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def readable(self) -> bool:
|
||||||
|
"""Mark the stream as readable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _load_next_chunk(self) -> None:
|
||||||
|
"""Load the next chunk into the buffer."""
|
||||||
|
if self._done:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._read_future:
|
||||||
|
# Fetch a larger chunk asynchronously
|
||||||
|
self._read_future = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._fetch_next_chunk(), self._loop
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._read_future.done():
|
||||||
|
try:
|
||||||
|
data = self._read_future.result()
|
||||||
|
if data:
|
||||||
|
self._buffer.extend(data)
|
||||||
|
else:
|
||||||
|
self._done = True
|
||||||
|
except StopAsyncIteration:
|
||||||
|
self._done = True
|
||||||
|
except Exception as err: # noqa: BLE001
|
||||||
|
raise io.BlockingIOError(f"Failed to load chunk: {err}") from err
|
||||||
|
finally:
|
||||||
|
self._read_future = None
|
||||||
|
|
||||||
|
async def _fetch_next_chunk(self) -> bytes:
|
||||||
|
"""Fetch multiple chunks until buffer size is filled."""
|
||||||
|
chunks = []
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fill the buffer up to the specified size
|
||||||
|
while total_size < self._buffer_size:
|
||||||
|
chunk = await anext(self._iterator)
|
||||||
|
chunks.append(chunk)
|
||||||
|
total_size += len(chunk)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass # The end, return what we have
|
||||||
|
|
||||||
|
return b"".join(chunks)
|
||||||
|
|
||||||
|
def read(self, size: int = -1) -> bytes:
|
||||||
|
"""Read bytes."""
|
||||||
|
if size == -1:
|
||||||
|
# Read all remaining data
|
||||||
|
while not self._done:
|
||||||
|
self._load_next_chunk()
|
||||||
|
size = len(self._buffer)
|
||||||
|
|
||||||
|
# Ensure enough data in the buffer
|
||||||
|
while len(self._buffer) < size and not self._done:
|
||||||
|
self._load_next_chunk()
|
||||||
|
|
||||||
|
# Return requested data
|
||||||
|
data = self._buffer[:size]
|
||||||
|
self._buffer = self._buffer[size:]
|
||||||
|
return bytes(data)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the stream."""
|
||||||
|
self._done = True
|
||||||
|
super().close()
|
Loading…
x
Reference in New Issue
Block a user