Don't expose flows that aren't initialised. (#30432)

* Don't expose flows that aren't initialised.

If a flow init does not return immediately then there is a window where our
behaviour is screwy:

 * Can try to configure a flow that isn't ready
 * Can show notifications for discoveries that might yet return an abort

This moves the flow discovery events and notifications to after the flow is
initialised and hides flows that don't have a cur_step from async_progress

* Fix tradfri test

* Black.

* Lint fixes
This commit is contained in:
Jc2k 2020-01-03 16:28:05 +00:00 committed by Paulus Schoutsen
parent df6c7b97f5
commit 6b519499a7
6 changed files with 106 additions and 21 deletions

View File

@ -61,7 +61,13 @@ ENTRY_STATE_FAILED_UNLOAD = "failed_unload"
UNRECOVERABLE_STATES = (ENTRY_STATE_MIGRATION_ERROR, ENTRY_STATE_FAILED_UNLOAD) UNRECOVERABLE_STATES = (ENTRY_STATE_MIGRATION_ERROR, ENTRY_STATE_FAILED_UNLOAD)
DISCOVERY_NOTIFICATION_ID = "config_entry_discovery" DISCOVERY_NOTIFICATION_ID = "config_entry_discovery"
DISCOVERY_SOURCES = (SOURCE_SSDP, SOURCE_ZEROCONF, SOURCE_DISCOVERY, SOURCE_IMPORT) DISCOVERY_SOURCES = (
SOURCE_SSDP,
SOURCE_ZEROCONF,
SOURCE_DISCOVERY,
SOURCE_IMPORT,
SOURCE_UNIGNORE,
)
EVENT_FLOW_DISCOVERED = "config_entry_discovered" EVENT_FLOW_DISCOVERED = "config_entry_discovered"
@ -511,7 +517,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
if not context or "source" not in context: if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set") raise KeyError("Context not set or doesn't have a source set")
source = context["source"] flow = cast(ConfigFlow, handler())
flow.init_step = context["source"]
return flow
async def async_post_init(
self, flow: data_entry_flow.FlowHandler, result: dict
) -> None:
"""After a flow is initialised trigger new flow notifications."""
source = flow.context["source"]
# Create notification. # Create notification.
if source in DISCOVERY_SOURCES: if source in DISCOVERY_SOURCES:
@ -525,10 +539,6 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
notification_id=DISCOVERY_NOTIFICATION_ID, notification_id=DISCOVERY_NOTIFICATION_ID,
) )
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
class ConfigEntries: class ConfigEntries:
"""Manage the configuration entries. """Manage the configuration entries.

View File

