mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Improve shutdown of _CipherBackupStreamer (#137257)
* Improve shutdown of _CipherBackupStreamer * Catch the right exception
This commit is contained in:
parent
e18062bce4
commit
ca53d97a6d
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncIterator, Callable, Coroutine
|
from collections.abc import AsyncIterator, Callable, Coroutine
|
||||||
|
from concurrent.futures import CancelledError, Future
|
||||||
import copy
|
import copy
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -12,6 +13,7 @@ import os
|
|||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
import tarfile
|
import tarfile
|
||||||
|
import threading
|
||||||
from typing import IO, Any, Self, cast
|
from typing import IO, Any, Self, cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -22,7 +24,6 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
from homeassistant.util.json import JsonObjectType, json_loads_object
|
from homeassistant.util.json import JsonObjectType, json_loads_object
|
||||||
from homeassistant.util.thread import ThreadWithException
|
|
||||||
|
|
||||||
from .const import BUF_SIZE, LOGGER
|
from .const import BUF_SIZE, LOGGER
|
||||||
from .models import AddonInfo, AgentBackup, Folder
|
from .models import AddonInfo, AgentBackup, Folder
|
||||||
@ -167,23 +168,38 @@ class AsyncIteratorReader:
|
|||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
||||||
"""Initialize the wrapper."""
|
"""Initialize the wrapper."""
|
||||||
|
self._aborted = False
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._stream = stream
|
self._stream = stream
|
||||||
self._buffer: bytes | None = None
|
self._buffer: bytes | None = None
|
||||||
|
self._next_future: Future[bytes | None] | None = None
|
||||||
self._pos: int = 0
|
self._pos: int = 0
|
||||||
|
|
||||||
async def _next(self) -> bytes | None:
|
async def _next(self) -> bytes | None:
|
||||||
"""Get the next chunk from the iterator."""
|
"""Get the next chunk from the iterator."""
|
||||||
return await anext(self._stream, None)
|
return await anext(self._stream, None)
|
||||||
|
|
||||||
|
def abort(self) -> None:
|
||||||
|
"""Abort the reader."""
|
||||||
|
self._aborted = True
|
||||||
|
if self._next_future is not None:
|
||||||
|
self._next_future.cancel()
|
||||||
|
|
||||||
def read(self, n: int = -1, /) -> bytes:
|
def read(self, n: int = -1, /) -> bytes:
|
||||||
"""Read data from the iterator."""
|
"""Read data from the iterator."""
|
||||||
result = bytearray()
|
result = bytearray()
|
||||||
while n < 0 or len(result) < n:
|
while n < 0 or len(result) < n:
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
self._buffer = asyncio.run_coroutine_threadsafe(
|
self._next_future = asyncio.run_coroutine_threadsafe(
|
||||||
self._next(), self._hass.loop
|
self._next(), self._hass.loop
|
||||||
).result()
|
)
|
||||||
|
if self._aborted:
|
||||||
|
self._next_future.cancel()
|
||||||
|
raise AbortCipher
|
||||||
|
try:
|
||||||
|
self._buffer = self._next_future.result()
|
||||||
|
except CancelledError as err:
|
||||||
|
raise AbortCipher from err
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
# The stream is exhausted
|
# The stream is exhausted
|
||||||
@ -205,9 +221,11 @@ class AsyncIteratorWriter:
|
|||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the wrapper."""
|
"""Initialize the wrapper."""
|
||||||
|
self._aborted = False
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._pos: int = 0
|
self._pos: int = 0
|
||||||
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
|
||||||
|
self._write_future: Future[bytes | None] | None = None
|
||||||
|
|
||||||
def __aiter__(self) -> Self:
|
def __aiter__(self) -> Self:
|
||||||
"""Return the iterator."""
|
"""Return the iterator."""
|
||||||
@ -219,13 +237,28 @@ class AsyncIteratorWriter:
|
|||||||
return data
|
return data
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
def abort(self) -> None:
|
||||||
|
"""Abort the writer."""
|
||||||
|
self._aborted = True
|
||||||
|
if self._write_future is not None:
|
||||||
|
self._write_future.cancel()
|
||||||
|
|
||||||
def tell(self) -> int:
|
def tell(self) -> int:
|
||||||
"""Return the current position in the iterator."""
|
"""Return the current position in the iterator."""
|
||||||
return self._pos
|
return self._pos
|
||||||
|
|
||||||
def write(self, s: bytes, /) -> int:
|
def write(self, s: bytes, /) -> int:
|
||||||
"""Write data to the iterator."""
|
"""Write data to the iterator."""
|
||||||
asyncio.run_coroutine_threadsafe(self._queue.put(s), self._hass.loop).result()
|
self._write_future = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._queue.put(s), self._hass.loop
|
||||||
|
)
|
||||||
|
if self._aborted:
|
||||||
|
self._write_future.cancel()
|
||||||
|
raise AbortCipher
|
||||||
|
try:
|
||||||
|
self._write_future.result()
|
||||||
|
except CancelledError as err:
|
||||||
|
raise AbortCipher from err
|
||||||
self._pos += len(s)
|
self._pos += len(s)
|
||||||
return len(s)
|
return len(s)
|
||||||
|
|
||||||
@ -415,7 +448,9 @@ def _encrypt_backup(
|
|||||||
class _CipherWorkerStatus:
|
class _CipherWorkerStatus:
|
||||||
done: asyncio.Event
|
done: asyncio.Event
|
||||||
error: Exception | None = None
|
error: Exception | None = None
|
||||||
thread: ThreadWithException
|
reader: AsyncIteratorReader
|
||||||
|
thread: threading.Thread
|
||||||
|
writer: AsyncIteratorWriter
|
||||||
|
|
||||||
|
|
||||||
class _CipherBackupStreamer:
|
class _CipherBackupStreamer:
|
||||||
@ -468,11 +503,13 @@ class _CipherBackupStreamer:
|
|||||||
stream = await self._open_stream()
|
stream = await self._open_stream()
|
||||||
reader = AsyncIteratorReader(self._hass, stream)
|
reader = AsyncIteratorReader(self._hass, stream)
|
||||||
writer = AsyncIteratorWriter(self._hass)
|
writer = AsyncIteratorWriter(self._hass)
|
||||||
worker = ThreadWithException(
|
worker = threading.Thread(
|
||||||
target=self._cipher_func,
|
target=self._cipher_func,
|
||||||
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
|
||||||
)
|
)
|
||||||
worker_status = _CipherWorkerStatus(done=asyncio.Event(), thread=worker)
|
worker_status = _CipherWorkerStatus(
|
||||||
|
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
|
||||||
|
)
|
||||||
self._workers.append(worker_status)
|
self._workers.append(worker_status)
|
||||||
worker.start()
|
worker.start()
|
||||||
return writer
|
return writer
|
||||||
@ -480,9 +517,8 @@ class _CipherBackupStreamer:
|
|||||||
async def wait(self) -> None:
|
async def wait(self) -> None:
|
||||||
"""Wait for the worker threads to finish."""
|
"""Wait for the worker threads to finish."""
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
if not worker.thread.is_alive():
|
worker.reader.abort()
|
||||||
continue
|
worker.writer.abort()
|
||||||
worker.thread.raise_exc(AbortCipher)
|
|
||||||
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import tarfile
|
import tarfile
|
||||||
@ -189,6 +190,73 @@ async def test_decrypted_backup_streamer(hass: HomeAssistant) -> None:
|
|||||||
assert decrypted_output == decrypted_backup_data + expected_padding
|
assert decrypted_output == decrypted_backup_data + expected_padding
|
||||||
|
|
||||||
|
|
||||||
|
async def test_decrypted_backup_streamer_interrupt_stuck_reader(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test the decrypted backup streamer."""
|
||||||
|
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||||
|
backup = AgentBackup(
|
||||||
|
addons=["addon_1", "addon_2"],
|
||||||
|
backup_id="1234",
|
||||||
|
date="2024-12-02T07:23:58.261875-05:00",
|
||||||
|
database_included=False,
|
||||||
|
extra_metadata={},
|
||||||
|
folders=[],
|
||||||
|
homeassistant_included=True,
|
||||||
|
homeassistant_version="2024.12.0.dev0",
|
||||||
|
name="test",
|
||||||
|
protected=True,
|
||||||
|
size=encrypted_backup_path.stat().st_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
stuck = asyncio.Event()
|
||||||
|
|
||||||
|
async def send_backup() -> AsyncIterator[bytes]:
|
||||||
|
f = encrypted_backup_path.open("rb")
|
||||||
|
while chunk := f.read(1024):
|
||||||
|
await stuck.wait()
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def open_backup() -> AsyncIterator[bytes]:
|
||||||
|
return send_backup()
|
||||||
|
|
||||||
|
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||||
|
await decryptor.open_stream()
|
||||||
|
await decryptor.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_decrypted_backup_streamer_interrupt_stuck_writer(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test the decrypted backup streamer."""
|
||||||
|
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||||
|
backup = AgentBackup(
|
||||||
|
addons=["addon_1", "addon_2"],
|
||||||
|
backup_id="1234",
|
||||||
|
date="2024-12-02T07:23:58.261875-05:00",
|
||||||
|
database_included=False,
|
||||||
|
extra_metadata={},
|
||||||
|
folders=[],
|
||||||
|
homeassistant_included=True,
|
||||||
|
homeassistant_version="2024.12.0.dev0",
|
||||||
|
name="test",
|
||||||
|
protected=True,
|
||||||
|
size=encrypted_backup_path.stat().st_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_backup() -> AsyncIterator[bytes]:
|
||||||
|
f = encrypted_backup_path.open("rb")
|
||||||
|
while chunk := f.read(1024):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def open_backup() -> AsyncIterator[bytes]:
|
||||||
|
return send_backup()
|
||||||
|
|
||||||
|
decryptor = DecryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||||
|
await decryptor.open_stream()
|
||||||
|
await decryptor.wait()
|
||||||
|
|
||||||
|
|
||||||
async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None:
|
async def test_decrypted_backup_streamer_wrong_password(hass: HomeAssistant) -> None:
|
||||||
"""Test the decrypted backup streamer with wrong password."""
|
"""Test the decrypted backup streamer with wrong password."""
|
||||||
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
encrypted_backup_path = get_fixture_path("test_backups/c0cb53bd.tar", DOMAIN)
|
||||||
@ -279,6 +347,77 @@ async def test_encrypted_backup_streamer(hass: HomeAssistant) -> None:
|
|||||||
assert encrypted_output == encrypted_backup_data + expected_padding
|
assert encrypted_output == encrypted_backup_data + expected_padding
|
||||||
|
|
||||||
|
|
||||||
|
async def test_encrypted_backup_streamer_interrupt_stuck_reader(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test the encrypted backup streamer."""
|
||||||
|
decrypted_backup_path = get_fixture_path(
|
||||||
|
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
|
||||||
|
)
|
||||||
|
backup = AgentBackup(
|
||||||
|
addons=["addon_1", "addon_2"],
|
||||||
|
backup_id="1234",
|
||||||
|
date="2024-12-02T07:23:58.261875-05:00",
|
||||||
|
database_included=False,
|
||||||
|
extra_metadata={},
|
||||||
|
folders=[],
|
||||||
|
homeassistant_included=True,
|
||||||
|
homeassistant_version="2024.12.0.dev0",
|
||||||
|
name="test",
|
||||||
|
protected=False,
|
||||||
|
size=decrypted_backup_path.stat().st_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
stuck = asyncio.Event()
|
||||||
|
|
||||||
|
async def send_backup() -> AsyncIterator[bytes]:
|
||||||
|
f = decrypted_backup_path.open("rb")
|
||||||
|
while chunk := f.read(1024):
|
||||||
|
await stuck.wait()
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def open_backup() -> AsyncIterator[bytes]:
|
||||||
|
return send_backup()
|
||||||
|
|
||||||
|
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||||
|
await decryptor.open_stream()
|
||||||
|
await decryptor.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_encrypted_backup_streamer_interrupt_stuck_writer(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test the encrypted backup streamer."""
|
||||||
|
decrypted_backup_path = get_fixture_path(
|
||||||
|
"test_backups/c0cb53bd.tar.decrypted", DOMAIN
|
||||||
|
)
|
||||||
|
backup = AgentBackup(
|
||||||
|
addons=["addon_1", "addon_2"],
|
||||||
|
backup_id="1234",
|
||||||
|
date="2024-12-02T07:23:58.261875-05:00",
|
||||||
|
database_included=False,
|
||||||
|
extra_metadata={},
|
||||||
|
folders=[],
|
||||||
|
homeassistant_included=True,
|
||||||
|
homeassistant_version="2024.12.0.dev0",
|
||||||
|
name="test",
|
||||||
|
protected=True,
|
||||||
|
size=decrypted_backup_path.stat().st_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_backup() -> AsyncIterator[bytes]:
|
||||||
|
f = decrypted_backup_path.open("rb")
|
||||||
|
while chunk := f.read(1024):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def open_backup() -> AsyncIterator[bytes]:
|
||||||
|
return send_backup()
|
||||||
|
|
||||||
|
decryptor = EncryptedBackupStreamer(hass, backup, open_backup, "hunter2")
|
||||||
|
await decryptor.open_stream()
|
||||||
|
await decryptor.wait()
|
||||||
|
|
||||||
|
|
||||||
async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None:
|
async def test_encrypted_backup_streamer_random_nonce(hass: HomeAssistant) -> None:
|
||||||
"""Test the encrypted backup streamer."""
|
"""Test the encrypted backup streamer."""
|
||||||
decrypted_backup_path = get_fixture_path(
|
decrypted_backup_path = get_fixture_path(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user