Add async_register_backup_agents_listener to cloud/backup (#133584)

* Add async_register_backup_agents_listener to cloud/backup

* Coverage

* more coverage
This commit is contained in:
Joakim Sørensen 2024-12-20 08:55:00 +01:00 committed by GitHub
parent ad34bc8910
commit 10191e7a23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 126 additions and 2 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
from collections.abc import AsyncIterator, Callable, Coroutine from collections.abc import AsyncIterator, Callable, Coroutine, Mapping
import hashlib import hashlib
from typing import Any, Self from typing import Any, Self
@ -18,9 +18,10 @@ from hass_nabucasa.cloud_api import (
from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .client import CloudClient from .client import CloudClient
from .const import DATA_CLOUD, DOMAIN from .const import DATA_CLOUD, DOMAIN, EVENT_CLOUD_EVENT
_STORAGE_BACKUP = "backup" _STORAGE_BACKUP = "backup"
@ -45,6 +46,31 @@ async def async_get_backup_agents(
return [CloudBackupAgent(hass=hass, cloud=cloud)] return [CloudBackupAgent(hass=hass, cloud=cloud)]
@callback
def async_register_backup_agents_listener(
hass: HomeAssistant,
*,
listener: Callable[[], None],
**kwargs: Any,
) -> Callable[[], None]:
"""Register a listener to be called when agents are added or removed."""
@callback
def unsub() -> None:
"""Unsubscribe from events."""
unsub_signal()
@callback
def handle_event(data: Mapping[str, Any]) -> None:
"""Handle event."""
if data["type"] not in ("login", "logout"):
return
listener()
unsub_signal = async_dispatcher_connect(hass, EVENT_CLOUD_EVENT, handle_event)
return unsub
class ChunkAsyncStreamIterator: class ChunkAsyncStreamIterator:
"""Async iterator for chunked streams. """Async iterator for chunked streams.

View File

@ -18,6 +18,8 @@ DATA_CLOUD: HassKey[Cloud[CloudClient]] = HassKey(DOMAIN)
DATA_PLATFORMS_SETUP: HassKey[dict[str, asyncio.Event]] = HassKey( DATA_PLATFORMS_SETUP: HassKey[dict[str, asyncio.Event]] = HassKey(
"cloud_platforms_setup" "cloud_platforms_setup"
) )
EVENT_CLOUD_EVENT = "cloud_event"
REQUEST_TIMEOUT = 10 REQUEST_TIMEOUT = 10
PREF_ENABLE_ALEXA = "alexa_enabled" PREF_ENABLE_ALEXA = "alexa_enabled"

View File

@ -34,6 +34,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.util.location import async_detect_location_info from homeassistant.util.location import async_detect_location_info
from .alexa_config import entity_supported as entity_supported_by_alexa from .alexa_config import entity_supported as entity_supported_by_alexa
@ -41,6 +42,7 @@ from .assist_pipeline import async_create_cloud_pipeline
from .client import CloudClient from .client import CloudClient
from .const import ( from .const import (
DATA_CLOUD, DATA_CLOUD,
EVENT_CLOUD_EVENT,
LOGIN_MFA_TIMEOUT, LOGIN_MFA_TIMEOUT,
PREF_ALEXA_REPORT_STATE, PREF_ALEXA_REPORT_STATE,
PREF_DISABLE_2FA, PREF_DISABLE_2FA,
@ -278,6 +280,8 @@ class CloudLoginView(HomeAssistantView):
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass) new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
else: else:
new_cloud_pipeline_id = None new_cloud_pipeline_id = None
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": "login"})
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id}) return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
@ -297,6 +301,7 @@ class CloudLogoutView(HomeAssistantView):
async with asyncio.timeout(REQUEST_TIMEOUT): async with asyncio.timeout(REQUEST_TIMEOUT):
await cloud.logout() await cloud.logout()
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": "logout"})
return self.json_message("ok") return self.json_message("ok")

View File

@ -17,7 +17,10 @@ from homeassistant.components.backup import (
Folder, Folder,
) )
from homeassistant.components.cloud import DOMAIN from homeassistant.components.cloud import DOMAIN
from homeassistant.components.cloud.backup import async_register_backup_agents_listener
from homeassistant.components.cloud.const import EVENT_CLOUD_EVENT
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
@ -576,3 +579,49 @@ async def test_agents_delete_not_found(
assert response["success"] assert response["success"]
assert response["result"] == {"agent_errors": {}} assert response["result"] == {"agent_errors": {}}
@pytest.mark.parametrize("event_type", ["login", "logout"])
async def test_calling_listener_on_login_logout(
hass: HomeAssistant,
event_type: str,
) -> None:
"""Test calling listener for login and logout events."""
listener = MagicMock()
async_register_backup_agents_listener(hass, listener=listener)
assert listener.call_count == 0
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": event_type})
await hass.async_block_till_done()
assert listener.call_count == 1
async def test_not_calling_listener_after_unsub(hass: HomeAssistant) -> None:
"""Test only calling listener until unsub."""
listener = MagicMock()
unsub = async_register_backup_agents_listener(hass, listener=listener)
assert listener.call_count == 0
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": "login"})
await hass.async_block_till_done()
assert listener.call_count == 1
unsub()
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": "login"})
await hass.async_block_till_done()
assert listener.call_count == 1
async def test_not_calling_listener_with_unknown_event_type(
hass: HomeAssistant,
) -> None:
"""Test not calling listener if we did not get the expected event type."""
listener = MagicMock()
async_register_backup_agents_listener(hass, listener=listener)
assert listener.call_count == 0
async_dispatcher_send(hass, EVENT_CLOUD_EVENT, {"type": "unknown"})
await hass.async_block_till_done()
assert listener.call_count == 0

View File

@ -1819,3 +1819,45 @@ async def test_api_calls_require_admin(
resp = await client.post(endpoint, json=data) resp = await client.post(endpoint, json=data)
assert resp.status == HTTPStatus.UNAUTHORIZED assert resp.status == HTTPStatus.UNAUTHORIZED
async def test_login_view_dispatch_event(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
) -> None:
"""Test dispatching event while logging in."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
cloud_client = await hass_client()
with patch(
"homeassistant.components.cloud.http_api.async_dispatcher_send"
) as async_dispatcher_send_mock:
await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert async_dispatcher_send_mock.call_count == 1
assert async_dispatcher_send_mock.mock_calls[0][1][1] == "cloud_event"
assert async_dispatcher_send_mock.mock_calls[0][1][2] == {"type": "login"}
async def test_logout_view_dispatch_event(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test dispatching event while logging out."""
cloud_client = await hass_client()
with patch(
"homeassistant.components.cloud.http_api.async_dispatcher_send"
) as async_dispatcher_send_mock:
await cloud_client.post("/api/cloud/logout")
assert async_dispatcher_send_mock.call_count == 1
assert async_dispatcher_send_mock.mock_calls[0][1][1] == "cloud_event"
assert async_dispatcher_send_mock.mock_calls[0][1][2] == {"type": "logout"}