From eee6e8a2c3fb503b457908926a65a4659d17e321 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 10 Apr 2025 16:58:46 +0200 Subject: [PATCH] Add WS command config_entries/flow/subscribe (#142459) --- .../components/config/config_entries.py | 65 ++++- homeassistant/components/onboarding/views.py | 32 ++- homeassistant/config_entries.py | 31 +++ .../components/config/test_config_entries.py | 251 ++++++++++++++++++ tests/components/onboarding/test_views.py | 83 +++++- 5 files changed, 458 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 74c9b5a9d0c..6e2d4a5da49 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -58,7 +58,8 @@ def async_setup(hass: HomeAssistant) -> bool: websocket_api.async_register_command(hass, config_entry_get_single) websocket_api.async_register_command(hass, config_entry_update) websocket_api.async_register_command(hass, config_entries_subscribe) - websocket_api.async_register_command(hass, config_entries_progress) + websocket_api.async_register_command(hass, config_entries_flow_progress) + websocket_api.async_register_command(hass, config_entries_flow_subscribe) websocket_api.async_register_command(hass, ignore_config_flow) websocket_api.async_register_command(hass, config_subentry_delete) @@ -357,7 +358,7 @@ class SubentryManagerFlowResourceView( @websocket_api.require_admin @websocket_api.websocket_command({"type": "config_entries/flow/progress"}) -def config_entries_progress( +def config_entries_flow_progress( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any], @@ -378,6 +379,66 @@ def config_entries_progress( ) +@websocket_api.require_admin +@websocket_api.websocket_command({"type": "config_entries/flow/subscribe"}) +def config_entries_flow_subscribe( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Subscribe to non user created flows being initiated or removed. + + When initiating the subscription, the current flows are sent to the client. + + Example of a non-user initiated flow is a discovered Hue hub that + requires user interaction to finish setup. + """ + + @callback + def async_on_flow_init_remove(change_type: str, flow_id: str) -> None: + """Forward config entry state events to websocket.""" + if change_type == "removed": + connection.send_message( + websocket_api.event_message( + msg["id"], + [{"type": change_type, "flow_id": flow_id}], + ) + ) + return + # change_type == "added" + connection.send_message( + websocket_api.event_message( + msg["id"], + [ + { + "type": change_type, + "flow_id": flow_id, + "flow": hass.config_entries.flow.async_get(flow_id), + } + ], + ) + ) + + connection.subscriptions[msg["id"]] = hass.config_entries.flow.async_subscribe_flow( + async_on_flow_init_remove + ) + connection.send_message( + websocket_api.event_message( + msg["id"], + [ + {"type": None, "flow_id": flw["flow_id"], "flow": flw} + for flw in hass.config_entries.flow.async_progress() + if flw["context"]["source"] + not in ( + config_entries.SOURCE_RECONFIGURE, + config_entries.SOURCE_USER, + ) + ], + ) + ) + connection.send_result(msg["id"]) + + def send_entry_not_found( connection: websocket_api.ActiveConnection, msg_id: int ) -> None: diff --git a/homeassistant/components/onboarding/views.py b/homeassistant/components/onboarding/views.py index 978e16963d9..47d9b1cb98b 100644 --- a/homeassistant/components/onboarding/views.py +++ b/homeassistant/components/onboarding/views.py @@ -31,7 +31,12 @@ from homeassistant.helpers import area_registry as ar from homeassistant.helpers.backup import async_get_manager as async_get_backup_manager from homeassistant.helpers.system_info import async_get_system_info from homeassistant.helpers.translation import async_get_translations -from homeassistant.setup import SetupPhases, async_pause_setup, async_setup_component +from homeassistant.setup import ( + SetupPhases, + async_pause_setup, + async_setup_component, + async_wait_component, +) if TYPE_CHECKING: from . import OnboardingData, OnboardingStorage, OnboardingStoreData @@ -60,6 +65,7 @@ async def async_setup( hass.http.register_view(BackupInfoView(data)) hass.http.register_view(RestoreBackupView(data)) hass.http.register_view(UploadBackupView(data)) + hass.http.register_view(WaitIntegrationOnboardingView(data)) await setup_cloud_views(hass, data) @@ -298,6 +304,30 @@ class IntegrationOnboardingView(_BaseOnboardingStepView): return self.json({"auth_code": auth_code}) +class WaitIntegrationOnboardingView(_NoAuthBaseOnboardingView): + """Get backup info view.""" + + url = "/api/onboarding/integration/wait" + name = "api:onboarding:integration:wait" + + @RequestDataValidator( + vol.Schema( + { + vol.Required("domain"): str, + } + ) + ) + async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: + """Handle wait for integration command.""" + hass = request.app[KEY_HASS] + domain = data["domain"] + return self.json( + { + "integration_loaded": await async_wait_component(hass, domain), + } + ) + + class AnalyticsOnboardingView(_BaseOnboardingStepView): """View to finish analytics onboarding step.""" diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 705cc01061b..30bd075ed95 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1375,6 +1375,7 @@ class ConfigEntriesFlowManager( function=self._async_fire_discovery_event, background=True, ) + self._flow_subscriptions: list[Callable[[str, str], None]] = [] async def async_wait_import_flow_initialized(self, handler: str) -> None: """Wait till all import flows in progress are initialized.""" @@ -1461,6 +1462,13 @@ class ConfigEntriesFlowManager( # Fire discovery event await self._discovery_event_debouncer.async_call() + if result["type"] != data_entry_flow.FlowResultType.ABORT and source in ( + DISCOVERY_SOURCES | {SOURCE_REAUTH} + ): + # Notify listeners that a flow is created + for subscription in self._flow_subscriptions: + subscription("added", flow.flow_id) + return result async def _async_init( @@ -1739,6 +1747,29 @@ class ConfigEntriesFlowManager( return True return False + @callback + def async_subscribe_flow( + self, listener: Callable[[str, str], None] + ) -> CALLBACK_TYPE: + """Subscribe to non user initiated flow init or remove.""" + self._flow_subscriptions.append(listener) + return lambda: self._flow_subscriptions.remove(listener) + + @callback + def _async_remove_flow_progress(self, flow_id: str) -> None: + """Remove a flow from in progress.""" + flow = self._progress.get(flow_id) + super()._async_remove_flow_progress(flow_id) + # Fire remove event for initialized non user initiated flows + if ( + not flow + or flow.cur_step is None + or flow.source not in (DISCOVERY_SOURCES | {SOURCE_REAUTH}) + ): + return + for listeners in self._flow_subscriptions: + listeners("removed", flow_id) + class ConfigEntryItems(UserDict[str, ConfigEntry]): """Container for config items, maps config_entry_id -> entry. diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index c6e65c312bb..6784866ea4b 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -8,6 +8,7 @@ from unittest.mock import ANY, AsyncMock, patch from aiohttp.test_utils import TestClient from freezegun.api import FrozenDateTimeFactory import pytest +from pytest_unordered import unordered import voluptuous as vol from homeassistant import config_entries as core_ce, data_entry_flow, loader @@ -882,6 +883,256 @@ async def test_get_progress_flow_unauth( assert resp2.status == HTTPStatus.UNAUTHORIZED +async def test_get_progress_subscribe( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test querying for the flows that are in progress.""" + assert await async_setup_component(hass, "config", {}) + mock_platform(hass, "test.config_flow", None) + ws_client = await hass_ws_client(hass) + + mock_integration( + hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True)) + ) + + entry = MockConfigEntry(domain="test", title="Test", entry_id="1234") + entry.add_to_hass(hass) + + class TestFlow(core_ce.ConfigFlow): + VERSION = 5 + + async def async_step_bluetooth( + self, discovery_info: HassioServiceInfo + ) -> ConfigFlowResult: + """Handle a bluetooth discovery.""" + return self.async_abort(reason="already_configured") + + async def async_step_hassio( + self, discovery_info: HassioServiceInfo + ) -> ConfigFlowResult: + """Handle a Hass.io discovery.""" + return await self.async_step_account() + + async def async_step_account(self, user_input: dict[str, Any] | None = None): + """Show a form to the user.""" + return self.async_show_form(step_id="account") + + async def async_step_user(self, user_input: dict[str, Any] | None = None): + """Handle a config flow initialized by the user.""" + return await self.async_step_account() + + async def async_step_reauth(self, user_input: dict[str, Any] | None = None): + """Handle a reauthentication flow.""" + nonlocal entry + assert self._get_reauth_entry() is entry + return await self.async_step_account() + + async def async_step_reconfigure( + self, user_input: dict[str, Any] | None = None + ): + """Handle a reconfiguration flow initialized by the user.""" + nonlocal entry + assert self._get_reconfigure_entry() is entry + return await self.async_step_account() + + await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"}) + response = await ws_client.receive_json() + assert response == {"id": 1, "event": [], "type": "event"} + response = await ws_client.receive_json() + assert response == {"id": 1, "result": None, "success": True, "type": "result"} + + flow_context = { + "bluetooth": {"source": core_ce.SOURCE_BLUETOOTH}, + "hassio": {"source": core_ce.SOURCE_HASSIO}, + "user": {"source": core_ce.SOURCE_USER}, + "reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"}, + "reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"}, + } + forms = {} + + with mock_config_flow("test", TestFlow): + for key, context in flow_context.items(): + forms[key] = await hass.config_entries.flow.async_init( + "test", context=context + ) + + assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT + for key in ("hassio", "user", "reauth", "reconfigure"): + assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM + assert forms[key]["step_id"] == "account" + + for key in ("hassio", "user", "reauth", "reconfigure"): + hass.config_entries.flow.async_abort(forms[key]["flow_id"]) + + # Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE + # should be filtered out + for key in ("hassio", "reauth"): + response = await ws_client.receive_json() + assert response == { + "event": [ + { + "flow": { + "flow_id": forms[key]["flow_id"], + "handler": "test", + "step_id": "account", + "context": flow_context[key], + }, + "flow_id": forms[key]["flow_id"], + "type": "added", + } + ], + "id": 1, + "type": "event", + } + for key in ("hassio", "reauth"): + response = await ws_client.receive_json() + assert response == { + "event": [ + { + "flow_id": forms[key]["flow_id"], + "type": "removed", + } + ], + "id": 1, + "type": "event", + } + + +async def test_get_progress_subscribe_in_progress( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test querying for the flows that are in progress.""" + assert await async_setup_component(hass, "config", {}) + mock_platform(hass, "test.config_flow", None) + ws_client = await hass_ws_client(hass) + + mock_integration( + hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True)) + ) + + entry = MockConfigEntry(domain="test", title="Test", entry_id="1234") + entry.add_to_hass(hass) + + class TestFlow(core_ce.ConfigFlow): + VERSION = 5 + + async def async_step_bluetooth( + self, discovery_info: HassioServiceInfo + ) -> ConfigFlowResult: + """Handle a bluetooth discovery.""" + return self.async_abort(reason="already_configured") + + async def async_step_hassio( + self, discovery_info: HassioServiceInfo + ) -> ConfigFlowResult: + """Handle a Hass.io discovery.""" + return await self.async_step_account() + + async def async_step_account(self, user_input: dict[str, Any] | None = None): + """Show a form to the user.""" + return self.async_show_form(step_id="account") + + async def async_step_user(self, user_input: dict[str, Any] | None = None): + """Handle a config flow initialized by the user.""" + return await self.async_step_account() + + async def async_step_reauth(self, user_input: dict[str, Any] | None = None): + """Handle a reauthentication flow.""" + nonlocal entry + assert self._get_reauth_entry() is entry + return await self.async_step_account() + + async def async_step_reconfigure( + self, user_input: dict[str, Any] | None = None + ): + """Handle a reconfiguration flow initialized by the user.""" + nonlocal entry + assert self._get_reconfigure_entry() is entry + return await self.async_step_account() + + flow_context = { + "bluetooth": {"source": core_ce.SOURCE_BLUETOOTH}, + "hassio": {"source": core_ce.SOURCE_HASSIO}, + "user": {"source": core_ce.SOURCE_USER}, + "reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"}, + "reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"}, + } + forms = {} + + with mock_config_flow("test", TestFlow): + for key, context in flow_context.items(): + forms[key] = await hass.config_entries.flow.async_init( + "test", context=context + ) + + assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT + for key in ("hassio", "user", "reauth", "reconfigure"): + assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM + assert forms[key]["step_id"] == "account" + + await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"}) + + # Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE + # should be filtered out + responses = [] + responses.append(await ws_client.receive_json()) + assert responses == [ + { + "event": unordered( + [ + { + "flow": { + "flow_id": forms[key]["flow_id"], + "handler": "test", + "step_id": "account", + "context": flow_context[key], + }, + "flow_id": forms[key]["flow_id"], + "type": None, + } + for key in ("hassio", "reauth") + ] + ), + "id": 1, + "type": "event", + } + ] + + response = await ws_client.receive_json() + assert response == {"id": ANY, "result": None, "success": True, "type": "result"} + + for key in ("hassio", "user", "reauth", "reconfigure"): + hass.config_entries.flow.async_abort(forms[key]["flow_id"]) + + for key in ("hassio", "reauth"): + response = await ws_client.receive_json() + assert response == { + "event": [ + { + "flow_id": forms[key]["flow_id"], + "type": "removed", + } + ], + "id": 1, + "type": "event", + } + + +async def test_get_progress_subscribe_unauth( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_admin_user: MockUser +) -> None: + """Test we can't subscribe to flows.""" + assert await async_setup_component(hass, "config", {}) + hass_admin_user.groups = [] + ws_client = await hass_ws_client(hass) + + await ws_client.send_json({"id": 5, "type": "config_entries/flow/subscribe"}) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"]["code"] == "unauthorized" + + async def test_options_flow(hass: HomeAssistant, client: TestClient) -> None: """Test we can change options.""" diff --git a/tests/components/onboarding/test_views.py b/tests/components/onboarding/test_views.py index 9c5e93e49fe..6a6be1da470 100644 --- a/tests/components/onboarding/test_views.py +++ b/tests/components/onboarding/test_views.py @@ -21,14 +21,16 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import area_registry as ar from homeassistant.helpers.backup import async_initialize_backup -from homeassistant.setup import async_setup_component +from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component from . import mock_storage from tests.common import ( CLIENT_ID, CLIENT_REDIRECT_URI, + MockModule, MockUser, + mock_integration, register_auth_provider, ) from tests.test_util.aiohttp import AiohttpClientMocker @@ -1205,3 +1207,82 @@ async def test_onboarding_cloud_status( assert req.status == HTTPStatus.OK data = await req.json() assert data == {"logged_in": False} + + +@pytest.mark.parametrize( + ("domain", "expected_result"), + [ + ("onboarding", {"integration_loaded": True}), + ("non_existing_domain", {"integration_loaded": False}), + ], +) +async def test_wait_integration( + hass: HomeAssistant, + hass_storage: dict[str, Any], + hass_client: ClientSessionGenerator, + domain: str, + expected_result: dict[str, Any], +) -> None: + """Test we can get wait for an integration to load.""" + mock_storage(hass_storage, {"done": []}) + + assert await async_setup_component(hass, "onboarding", {}) + await hass.async_block_till_done() + + client = await hass_client() + req = await client.post("/api/onboarding/integration/wait", json={"domain": domain}) + + assert req.status == HTTPStatus.OK + data = await req.json() + assert data == expected_result + + +async def test_wait_integration_startup( + hass: HomeAssistant, + hass_storage: dict[str, Any], + hass_client: ClientSessionGenerator, +) -> None: + """Test we can get wait for an integration to load during startup.""" + mock_storage(hass_storage, {"done": []}) + + assert await async_setup_component(hass, "onboarding", {}) + await hass.async_block_till_done() + client = await hass_client() + + setup_stall = asyncio.Event() + setup_started = asyncio.Event() + + async def mock_setup(hass: HomeAssistant, _) -> bool: + setup_started.set() + await setup_stall.wait() + return True + + mock_integration(hass, MockModule("test", async_setup=mock_setup)) + + # The integration is not loaded, and is also not scheduled to load + req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"}) + assert req.status == HTTPStatus.OK + data = await req.json() + assert data == {"integration_loaded": False} + + # Mark the component as scheduled to be loaded + async_set_domains_to_be_loaded(hass, {"test"}) + + # Start loading the component, including its config entries + hass.async_create_task(async_setup_component(hass, "test", {})) + await setup_started.wait() + + # The component is not yet loaded + assert "test" not in hass.config.components + + # Allow setup to proceed + setup_stall.set() + + # The component is scheduled to load, this will block until the config entry is loaded + req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"}) + assert req.status == HTTPStatus.OK + data = await req.json() + assert data == {"integration_loaded": True} + + # The component has been loaded + assert "test" in hass.config.components