mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 03:37:07 +00:00
Add WS command config_entries/flow/subscribe (#142459)
This commit is contained in:
parent
a26cdef427
commit
eee6e8a2c3
@ -58,7 +58,8 @@ def async_setup(hass: HomeAssistant) -> bool:
|
||||
websocket_api.async_register_command(hass, config_entry_get_single)
|
||||
websocket_api.async_register_command(hass, config_entry_update)
|
||||
websocket_api.async_register_command(hass, config_entries_subscribe)
|
||||
websocket_api.async_register_command(hass, config_entries_progress)
|
||||
websocket_api.async_register_command(hass, config_entries_flow_progress)
|
||||
websocket_api.async_register_command(hass, config_entries_flow_subscribe)
|
||||
websocket_api.async_register_command(hass, ignore_config_flow)
|
||||
|
||||
websocket_api.async_register_command(hass, config_subentry_delete)
|
||||
@ -357,7 +358,7 @@ class SubentryManagerFlowResourceView(
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command({"type": "config_entries/flow/progress"})
|
||||
def config_entries_progress(
|
||||
def config_entries_flow_progress(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
@ -378,6 +379,66 @@ def config_entries_progress(
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command({"type": "config_entries/flow/subscribe"})
|
||||
def config_entries_flow_subscribe(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Subscribe to non user created flows being initiated or removed.
|
||||
|
||||
When initiating the subscription, the current flows are sent to the client.
|
||||
|
||||
Example of a non-user initiated flow is a discovered Hue hub that
|
||||
requires user interaction to finish setup.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def async_on_flow_init_remove(change_type: str, flow_id: str) -> None:
|
||||
"""Forward config entry state events to websocket."""
|
||||
if change_type == "removed":
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
msg["id"],
|
||||
[{"type": change_type, "flow_id": flow_id}],
|
||||
)
|
||||
)
|
||||
return
|
||||
# change_type == "added"
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
msg["id"],
|
||||
[
|
||||
{
|
||||
"type": change_type,
|
||||
"flow_id": flow_id,
|
||||
"flow": hass.config_entries.flow.async_get(flow_id),
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
connection.subscriptions[msg["id"]] = hass.config_entries.flow.async_subscribe_flow(
|
||||
async_on_flow_init_remove
|
||||
)
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
msg["id"],
|
||||
[
|
||||
{"type": None, "flow_id": flw["flow_id"], "flow": flw}
|
||||
for flw in hass.config_entries.flow.async_progress()
|
||||
if flw["context"]["source"]
|
||||
not in (
|
||||
config_entries.SOURCE_RECONFIGURE,
|
||||
config_entries.SOURCE_USER,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
def send_entry_not_found(
|
||||
connection: websocket_api.ActiveConnection, msg_id: int
|
||||
) -> None:
|
||||
|
@ -31,7 +31,12 @@ from homeassistant.helpers import area_registry as ar
|
||||
from homeassistant.helpers.backup import async_get_manager as async_get_backup_manager
|
||||
from homeassistant.helpers.system_info import async_get_system_info
|
||||
from homeassistant.helpers.translation import async_get_translations
|
||||
from homeassistant.setup import SetupPhases, async_pause_setup, async_setup_component
|
||||
from homeassistant.setup import (
|
||||
SetupPhases,
|
||||
async_pause_setup,
|
||||
async_setup_component,
|
||||
async_wait_component,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import OnboardingData, OnboardingStorage, OnboardingStoreData
|
||||
@ -60,6 +65,7 @@ async def async_setup(
|
||||
hass.http.register_view(BackupInfoView(data))
|
||||
hass.http.register_view(RestoreBackupView(data))
|
||||
hass.http.register_view(UploadBackupView(data))
|
||||
hass.http.register_view(WaitIntegrationOnboardingView(data))
|
||||
await setup_cloud_views(hass, data)
|
||||
|
||||
|
||||
@ -298,6 +304,30 @@ class IntegrationOnboardingView(_BaseOnboardingStepView):
|
||||
return self.json({"auth_code": auth_code})
|
||||
|
||||
|
||||
class WaitIntegrationOnboardingView(_NoAuthBaseOnboardingView):
|
||||
"""Get backup info view."""
|
||||
|
||||
url = "/api/onboarding/integration/wait"
|
||||
name = "api:onboarding:integration:wait"
|
||||
|
||||
@RequestDataValidator(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required("domain"): str,
|
||||
}
|
||||
)
|
||||
)
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Handle wait for integration command."""
|
||||
hass = request.app[KEY_HASS]
|
||||
domain = data["domain"]
|
||||
return self.json(
|
||||
{
|
||||
"integration_loaded": await async_wait_component(hass, domain),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AnalyticsOnboardingView(_BaseOnboardingStepView):
|
||||
"""View to finish analytics onboarding step."""
|
||||
|
||||
|
@ -1375,6 +1375,7 @@ class ConfigEntriesFlowManager(
|
||||
function=self._async_fire_discovery_event,
|
||||
background=True,
|
||||
)
|
||||
self._flow_subscriptions: list[Callable[[str, str], None]] = []
|
||||
|
||||
async def async_wait_import_flow_initialized(self, handler: str) -> None:
|
||||
"""Wait till all import flows in progress are initialized."""
|
||||
@ -1461,6 +1462,13 @@ class ConfigEntriesFlowManager(
|
||||
# Fire discovery event
|
||||
await self._discovery_event_debouncer.async_call()
|
||||
|
||||
if result["type"] != data_entry_flow.FlowResultType.ABORT and source in (
|
||||
DISCOVERY_SOURCES | {SOURCE_REAUTH}
|
||||
):
|
||||
# Notify listeners that a flow is created
|
||||
for subscription in self._flow_subscriptions:
|
||||
subscription("added", flow.flow_id)
|
||||
|
||||
return result
|
||||
|
||||
async def _async_init(
|
||||
@ -1739,6 +1747,29 @@ class ConfigEntriesFlowManager(
|
||||
return True
|
||||
return False
|
||||
|
||||
@callback
|
||||
def async_subscribe_flow(
|
||||
self, listener: Callable[[str, str], None]
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Subscribe to non user initiated flow init or remove."""
|
||||
self._flow_subscriptions.append(listener)
|
||||
return lambda: self._flow_subscriptions.remove(listener)
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_progress(self, flow_id: str) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
flow = self._progress.get(flow_id)
|
||||
super()._async_remove_flow_progress(flow_id)
|
||||
# Fire remove event for initialized non user initiated flows
|
||||
if (
|
||||
not flow
|
||||
or flow.cur_step is None
|
||||
or flow.source not in (DISCOVERY_SOURCES | {SOURCE_REAUTH})
|
||||
):
|
||||
return
|
||||
for listeners in self._flow_subscriptions:
|
||||
listeners("removed", flow_id)
|
||||
|
||||
|
||||
class ConfigEntryItems(UserDict[str, ConfigEntry]):
|
||||
"""Container for config items, maps config_entry_id -> entry.
|
||||
|
@ -8,6 +8,7 @@ from unittest.mock import ANY, AsyncMock, patch
|
||||
from aiohttp.test_utils import TestClient
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
from pytest_unordered import unordered
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries as core_ce, data_entry_flow, loader
|
||||
@ -882,6 +883,256 @@ async def test_get_progress_flow_unauth(
|
||||
assert resp2.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
||||
async def test_get_progress_subscribe(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> None:
|
||||
"""Test querying for the flows that are in progress."""
|
||||
assert await async_setup_component(hass, "config", {})
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
mock_integration(
|
||||
hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True))
|
||||
)
|
||||
|
||||
entry = MockConfigEntry(domain="test", title="Test", entry_id="1234")
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
class TestFlow(core_ce.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_bluetooth(
|
||||
self, discovery_info: HassioServiceInfo
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a bluetooth discovery."""
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
async def async_step_hassio(
|
||||
self, discovery_info: HassioServiceInfo
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a Hass.io discovery."""
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_account(self, user_input: dict[str, Any] | None = None):
|
||||
"""Show a form to the user."""
|
||||
return self.async_show_form(step_id="account")
|
||||
|
||||
async def async_step_user(self, user_input: dict[str, Any] | None = None):
|
||||
"""Handle a config flow initialized by the user."""
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_reauth(self, user_input: dict[str, Any] | None = None):
|
||||
"""Handle a reauthentication flow."""
|
||||
nonlocal entry
|
||||
assert self._get_reauth_entry() is entry
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
):
|
||||
"""Handle a reconfiguration flow initialized by the user."""
|
||||
nonlocal entry
|
||||
assert self._get_reconfigure_entry() is entry
|
||||
return await self.async_step_account()
|
||||
|
||||
await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {"id": 1, "event": [], "type": "event"}
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {"id": 1, "result": None, "success": True, "type": "result"}
|
||||
|
||||
flow_context = {
|
||||
"bluetooth": {"source": core_ce.SOURCE_BLUETOOTH},
|
||||
"hassio": {"source": core_ce.SOURCE_HASSIO},
|
||||
"user": {"source": core_ce.SOURCE_USER},
|
||||
"reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"},
|
||||
"reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"},
|
||||
}
|
||||
forms = {}
|
||||
|
||||
with mock_config_flow("test", TestFlow):
|
||||
for key, context in flow_context.items():
|
||||
forms[key] = await hass.config_entries.flow.async_init(
|
||||
"test", context=context
|
||||
)
|
||||
|
||||
assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
for key in ("hassio", "user", "reauth", "reconfigure"):
|
||||
assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert forms[key]["step_id"] == "account"
|
||||
|
||||
for key in ("hassio", "user", "reauth", "reconfigure"):
|
||||
hass.config_entries.flow.async_abort(forms[key]["flow_id"])
|
||||
|
||||
# Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE
|
||||
# should be filtered out
|
||||
for key in ("hassio", "reauth"):
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {
|
||||
"event": [
|
||||
{
|
||||
"flow": {
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"handler": "test",
|
||||
"step_id": "account",
|
||||
"context": flow_context[key],
|
||||
},
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"type": "added",
|
||||
}
|
||||
],
|
||||
"id": 1,
|
||||
"type": "event",
|
||||
}
|
||||
for key in ("hassio", "reauth"):
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {
|
||||
"event": [
|
||||
{
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"type": "removed",
|
||||
}
|
||||
],
|
||||
"id": 1,
|
||||
"type": "event",
|
||||
}
|
||||
|
||||
|
||||
async def test_get_progress_subscribe_in_progress(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> None:
|
||||
"""Test querying for the flows that are in progress."""
|
||||
assert await async_setup_component(hass, "config", {})
|
||||
mock_platform(hass, "test.config_flow", None)
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
mock_integration(
|
||||
hass, MockModule("test", async_setup_entry=AsyncMock(return_value=True))
|
||||
)
|
||||
|
||||
entry = MockConfigEntry(domain="test", title="Test", entry_id="1234")
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
class TestFlow(core_ce.ConfigFlow):
|
||||
VERSION = 5
|
||||
|
||||
async def async_step_bluetooth(
|
||||
self, discovery_info: HassioServiceInfo
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a bluetooth discovery."""
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
async def async_step_hassio(
|
||||
self, discovery_info: HassioServiceInfo
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a Hass.io discovery."""
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_account(self, user_input: dict[str, Any] | None = None):
|
||||
"""Show a form to the user."""
|
||||
return self.async_show_form(step_id="account")
|
||||
|
||||
async def async_step_user(self, user_input: dict[str, Any] | None = None):
|
||||
"""Handle a config flow initialized by the user."""
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_reauth(self, user_input: dict[str, Any] | None = None):
|
||||
"""Handle a reauthentication flow."""
|
||||
nonlocal entry
|
||||
assert self._get_reauth_entry() is entry
|
||||
return await self.async_step_account()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
):
|
||||
"""Handle a reconfiguration flow initialized by the user."""
|
||||
nonlocal entry
|
||||
assert self._get_reconfigure_entry() is entry
|
||||
return await self.async_step_account()
|
||||
|
||||
flow_context = {
|
||||
"bluetooth": {"source": core_ce.SOURCE_BLUETOOTH},
|
||||
"hassio": {"source": core_ce.SOURCE_HASSIO},
|
||||
"user": {"source": core_ce.SOURCE_USER},
|
||||
"reauth": {"source": core_ce.SOURCE_REAUTH, "entry_id": "1234"},
|
||||
"reconfigure": {"source": core_ce.SOURCE_RECONFIGURE, "entry_id": "1234"},
|
||||
}
|
||||
forms = {}
|
||||
|
||||
with mock_config_flow("test", TestFlow):
|
||||
for key, context in flow_context.items():
|
||||
forms[key] = await hass.config_entries.flow.async_init(
|
||||
"test", context=context
|
||||
)
|
||||
|
||||
assert forms["bluetooth"]["type"] == data_entry_flow.FlowResultType.ABORT
|
||||
for key in ("hassio", "user", "reauth", "reconfigure"):
|
||||
assert forms[key]["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert forms[key]["step_id"] == "account"
|
||||
|
||||
await ws_client.send_json({"id": 1, "type": "config_entries/flow/subscribe"})
|
||||
|
||||
# Uninitialized flows and flows with SOURCE_USER and SOURCE_RECONFIGURE
|
||||
# should be filtered out
|
||||
responses = []
|
||||
responses.append(await ws_client.receive_json())
|
||||
assert responses == [
|
||||
{
|
||||
"event": unordered(
|
||||
[
|
||||
{
|
||||
"flow": {
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"handler": "test",
|
||||
"step_id": "account",
|
||||
"context": flow_context[key],
|
||||
},
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"type": None,
|
||||
}
|
||||
for key in ("hassio", "reauth")
|
||||
]
|
||||
),
|
||||
"id": 1,
|
||||
"type": "event",
|
||||
}
|
||||
]
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {"id": ANY, "result": None, "success": True, "type": "result"}
|
||||
|
||||
for key in ("hassio", "user", "reauth", "reconfigure"):
|
||||
hass.config_entries.flow.async_abort(forms[key]["flow_id"])
|
||||
|
||||
for key in ("hassio", "reauth"):
|
||||
response = await ws_client.receive_json()
|
||||
assert response == {
|
||||
"event": [
|
||||
{
|
||||
"flow_id": forms[key]["flow_id"],
|
||||
"type": "removed",
|
||||
}
|
||||
],
|
||||
"id": 1,
|
||||
"type": "event",
|
||||
}
|
||||
|
||||
|
||||
async def test_get_progress_subscribe_unauth(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_admin_user: MockUser
|
||||
) -> None:
|
||||
"""Test we can't subscribe to flows."""
|
||||
assert await async_setup_component(hass, "config", {})
|
||||
hass_admin_user.groups = []
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json({"id": 5, "type": "config_entries/flow/subscribe"})
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"]["code"] == "unauthorized"
|
||||
|
||||
|
||||
async def test_options_flow(hass: HomeAssistant, client: TestClient) -> None:
|
||||
"""Test we can change options."""
|
||||
|
||||
|
@ -21,14 +21,16 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import area_registry as ar
|
||||
from homeassistant.helpers.backup import async_initialize_backup
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.setup import async_set_domains_to_be_loaded, async_setup_component
|
||||
|
||||
from . import mock_storage
|
||||
|
||||
from tests.common import (
|
||||
CLIENT_ID,
|
||||
CLIENT_REDIRECT_URI,
|
||||
MockModule,
|
||||
MockUser,
|
||||
mock_integration,
|
||||
register_auth_provider,
|
||||
)
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
@ -1205,3 +1207,82 @@ async def test_onboarding_cloud_status(
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data == {"logged_in": False}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("domain", "expected_result"),
|
||||
[
|
||||
("onboarding", {"integration_loaded": True}),
|
||||
("non_existing_domain", {"integration_loaded": False}),
|
||||
],
|
||||
)
|
||||
async def test_wait_integration(
|
||||
hass: HomeAssistant,
|
||||
hass_storage: dict[str, Any],
|
||||
hass_client: ClientSessionGenerator,
|
||||
domain: str,
|
||||
expected_result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test we can get wait for an integration to load."""
|
||||
mock_storage(hass_storage, {"done": []})
|
||||
|
||||
assert await async_setup_component(hass, "onboarding", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client()
|
||||
req = await client.post("/api/onboarding/integration/wait", json={"domain": domain})
|
||||
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data == expected_result
|
||||
|
||||
|
||||
async def test_wait_integration_startup(
|
||||
hass: HomeAssistant,
|
||||
hass_storage: dict[str, Any],
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test we can get wait for an integration to load during startup."""
|
||||
mock_storage(hass_storage, {"done": []})
|
||||
|
||||
assert await async_setup_component(hass, "onboarding", {})
|
||||
await hass.async_block_till_done()
|
||||
client = await hass_client()
|
||||
|
||||
setup_stall = asyncio.Event()
|
||||
setup_started = asyncio.Event()
|
||||
|
||||
async def mock_setup(hass: HomeAssistant, _) -> bool:
|
||||
setup_started.set()
|
||||
await setup_stall.wait()
|
||||
return True
|
||||
|
||||
mock_integration(hass, MockModule("test", async_setup=mock_setup))
|
||||
|
||||
# The integration is not loaded, and is also not scheduled to load
|
||||
req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data == {"integration_loaded": False}
|
||||
|
||||
# Mark the component as scheduled to be loaded
|
||||
async_set_domains_to_be_loaded(hass, {"test"})
|
||||
|
||||
# Start loading the component, including its config entries
|
||||
hass.async_create_task(async_setup_component(hass, "test", {}))
|
||||
await setup_started.wait()
|
||||
|
||||
# The component is not yet loaded
|
||||
assert "test" not in hass.config.components
|
||||
|
||||
# Allow setup to proceed
|
||||
setup_stall.set()
|
||||
|
||||
# The component is scheduled to load, this will block until the config entry is loaded
|
||||
req = await client.post("/api/onboarding/integration/wait", json={"domain": "test"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data == {"integration_loaded": True}
|
||||
|
||||
# The component has been loaded
|
||||
assert "test" in hass.config.components
|
||||
|
Loading…
x
Reference in New Issue
Block a user