Adjust backup agent platform (#132944)

* Adjust backup agent platform

* Adjust according to discussion

* Clean up the local agent dict too

* Add test

* Update kitchen_sink

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Adjust tests

* Clean up

* Fix kitchen sink reload

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2024-12-12 13:41:56 +01:00 committed by GitHub
parent f2aaf2ac4a
commit 85d4572a17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 235 additions and 32 deletions

View File

@ -7,7 +7,9 @@ from collections.abc import AsyncIterator, Callable, Coroutine
from pathlib import Path from pathlib import Path
from typing import Any, Protocol from typing import Any, Protocol
from homeassistant.core import HomeAssistant from propcache import cached_property
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .models import AgentBackup from .models import AgentBackup
@ -26,8 +28,14 @@ class BackupAgentUnreachableError(BackupAgentError):
class BackupAgent(abc.ABC): class BackupAgent(abc.ABC):
"""Backup agent interface.""" """Backup agent interface."""
domain: str
name: str name: str
@cached_property
def agent_id(self) -> str:
"""Return the agent_id."""
return f"{self.domain}.{self.name}"
@abc.abstractmethod @abc.abstractmethod
async def async_download_backup( async def async_download_backup(
self, self,
@ -98,3 +106,16 @@ class BackupAgentPlatformProtocol(Protocol):
**kwargs: Any, **kwargs: Any,
) -> list[BackupAgent]: ) -> list[BackupAgent]:
"""Return a list of backup agents.""" """Return a list of backup agents."""
@callback
def async_register_backup_agents_listener(
self,
hass: HomeAssistant,
*,
listener: Callable[[], None],
**kwargs: Any,
) -> Callable[[], None]:
"""Register a listener to be called when agents are added or removed.
:return: A function to unregister the listener.
"""

View File

@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.hassio import is_hassio from homeassistant.helpers.hassio import is_hassio
from .agent import BackupAgent, LocalBackupAgent from .agent import BackupAgent, LocalBackupAgent
from .const import LOGGER from .const import DOMAIN, LOGGER
from .models import AgentBackup from .models import AgentBackup
from .util import read_backup from .util import read_backup
@ -30,6 +30,7 @@ async def async_get_backup_agents(
class CoreLocalBackupAgent(LocalBackupAgent): class CoreLocalBackupAgent(LocalBackupAgent):
"""Local backup agent for Core and Container installations.""" """Local backup agent for Core and Container installations."""
domain = DOMAIN
name = "local" name = "local"
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:

View File

@ -243,6 +243,7 @@ class BackupManager:
"""Initialize the backup manager.""" """Initialize the backup manager."""
self.hass = hass self.hass = hass
self.platforms: dict[str, BackupPlatformProtocol] = {} self.platforms: dict[str, BackupPlatformProtocol] = {}
self.backup_agent_platforms: dict[str, BackupAgentPlatformProtocol] = {}
self.backup_agents: dict[str, BackupAgent] = {} self.backup_agents: dict[str, BackupAgent] = {}
self.local_backup_agents: dict[str, LocalBackupAgent] = {} self.local_backup_agents: dict[str, LocalBackupAgent] = {}
@ -291,22 +292,48 @@ class BackupManager:
self.platforms[integration_domain] = platform self.platforms[integration_domain] = platform
async def _async_add_platform_agents( @callback
def _async_add_backup_agent_platform(
self, self,
integration_domain: str, integration_domain: str,
platform: BackupAgentPlatformProtocol, platform: BackupAgentPlatformProtocol,
) -> None: ) -> None:
"""Add a platform to the backup manager.""" """Add backup agent platform to the backup manager."""
if not hasattr(platform, "async_get_backup_agents"): if not hasattr(platform, "async_get_backup_agents"):
return return
agents = await platform.async_get_backup_agents(self.hass) self.backup_agent_platforms[integration_domain] = platform
self.backup_agents.update(
{f"{integration_domain}.{agent.name}": agent for agent in agents} @callback
def listener() -> None:
LOGGER.debug("Loading backup agents for %s", integration_domain)
self.hass.async_create_task(
self._async_reload_backup_agents(integration_domain)
) )
if hasattr(platform, "async_register_backup_agents_listener"):
platform.async_register_backup_agents_listener(self.hass, listener=listener)
listener()
async def _async_reload_backup_agents(self, domain: str) -> None:
"""Add backup agent platform to the backup manager."""
platform = self.backup_agent_platforms[domain]
# Remove all agents for the domain
for agent_id in list(self.backup_agents):
if self.backup_agents[agent_id].domain == domain:
del self.backup_agents[agent_id]
for agent_id in list(self.local_backup_agents):
if self.local_backup_agents[agent_id].domain == domain:
del self.local_backup_agents[agent_id]
# Add new agents
agents = await platform.async_get_backup_agents(self.hass)
self.backup_agents.update({agent.agent_id: agent for agent in agents})
self.local_backup_agents.update( self.local_backup_agents.update(
{ {
f"{integration_domain}.{agent.name}": agent agent.agent_id: agent
for agent in agents for agent in agents
if isinstance(agent, LocalBackupAgent) if isinstance(agent, LocalBackupAgent)
} }
@ -320,7 +347,7 @@ class BackupManager:
) -> None: ) -> None:
"""Add a backup platform manager.""" """Add a backup platform manager."""
self._add_platform_pre_post_handler(integration_domain, platform) self._add_platform_pre_post_handler(integration_domain, platform)
await self._async_add_platform_agents(integration_domain, platform) self._async_add_backup_agent_platform(integration_domain, platform)
LOGGER.debug("Backup platform %s loaded", integration_domain) LOGGER.debug("Backup platform %s loaded", integration_domain)
LOGGER.debug("%s platforms loaded in total", len(self.platforms)) LOGGER.debug("%s platforms loaded in total", len(self.platforms))
LOGGER.debug("%s agents loaded in total", len(self.backup_agents)) LOGGER.debug("%s agents loaded in total", len(self.backup_agents))

View File

@ -38,7 +38,11 @@ async def async_get_backup_agents(
**kwargs: Any, **kwargs: Any,
) -> list[BackupAgent]: ) -> list[BackupAgent]:
"""Return the cloud backup agent.""" """Return the cloud backup agent."""
return [CloudBackupAgent(hass=hass, cloud=hass.data[DATA_CLOUD])] cloud = hass.data[DATA_CLOUD]
if not cloud.is_logged_in:
return []
return [CloudBackupAgent(hass=hass, cloud=cloud)]
class ChunkAsyncStreamIterator: class ChunkAsyncStreamIterator:
@ -69,6 +73,7 @@ class ChunkAsyncStreamIterator:
class CloudBackupAgent(BackupAgent): class CloudBackupAgent(BackupAgent):
"""Cloud backup agent.""" """Cloud backup agent."""
domain = DOMAIN
name = DOMAIN name = DOMAIN
def __init__(self, hass: HomeAssistant, cloud: Cloud[CloudClient]) -> None: def __init__(self, hass: HomeAssistant, cloud: Cloud[CloudClient]) -> None:

View File

@ -79,6 +79,8 @@ def _backup_details_to_agent_backup(
class SupervisorBackupAgent(BackupAgent): class SupervisorBackupAgent(BackupAgent):
"""Backup agent for supervised installations.""" """Backup agent for supervised installations."""
domain = DOMAIN
def __init__(self, hass: HomeAssistant, name: str, location: str | None) -> None: def __init__(self, hass: HomeAssistant, name: str, location: str | None) -> None:
"""Initialize the backup agent.""" """Initialize the backup agent."""
super().__init__() super().__init__()

View File

@ -26,8 +26,7 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
DOMAIN = "kitchen_sink" from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
COMPONENTS_WITH_DEMO_PLATFORM = [ COMPONENTS_WITH_DEMO_PLATFORM = [
Platform.BUTTON, Platform.BUTTON,
@ -88,9 +87,27 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
# Start a reauth flow # Start a reauth flow
config_entry.async_start_reauth(hass) config_entry.async_start_reauth(hass)
# Notify backup listeners
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload config entry."""
# Notify backup listeners
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
return await hass.config_entries.async_unload_platforms(
entry, COMPONENTS_WITH_DEMO_PLATFORM
)
async def _notify_backup_listeners(hass: HomeAssistant) -> None:
for listener in hass.data.get(DATA_BACKUP_AGENT_LISTENERS, []):
listener()
def _create_issues(hass: HomeAssistant) -> None: def _create_issues(hass: HomeAssistant) -> None:
"""Create some issue registry issues.""" """Create some issue registry issues."""
async_create_issue( async_create_issue(

View File

@ -8,7 +8,9 @@ import logging
from typing import Any from typing import Any
from homeassistant.components.backup import AddonInfo, AgentBackup, BackupAgent, Folder from homeassistant.components.backup import AddonInfo, AgentBackup, BackupAgent, Folder
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from . import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@ -17,12 +19,35 @@ async def async_get_backup_agents(
hass: HomeAssistant, hass: HomeAssistant,
) -> list[BackupAgent]: ) -> list[BackupAgent]:
"""Register the backup agents.""" """Register the backup agents."""
if not hass.config_entries.async_loaded_entries(DOMAIN):
LOGGER.info("No config entry found or entry is not loaded")
return []
return [KitchenSinkBackupAgent("syncer")] return [KitchenSinkBackupAgent("syncer")]
@callback
def async_register_backup_agents_listener(
hass: HomeAssistant,
*,
listener: Callable[[], None],
**kwargs: Any,
) -> Callable[[], None]:
"""Register a listener to be called when agents are added or removed."""
hass.data.setdefault(DATA_BACKUP_AGENT_LISTENERS, []).append(listener)
@callback
def remove_listener() -> None:
"""Remove the listener."""
hass.data[DATA_BACKUP_AGENT_LISTENERS].remove(listener)
return remove_listener
class KitchenSinkBackupAgent(BackupAgent): class KitchenSinkBackupAgent(BackupAgent):
"""Kitchen sink backup agent.""" """Kitchen sink backup agent."""
domain = DOMAIN
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
"""Initialize the kitchen sink backup sync agent.""" """Initialize the kitchen sink backup sync agent."""
super().__init__() super().__init__()

View File

@ -0,0 +1,12 @@
"""Constants for the Kitchen Sink integration."""
from __future__ import annotations
from collections.abc import Callable
from homeassistant.util.hass_dict import HassKey
DOMAIN = "kitchen_sink"
DATA_BACKUP_AGENT_LISTENERS: HassKey[list[Callable[[], None]]] = HassKey(
f"{DOMAIN}.backup_agent_listeners"
)

View File

@ -57,6 +57,8 @@ TEST_DOMAIN = "test"
class BackupAgentTest(BackupAgent): class BackupAgentTest(BackupAgent):
"""Test backup agent.""" """Test backup agent."""
domain = "test"
def __init__(self, name: str, backups: list[AgentBackup] | None = None) -> None: def __init__(self, name: str, backups: list[AgentBackup] | None = None) -> None:
"""Initialize the backup agent.""" """Initialize the backup agent."""
self.name = name self.name = name

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import Generator from collections.abc import Generator
from io import StringIO from io import StringIO
import json import json
from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, mock_open, patch from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, mock_open, patch
@ -18,6 +19,7 @@ from homeassistant.components.backup import (
BackupManager, BackupManager,
BackupPlatformProtocol, BackupPlatformProtocol,
Folder, Folder,
LocalBackupAgent,
backup as local_backup_platform, backup as local_backup_platform,
) )
from homeassistant.components.backup.const import DATA_MANAGER from homeassistant.components.backup.const import DATA_MANAGER
@ -534,21 +536,86 @@ async def test_loading_platforms(
assert not manager.platforms assert not manager.platforms
get_agents_mock = AsyncMock(return_value=[])
await _setup_backup_platform( await _setup_backup_platform(
hass, hass,
platform=Mock( platform=Mock(
async_pre_backup=AsyncMock(), async_pre_backup=AsyncMock(),
async_post_backup=AsyncMock(), async_post_backup=AsyncMock(),
async_get_backup_agents=AsyncMock(), async_get_backup_agents=get_agents_mock,
), ),
) )
await manager.load_platforms() await manager.load_platforms()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(manager.platforms) == 1 assert len(manager.platforms) == 1
assert "Loaded 1 platforms" in caplog.text assert "Loaded 1 platforms" in caplog.text
get_agents_mock.assert_called_once_with(hass)
class LocalBackupAgentTest(BackupAgentTest, LocalBackupAgent):
"""Local backup agent."""
def get_backup_path(self, backup_id: str) -> Path:
"""Return the local path to a backup."""
return "test.tar"
@pytest.mark.parametrize(
("agent_class", "num_local_agents"),
[(LocalBackupAgentTest, 2), (BackupAgentTest, 1)],
)
async def test_loading_platform_with_listener(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
agent_class: type[BackupAgentTest],
num_local_agents: int,
) -> None:
"""Test loading a backup agent platform which can be listened to."""
ws_client = await hass_ws_client(hass)
assert await async_setup_component(hass, DOMAIN, {})
manager = hass.data[DATA_MANAGER]
get_agents_mock = AsyncMock(return_value=[agent_class("remote1", backups=[])])
register_listener_mock = Mock()
await _setup_backup_platform(
hass,
domain="test",
platform=Mock(
async_get_backup_agents=get_agents_mock,
async_register_backup_agents_listener=register_listener_mock,
),
)
await hass.async_block_till_done()
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
resp = await ws_client.receive_json()
assert resp["result"]["agents"] == [
{"agent_id": "backup.local"},
{"agent_id": "test.remote1"},
]
assert len(manager.local_backup_agents) == num_local_agents
get_agents_mock.assert_called_once_with(hass)
register_listener_mock.assert_called_once_with(hass, listener=ANY)
get_agents_mock.reset_mock()
get_agents_mock.return_value = [agent_class("remote2", backups=[])]
listener = register_listener_mock.call_args[1]["listener"]
listener()
get_agents_mock.assert_called_once_with(hass)
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
resp = await ws_client.receive_json()
assert resp["result"]["agents"] == [
{"agent_id": "backup.local"},
{"agent_id": "test.remote2"},
]
assert len(manager.local_backup_agents) == num_local_agents
@pytest.mark.parametrize( @pytest.mark.parametrize(
"platform_mock", "platform_mock",

View File

@ -26,7 +26,10 @@ from tests.typing import ClientSessionGenerator, MagicMock, WebSocketGenerator
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def setup_integration( async def setup_integration(
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, cloud: MagicMock hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
cloud: MagicMock,
cloud_logged_in: None,
) -> AsyncGenerator[None]: ) -> AsyncGenerator[None]:
"""Set up cloud integration.""" """Set up cloud integration."""
with patch("homeassistant.components.backup.is_hassio", return_value=False): with patch("homeassistant.components.backup.is_hassio", return_value=False):

View File

@ -57,6 +57,27 @@ async def test_agents_info(
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}], "agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
} }
config_entry = hass.config_entries.async_entries(DOMAIN)[0]
await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done()
await client.send_json_auto_id({"type": "backup/agents/info"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"agents": [{"agent_id": "backup.local"}]}
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
await client.send_json_auto_id({"type": "backup/agents/info"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == {
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
}
async def test_agents_list_backups( async def test_agents_list_backups(
hass: HomeAssistant, hass: HomeAssistant,