Don't raise in ConfigFlow.async_set_unique_id if the other flow is a reauth flow (#140723)

* Don't raise in ConfigFlow.async_set_unique_id if the other flow is a reauth flow

* Improve test
This commit is contained in:
Erik Montnemery 2025-03-17 20:04:30 +01:00 committed by GitHub
parent e16f0e9af3
commit 290dab25bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 7 deletions

View File

@ -2986,8 +2986,11 @@ class ConfigFlow(ConfigEntryBaseFlow):
return None
if raise_on_progress:
if self._async_in_progress(
include_uninitialized=True, match_context={"unique_id": unique_id}
if any(
flow["context"]["source"] != SOURCE_REAUTH
for flow in self._async_in_progress(
include_uninitialized=True, match_context={"unique_id": unique_id}
)
):
raise data_entry_flow.AbortFlow("already_in_progress")

View File

@ -3566,37 +3566,97 @@ async def test_unique_id_not_update_existing_entry(
assert len(async_reload.mock_calls) == 0
ABORT_IN_PROGRESS = {
"type": data_entry_flow.FlowResultType.ABORT,
"reason": "already_in_progress",
}
@pytest.mark.parametrize(
("existing_flow_source", "expected_result"),
# Test all sources except SOURCE_IGNORE
[
(config_entries.SOURCE_BLUETOOTH, ABORT_IN_PROGRESS),
(config_entries.SOURCE_DHCP, ABORT_IN_PROGRESS),
(config_entries.SOURCE_DISCOVERY, ABORT_IN_PROGRESS),
(config_entries.SOURCE_HARDWARE, ABORT_IN_PROGRESS),
(config_entries.SOURCE_HASSIO, ABORT_IN_PROGRESS),
(config_entries.SOURCE_HOMEKIT, ABORT_IN_PROGRESS),
(config_entries.SOURCE_IMPORT, ABORT_IN_PROGRESS),
(config_entries.SOURCE_INTEGRATION_DISCOVERY, ABORT_IN_PROGRESS),
(config_entries.SOURCE_MQTT, ABORT_IN_PROGRESS),
(config_entries.SOURCE_REAUTH, {"type": data_entry_flow.FlowResultType.FORM}),
(config_entries.SOURCE_RECONFIGURE, ABORT_IN_PROGRESS),
(config_entries.SOURCE_SSDP, ABORT_IN_PROGRESS),
(config_entries.SOURCE_SYSTEM, ABORT_IN_PROGRESS),
(config_entries.SOURCE_USB, ABORT_IN_PROGRESS),
(config_entries.SOURCE_USER, ABORT_IN_PROGRESS),
(config_entries.SOURCE_ZEROCONF, ABORT_IN_PROGRESS),
],
)
async def test_unique_id_in_progress(
hass: HomeAssistant, manager: config_entries.ConfigEntries
hass: HomeAssistant,
manager: config_entries.ConfigEntries,
existing_flow_source: str,
expected_result: dict,
) -> None:
"""Test that we abort if there is already a flow in progress with same unique id."""
mock_integration(hass, MockModule("comp"))
mock_platform(hass, "comp.config_flow", None)
entry = MockConfigEntry(domain="comp")
entry.add_to_hass(hass)
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
VERSION = 1
async def _async_step_discovery_without_unique_id(self):
"""Handle a flow initialized by discovery."""
return await self._async_step()
async def async_step_hardware(self, user_input=None):
"""Test hardware step."""
return await self._async_step()
async def async_step_import(self, user_input=None):
"""Test import step."""
return await self._async_step()
async def async_step_reauth(self, user_input=None):
"""Test reauth step."""
return await self._async_step()
async def async_step_reconfigure(self, user_input=None):
"""Test reconfigure step."""
return await self._async_step()
async def async_step_system(self, user_input=None):
"""Test system step."""
return await self._async_step()
async def async_step_user(self, user_input=None):
"""Test user step."""
return await self._async_step()
async def _async_step(self, user_input=None):
"""Test step."""
await self.async_set_unique_id("mock-unique-id")
return self.async_show_form(step_id="discovery")
with mock_config_flow("comp", TestFlow):
# Create one to be in progress
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
"comp", context={"source": existing_flow_source, "entry_id": entry.entry_id}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
# Will be canceled
result2 = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result2["type"] == data_entry_flow.FlowResultType.ABORT
assert result2["reason"] == "already_in_progress"
for k, v in expected_result.items():
assert result2[k] == v
async def test_finish_flow_aborts_progress(