Move cloud backup upload/download handlers to lib (#137416)

* Move cloud backup upload/download handlers to lib

* Update backup.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Joakim Sørensen 2025-02-06 07:32:46 +01:00 committed by GitHub
parent 3b871afcc4
commit 283b0908c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 206 deletions

View File

@ -8,16 +8,11 @@ from collections.abc import AsyncIterator, Callable, Coroutine, Mapping
import hashlib import hashlib
import logging import logging
import random import random
from typing import Any from typing import Any, Literal
from aiohttp import ClientError, ClientTimeout from aiohttp import ClientError
from hass_nabucasa import Cloud, CloudError from hass_nabucasa import Cloud, CloudError
from hass_nabucasa.cloud_api import ( from hass_nabucasa.cloud_api import async_files_delete_file, async_files_list
async_files_delete_file,
async_files_download_details,
async_files_list,
async_files_upload_details,
)
from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -28,7 +23,7 @@ from .client import CloudClient
from .const import DATA_CLOUD, DOMAIN, EVENT_CLOUD_EVENT from .const import DATA_CLOUD, DOMAIN, EVENT_CLOUD_EVENT
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_STORAGE_BACKUP = "backup" _STORAGE_BACKUP: Literal["backup"] = "backup"
_RETRY_LIMIT = 5 _RETRY_LIMIT = 5
_RETRY_SECONDS_MIN = 60 _RETRY_SECONDS_MIN = 60
_RETRY_SECONDS_MAX = 600 _RETRY_SECONDS_MAX = 600
@ -109,63 +104,14 @@ class CloudBackupAgent(BackupAgent):
raise BackupAgentError("Backup not found") raise BackupAgentError("Backup not found")
try: try:
details = await async_files_download_details( content = await self._cloud.files.download(
self._cloud,
storage_type=_STORAGE_BACKUP, storage_type=_STORAGE_BACKUP,
filename=self._get_backup_filename(), filename=self._get_backup_filename(),
) )
except (ClientError, CloudError) as err: except CloudError as err:
raise BackupAgentError("Failed to get download details") from err raise BackupAgentError(f"Failed to download backup: {err}") from err
try: return ChunkAsyncStreamIterator(content)
resp = await self._cloud.websession.get(
details["url"],
timeout=ClientTimeout(connect=10.0, total=43200.0), # 43200s == 12h
)
resp.raise_for_status()
except ClientError as err:
raise BackupAgentError("Failed to download backup") from err
return ChunkAsyncStreamIterator(resp.content)
async def _async_do_upload_backup(
self,
*,
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
filename: str,
base64md5hash: str,
metadata: dict[str, Any],
size: int,
) -> None:
"""Upload a backup."""
try:
details = await async_files_upload_details(
self._cloud,
storage_type=_STORAGE_BACKUP,
filename=filename,
metadata=metadata,
size=size,
base64md5hash=base64md5hash,
)
except (ClientError, CloudError) as err:
raise BackupAgentError("Failed to get upload details") from err
try:
upload_status = await self._cloud.websession.put(
details["url"],
data=await open_stream(),
headers=details["headers"] | {"content-length": str(size)},
timeout=ClientTimeout(connect=10.0, total=43200.0), # 43200s == 12h
)
_LOGGER.log(
logging.DEBUG if upload_status.status < 400 else logging.WARNING,
"Backup upload status: %s",
upload_status.status,
)
upload_status.raise_for_status()
except (TimeoutError, ClientError) as err:
raise BackupAgentError("Failed to upload backup") from err
async def async_upload_backup( async def async_upload_backup(
self, self,
@ -190,7 +136,8 @@ class CloudBackupAgent(BackupAgent):
tries = 1 tries = 1
while tries <= _RETRY_LIMIT: while tries <= _RETRY_LIMIT:
try: try:
await self._async_do_upload_backup( await self._cloud.files.upload(
storage_type=_STORAGE_BACKUP,
open_stream=open_stream, open_stream=open_stream,
filename=filename, filename=filename,
base64md5hash=base64md5hash, base64md5hash=base64md5hash,
@ -198,9 +145,9 @@ class CloudBackupAgent(BackupAgent):
size=size, size=size,
) )
break break
except BackupAgentError as err: except CloudError as err:
if tries == _RETRY_LIMIT: if tries == _RETRY_LIMIT:
raise raise BackupAgentError("Failed to upload backup") from err
tries += 1 tries += 1
retry_timer = random.randint(_RETRY_SECONDS_MIN, _RETRY_SECONDS_MAX) retry_timer = random.randint(_RETRY_SECONDS_MIN, _RETRY_SECONDS_MAX)
_LOGGER.info( _LOGGER.info(

View File

@ -9,6 +9,7 @@ from hass_nabucasa import Cloud
from hass_nabucasa.auth import CognitoAuth from hass_nabucasa.auth import CognitoAuth
from hass_nabucasa.cloudhooks import Cloudhooks from hass_nabucasa.cloudhooks import Cloudhooks
from hass_nabucasa.const import DEFAULT_SERVERS, DEFAULT_VALUES, STATE_CONNECTED from hass_nabucasa.const import DEFAULT_SERVERS, DEFAULT_VALUES, STATE_CONNECTED
from hass_nabucasa.files import Files
from hass_nabucasa.google_report_state import GoogleReportState from hass_nabucasa.google_report_state import GoogleReportState
from hass_nabucasa.ice_servers import IceServers from hass_nabucasa.ice_servers import IceServers
from hass_nabucasa.iot import CloudIoT from hass_nabucasa.iot import CloudIoT
@ -68,6 +69,7 @@ async def cloud_fixture() -> AsyncGenerator[MagicMock]:
spec=CloudIoT, last_disconnect_reason=None, state=STATE_CONNECTED spec=CloudIoT, last_disconnect_reason=None, state=STATE_CONNECTED
) )
mock_cloud.voice = MagicMock(spec=Voice) mock_cloud.voice = MagicMock(spec=Voice)
mock_cloud.files = MagicMock(spec=Files)
mock_cloud.started = None mock_cloud.started = None
mock_cloud.ice_servers = MagicMock( mock_cloud.ice_servers = MagicMock(
spec=IceServers, spec=IceServers,

View File

@ -1,14 +1,14 @@
"""Test the cloud backup platform.""" """Test the cloud backup platform."""
from collections.abc import AsyncGenerator, AsyncIterator, Generator from collections.abc import AsyncGenerator, Generator
from io import StringIO from io import StringIO
from typing import Any from typing import Any
from unittest.mock import Mock, PropertyMock, patch from unittest.mock import Mock, PropertyMock, patch
from aiohttp import ClientError from aiohttp import ClientError
from hass_nabucasa import CloudError from hass_nabucasa import CloudError
from hass_nabucasa.files import FilesError
import pytest import pytest
from yarl import URL
from homeassistant.components.backup import ( from homeassistant.components.backup import (
DOMAIN as BACKUP_DOMAIN, DOMAIN as BACKUP_DOMAIN,
@ -22,11 +22,20 @@ from homeassistant.components.cloud.const import EVENT_CLOUD_EVENT
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.aiohttp import MockStreamReader
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator, MagicMock, WebSocketGenerator from tests.typing import ClientSessionGenerator, MagicMock, WebSocketGenerator
class MockStreamReaderChunked(MockStreamReader):
"""Mock a stream reader with simulated chunked data."""
async def readchunk(self) -> tuple[bytes, bool]:
"""Read bytes."""
return (self._content.read(), False)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def setup_integration( async def setup_integration(
hass: HomeAssistant, hass: HomeAssistant,
@ -55,49 +64,6 @@ def mock_delete_file() -> Generator[MagicMock]:
yield delete_file yield delete_file
@pytest.fixture
def mock_get_download_details() -> Generator[MagicMock]:
"""Mock list files."""
with patch(
"homeassistant.components.cloud.backup.async_files_download_details",
spec_set=True,
) as download_details:
download_details.return_value = {
"url": (
"https://blabla.cloudflarestorage.com/blabla/backup/"
"462e16810d6841228828d9dd2f9e341e.tar?X-Amz-Algorithm=blah"
),
}
yield download_details
@pytest.fixture
def mock_get_upload_details() -> Generator[MagicMock]:
"""Mock list files."""
with patch(
"homeassistant.components.cloud.backup.async_files_upload_details",
spec_set=True,
) as download_details:
download_details.return_value = {
"url": (
"https://blabla.cloudflarestorage.com/blabla/backup/"
"ea5c969e492c49df89d432a1483b8dc3.tar?X-Amz-Algorithm=blah"
),
"headers": {
"content-md5": "HOhSM3WZkpHRYGiz4YRGIQ==",
"x-amz-meta-storage-type": "backup",
"x-amz-meta-b64json": (
"eyJhZGRvbnMiOltdLCJiYWNrdXBfaWQiOiJjNDNiNWU2MCIsImRhdGUiOiIyMDI0LT"
"EyLTAzVDA0OjI1OjUwLjMyMDcwMy0wNTowMCIsImRhdGFiYXNlX2luY2x1ZGVkIjpm"
"YWxzZSwiZm9sZGVycyI6W10sImhvbWVhc3Npc3RhbnRfaW5jbHVkZWQiOnRydWUsIm"
"hvbWVhc3Npc3RhbnRfdmVyc2lvbiI6IjIwMjQuMTIuMC5kZXYwIiwibmFtZSI6ImVy"
"aWsiLCJwcm90ZWN0ZWQiOnRydWUsInNpemUiOjM1NjI0OTYwfQ=="
),
},
}
yield download_details
@pytest.fixture @pytest.fixture
def mock_list_files() -> Generator[MagicMock]: def mock_list_files() -> Generator[MagicMock]:
"""Mock list files.""" """Mock list files."""
@ -264,52 +230,30 @@ async def test_agents_download(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker, aioclient_mock: AiohttpClientMocker,
mock_get_download_details: Mock, cloud: Mock,
) -> None: ) -> None:
"""Test agent download backup.""" """Test agent download backup."""
client = await hass_client() client = await hass_client()
backup_id = "23e64aec" backup_id = "23e64aec"
aioclient_mock.get( cloud.files.download.return_value = MockStreamReaderChunked(b"backup data")
mock_get_download_details.return_value["url"], content=b"backup data"
)
resp = await client.get(f"/api/backup/download/{backup_id}?agent_id=cloud.cloud") resp = await client.get(f"/api/backup/download/{backup_id}?agent_id=cloud.cloud")
assert resp.status == 200 assert resp.status == 200
assert await resp.content.read() == b"backup data" assert await resp.content.read() == b"backup data"
@pytest.mark.parametrize("side_effect", [ClientError, CloudError])
@pytest.mark.usefixtures("cloud_logged_in", "mock_list_files")
async def test_agents_download_fail_cloud(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_get_download_details: Mock,
side_effect: Exception,
) -> None:
"""Test agent download backup, when cloud user is logged in."""
client = await hass_client()
backup_id = "23e64aec"
mock_get_download_details.side_effect = side_effect
resp = await client.get(f"/api/backup/download/{backup_id}?agent_id=cloud.cloud")
assert resp.status == 500
content = await resp.content.read()
assert "Failed to get download details" in content.decode()
@pytest.mark.usefixtures("cloud_logged_in", "mock_list_files") @pytest.mark.usefixtures("cloud_logged_in", "mock_list_files")
async def test_agents_download_fail_get( async def test_agents_download_fail_get(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker, cloud: Mock,
mock_get_download_details: Mock,
) -> None: ) -> None:
"""Test agent download backup, when cloud user is logged in.""" """Test agent download backup, when cloud user is logged in."""
client = await hass_client() client = await hass_client()
backup_id = "23e64aec" backup_id = "23e64aec"
aioclient_mock.get(mock_get_download_details.return_value["url"], status=500) cloud.files.download.side_effect = FilesError("Oh no :(")
resp = await client.get(f"/api/backup/download/{backup_id}?agent_id=cloud.cloud") resp = await client.get(f"/api/backup/download/{backup_id}?agent_id=cloud.cloud")
assert resp.status == 500 assert resp.status == 500
@ -336,8 +280,7 @@ async def test_agents_upload(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
aioclient_mock: AiohttpClientMocker, cloud: Mock,
mock_get_upload_details: Mock,
) -> None: ) -> None:
"""Test agent upload backup.""" """Test agent upload backup."""
client = await hass_client() client = await hass_client()
@ -355,8 +298,6 @@ async def test_agents_upload(
protected=True, protected=True,
size=0, size=0,
) )
aioclient_mock.put(mock_get_upload_details.return_value["url"])
with ( with (
patch( patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup", "homeassistant.components.backup.manager.BackupManager.async_get_backup",
@ -374,26 +315,22 @@ async def test_agents_upload(
data={"file": StringIO("test")}, data={"file": StringIO("test")},
) )
assert len(aioclient_mock.mock_calls) == 1 assert len(cloud.files.upload.mock_calls) == 1
assert aioclient_mock.mock_calls[-1][0] == "PUT" metadata = cloud.files.upload.mock_calls[-1].kwargs["metadata"]
assert aioclient_mock.mock_calls[-1][1] == URL( assert metadata["backup_id"] == backup_id
mock_get_upload_details.return_value["url"]
)
assert isinstance(aioclient_mock.mock_calls[-1][2], AsyncIterator)
assert resp.status == 201 assert resp.status == 201
assert f"Uploading backup {backup_id}" in caplog.text assert f"Uploading backup {backup_id}" in caplog.text
@pytest.mark.parametrize("put_mock_kwargs", [{"status": 500}, {"exc": TimeoutError}]) @pytest.mark.parametrize("side_effect", [FilesError("Boom!"), CloudError("Boom!")])
@pytest.mark.usefixtures("cloud_logged_in", "mock_list_files") @pytest.mark.usefixtures("cloud_logged_in", "mock_list_files")
async def test_agents_upload_fail_put( async def test_agents_upload_fail(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
hass_storage: dict[str, Any], hass_storage: dict[str, Any],
aioclient_mock: AiohttpClientMocker, side_effect: Exception,
mock_get_upload_details: Mock, cloud: Mock,
put_mock_kwargs: dict[str, Any],
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test agent upload backup fails.""" """Test agent upload backup fails."""
@ -412,7 +349,8 @@ async def test_agents_upload_fail_put(
protected=True, protected=True,
size=0, size=0,
) )
aioclient_mock.put(mock_get_upload_details.return_value["url"], **put_mock_kwargs)
cloud.files.upload.side_effect = side_effect
with ( with (
patch( patch(
@ -435,7 +373,6 @@ async def test_agents_upload_fail_put(
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(aioclient_mock.mock_calls) == 2
assert "Failed to upload backup, retrying (2/2) in 60s" in caplog.text assert "Failed to upload backup, retrying (2/2) in 60s" in caplog.text
assert resp.status == 201 assert resp.status == 201
store_backups = hass_storage[BACKUP_DOMAIN]["data"]["backups"] store_backups = hass_storage[BACKUP_DOMAIN]["data"]["backups"]
@ -445,59 +382,6 @@ async def test_agents_upload_fail_put(
assert stored_backup["failed_agent_ids"] == ["cloud.cloud"] assert stored_backup["failed_agent_ids"] == ["cloud.cloud"]
@pytest.mark.parametrize("side_effect", [ClientError, CloudError])
@pytest.mark.usefixtures("cloud_logged_in")
async def test_agents_upload_fail_cloud(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
hass_storage: dict[str, Any],
mock_get_upload_details: Mock,
side_effect: Exception,
) -> None:
"""Test agent upload backup, when cloud user is logged in."""
client = await hass_client()
backup_id = "test-backup"
mock_get_upload_details.side_effect = side_effect
test_backup = AgentBackup(
addons=[AddonInfo(name="Test", slug="test", version="1.0.0")],
backup_id=backup_id,
database_included=True,
date="1970-01-01T00:00:00.000Z",
extra_metadata={},
folders=[Folder.MEDIA, Folder.SHARE],
homeassistant_included=True,
homeassistant_version="2024.12.0",
name="Test",
protected=True,
size=0,
)
with (
patch(
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
) as fetch_backup,
patch(
"homeassistant.components.backup.manager.read_backup",
return_value=test_backup,
),
patch("pathlib.Path.open") as mocked_open,
patch("homeassistant.components.cloud.backup.asyncio.sleep"),
):
mocked_open.return_value.read = Mock(side_effect=[b"test", b""])
fetch_backup.return_value = test_backup
resp = await client.post(
"/api/backup/upload?agent_id=cloud.cloud",
data={"file": StringIO("test")},
)
await hass.async_block_till_done()
assert resp.status == 201
store_backups = hass_storage[BACKUP_DOMAIN]["data"]["backups"]
assert len(store_backups) == 1
stored_backup = store_backups[0]
assert stored_backup["backup_id"] == backup_id
assert stored_backup["failed_agent_ids"] == ["cloud.cloud"]
async def test_agents_upload_not_protected( async def test_agents_upload_not_protected(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,