mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 02:07:54 +00:00
Add upload capability to the backup integration (#128546)
* Add upload capability to the backup integration * Limit context switch * rename * coverage for http * Test receiving a backup file * Update test_manager.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
cb9cc0f801
commit
ac0c75a598
@ -2,23 +2,26 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from aiohttp import BodyPartReader
|
||||||
from aiohttp.hdrs import CONTENT_DISPOSITION
|
from aiohttp.hdrs import CONTENT_DISPOSITION
|
||||||
from aiohttp.web import FileResponse, Request, Response
|
from aiohttp.web import FileResponse, Request, Response
|
||||||
|
|
||||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DATA_MANAGER
|
||||||
from .manager import BaseBackupManager
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_http_views(hass: HomeAssistant) -> None:
|
def async_register_http_views(hass: HomeAssistant) -> None:
|
||||||
"""Register the http views."""
|
"""Register the http views."""
|
||||||
hass.http.register_view(DownloadBackupView)
|
hass.http.register_view(DownloadBackupView)
|
||||||
|
hass.http.register_view(UploadBackupView)
|
||||||
|
|
||||||
|
|
||||||
class DownloadBackupView(HomeAssistantView):
|
class DownloadBackupView(HomeAssistantView):
|
||||||
@ -36,7 +39,7 @@ class DownloadBackupView(HomeAssistantView):
|
|||||||
if not request["hass_user"].is_admin:
|
if not request["hass_user"].is_admin:
|
||||||
return Response(status=HTTPStatus.UNAUTHORIZED)
|
return Response(status=HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
||||||
manager: BaseBackupManager = request.app[KEY_HASS].data[DOMAIN]
|
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
||||||
backup = await manager.async_get_backup(slug=slug)
|
backup = await manager.async_get_backup(slug=slug)
|
||||||
|
|
||||||
if backup is None or not backup.path.exists():
|
if backup is None or not backup.path.exists():
|
||||||
@ -48,3 +51,29 @@ class DownloadBackupView(HomeAssistantView):
|
|||||||
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadBackupView(HomeAssistantView):
|
||||||
|
"""Generate backup view."""
|
||||||
|
|
||||||
|
url = "/api/backup/upload"
|
||||||
|
name = "api:backup:upload"
|
||||||
|
|
||||||
|
@require_admin
|
||||||
|
async def post(self, request: Request) -> Response:
|
||||||
|
"""Upload a backup file."""
|
||||||
|
manager = request.app[KEY_HASS].data[DATA_MANAGER]
|
||||||
|
reader = await request.multipart()
|
||||||
|
contents = cast(BodyPartReader, await reader.next())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await manager.async_receive_backup(contents=contents)
|
||||||
|
except OSError as err:
|
||||||
|
return Response(
|
||||||
|
body=f"Can't write backup file {err}",
|
||||||
|
status=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
return Response(status=HTTPStatus.CREATED)
|
||||||
|
@ -9,11 +9,15 @@ import hashlib
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from queue import SimpleQueue
|
||||||
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
from tarfile import TarError
|
from tarfile import TarError
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
import time
|
import time
|
||||||
from typing import Any, Protocol, cast
|
from typing import Any, Protocol, cast
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
from securetar import SecureTarFile, atomic_contents_add
|
from securetar import SecureTarFile, atomic_contents_add
|
||||||
|
|
||||||
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
|
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
|
||||||
@ -147,6 +151,15 @@ class BaseBackupManager(abc.ABC):
|
|||||||
async def async_remove_backup(self, *, slug: str, **kwargs: Any) -> None:
|
async def async_remove_backup(self, *, slug: str, **kwargs: Any) -> None:
|
||||||
"""Remove a backup."""
|
"""Remove a backup."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def async_receive_backup(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
contents: aiohttp.BodyPartReader,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Receive and store a backup file from upload."""
|
||||||
|
|
||||||
|
|
||||||
class BackupManager(BaseBackupManager):
|
class BackupManager(BaseBackupManager):
|
||||||
"""Backup manager for the Backup integration."""
|
"""Backup manager for the Backup integration."""
|
||||||
@ -222,6 +235,63 @@ class BackupManager(BaseBackupManager):
|
|||||||
LOGGER.debug("Removed backup located at %s", backup.path)
|
LOGGER.debug("Removed backup located at %s", backup.path)
|
||||||
self.backups.pop(slug)
|
self.backups.pop(slug)
|
||||||
|
|
||||||
|
async def async_receive_backup(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
contents: aiohttp.BodyPartReader,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Receive and store a backup file from upload."""
|
||||||
|
queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = (
|
||||||
|
SimpleQueue()
|
||||||
|
)
|
||||||
|
temp_dir_handler = await self.hass.async_add_executor_job(TemporaryDirectory)
|
||||||
|
target_temp_file = Path(
|
||||||
|
temp_dir_handler.name, contents.filename or "backup.tar"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_queue_consumer() -> None:
|
||||||
|
with target_temp_file.open("wb") as file_handle:
|
||||||
|
while True:
|
||||||
|
if (_chunk_future := queue.get()) is None:
|
||||||
|
break
|
||||||
|
_chunk, _future = _chunk_future
|
||||||
|
if _future is not None:
|
||||||
|
self.hass.loop.call_soon_threadsafe(_future.set_result, None)
|
||||||
|
file_handle.write(_chunk)
|
||||||
|
|
||||||
|
fut: asyncio.Future[None] | None = None
|
||||||
|
try:
|
||||||
|
fut = self.hass.async_add_executor_job(_sync_queue_consumer)
|
||||||
|
megabytes_sending = 0
|
||||||
|
while chunk := await contents.read_chunk(BUF_SIZE):
|
||||||
|
megabytes_sending += 1
|
||||||
|
if megabytes_sending % 5 != 0:
|
||||||
|
queue.put_nowait((chunk, None))
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_future = self.hass.loop.create_future()
|
||||||
|
queue.put_nowait((chunk, chunk_future))
|
||||||
|
await asyncio.wait(
|
||||||
|
(fut, chunk_future),
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
if fut.done():
|
||||||
|
# The executor job failed
|
||||||
|
break
|
||||||
|
|
||||||
|
queue.put_nowait(None) # terminate queue consumer
|
||||||
|
finally:
|
||||||
|
if fut is not None:
|
||||||
|
await fut
|
||||||
|
|
||||||
|
def _move_and_cleanup() -> None:
|
||||||
|
shutil.move(target_temp_file, self.backup_dir / target_temp_file.name)
|
||||||
|
temp_dir_handler.cleanup()
|
||||||
|
|
||||||
|
await self.hass.async_add_executor_job(_move_and_cleanup)
|
||||||
|
await self.load_backups()
|
||||||
|
|
||||||
async def async_create_backup(self, **kwargs: Any) -> Backup:
|
async def async_create_backup(self, **kwargs: Any) -> Backup:
|
||||||
"""Generate a backup."""
|
"""Generate a backup."""
|
||||||
if self.backing_up:
|
if self.backing_up:
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
"""Tests for the Backup integration."""
|
"""Tests for the Backup integration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from io import StringIO
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
@ -49,12 +52,12 @@ async def test_downloading_backup_not_found(
|
|||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
async def test_non_admin(
|
async def test_downloading_as_non_admin(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
hass_admin_user: MockUser,
|
hass_admin_user: MockUser,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test downloading a backup file that does not exist."""
|
"""Test downloading a backup file when you are not an admin."""
|
||||||
hass_admin_user.groups = []
|
hass_admin_user.groups = []
|
||||||
await setup_backup_integration(hass)
|
await setup_backup_integration(hass)
|
||||||
|
|
||||||
@ -62,3 +65,53 @@ async def test_non_admin(
|
|||||||
|
|
||||||
resp = await client.get("/api/backup/download/abc123")
|
resp = await client.get("/api/backup/download/abc123")
|
||||||
assert resp.status == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
|
|
||||||
|
async def test_uploading_a_backup_file(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test uploading a backup file."""
|
||||||
|
await setup_backup_integration(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
|
||||||
|
) as async_receive_backup_mock:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/backup/upload",
|
||||||
|
data={"file": StringIO("test")},
|
||||||
|
)
|
||||||
|
assert resp.status == 201
|
||||||
|
assert async_receive_backup_mock.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("error", "message"),
|
||||||
|
[
|
||||||
|
(OSError("Boom!"), "Can't write backup file Boom!"),
|
||||||
|
(asyncio.CancelledError("Boom!"), ""),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_error_handling_uploading_a_backup_file(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
error: Exception,
|
||||||
|
message: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test error handling when uploading a backup file."""
|
||||||
|
await setup_backup_integration(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
|
||||||
|
side_effect=error,
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/backup/upload",
|
||||||
|
data={"file": StringIO("test")},
|
||||||
|
)
|
||||||
|
assert resp.status == 500
|
||||||
|
assert await resp.text() == message
|
||||||
|
@ -3,8 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from multidict import CIMultiDict, CIMultiDictProxy
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.backup import BackupManager
|
from homeassistant.components.backup import BackupManager
|
||||||
@ -335,6 +337,40 @@ async def test_loading_platforms_when_running_async_post_backup_actions(
|
|||||||
assert "Loaded 1 platforms" in caplog.text
|
assert "Loaded 1 platforms" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_receive_backup(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test receiving a backup file."""
|
||||||
|
manager = BackupManager(hass)
|
||||||
|
|
||||||
|
size = 2 * 2**16
|
||||||
|
protocol = Mock(_reading_paused=False)
|
||||||
|
stream = aiohttp.StreamReader(protocol, 2**16)
|
||||||
|
stream.feed_data(b"0" * size + b"\r\n--:--")
|
||||||
|
stream.feed_eof()
|
||||||
|
|
||||||
|
open_mock = mock_open()
|
||||||
|
|
||||||
|
with patch("pathlib.Path.open", open_mock), patch("shutil.move") as mover_mock:
|
||||||
|
await manager.async_receive_backup(
|
||||||
|
contents=aiohttp.BodyPartReader(
|
||||||
|
b"--:",
|
||||||
|
CIMultiDictProxy(
|
||||||
|
CIMultiDict(
|
||||||
|
{
|
||||||
|
aiohttp.hdrs.CONTENT_DISPOSITION: "attachment; filename=abc123.tar"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
),
|
||||||
|
stream,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert open_mock.call_count == 1
|
||||||
|
assert mover_mock.call_count == 1
|
||||||
|
assert mover_mock.mock_calls[0].args[1].name == "abc123.tar"
|
||||||
|
|
||||||
|
|
||||||
async def test_async_trigger_restore(
|
async def test_async_trigger_restore(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user