Add WS command config_entries/flow/subscribe (#142459)

This commit is contained in:
Erik Montnemery 2025-04-10 16:58:46 +02:00 committed by GitHub
parent a26cdef427
commit eee6e8a2c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 458 additions and 4 deletions

View File

@ -58,7 +58,8 @@ def async_setup(hass: HomeAssistant) -> bool:
websocket_api.async_register_command(hass, config_entry_get_single)
websocket_api.async_register_command(hass, config_entry_update)
websocket_api.async_register_command(hass, config_entries_subscribe)
websocket_api.async_register_command(hass, config_entries_progress)
websocket_api.async_register_command(hass, config_entries_flow_progress)
websocket_api.async_register_command(hass, config_entries_flow_subscribe)
websocket_api.async_register_command(hass, ignore_config_flow)
websocket_api.async_register_command(hass, config_subentry_delete)
@ -357,7 +358,7 @@ class SubentryManagerFlowResourceView(
@websocket_api.require_admin
@websocket_api.websocket_command({"type": "config_entries/flow/progress"})
def config_entries_progress(
def config_entries_flow_progress(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
@ -378,6 +379,66 @@ def config_entries_progress(
)
@websocket_api.require_admin
@websocket_api.websocket_command({"type": "config_entries/flow/subscribe"})
def config_entries_flow_subscribe(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Subscribe to non user created flows being initiated or removed.
When initiating the subscription, the current flows are sent to the client.
Example of a non-user initiated flow is a discovered Hue hub that
requires user interaction to finish setup.
"""
@callback
def async_on_flow_init_remove(change_type: str, flow_id: str) -> None:
"""Forward config entry state events to websocket."""
if change_type == "removed":
connection.send_message(
websocket_api.event_message(
msg["id"],
[{"type": change_type, "flow_id": flow_id}],
)
)
return
# change_type == "added"
connection.send_message(
websocket_api.event_message(
msg["id"],
[
{
"type": change_type,
"flow_id": flow_id,
"flow": hass.config_entries.flow.async_get(flow_id),
}
],
)
)
connection.subscriptions[msg["id"]] = hass.config_entries.flow.async_subscribe_flow(
async_on_flow_init_remove
)
connection.send_message(
websocket_api.event_message(
msg["id"],
[
{"type": None, "flow_id": flw["flow_id"], "flow": flw}
for flw in hass.config_entries.flow.async_progress()
if flw["context"]["source"]
not in (
config_entries.SOURCE_RECONFIGURE,
config_entries.SOURCE_USER,
)
],
)
)
connection.send_result(msg["id"])
def send_entry_not_found(
connection: websocket_api.ActiveConnection, msg_id: int
) -> None:

View File

@ -31,7 +31,12 @@ from homeassistant.helpers import area_registry as ar
from homeassistant.helpers.backup import async_get_manager as async_get_backup_manager
from homeassistant.helpers.system_info import async_get_system_info
from homeassistant.helpers.translation import async_get_translations
from homeassistant.setup import SetupPhases, async_pause_setup, async_setup_component
from homeassistant.setup import (
SetupPhases,
async_pause_setup,
async_setup_component,
async_wait_component,
)
if TYPE_CHECKING:
from . import OnboardingData, OnboardingStorage, OnboardingStoreData
@ -60,6 +65,7 @@ async def async_setup(
hass.http.register_view(BackupInfoView(data))
hass.http.register_view(RestoreBackupView(data))
hass.http.register_view(UploadBackupView(data))
hass.http.register_view(WaitIntegrationOnboardingView(data))
await setup_cloud_views(hass, data)
@ -298,6 +304,30 @@ class IntegrationOnboardingView(_BaseOnboardingStepView):
return self.json({"auth_code": auth_code})
class WaitIntegrationOnboardingView(_NoAuthBaseOnboardingView):
"""Get backup info view."""
url = "/api/onboarding/integration/wait"
name = "api:onboarding:integration:wait"
@RequestDataValidator(
vol.Schema(
{
vol.Required("domain"): str,
}
)
)
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle wait for integration command."""
hass = request.app[KEY_HASS]
domain = data["domain"]
return self.json(
{
"integration_loaded": await async_wait_component(hass, domain),
}
)
class AnalyticsOnboardingView(_BaseOnboardingStepView):
"""View to finish analytics onboarding step."""

View File

@ -1375,6 +1375,7 @@ class ConfigEntriesFlowManager(
function=self._async_fire_discovery_event,
background=True,
)
self._flow_subscriptions: list[Callable[[str, str], None]] = []
async def async_wait_import_flow_initialized(self, handler: str) -> None:
"""Wait till all import flows in progress are initialized."""
@ -1461,6 +1462,13 @@ class ConfigEntriesFlowManager(
# Fire discovery event
await self._discovery_event_debouncer.async_call()
if result["type"] != data_entry_flow.FlowResultType.ABORT and source in (
DISCOVERY_SOURCES | {SOURCE_REAUTH}
):
# Notify listeners that a flow is created
for subscription in self._flow_subscriptions:
subscription("added", flow.flow_id)
return result
async def _async_init(
@ -1739,6 +1747,29 @@ class ConfigEntriesFlowManager(
return True
return False
@callback
def async_subscribe_flow(
self, listener: Callable[[str, str], None]
) -> CALLBACK_TYPE:
"""Subscribe to non user initiated flow init or remove."""
self._flow_subscriptions.append(listener)
return lambda: self._flow_subscriptions.remove(listener)
@callback
def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress."""
flow = self._progress.get(flow_id)
super()._async_remove_flow_progress(flow_id)
# Fire remove event for initialized non user initiated flows
if (
not flow
or flow.cur_step is None
or flow.source not in (DISCOVERY_SOURCES | {SOURCE_REAUTH})
):
return
for listeners in self._flow_subscriptions:
listeners("removed", flow_id)
class ConfigEntryItems(UserDict[str, ConfigEntry]):
"""Container for config items, maps config_entry_id -> entry.

View File

@ -8,6 +8,7 @@ from unittest.mock import ANY, AsyncMock, patch
from aiohttp.test_utils import TestClient
from freezegun.api import FrozenDateTimeFactory
import pytest
from pytest_unordered import unordered
import voluptuous as vol
from homeassistant import config_entries as core_ce, data_entry_flow, loader
@ -882,6 +883,256 @@ async def test_get_progress_flow_unauth(
assert resp2.status == HTTPStatus.UNAUTHORIZED
async def test_get_progress_subscribe(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test querying for the flows that are in progress."""
assert await async_setup_component(hass, "config", {})
mock_platform(hass, "test.config_flow", None)
ws_client = await hass_ws_client(hass)
mock_integration(
hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True))
)
entry = MockConfigEntry(domain="test", title="Test", entry_id="1234")
entry.add_to_hass(hass)
class TestFlow(core_ce.ConfigFlow):
VERSION = 5
async def async_step_bluetooth(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Handle a bluetooth discovery."""
return self.async_abort(reason="already_configured")
async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Handle a Hass.io discovery."""
return await self.async_step_account()
async def async_step_account(self, user_input: dict[str, Any] | None = None):
"""Show a form to the user."""
return self.async_show_form(step_id="account")
async def async_step_user(self, user_input: dict[str, Any] | None = None):
"""Handle a config flow initialized by the user."""
return await self.async_step_account()
async def async_step_reauth(self, user_input: dict[str, Any] | None = None):
"""Handle a reauthentication flow."""
nonlocal entry
assert self._get_reauth_entry() is entry
return await self.async_step_account()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
):
"""Handle a reconfiguration flow initialized by the user."""
nonlocal entry
assert self._get_reconfigure_entry() is entry
return await self.async_step_account()
await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"})
response = await ws_client.receive_json()
assert response == {"id": 1, "event": [], "type": "event"}
response = await ws_client.receive_json()
assert response == {"id": 1, "result": None, "success": True, "type": "result"}
flow_context = {
"bluetooth": {"source": core_ce.SOURCE_BLUETOOTH},
"hassio": {"source": core_ce.SOURCE_HASSIO},
"user": {"source": core_ce.SOURCE_USER},
"reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"},
"reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"},
}
forms = {}
with mock_config_flow("test", TestFlow):
for key, context in flow_context.items():
forms[key] = await hass.config_entries.flow.async_init(
"test", context=context
)
assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT
for key in ("hassio", "user", "reauth", "reconfigure"):
assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM
assert forms[key]["step_id"] == "account"
for key in ("hassio", "user", "reauth", "reconfigure"):
hass.config_entries.flow.async_abort(forms[key]["flow_id"])
# Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE
# should be filtered out
for key in ("hassio", "reauth"):
response = await ws_client.receive_json()
assert response == {
"event": [
{
"flow": {
"flow_id": forms[key]["flow_id"],
"handler": "test",
"step_id": "account",
"context": flow_context[key],
},
"flow_id": forms[key]["flow_id"],
"type": "added",
}
],
"id": 1,
"type": "event",
}
for key in ("hassio", "reauth"):
response = await ws_client.receive_json()
assert response == {
"event": [
{
"flow_id": forms[key]["flow_id"],
"type": "removed",
}
],
"id": 1,
"type": "event",
}
async def test_get_progress_subscribe_in_progress(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test querying for the flows that are in progress."""
assert await async_setup_component(hass, "config", {})
mock_platform(hass, "test.config_flow", None)
ws_client = await hass_ws_client(hass)
mock_integration(
hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True))
)
entry = MockConfigEntry(domain="test", title="Test", entry_id="1234")
entry.add_to_hass(hass)
class TestFlow(core_ce.ConfigFlow):
VERSION = 5
async def async_step_bluetooth(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Handle a bluetooth discovery."""
return self.async_abort(reason="already_configured")
async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Handle a Hass.io discovery."""
return await self.async_step_account()
async def async_step_account(self, user_input: dict[str, Any] | None = None):
"""Show a form to the user."""
return self.async_show_form(step_id="account")
async def async_step_user(self, user_input: dict[str, Any] | None = None):
"""Handle a config flow initialized by the user."""
return await self.async_step_account()
async def async_step_reauth(self, user_input: dict[str, Any] | None = None):
"""Handle a reauthentication flow."""
nonlocal entry
assert self._get_reauth_entry() is entry
return await self.async_step_account()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
):
"""Handle a reconfiguration flow initialized by the user."""
nonlocal entry
assert self._get_reconfigure_entry() is entry
return await self.async_step_account()
flow_context = {
"bluetooth": {"source": core_ce.SOURCE_BLUETOOTH},
"hassio": {"source": core_ce.SOURCE_HASSIO},
"user": {"source": core_ce.SOURCE_USER},
"reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"},
"reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"},
}
forms = {}
with mock_config_flow("test", TestFlow):
for key, context in flow_context.items():
forms[key] = await hass.config_entries.flow.async_init(
"test", context=context
)
assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT
for key in ("hassio", "user", "reauth", "reconfigure"):
assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM
assert forms[key]["step_id"] == "account"
await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"})
# Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE
# should be filtered out
responses = []
responses.append(await ws_client.receive_json())
assert responses == [
{
"event": unordered(
[
{
"flow": {
"flow_id": forms[key]["flow_id"],
"handler": "test",
"step_id": "account",
"context": flow_context[key],
},
"flow_id": forms[key]["flow_id"],
"type": None,
}
for key in ("hassio", "reauth")
]
),
"id": 1,
"type": "event",
}
]
response = await ws_client.receive_json()
assert response == {"id": ANY, "result": None, "success": True, "type": "result"}
for key in ("hassio", "user", "reauth", "reconfigure"):
hass.config_entries.flow.async_abort(forms[key]["flow_id"])
for key in ("hassio", "reauth"):
response = await ws_client.receive_json()
assert response == {
"event": [
{
"flow_id": forms[key]["flow_id"],
"type": "removed",
}
],
"id": 1,
"type": "event",
}
async def test_get_progress_subscribe_unauth(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_admin_user: MockUser
) -> None:
"""Test we can't subscribe to flows."""
assert await async_setup_component(hass, "config", {})
hass_admin_user.groups = []
ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 5, "type": "config_entries/flow/subscribe"})
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"]["code"] == "unauthorized"
async def test_options_flow(hass: HomeAssistant, client: TestClient) -> None:
"""Test we can change options."""

View File

@ -21,14 +21,16 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar
from homeassistant.helpers.backup import async_initialize_backup
from homeassistant.setup import async_setup_component
from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component
from . import mock_storage
from tests.common import (
CLIENT_ID,
CLIENT_REDIRECT_URI,
MockModule,
MockUser,
mock_integration,
register_auth_provider,
)
from tests.test_util.aiohttp import AiohttpClientMocker
@ -1205,3 +1207,82 @@ async def test_onboarding_cloud_status(
assert req.status == HTTPStatus.OK
data = await req.json()
assert data == {"logged_in": False}
@pytest.mark.parametrize(
("domain", "expected_result"),
[
("onboarding", {"integration_loaded": True}),
("non_existing_domain", {"integration_loaded": False}),
],
)
async def test_wait_integration(
hass: HomeAssistant,
hass_storage: dict[str, Any],
hass_client: ClientSessionGenerator,
domain: str,
expected_result: dict[str, Any],
) -> None:
"""Test we can get wait for an integration to load."""
mock_storage(hass_storage, {"done": []})
assert await async_setup_component(hass, "onboarding", {})
await hass.async_block_till_done()
client = await hass_client()
req = await client.post("/api/onboarding/integration/wait", json={"domain": domain})
assert req.status == HTTPStatus.OK
data = await req.json()
assert data == expected_result
async def test_wait_integration_startup(
hass: HomeAssistant,
hass_storage: dict[str, Any],
hass_client: ClientSessionGenerator,
) -> None:
"""Test we can get wait for an integration to load during startup."""
mock_storage(hass_storage, {"done": []})
assert await async_setup_component(hass, "onboarding", {})
await hass.async_block_till_done()
client = await hass_client()
setup_stall = asyncio.Event()
setup_started = asyncio.Event()
async def mock_setup(hass: HomeAssistant, _) -> bool:
setup_started.set()
await setup_stall.wait()
return True
mock_integration(hass, MockModule("test", async_setup=mock_setup))
# The integration is not loaded, and is also not scheduled to load
req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"})
assert req.status == HTTPStatus.OK
data = await req.json()
assert data == {"integration_loaded": False}
# Mark the component as scheduled to be loaded
async_set_domains_to_be_loaded(hass, {"test"})
# Start loading the component, including its config entries
hass.async_create_task(async_setup_component(hass, "test", {}))
await setup_started.wait()
# The component is not yet loaded
assert "test" not in hass.config.components
# Allow setup to proceed
setup_stall.set()
# The component is scheduled to load, this will block until the config entry is loaded
req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"})
assert req.status == HTTPStatus.OK
data = await req.json()
assert data == {"integration_loaded": True}
# The component has been loaded
assert "test" in hass.config.components