diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index d1b5c927a2b..6fb5595dac4 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -61,7 +61,13 @@ ENTRY_STATE_FAILED_UNLOAD = "failed_unload" UNRECOVERABLE_STATES = (ENTRY_STATE_MIGRATION_ERROR, ENTRY_STATE_FAILED_UNLOAD) DISCOVERY_NOTIFICATION_ID = "config_entry_discovery" -DISCOVERY_SOURCES = (SOURCE_SSDP, SOURCE_ZEROCONF, SOURCE_DISCOVERY, SOURCE_IMPORT) +DISCOVERY_SOURCES = ( + SOURCE_SSDP, + SOURCE_ZEROCONF, + SOURCE_DISCOVERY, + SOURCE_IMPORT, + SOURCE_UNIGNORE, +) EVENT_FLOW_DISCOVERED = "config_entry_discovered" @@ -511,7 +517,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): if not context or "source" not in context: raise KeyError("Context not set or doesn't have a source set") - source = context["source"] + flow = cast(ConfigFlow, handler()) + flow.init_step = context["source"] + return flow + + async def async_post_init( + self, flow: data_entry_flow.FlowHandler, result: dict + ) -> None: + """After a flow is initialised trigger new flow notifications.""" + source = flow.context["source"] # Create notification. if source in DISCOVERY_SOURCES: @@ -525,10 +539,6 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): notification_id=DISCOVERY_NOTIFICATION_ID, ) - flow = cast(ConfigFlow, handler()) - flow.init_step = source - return flow - class ConfigEntries: """Manage the configuration entries. diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 5e72dea9273..4dd1c7acf50 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -76,12 +76,19 @@ class FlowManager(abc.ABC): """Finish a config flow and add an entry.""" pass + async def async_post_init( + self, flow: "FlowHandler", result: Dict[str, Any] + ) -> None: + """Entry has finished executing its first step asynchronously.""" + pass + @callback def async_progress(self) -> List[Dict]: """Return the flows in progress.""" return [ {"flow_id": flow.flow_id, "handler": flow.handler, "context": flow.context} for flow in self._progress.values() + if flow.cur_step is not None ] async def async_init( @@ -99,7 +106,12 @@ class FlowManager(abc.ABC): flow.context = context self._progress[flow.flow_id] = flow - return await self._async_handle_step(flow, flow.init_step, data) + result = await self._async_handle_step(flow, flow.init_step, data) + + if result["type"] != RESULT_TYPE_ABORT: + await self.async_post_init(flow, result) + + return result async def async_configure( self, flow_id: str, user_input: Optional[Dict] = None diff --git a/tests/components/tradfri/conftest.py b/tests/components/tradfri/conftest.py index 1c6e572b81f..d835c06f256 100644 --- a/tests/components/tradfri/conftest.py +++ b/tests/components/tradfri/conftest.py @@ -3,6 +3,8 @@ from unittest.mock import patch import pytest +from tests.common import mock_coro + @pytest.fixture def mock_gateway_info(): @@ -11,3 +13,11 @@ def mock_gateway_info(): "homeassistant.components.tradfri.config_flow.get_gateway_info" ) as mock_gateway: yield mock_gateway + + +@pytest.fixture +def mock_entry_setup(): + """Mock entry setup.""" + with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: + mock_setup.return_value = mock_coro(True) + yield mock_setup diff --git a/tests/components/tradfri/test_config_flow.py b/tests/components/tradfri/test_config_flow.py index ad7386c530f..18fb55eda2f 100644 --- a/tests/components/tradfri/test_config_flow.py +++ b/tests/components/tradfri/test_config_flow.py @@ -18,14 +18,6 @@ def mock_auth(): yield mock_auth -@pytest.fixture -def mock_entry_setup(): - """Mock entry setup.""" - with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup: - mock_setup.return_value = mock_coro(True) - yield mock_setup - - async def test_user_connection_successful(hass, mock_auth, mock_entry_setup): """Test a successful connection.""" mock_auth.side_effect = lambda hass, host, code: mock_coro( diff --git a/tests/components/tradfri/test_init.py b/tests/components/tradfri/test_init.py index 67ecb8d054a..cf9034df8d6 100644 --- a/tests/components/tradfri/test_init.py +++ b/tests/components/tradfri/test_init.py @@ -3,7 +3,7 @@ from unittest.mock import patch from homeassistant.setup import async_setup_component -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, mock_coro async def test_config_yaml_host_not_imported(hass): @@ -49,8 +49,12 @@ async def test_config_json_host_not_imported(hass): assert len(mock_init.mock_calls) == 0 -async def test_config_json_host_imported(hass, mock_gateway_info): +async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup): """Test that we import a configured host.""" + mock_gateway_info.side_effect = lambda hass, host, identity, key: mock_coro( + {"host": host, "identity": identity, "key": key, "gateway_id": "mock-gateway"} + ) + with patch( "homeassistant.components.tradfri.load_json", return_value={"mock-host": {"key": "some-info"}}, @@ -58,7 +62,7 @@ async def test_config_json_host_imported(hass, mock_gateway_info): assert await async_setup_component(hass, "tradfri", {"tradfri": {}}) await hass.async_block_till_done() - progress = hass.config_entries.flow.async_progress() - assert len(progress) == 1 - assert progress[0]["handler"] == "tradfri" - assert progress[0]["context"] == {"source": "import"} + config_entry = mock_entry_setup.mock_calls[0][1][1] + assert config_entry.domain == "tradfri" + assert config_entry.source == "import" + assert config_entry.title == "mock-host" diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index c3a87bcf3a0..46410de1999 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1303,3 +1303,60 @@ async def test_unignore_default_impl(hass, manager): assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.flow.async_progress()) == 0 + + +async def test_partial_flows_hidden(hass, manager): + """Test that flows that don't have a cur_step and haven't finished initing are hidden.""" + async_setup_entry = MagicMock(return_value=mock_coro(True)) + mock_integration(hass, MockModule("comp", async_setup_entry=async_setup_entry)) + mock_entity_platform(hass, "config_flow.comp", None) + await async_setup_component(hass, "persistent_notification", {}) + + # A flag to test our assertion that `async_step_discovery` was called and is in its blocked state + # This simulates if the step was e.g. doing network i/o + discovery_started = asyncio.Event() + + # A flag to allow `async_step_discovery` to resume after we have verified the uninited flow is not + # visible and has not triggered a discovery alert. This lets us control when the mocked network + # i/o is complete. + pause_discovery = asyncio.Event() + + class TestFlow(config_entries.ConfigFlow): + + VERSION = 1 + + async def async_step_discovery(self, user_input): + discovery_started.set() + await pause_discovery.wait() + return self.async_show_form(step_id="someform") + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + # Start a config entry flow and wait for it to be blocked + init_task = asyncio.ensure_future( + manager.flow.async_init( + "comp", + context={"source": config_entries.SOURCE_DISCOVERY}, + data={"unique_id": "mock-unique-id"}, + ) + ) + await discovery_started.wait() + + # While it's blocked it shouldn't be visible or trigger discovery notifications + assert len(hass.config_entries.flow.async_progress()) == 0 + + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is None + + # Let the flow init complete + pause_discovery.set() + + # When it's complete it should now be visible in async_progress and have triggered + # discovery notifications + result = await init_task + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert len(hass.config_entries.flow.async_progress()) == 1 + + await hass.async_block_till_done() + state = hass.states.get("persistent_notification.config_entry_discovery") + assert state is not None