From 6b519499a72bef68a1ec8f1d251cf055698819c2 Mon Sep 17 00:00:00 2001 From: Jc2k Date: Fri, 3 Jan 2020 16:28:05 +0000 Subject: [PATCH] Don't expose flows that aren't initialised. (#30432) * Don't expose flows that aren't initialised. If a flow init does not return immediately then there is a window where our behaviour is screwy: * Can try to configure a flow that isn't ready * Can show notifications for discoveries that might yet return an abort This moves the flow discovery events and notifications to after the flow is initialised and hides flows that don't have a cur_step from async_progress * Fix tradfri test * Black. * Lint fixes --- homeassistant/config_entries.py | 22 +++++--- homeassistant/data_entry_flow.py | 14 ++++- tests/components/tradfri/conftest.py | 10 ++++ tests/components/tradfri/test_config_flow.py | 8 --- tests/components/tradfri/test_init.py | 16 +++--- tests/test_config_entries.py | 57 ++++++++++++++++++++ 6 files changed, 106 insertions(+), 21 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d1b5c927a2b..6fb5595dac4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -61,7 +61,13 @@ ENTRY_STATE_FAILED_UNLOAD = "failed_unload" UNRECOVERABLE_STATES = (ENTRY_STATE_MIGRATION_ERROR, ENTRY_STATE_FAILED_UNLOAD) DISCOVERY_NOTIFICATION_ID = "config_entry_discovery" -DISCOVERY_SOURCES = (SOURCE_SSDP, SOURCE_ZEROCONF, SOURCE_DISCOVERY, SOURCE_IMPORT) +DISCOVERY_SOURCES = ( + SOURCE_SSDP, + SOURCE_ZEROCONF, + SOURCE_DISCOVERY, + SOURCE_IMPORT, + SOURCE_UNIGNORE, +) EVENT_FLOW_DISCOVERED = "config_entry_discovered" @@ -511,7 +517,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): if not context or "source" not in context: raise KeyError("Context not set or doesn't have a source set") - source = context["source"] + flow = cast(ConfigFlow, handler()) + flow.init_step = context["source"] + return flow + + async def async_post_init( + self, flow: data_entry_flow.FlowHandler, result: dict + ) -> None: + """After a flow is initialised trigger new flow notifications.""" + source = flow.context["source"] # Create notification. if source in DISCOVERY_SOURCES: @@ -525,10 +539,6 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): notification_id=DISCOVERY_NOTIFICATION_ID, ) - flow = cast(ConfigFlow, handler()) - flow.init_step = source - return flow - class ConfigEntries: """Manage the configuration entries. diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 5e72dea9273..4dd1c7acf50 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -76,12 +76,19 @@ class FlowManager(abc.ABC): """Finish a config flow and add an entry.""" pass + async def async_post_init( + self, flow: "FlowHandler", result: Dict[str, Any] + ) -> None: + """Entry has finished executing its first step asynchronously.""" + pass + @callback def async_progress(self) -> List[Dict]: """Return the flows in progress.""" return [ {"flow_id": flow.flow_id, "handler": flow.handler, "context": flow.context} for flow in self._progress.values() + if flow.cur_step is not None ] async def async_init( @@ -99,7 +106,12 @@ class FlowManager(abc.ABC): flow.context = context self._progress[flow.flow_id] = flow - return await self._async_handle_step(flow, flow.init_step, data) + result = await self._async_handle_step(flow, flow.init_step, data) + + if result["type"] != RESULT_TYPE_ABORT: + await self.async_post_init(flow, result) + + return result async def async_configure( self, flow_id: str, user_input: Optional[Dict] = None diff --git a/tests/components/tradfri/conftest.py b/tests/components/tradfri/conftest.py index 1c6e572b81f..d835c06f256 100644 --- a/tests/components/tradfri/conftest.py +++ b/tests/components/tradfri/conftest.py @@ -3,6 +3,8 @@ from unittest.mock import patch import pytest +from tests.common import mock_coro + @pytest.fixture def mock_gateway_info(): @@ -11,3 +13,11 @@ def mock_gateway_info(): "homeassistant.components.tradfri.config_flow.get_gateway_info" ) as mock_gateway: yield mock_gateway + + +@pytest.fixture +def mock_entry_setup(): + """Mock entry setup.""" + with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: + mock_setup.return_value = mock_coro(True) + yield mock_setup diff --git a/tests/components/tradfri/test_config_flow.py b/tests/components/tradfri/test_config_flow.py index ad7386c530f..18fb55eda2f 100644 --- a/tests/components/tradfri/test_config_flow.py +++ b/tests/components/tradfri/test_config_flow.py @@ -18,14 +18,6 @@ def mock_auth(): yield mock_auth -@pytest.fixture -def mock_entry_setup(): - """Mock entry setup.""" - with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: - mock_setup.return_value = mock_coro(True) - yield mock_setup - - async def test_user_connection_successful(hass, mock_auth, mock_entry_setup): """Test a successful connection.""" mock_auth.side_effect = lambda hass, host, code: mock_coro( diff --git a/tests/components/tradfri/test_init.py b/tests/components/tradfri/test_init.py index 67ecb8d054a..cf9034df8d6 100644 --- a/tests/components/tradfri/test_init.py +++ b/tests/components/tradfri/test_init.py @@ -3,7 +3,7 @@ from unittest.mock import patch from homeassistant.setup import async_setup_component -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, mock_coro async def test_config_yaml_host_not_imported(hass): @@ -49,8 +49,12 @@ async def test_config_json_host_not_imported(hass): assert len(mock_init.mock_calls) == 0 -async def test_config_json_host_imported(hass, mock_gateway_info): +async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup): """Test that we import a configured host.""" + mock_gateway_info.side_effect = lambda hass, host, identity, key: mock_coro( + {"host": host, "identity": identity, "key": key, "gateway_id": "mock-gateway"} + ) + with patch( "homeassistant.components.tradfri.load_json", return_value={"mock-host": {"key": "some-info"}}, @@ -58,7 +62,7 @@ async def test_config_json_host_imported(hass, mock_gateway_info): assert await async_setup_component(hass, "tradfri", {"tradfri": {}}) await hass.async_block_till_done() - progress = hass.config_entries.flow.async_progress() - assert len(progress) == 1 - assert progress[0]["handler"] == "tradfri" - assert progress[0]["context"] == {"source": "import"} + config_entry = mock_entry_setup.mock_calls[0][1][1] + assert config_entry.domain == "tradfri" + assert config_entry.source == "import" + assert config_entry.title == "mock-host" diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index c3a87bcf3a0..46410de1999 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1303,3 +1303,60 @@ async def test_unignore_default_impl(hass, manager): assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.flow.async_progress()) == 0 + + +async def test_partial_flows_hidden(hass, manager): + """Test that flows that don't have a cur_step and haven't finished initing are hidden.""" + async_setup_entry = MagicMock(return_value=mock_coro(True)) + mock_integration(hass, MockModule("comp", async_setup_entry=async_setup_entry)) + mock_entity_platform(hass, "config_flow.comp", None) + await async_setup_component(hass, "persistent_notification", {}) + + # A flag to test our assertion that `async_step_discovery` was called and is in its blocked state + # This simulates if the step was e.g. doing network i/o + discovery_started = asyncio.Event() + + # A flag to allow `async_step_discovery` to resume after we have verified the uninited flow is not + # visible and has not triggered a discovery alert. This lets us control when the mocked network + # i/o is complete. + pause_discovery = asyncio.Event() + + class TestFlow(config_entries.ConfigFlow): + + VERSION = 1 + + async def async_step_discovery(self, user_input): + discovery_started.set() + await pause_discovery.wait() + return self.async_show_form(step_id="someform") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + # Start a config entry flow and wait for it to be blocked + init_task = asyncio.ensure_future( + manager.flow.async_init( + "comp", + context={"source": config_entries.SOURCE_DISCOVERY}, + data={"unique_id": "mock-unique-id"}, + ) + ) + await discovery_started.wait() + + # While it's blocked it shouldn't be visible or trigger discovery notifications + assert len(hass.config_entries.flow.async_progress()) == 0 + + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is None + + # Let the flow init complete + pause_discovery.set() + + # When it's complete it should now be visible in async_progress and have triggered + # discovery notifications + result = await init_task + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert len(hass.config_entries.flow.async_progress()) == 1 + + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is not None