@ -76,12 +76,19 @@ class FlowManager(abc.ABC):
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
pass pass
async def async_post_init(
self, flow: "FlowHandler", result: Dict[str, Any]
) -> None:
"""Entry has finished executing its first step asynchronously."""
pass
@callback @callback
def async_progress(self) -> List[Dict]: def async_progress(self) -> List[Dict]:
"""Return the flows in progress.""" """Return the flows in progress."""
return [ return [
{"flow_id": flow.flow_id, "handler": flow.handler, "context": flow.context} {"flow_id": flow.flow_id, "handler": flow.handler, "context": flow.context}
for flow in self._progress.values() for flow in self._progress.values()
if flow.cur_step is not None
] ]
async def async_init( async def async_init(
@ -99,7 +106,12 @@ class FlowManager(abc.ABC):
flow.context = context flow.context = context
self._progress[flow.flow_id] = flow self._progress[flow.flow_id] = flow
return await self._async_handle_step(flow, flow.init_step, data) result = await self._async_handle_step(flow, flow.init_step, data)
if result["type"] != RESULT_TYPE_ABORT:
await self.async_post_init(flow, result)
return result
async def async_configure( async def async_configure(
self, flow_id: str, user_input: Optional[Dict] = None self, flow_id: str, user_input: Optional[Dict] = None

View File

@ -3,6 +3,8 @@ from unittest.mock import patch
import pytest import pytest
from tests.common import mock_coro
@pytest.fixture @pytest.fixture
def mock_gateway_info(): def mock_gateway_info():
@ -11,3 +13,11 @@ def mock_gateway_info():
"homeassistant.components.tradfri.config_flow.get_gateway_info" "homeassistant.components.tradfri.config_flow.get_gateway_info"
) as mock_gateway: ) as mock_gateway:
yield mock_gateway yield mock_gateway
@pytest.fixture
def mock_entry_setup():
"""Mock entry setup."""
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
mock_setup.return_value = mock_coro(True)
yield mock_setup

View File

@ -18,14 +18,6 @@ def mock_auth():
yield mock_auth yield mock_auth
@pytest.fixture
def mock_entry_setup():
"""Mock entry setup."""
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
mock_setup.return_value = mock_coro(True)
yield mock_setup
async def test_user_connection_successful(hass, mock_auth, mock_entry_setup): async def test_user_connection_successful(hass, mock_auth, mock_entry_setup):
"""Test a successful connection.""" """Test a successful connection."""
mock_auth.side_effect = lambda hass, host, code: mock_coro( mock_auth.side_effect = lambda hass, host, code: mock_coro(

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry, mock_coro
async def test_config_yaml_host_not_imported(hass): async def test_config_yaml_host_not_imported(hass):
@ -49,8 +49,12 @@ async def test_config_json_host_not_imported(hass):
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_config_json_host_imported(hass, mock_gateway_info): async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup):
"""Test that we import a configured host.""" """Test that we import a configured host."""
mock_gateway_info.side_effect = lambda hass, host, identity, key: mock_coro(
{"host": host, "identity": identity, "key": key, "gateway_id": "mock-gateway"}
)
with patch( with patch(
"homeassistant.components.tradfri.load_json", "homeassistant.components.tradfri.load_json",
return_value={"mock-host": {"key": "some-info"}}, return_value={"mock-host": {"key": "some-info"}},
@ -58,7 +62,7 @@ async def test_config_json_host_imported(hass, mock_gateway_info):
assert await async_setup_component(hass, "tradfri", {"tradfri": {}}) assert await async_setup_component(hass, "tradfri", {"tradfri": {}})
await hass.async_block_till_done() await hass.async_block_till_done()
progress = hass.config_entries.flow.async_progress() config_entry = mock_entry_setup.mock_calls[0][1][1]
assert len(progress) == 1 assert config_entry.domain == "tradfri"
assert progress[0]["handler"] == "tradfri" assert config_entry.source == "import"
assert progress[0]["context"] == {"source": "import"} assert config_entry.title == "mock-host"

View File

@ -1303,3 +1303,60 @@ async def test_unignore_default_impl(hass, manager):
assert len(hass.config_entries.async_entries("comp")) == 0 assert len(hass.config_entries.async_entries("comp")) == 0
assert len(hass.config_entries.flow.async_progress()) == 0 assert len(hass.config_entries.flow.async_progress()) == 0
async def test_partial_flows_hidden(hass, manager):
"""Test that flows that don't have a cur_step and haven't finished initing are hidden."""
async_setup_entry = MagicMock(return_value=mock_coro(True))
mock_integration(hass, MockModule("comp", async_setup_entry=async_setup_entry))
mock_entity_platform(hass, "config_flow.comp", None)
await async_setup_component(hass, "persistent_notification", {})
# A flag to test our assertion that `async_step_discovery` was called and is in its blocked state
# This simulates if the step was e.g. doing network i/o
discovery_started = asyncio.Event()
# A flag to allow `async_step_discovery` to resume after we have verified the uninited flow is not
# visible and has not triggered a discovery alert. This lets us control when the mocked network
# i/o is complete.
pause_discovery = asyncio.Event()
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_discovery(self, user_input):
discovery_started.set()
await pause_discovery.wait()
return self.async_show_form(step_id="someform")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Start a config entry flow and wait for it to be blocked
init_task = asyncio.ensure_future(
manager.flow.async_init(
"comp",
context={"source": config_entries.SOURCE_DISCOVERY},
data={"unique_id": "mock-unique-id"},
)
)
await discovery_started.wait()
# While it's blocked it shouldn't be visible or trigger discovery notifications
assert len(hass.config_entries.flow.async_progress()) == 0
await hass.async_block_till_done()
state = hass.states.get("persistent_notification.config_entry_discovery")
assert state is None
# Let the flow init complete
pause_discovery.set()
# When it's complete it should now be visible in async_progress and have triggered
# discovery notifications
result = await init_task
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert len(hass.config_entries.flow.async_progress()) == 1
await hass.async_block_till_done()
state = hass.states.get("persistent_notification.config_entry_discovery")
assert state is not None