mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 04:37:06 +00:00
Adjust backup agent platform (#132944)
* Adjust backup agent platform * Adjust according to discussion * Clean up the local agent dict too * Add test * Update kitchen_sink * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Adjust tests * Clean up * Fix kitchen sink reload --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
f2aaf2ac4a
commit
85d4572a17
@ -7,7 +7,9 @@ from collections.abc import AsyncIterator, Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from propcache import cached_property
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .models import AgentBackup
|
||||
@ -26,8 +28,14 @@ class BackupAgentUnreachableError(BackupAgentError):
|
||||
class BackupAgent(abc.ABC):
|
||||
"""Backup agent interface."""
|
||||
|
||||
domain: str
|
||||
name: str
|
||||
|
||||
@cached_property
|
||||
def agent_id(self) -> str:
|
||||
"""Return the agent_id."""
|
||||
return f"{self.domain}.{self.name}"
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_download_backup(
|
||||
self,
|
||||
@ -98,3 +106,16 @@ class BackupAgentPlatformProtocol(Protocol):
|
||||
**kwargs: Any,
|
||||
) -> list[BackupAgent]:
|
||||
"""Return a list of backup agents."""
|
||||
|
||||
@callback
|
||||
def async_register_backup_agents_listener(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
listener: Callable[[], None],
|
||||
**kwargs: Any,
|
||||
) -> Callable[[], None]:
|
||||
"""Register a listener to be called when agents are added or removed.
|
||||
|
||||
:return: A function to unregister the listener.
|
||||
"""
|
||||
|
@ -12,7 +12,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.hassio import is_hassio
|
||||
|
||||
from .agent import BackupAgent, LocalBackupAgent
|
||||
from .const import LOGGER
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .models import AgentBackup
|
||||
from .util import read_backup
|
||||
|
||||
@ -30,6 +30,7 @@ async def async_get_backup_agents(
|
||||
class CoreLocalBackupAgent(LocalBackupAgent):
|
||||
"""Local backup agent for Core and Container installations."""
|
||||
|
||||
domain = DOMAIN
|
||||
name = "local"
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
|
@ -243,6 +243,7 @@ class BackupManager:
|
||||
"""Initialize the backup manager."""
|
||||
self.hass = hass
|
||||
self.platforms: dict[str, BackupPlatformProtocol] = {}
|
||||
self.backup_agent_platforms: dict[str, BackupAgentPlatformProtocol] = {}
|
||||
self.backup_agents: dict[str, BackupAgent] = {}
|
||||
self.local_backup_agents: dict[str, LocalBackupAgent] = {}
|
||||
|
||||
@ -291,22 +292,48 @@ class BackupManager:
|
||||
|
||||
self.platforms[integration_domain] = platform
|
||||
|
||||
async def _async_add_platform_agents(
|
||||
@callback
|
||||
def _async_add_backup_agent_platform(
|
||||
self,
|
||||
integration_domain: str,
|
||||
platform: BackupAgentPlatformProtocol,
|
||||
) -> None:
|
||||
"""Add a platform to the backup manager."""
|
||||
"""Add backup agent platform to the backup manager."""
|
||||
if not hasattr(platform, "async_get_backup_agents"):
|
||||
return
|
||||
|
||||
agents = await platform.async_get_backup_agents(self.hass)
|
||||
self.backup_agents.update(
|
||||
{f"{integration_domain}.{agent.name}": agent for agent in agents}
|
||||
self.backup_agent_platforms[integration_domain] = platform
|
||||
|
||||
@callback
|
||||
def listener() -> None:
|
||||
LOGGER.debug("Loading backup agents for %s", integration_domain)
|
||||
self.hass.async_create_task(
|
||||
self._async_reload_backup_agents(integration_domain)
|
||||
)
|
||||
|
||||
if hasattr(platform, "async_register_backup_agents_listener"):
|
||||
platform.async_register_backup_agents_listener(self.hass, listener=listener)
|
||||
|
||||
listener()
|
||||
|
||||
async def _async_reload_backup_agents(self, domain: str) -> None:
|
||||
"""Add backup agent platform to the backup manager."""
|
||||
platform = self.backup_agent_platforms[domain]
|
||||
|
||||
# Remove all agents for the domain
|
||||
for agent_id in list(self.backup_agents):
|
||||
if self.backup_agents[agent_id].domain == domain:
|
||||
del self.backup_agents[agent_id]
|
||||
for agent_id in list(self.local_backup_agents):
|
||||
if self.local_backup_agents[agent_id].domain == domain:
|
||||
del self.local_backup_agents[agent_id]
|
||||
|
||||
# Add new agents
|
||||
agents = await platform.async_get_backup_agents(self.hass)
|
||||
self.backup_agents.update({agent.agent_id: agent for agent in agents})
|
||||
self.local_backup_agents.update(
|
||||
{
|
||||
f"{integration_domain}.{agent.name}": agent
|
||||
agent.agent_id: agent
|
||||
for agent in agents
|
||||
if isinstance(agent, LocalBackupAgent)
|
||||
}
|
||||
@ -320,7 +347,7 @@ class BackupManager:
|
||||
) -> None:
|
||||
"""Add a backup platform manager."""
|
||||
self._add_platform_pre_post_handler(integration_domain, platform)
|
||||
await self._async_add_platform_agents(integration_domain, platform)
|
||||
self._async_add_backup_agent_platform(integration_domain, platform)
|
||||
LOGGER.debug("Backup platform %s loaded", integration_domain)
|
||||
LOGGER.debug("%s platforms loaded in total", len(self.platforms))
|
||||
LOGGER.debug("%s agents loaded in total", len(self.backup_agents))
|
||||
|
@ -38,7 +38,11 @@ async def async_get_backup_agents(
|
||||
**kwargs: Any,
|
||||
) -> list[BackupAgent]:
|
||||
"""Return the cloud backup agent."""
|
||||
return [CloudBackupAgent(hass=hass, cloud=hass.data[DATA_CLOUD])]
|
||||
cloud = hass.data[DATA_CLOUD]
|
||||
if not cloud.is_logged_in:
|
||||
return []
|
||||
|
||||
return [CloudBackupAgent(hass=hass, cloud=cloud)]
|
||||
|
||||
|
||||
class ChunkAsyncStreamIterator:
|
||||
@ -69,6 +73,7 @@ class ChunkAsyncStreamIterator:
|
||||
class CloudBackupAgent(BackupAgent):
|
||||
"""Cloud backup agent."""
|
||||
|
||||
domain = DOMAIN
|
||||
name = DOMAIN
|
||||
|
||||
def __init__(self, hass: HomeAssistant, cloud: Cloud[CloudClient]) -> None:
|
||||
|
@ -79,6 +79,8 @@ def _backup_details_to_agent_backup(
|
||||
class SupervisorBackupAgent(BackupAgent):
|
||||
"""Backup agent for supervised installations."""
|
||||
|
||||
domain = DOMAIN
|
||||
|
||||
def __init__(self, hass: HomeAssistant, name: str, location: str | None) -> None:
|
||||
"""Initialize the backup agent."""
|
||||
super().__init__()
|
||||
|
@ -26,8 +26,7 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
DOMAIN = "kitchen_sink"
|
||||
|
||||
from .const import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
|
||||
|
||||
COMPONENTS_WITH_DEMO_PLATFORM = [
|
||||
Platform.BUTTON,
|
||||
@ -88,9 +87,27 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
|
||||
# Start a reauth flow
|
||||
config_entry.async_start_reauth(hass)
|
||||
|
||||
# Notify backup listeners
|
||||
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload config entry."""
|
||||
# Notify backup listeners
|
||||
hass.async_create_task(_notify_backup_listeners(hass), eager_start=False)
|
||||
|
||||
return await hass.config_entries.async_unload_platforms(
|
||||
entry, COMPONENTS_WITH_DEMO_PLATFORM
|
||||
)
|
||||
|
||||
|
||||
async def _notify_backup_listeners(hass: HomeAssistant) -> None:
|
||||
for listener in hass.data.get(DATA_BACKUP_AGENT_LISTENERS, []):
|
||||
listener()
|
||||
|
||||
|
||||
def _create_issues(hass: HomeAssistant) -> None:
|
||||
"""Create some issue registry issues."""
|
||||
async_create_issue(
|
||||
|
@ -8,7 +8,9 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.backup import AddonInfo, AgentBackup, BackupAgent, Folder
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from . import DATA_BACKUP_AGENT_LISTENERS, DOMAIN
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -17,12 +19,35 @@ async def async_get_backup_agents(
|
||||
hass: HomeAssistant,
|
||||
) -> list[BackupAgent]:
|
||||
"""Register the backup agents."""
|
||||
if not hass.config_entries.async_loaded_entries(DOMAIN):
|
||||
LOGGER.info("No config entry found or entry is not loaded")
|
||||
return []
|
||||
return [KitchenSinkBackupAgent("syncer")]
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_backup_agents_listener(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
listener: Callable[[], None],
|
||||
**kwargs: Any,
|
||||
) -> Callable[[], None]:
|
||||
"""Register a listener to be called when agents are added or removed."""
|
||||
hass.data.setdefault(DATA_BACKUP_AGENT_LISTENERS, []).append(listener)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove the listener."""
|
||||
hass.data[DATA_BACKUP_AGENT_LISTENERS].remove(listener)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
||||
class KitchenSinkBackupAgent(BackupAgent):
|
||||
"""Kitchen sink backup agent."""
|
||||
|
||||
domain = DOMAIN
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialize the kitchen sink backup sync agent."""
|
||||
super().__init__()
|
||||
|
12
homeassistant/components/kitchen_sink/const.py
Normal file
12
homeassistant/components/kitchen_sink/const.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Constants for the Kitchen Sink integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
DOMAIN = "kitchen_sink"
|
||||
DATA_BACKUP_AGENT_LISTENERS: HassKey[list[Callable[[], None]]] = HassKey(
|
||||
f"{DOMAIN}.backup_agent_listeners"
|
||||
)
|
@ -57,6 +57,8 @@ TEST_DOMAIN = "test"
|
||||
class BackupAgentTest(BackupAgent):
|
||||
"""Test backup agent."""
|
||||
|
||||
domain = "test"
|
||||
|
||||
def __init__(self, name: str, backups: list[AgentBackup] | None = None) -> None:
|
||||
"""Initialize the backup agent."""
|
||||
self.name = name
|
||||
|
@ -6,6 +6,7 @@ import asyncio
|
||||
from collections.abc import Generator
|
||||
from io import StringIO
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, mock_open, patch
|
||||
|
||||
@ -18,6 +19,7 @@ from homeassistant.components.backup import (
|
||||
BackupManager,
|
||||
BackupPlatformProtocol,
|
||||
Folder,
|
||||
LocalBackupAgent,
|
||||
backup as local_backup_platform,
|
||||
)
|
||||
from homeassistant.components.backup.const import DATA_MANAGER
|
||||
@ -534,21 +536,86 @@ async def test_loading_platforms(
|
||||
|
||||
assert not manager.platforms
|
||||
|
||||
get_agents_mock = AsyncMock(return_value=[])
|
||||
|
||||
await _setup_backup_platform(
|
||||
hass,
|
||||
platform=Mock(
|
||||
async_pre_backup=AsyncMock(),
|
||||
async_post_backup=AsyncMock(),
|
||||
async_get_backup_agents=AsyncMock(),
|
||||
async_get_backup_agents=get_agents_mock,
|
||||
),
|
||||
)
|
||||
await manager.load_platforms()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(manager.platforms) == 1
|
||||
|
||||
assert "Loaded 1 platforms" in caplog.text
|
||||
|
||||
get_agents_mock.assert_called_once_with(hass)
|
||||
|
||||
|
||||
class LocalBackupAgentTest(BackupAgentTest, LocalBackupAgent):
|
||||
"""Local backup agent."""
|
||||
|
||||
def get_backup_path(self, backup_id: str) -> Path:
|
||||
"""Return the local path to a backup."""
|
||||
return "test.tar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("agent_class", "num_local_agents"),
|
||||
[(LocalBackupAgentTest, 2), (BackupAgentTest, 1)],
|
||||
)
|
||||
async def test_loading_platform_with_listener(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
agent_class: type[BackupAgentTest],
|
||||
num_local_agents: int,
|
||||
) -> None:
|
||||
"""Test loading a backup agent platform which can be listened to."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
assert await async_setup_component(hass, DOMAIN, {})
|
||||
manager = hass.data[DATA_MANAGER]
|
||||
|
||||
get_agents_mock = AsyncMock(return_value=[agent_class("remote1", backups=[])])
|
||||
register_listener_mock = Mock()
|
||||
|
||||
await _setup_backup_platform(
|
||||
hass,
|
||||
domain="test",
|
||||
platform=Mock(
|
||||
async_get_backup_agents=get_agents_mock,
|
||||
async_register_backup_agents_listener=register_listener_mock,
|
||||
),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
|
||||
resp = await ws_client.receive_json()
|
||||
assert resp["result"]["agents"] == [
|
||||
{"agent_id": "backup.local"},
|
||||
{"agent_id": "test.remote1"},
|
||||
]
|
||||
assert len(manager.local_backup_agents) == num_local_agents
|
||||
|
||||
get_agents_mock.assert_called_once_with(hass)
|
||||
register_listener_mock.assert_called_once_with(hass, listener=ANY)
|
||||
|
||||
get_agents_mock.reset_mock()
|
||||
get_agents_mock.return_value = [agent_class("remote2", backups=[])]
|
||||
listener = register_listener_mock.call_args[1]["listener"]
|
||||
listener()
|
||||
|
||||
get_agents_mock.assert_called_once_with(hass)
|
||||
await ws_client.send_json_auto_id({"type": "backup/agents/info"})
|
||||
resp = await ws_client.receive_json()
|
||||
assert resp["result"]["agents"] == [
|
||||
{"agent_id": "backup.local"},
|
||||
{"agent_id": "test.remote2"},
|
||||
]
|
||||
assert len(manager.local_backup_agents) == num_local_agents
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"platform_mock",
|
||||
|
@ -26,7 +26,10 @@ from tests.typing import ClientSessionGenerator, MagicMock, WebSocketGenerator
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_integration(
|
||||
hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, cloud: MagicMock
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
cloud: MagicMock,
|
||||
cloud_logged_in: None,
|
||||
) -> AsyncGenerator[None]:
|
||||
"""Set up cloud integration."""
|
||||
with patch("homeassistant.components.backup.is_hassio", return_value=False):
|
||||
|
@ -57,6 +57,27 @@ async def test_agents_info(
|
||||
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
|
||||
}
|
||||
|
||||
config_entry = hass.config_entries.async_entries(DOMAIN)[0]
|
||||
await hass.config_entries.async_unload(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await client.send_json_auto_id({"type": "backup/agents/info"})
|
||||
response = await client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {"agents": [{"agent_id": "backup.local"}]}
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await client.send_json_auto_id({"type": "backup/agents/info"})
|
||||
response = await client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {
|
||||
"agents": [{"agent_id": "backup.local"}, {"agent_id": "kitchen_sink.syncer"}],
|
||||
}
|
||||
|
||||
|
||||
async def test_agents_list_backups(
|
||||
hass: HomeAssistant,
|
||||
|
Loading…
x
Reference in New Issue
Block a user