Add upload capability to the backup integration (#128546)

* Add upload capability to the backup integration

* Limit context switch

* rename

* coverage for http

* Test receiving a backup file

* Update test_manager.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Joakim Sørensen 2024-11-12 15:27:53 +01:00 committed by GitHub
parent cb9cc0f801
commit ac0c75a598
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 7 deletions

View File

@ -2,23 +2,26 @@
from __future__ import annotations
import asyncio
from http import HTTPStatus
from typing import cast
from aiohttp import BodyPartReader
from aiohttp.hdrs import CONTENT_DISPOSITION
from aiohttp.web import FileResponse, Request, Response
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.components.http import KEY_HASS, HomeAssistantView, require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from .const import DOMAIN
from .manager import BaseBackupManager
from .const import DATA_MANAGER
@callback
def async_register_http_views(hass: HomeAssistant) -> None:
"""Register the http views."""
hass.http.register_view(DownloadBackupView)
hass.http.register_view(UploadBackupView)
class DownloadBackupView(HomeAssistantView):
@ -36,7 +39,7 @@ class DownloadBackupView(HomeAssistantView):
if not request["hass_user"].is_admin:
return Response(status=HTTPStatus.UNAUTHORIZED)
manager: BaseBackupManager = request.app[KEY_HASS].data[DOMAIN]
manager = request.app[KEY_HASS].data[DATA_MANAGER]
backup = await manager.async_get_backup(slug=slug)
if backup is None or not backup.path.exists():
@ -48,3 +51,29 @@ class DownloadBackupView(HomeAssistantView):
CONTENT_DISPOSITION: f"attachment; filename={slugify(backup.name)}.tar"
},
)
class UploadBackupView(HomeAssistantView):
"""Generate backup view."""
url = "/api/backup/upload"
name = "api:backup:upload"
@require_admin
async def post(self, request: Request) -> Response:
"""Upload a backup file."""
manager = request.app[KEY_HASS].data[DATA_MANAGER]
reader = await request.multipart()
contents = cast(BodyPartReader, await reader.next())
try:
await manager.async_receive_backup(contents=contents)
except OSError as err:
return Response(
body=f"Can't write backup file {err}",
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
except asyncio.CancelledError:
return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
return Response(status=HTTPStatus.CREATED)

View File

@ -9,11 +9,15 @@ import hashlib
import io
import json
from pathlib import Path
from queue import SimpleQueue
import shutil
import tarfile
from tarfile import TarError
from tempfile import TemporaryDirectory
import time
from typing import Any, Protocol, cast
import aiohttp
from securetar import SecureTarFile, atomic_contents_add
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
@ -147,6 +151,15 @@ class BaseBackupManager(abc.ABC):
async def async_remove_backup(self, *, slug: str, **kwargs: Any) -> None:
"""Remove a backup."""
@abc.abstractmethod
async def async_receive_backup(
self,
*,
contents: aiohttp.BodyPartReader,
**kwargs: Any,
) -> None:
"""Receive and store a backup file from upload."""
class BackupManager(BaseBackupManager):
"""Backup manager for the Backup integration."""
@ -222,6 +235,63 @@ class BackupManager(BaseBackupManager):
LOGGER.debug("Removed backup located at %s", backup.path)
self.backups.pop(slug)
async def async_receive_backup(
self,
*,
contents: aiohttp.BodyPartReader,
**kwargs: Any,
) -> None:
"""Receive and store a backup file from upload."""
queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = (
SimpleQueue()
)
temp_dir_handler = await self.hass.async_add_executor_job(TemporaryDirectory)
target_temp_file = Path(
temp_dir_handler.name, contents.filename or "backup.tar"
)
def _sync_queue_consumer() -> None:
with target_temp_file.open("wb") as file_handle:
while True:
if (_chunk_future := queue.get()) is None:
break
_chunk, _future = _chunk_future
if _future is not None:
self.hass.loop.call_soon_threadsafe(_future.set_result, None)
file_handle.write(_chunk)
fut: asyncio.Future[None] | None = None
try:
fut = self.hass.async_add_executor_job(_sync_queue_consumer)
megabytes_sending = 0
while chunk := await contents.read_chunk(BUF_SIZE):
megabytes_sending += 1
if megabytes_sending % 5 != 0:
queue.put_nowait((chunk, None))
continue
chunk_future = self.hass.loop.create_future()
queue.put_nowait((chunk, chunk_future))
await asyncio.wait(
(fut, chunk_future),
return_when=asyncio.FIRST_COMPLETED,
)
if fut.done():
# The executor job failed
break
queue.put_nowait(None) # terminate queue consumer
finally:
if fut is not None:
await fut
def _move_and_cleanup() -> None:
shutil.move(target_temp_file, self.backup_dir / target_temp_file.name)
temp_dir_handler.cleanup()
await self.hass.async_add_executor_job(_move_and_cleanup)
await self.load_backups()
async def async_create_backup(self, **kwargs: Any) -> Backup:
"""Generate a backup."""
if self.backing_up:

View File

@ -1,8 +1,11 @@
"""Tests for the Backup integration."""
import asyncio
from io import StringIO
from unittest.mock import patch
from aiohttp import web
import pytest
from homeassistant.core import HomeAssistant
@ -49,12 +52,12 @@ async def test_downloading_backup_not_found(
assert resp.status == 404
async def test_non_admin(
async def test_downloading_as_non_admin(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_admin_user: MockUser,
) -> None:
"""Test downloading a backup file that does not exist."""
"""Test downloading a backup file when you are not an admin."""
hass_admin_user.groups = []
await setup_backup_integration(hass)
@ -62,3 +65,53 @@ async def test_non_admin(
resp = await client.get("/api/backup/download/abc123")
assert resp.status == 401
async def test_uploading_a_backup_file(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test uploading a backup file."""
await setup_backup_integration(hass)
client = await hass_client()
with patch(
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
) as async_receive_backup_mock:
resp = await client.post(
"/api/backup/upload",
data={"file": StringIO("test")},
)
assert resp.status == 201
assert async_receive_backup_mock.called
@pytest.mark.parametrize(
("error", "message"),
[
(OSError("Boom!"), "Can't write backup file Boom!"),
(asyncio.CancelledError("Boom!"), ""),
],
)
async def test_error_handling_uploading_a_backup_file(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
error: Exception,
message: str,
) -> None:
"""Test error handling when uploading a backup file."""
await setup_backup_integration(hass)
client = await hass_client()
with patch(
"homeassistant.components.backup.manager.BackupManager.async_receive_backup",
side_effect=error,
):
resp = await client.post(
"/api/backup/upload",
data={"file": StringIO("test")},
)
assert resp.status == 500
assert await resp.text() == message

View File

@ -3,8 +3,10 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch
import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy
import pytest
from homeassistant.components.backup import BackupManager
@ -335,6 +337,40 @@ async def test_loading_platforms_when_running_async_post_backup_actions(
assert "Loaded 1 platforms" in caplog.text
async def test_async_receive_backup(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test receiving a backup file."""
manager = BackupManager(hass)
size = 2 * 2**16
protocol = Mock(_reading_paused=False)
stream = aiohttp.StreamReader(protocol, 2**16)
stream.feed_data(b"0" * size + b"\r\n--:--")
stream.feed_eof()
open_mock = mock_open()
with patch("pathlib.Path.open", open_mock), patch("shutil.move") as mover_mock:
await manager.async_receive_backup(
contents=aiohttp.BodyPartReader(
b"--:",
CIMultiDictProxy(
CIMultiDict(
{
aiohttp.hdrs.CONTENT_DISPOSITION: "attachment; filename=abc123.tar"
}
)
),
stream,
)
)
assert open_mock.call_count == 1
assert mover_mock.call_count == 1
assert mover_mock.mock_calls[0].args[1].name == "abc123.tar"
async def test_async_trigger_restore(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,