mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 17:57:55 +00:00
Add upload capability to the backup integration (#128546)
* Add upload capability to the backup integration * Limit context switch * rename * coverage for http * Test receiving a backup file * Update test_manager.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
cb9cc0f801
commit
ac0c75a598
@ -2,23 +2,26 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
from aiohttp import BodyPartReader
|
||||
from aiohttp.hdrs import CONTENT_DISPOSITION
|
||||
from aiohttp.web import FileResponse, Request, Response
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util import slugify
|
||||
|
||||
from .const import DOMAIN
|
||||
from .manager import BaseBackupManager
|
||||
from .const import DATA_MANAGER
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_http_views(hass: HomeAssistant) -> None:
|
||||
"""Register the http views."""
|
||||
hass.http.register_view(DownloadBackupView)
|
||||
hass.http.register_view(UploadBackupView)
|
||||
|
||||
|
||||
class DownloadBackupView(HomeAssistantView):
|
||||
@ -36,7 +39,7 @@ class DownloadBackupView(HomeAssistantView):
|
||||
if not request["hass_user"].is_admin:
|
||||
return Response(status=HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
manager: BaseBackupManager = request.app[KEY_HASS].data[DOMAIN]
|
||||
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
||||
backup = await manager.async_get_backup(slug=slug)
|
||||
|
||||
if backup is None or not backup.path.exists():
|
||||
@ -48,3 +51,29 @@ class DownloadBackupView(HomeAssistantView):
|
||||
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class UploadBackupView(HomeAssistantView):
|
||||
"""Generate backup view."""
|
||||
|
||||
url = "/api/backup/upload"
|
||||
name = "api:backup:upload"
|
||||
|
||||
@require_admin
|
||||
async def post(self, request: Request) -> Response:
|
||||
"""Upload a backup file."""
|
||||
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
||||
reader = await request.multipart()
|
||||
contents = cast(BodyPartReader, await reader.next())
|
||||
|
||||
try:
|
||||
await manager.async_receive_backup(contents=contents)
|
||||
except OSError as err:
|
||||
return Response(
|
||||
body=f"Can't write backup file {err}",
|
||||
status=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
return Response(status=HTTPStatus.CREATED)
|
||||
|
@ -9,11 +9,15 @@ import hashlib
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
from queue import SimpleQueue
|
||||
import shutil
|
||||
import tarfile
|
||||
from tarfile import TarError
|
||||
from tempfile import TemporaryDirectory
|
||||
import time
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import aiohttp
|
||||
from securetar import SecureTarFile, atomic_contents_add
|
||||
|
||||
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
|
||||
@ -147,6 +151,15 @@ class BaseBackupManager(abc.ABC):
|
||||
async def async_remove_backup(self, *, slug: str, **kwargs: Any) -> None:
|
||||
"""Remove a backup."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_receive_backup(
|
||||
self,
|
||||
*,
|
||||
contents: aiohttp.BodyPartReader,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Receive and store a backup file from upload."""
|
||||
|
||||
|
||||
class BackupManager(BaseBackupManager):
|
||||
"""Backup manager for the Backup integration."""
|
||||
@ -222,6 +235,63 @@ class BackupManager(BaseBackupManager):
|
||||
LOGGER.debug("Removed backup located at %s", backup.path)
|
||||
self.backups.pop(slug)
|
||||
|
||||
async def async_receive_backup(
|
||||
self,
|
||||
*,
|
||||
contents: aiohttp.BodyPartReader,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Receive and store a backup file from upload."""
|
||||
queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = (
|
||||
SimpleQueue()
|
||||
)
|
||||
temp_dir_handler = await self.hass.async_add_executor_job(TemporaryDirectory)
|
||||
target_temp_file = Path(
|
||||
temp_dir_handler.name, contents.filename or "backup.tar"
|
||||
)
|
||||
|
||||
def _sync_queue_consumer() -> None:
|
||||
with target_temp_file.open("wb") as file_handle:
|
||||
while True:
|
||||
if (_chunk_future := queue.get()) is None:
|
||||
break
|
||||
_chunk, _future = _chunk_future
|
||||
if _future is not None:
|
||||
self.hass.loop.call_soon_threadsafe(_future.set_result, None)
|
||||
file_handle.write(_chunk)
|
||||
|
||||
fut: asyncio.Future[None] | None = None
|
||||
try:
|
||||
fut = self.hass.async_add_executor_job(_sync_queue_consumer)
|
||||
megabytes_sending = 0
|
||||
while chunk := await contents.read_chunk(BUF_SIZE):
|
||||
megabytes_sending += 1
|
||||
if megabytes_sending % 5 != 0:
|
||||
queue.put_nowait((chunk, None))
|
||||
continue
|
||||
|
||||
chunk_future = self.hass.loop.create_future()
|
||||
queue.put_nowait((chunk, chunk_future))
|
||||
await asyncio.wait(
|
||||
(fut, chunk_future),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if fut.done():
|
||||
# The executor job failed
|
||||
break
|
||||
|
||||
queue.put_nowait(None) # terminate queue consumer
|
||||
finally:
|
||||
if fut is not None:
|
||||
await fut
|
||||
|
||||
def _move_and_cleanup() -> None:
|
||||
shutil.move(target_temp_file, self.backup_dir / target_temp_file.name)
|
||||
temp_dir_handler.cleanup()
|
||||
|
||||
await self.hass.async_add_executor_job(_move_and_cleanup)
|
||||
await self.load_backups()
|
||||
|
||||
async def async_create_backup(self, **kwargs: Any) -> Backup:
|
||||
"""Generate a backup."""
|
||||
if self.backing_up:
|
||||
|
@ -1,8 +1,11 @@
|
||||
"""Tests for the Backup integration."""
|
||||
|
||||
import asyncio
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
@ -49,12 +52,12 @@ async def test_downloading_backup_not_found(
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
async def test_non_admin(
|
||||
async def test_downloading_as_non_admin(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test downloading a backup file that does not exist."""
|
||||
"""Test downloading a backup file when you are not an admin."""
|
||||
hass_admin_user.groups = []
|
||||
await setup_backup_integration(hass)
|
||||
|
||||
@ -62,3 +65,53 @@ async def test_non_admin(
|
||||
|
||||
resp = await client.get("/api/backup/download/abc123")
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_uploading_a_backup_file(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test uploading a backup file."""
|
||||
await setup_backup_integration(hass)
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
|
||||
) as async_receive_backup_mock:
|
||||
resp = await client.post(
|
||||
"/api/backup/upload",
|
||||
data={"file": StringIO("test")},
|
||||
)
|
||||
assert resp.status == 201
|
||||
assert async_receive_backup_mock.called
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "message"),
|
||||
[
|
||||
(OSError("Boom!"), "Can't write backup file Boom!"),
|
||||
(asyncio.CancelledError("Boom!"), ""),
|
||||
],
|
||||
)
|
||||
async def test_error_handling_uploading_a_backup_file(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
error: Exception,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Test error handling when uploading a backup file."""
|
||||
await setup_backup_integration(hass)
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
|
||||
side_effect=error,
|
||||
):
|
||||
resp = await client.post(
|
||||
"/api/backup/upload",
|
||||
data={"file": StringIO("test")},
|
||||
)
|
||||
assert resp.status == 500
|
||||
assert await resp.text() == message
|
||||
|
@ -3,8 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch
|
||||
|
||||
import aiohttp
|
||||
from multidict import CIMultiDict, CIMultiDictProxy
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup import BackupManager
|
||||
@ -335,6 +337,40 @@ async def test_loading_platforms_when_running_async_post_backup_actions(
|
||||
assert "Loaded 1 platforms" in caplog.text
|
||||
|
||||
|
||||
async def test_async_receive_backup(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test receiving a backup file."""
|
||||
manager = BackupManager(hass)
|
||||
|
||||
size = 2 * 2**16
|
||||
protocol = Mock(_reading_paused=False)
|
||||
stream = aiohttp.StreamReader(protocol, 2**16)
|
||||
stream.feed_data(b"0" * size + b"\r\n--:--")
|
||||
stream.feed_eof()
|
||||
|
||||
open_mock = mock_open()
|
||||
|
||||
with patch("pathlib.Path.open", open_mock), patch("shutil.move") as mover_mock:
|
||||
await manager.async_receive_backup(
|
||||
contents=aiohttp.BodyPartReader(
|
||||
b"--:",
|
||||
CIMultiDictProxy(
|
||||
CIMultiDict(
|
||||
{
|
||||
aiohttp.hdrs.CONTENT_DISPOSITION: "attachment; filename=abc123.tar"
|
||||
}
|
||||
)
|
||||
),
|
||||
stream,
|
||||
)
|
||||
)
|
||||
assert open_mock.call_count == 1
|
||||
assert mover_mock.call_count == 1
|
||||
assert mover_mock.mock_calls[0].args[1].name == "abc123.tar"
|
||||
|
||||
|
||||
async def test_async_trigger_restore(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
|
Loading…
x
Reference in New Issue
Block a user