From e92e2065442c1fcd658175da62ded590053f57c6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 23 Aug 2021 23:01:21 -0500 Subject: [PATCH] Fix race that allowed multiple config flows with the same unique id (#55131) - If a config flow set a unique id and then did an await to return control to the event loop, another discovery with the same unique id could start and it would not see the first one because it was still uninitialized. We now check uninitialized flows when setting the unique id --- homeassistant/config_entries.py | 4 +- tests/test_config_entries.py | 68 +++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 67c718a497d..50d279ec8b0 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1205,7 +1205,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): return None if raise_on_progress: - for progress in self._async_in_progress(): + for progress in self._async_in_progress(include_uninitialized=True): if progress["context"].get("unique_id") == unique_id: raise data_entry_flow.AbortFlow("already_in_progress") @@ -1213,7 +1213,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): # Abort discoveries done using the default discovery unique id if unique_id != DEFAULT_DISCOVERY_UNIQUE_ID: - for progress in self._async_in_progress(): + for progress in self._async_in_progress(include_uninitialized=True): if progress["context"].get("unique_id") == DEFAULT_DISCOVERY_UNIQUE_ID: self.hass.config_entries.flow.async_abort(progress["flow_id"]) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 4c002ad8228..2ae4ad036d4 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -2585,6 +2585,74 @@ async def test_default_discovery_abort_on_user_flow_complete(hass, manager): assert len(flows) == 0 +async def test_flow_same_device_multiple_sources(hass, manager): + """Test discovery of the same devices from multiple discovery sources.""" + mock_integration( + hass, + MockModule("comp", async_setup_entry=AsyncMock(return_value=True)), + ) + mock_entity_platform(hass, "config_flow.comp", None) + + class TestFlow(config_entries.ConfigFlow): + """Test flow.""" + + VERSION = 1 + + async def async_step_zeroconf(self, discovery_info=None): + """Test zeroconf step.""" + return await self._async_discovery_handler(discovery_info) + + async def async_step_homekit(self, discovery_info=None): + """Test homekit step.""" + return await self._async_discovery_handler(discovery_info) + + async def _async_discovery_handler(self, discovery_info=None): + """Test any discovery handler.""" + await self.async_set_unique_id("thisid") + self._abort_if_unique_id_configured() + await asyncio.sleep(0.1) + return await self.async_step_link() + + async def async_step_link(self, user_input=None): + """Test a link step.""" + if user_input is None: + return self.async_show_form(step_id="link") + return self.async_create_entry(title="title", data={"token": "supersecret"}) + + with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}): + # Create one to be in progress + flow1 = manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_ZEROCONF} + ) + flow2 = manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_ZEROCONF} + ) + flow3 = manager.flow.async_init( + "comp", context={"source": config_entries.SOURCE_HOMEKIT} + ) + result1, result2, result3 = await asyncio.gather(flow1, flow2, flow3) + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + assert flows[0]["context"]["unique_id"] == "thisid" + + # Finish flow + result2 = await manager.flow.async_configure( + flows[0]["flow_id"], user_input={"fake": "data"} + ) + assert result2["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + + assert len(hass.config_entries.flow.async_progress()) == 0 + + entry = hass.config_entries.async_entries("comp")[0] + assert entry.title == "title" + assert entry.source in { + config_entries.SOURCE_ZEROCONF, + config_entries.SOURCE_HOMEKIT, + } + assert entry.unique_id == "thisid" + + async def test_updating_entry_with_and_without_changes(manager): """Test that we can update an entry data.""" entry = MockConfigEntry(