Adjust backup agent platform (#132944)

* Adjust backup agent platform

* Adjust according to discussion

* Clean up the local agent dict too

* Add test

* Update kitchen_sink

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Adjust tests

* Clean up

* Fix kitchen sink reload

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2024-12-12 13:41:56 +01:00 committed by GitHub
parent f2aaf2ac4a
commit 85d4572a17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 235 additions and 32 deletions

View File

@ -7,7 +7,9 @@ from collections.abc import AsyncIterator, Callable, Coroutine
from pathlib import Path
from typing import Any, Protocol
from homeassistant.core import HomeAssistant
from propcache import cached_property
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from .models import AgentBackup
@ -26,8 +28,14 @@ class BackupAgentUnreachableError(BackupAgentError):
class BackupAgent(abc.ABC):
"""Backup agent interface."""
domain: str
name: str
@cached_property
def agent_id(self) -> str:
"""Return the agent_id."""
return f"{self.domain}.{self.name}"
@abc.abstractmethod
async def async_download_backup(
self,
@ -98,3 +106,16 @@ class BackupAgentPlatformProtocol(Protocol):
**kwargs: Any,
) -> list[BackupAgent]:
"""Return a list of backup agents."""
@callback
def async_register_backup_agents_listener(
self,
hass: HomeAssistant,
*,
listener: Callable[[], None],
**kwargs: Any,
) -> Callable[[], None]:
"""Register a listener to be called when agents are added or removed.
:return: A function to unregister the listener.
"""

View File

@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.hassio import is_hassio
from .agent import BackupAgent, LocalBackupAgent
from .const import LOGGER
from .const import DOMAIN, LOGGER
from .models import AgentBackup
from .util import read_backup
@ -30,6 +30,7 @@ async def async_get_backup_agents(
class CoreLocalBackupAgent(LocalBackupAgent):
"""Local backup agent for Core and Container installations."""
domain = DOMAIN
name = "local"
def __init__(self, hass: HomeAssistant) -> None:

View File

@ -243,6 +243,7 @@ class BackupManager:
"""Initialize the backup manager."""
self.hass = hass
self.platforms: dict[str, BackupPlatformProtocol] = {}
self.backup_agent_platforms: dict[str, BackupAgentPlatformProtocol] = {}
self.backup_agents: dict[str, BackupAgent] = {}
self.local_backup_agents: dict[str, LocalBackupAgent] = {}
@ -291,22 +292,48 @@ class BackupManager:
self.platforms[integration_domain] = platform
async def _async_add_platform_agents(
@callback
def _async_add_backup_agent_platform(
self,
integration_domain: str,
platform: BackupAgentPlatformProtocol,
) -> None:
"""Add a platform to the backup manager."""
"""Add backup agent platform to the backup manager."""
if not hasattr(platform, "async_get_backup_agents"):
return
agents = await platform.async_get_backup_agents(self.hass)
self.backup_agents.update(
{f"{integration_domain}.{agent.name}": agent for agent in agents}
self.backup_agent_platforms[integration_domain] = platform
@callback
def listener() -> None:
LOGGER.debug("Loading backup agents for %s", integration_domain)
self.hass.async_create_task(
self._async_reload_backup_agents(integration_domain)
)
if hasattr(platform, "async_register_backup_agents_listener"):
platform.async_register_backup_agents_listener(self.hass, listener=listener)
listener()
async def _async_reload_backup_agents(self, domain: str) -> None:
"""Add backup agent platform to the backup manager."""
platform = self.backup_agent_platforms[domain]
# Remove all agents for the domain
for agent_id in list(self.backup_agents):
if self.backup_agents[agent_id].domain == domain:
del self.backup_agents[agent_id]
for agent_id in list(self.local_backup_agents):
if self.local_backup_agents[agent_id].domain == domain:
del self.local_backup_agents[agent_id]
# Add new agents
agents = await platform.async_get_backup_agents(self.hass)
self.backup_agents.update({agent.agent_id: agent for agent in agents})
self.local_backup_agents.update(
{
f"{integration_domain}.{agent.name}": agent
agent.agent_id: agent
for agent in agents
if isinstance(agent, LocalBackupAgent)
}
@ -320,7 +347,7 @@ class BackupManager:
) -> None:
"""Add a backup platform manager."""
self._add_platform_pre_post_handler(integration_domain, platform)
await self._async_add_platform_agents(integration_domain, platform)
self._async_add_backup_agent_platform(integration_domain, platform)
LOGGER.debug("Backup platform %s loaded", integration_domain)
LOGGER.debug("%s platforms loaded in total", len(self.platforms))
LOGGER.debug("%s agents loaded in total", len(self.backup_agents))

View File

@ -38,7 +38,11 @@ async def async_get_backup_agents(
**kwargs: Any,
) -> list[BackupAgent]:
"""Return the cloud backup agent."""
return [CloudBackupAgent(hass=hass, cloud=hass.data[DATA_CLOUD])]
cloud = hass.data[DATA_CLOUD]
if not cloud.is_logged_in:
return []
return [CloudBackupAgent(hass=hass, cloud=cloud)]
class ChunkAsyncStreamIterator:
@ -69,6 +73,7 @@ class ChunkAsyncStreamIterator:
class CloudBackupAgent(BackupAgent):
"""Cloud backup agent."""
domain = DOMAIN
name = DOMAIN
def __init__(self, hass: HomeAssistant, cloud: Cloud[CloudClient]) -> None:

View File

@ -79,6 +79,8 @@ def _backup_details_to_agent_backup(
class SupervisorBackupAgent(BackupAgent):
"""Backup agent for supervised installations."""
domain = DOMAIN
def __init__(self, hass: HomeAssistant, name: str, location: str | None) -> None:
"""Initialize the backup agent."""
super().__init__()

View File

@ -26,8 +26,7 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util
DOMAIN = "kitchen_sink"
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
COMPONENTS_WITH_DEMO_PLATFORM = [
Platform.BUTTON,
@ -88,9 +87,27 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
# Start a reauth flow
config_entry.async_start_reauth(hass)
# Notify backup listeners
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload config entry."""
# Notify backup listeners
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
return await hass.config_entries.async_unload_platforms(
entry, COMPONENTS_WITH_DEMO_PLATFORM
)
async def _notify_backup_listeners(hass: HomeAssistant) -> None:
for listener in hass.data.get(DATA_BACKUP_AGENT_LISTENERS, []):
listener()
def _create_issues(hass: HomeAssistant) -> None:
"""Create some issue registry issues."""
async_create_issue(

View File

@ -8,7 +8,9 @@ import logging
from typing import Any
from homeassistant.components.backup import AddonInfo, AgentBackup, BackupAgent, Folder
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from . import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
LOGGER = logging.getLogger(__name__)
@ -17,12 +19,35 @@ async def async_get_backup_agents(
hass: HomeAssistant,
) -> list[BackupAgent]:
"""Register the backup agents."""
if not hass.config_entries.async_loaded_entries(DOMAIN):
LOGGER.info("No config entry found or entry is not loaded")
return []
return [KitchenSinkBackupAgent("syncer")]
@callback
def async_register_backup_agents_listener(
hass: HomeAssistant,
*,
listener: Callable[[], None],
**kwargs: Any,
) -> Callable[[], None]:
"""Register a listener to be called when agents are added or removed."""
hass.data.setdefault(DATA_BACKUP_AGENT_LISTENERS, []).append(listener)
@callback
def remove_listener() -> None:
"""Remove the listener."""
hass.data[DATA_BACKUP_AGENT_LISTENERS].remove(listener)
return remove_listener
class KitchenSinkBackupAgent(BackupAgent):
"""Kitchen sink backup agent."""
domain = DOMAIN
def __init__(self, name: str) -> None:
"""Initialize the kitchen sink backup sync agent."""
super().__init__()

View File

@ -0,0 +1,12 @@
"""Constants for the Kitchen Sink integration."""
from __future__ import annotations
from collections.abc import Callable
from homeassistant.util.hass_dict import HassKey
DOMAIN = "kitchen_sink"
DATA_BACKUP_AGENT_LISTENERS: HassKey[list[Callable[[], None]]] = HassKey(
f"{DOMAIN}.backup_agent_listeners"
)

View File

@ -57,6 +57,8 @@ TEST_DOMAIN = "test"
class BackupAgentTest(BackupAgent):
"""Test backup agent."""
domain = "test"
def __init__(self, name: str, backups: list[AgentBackup] | None = None) -> None:
"""Initialize the backup agent."""
self.name = name

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import Generator
from io import StringIO
import json
from pathlib import Path
from typing import Any
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, mock_open, patch
@ -18,6 +19,7 @@ from homeassistant.components.backup import (
BackupManager,
BackupPlatformProtocol,
Folder,
LocalBackupAgent,
backup as local_backup_platform,
)
from homeassistant.components.backup.const import DATA_MANAGER
@ -534,21 +536,86 @@ async def test_loading_platforms(
assert not manager.platforms
get_agents_mock = AsyncMock(return_value=[])
await _setup_backup_platform(
hass,
platform=Mock(
async_pre_backup=AsyncMock(),
async_post_backup=AsyncMock(),
async_get_backup_agents=AsyncMock(),
async_get_backup_agents=get_agents_mock,
),
)
await manager.load_platforms()
await hass.async_block_till_done()
assert len(manager.platforms) == 1
assert "Loaded 1 platforms" in caplog.text
get_agents_mock.assert_called_once_with(hass)
class LocalBackupAgentTest(BackupAgentTest, LocalBackupAgent):
"""Local backup agent."""
def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup."""
return "test.tar"
@pytest.mark.parametrize(
("agent_class", "num_local_agents"),
[(LocalBackupAgentTest, 2), (BackupAgentTest, 1)],
)
async def test_loading_platform_with_listener(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
agent_class: type[BackupAgentTest],
num_local_agents: int,
) -> None:
"""Test loading a backup agent platform which can be listened to."""
ws_client = await hass_ws_client(hass)
assert await async_setup_component(hass, DOMAIN, {})
manager = hass.data[DATA_MANAGER]
get_agents_mock = AsyncMock(return_value=[agent_class("remote1", backups=[])])
register_listener_mock = Mock()
await _setup_backup_platform(
hass,
domain="test",
platform=Mock(
async_get_backup_agents=get_agents_mock,
async_register_backup_agents_listener=register_listener_mock,
),
)
await hass.async_block_till_done()
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
resp = await ws_client.receive_json()
assert resp["result"]["agents"] == [
{"agent_id": "backup.local"},
{"agent_id": "test.remote1"},
]
assert len(manager.local_backup_agents) == num_local_agents
get_agents_mock.assert_called_once_with(hass)
register_listener_mock.assert_called_once_with(hass, listener=ANY)
get_agents_mock.reset_mock()
get_agents_mock.return_value = [agent_class("remote2", backups=[])]
listener = register_listener_mock.call_args[1]["listener"]
listener()
get_agents_mock.assert_called_once_with(hass)
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
resp = await ws_client.receive_json()
assert resp["result"]["agents"] == [
{"agent_id": "backup.local"},
{"agent_id": "test.remote2"},
]
assert len(manager.local_backup_agents) == num_local_agents
@pytest.mark.parametrize(
"platform_mock",

View File

@ -26,7 +26,10 @@ from tests.typing import ClientSessionGenerator, MagicMock, WebSocketGenerator
@pytest.fixture(autouse=True)
async def setup_integration(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, cloud: MagicMock
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
cloud: MagicMock,
cloud_logged_in: None,
) -> AsyncGenerator[None]:
"""Set up cloud integration."""
with patch("homeassistant.components.backup.is_hassio", return_value=False):

View File

@ -57,6 +57,27 @@ async def test_agents_info(
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
}
config_entry = hass.config_entries.async_entries(DOMAIN)[0]
await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done()
await client.send_json_auto_id({"type": "backup/agents/info"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agents": [{"agent_id": "backup.local"}]}
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
await client.send_json_auto_id({"type": "backup/agents/info"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
}
async def test_agents_list_backups(
hass: HomeAssistant,