diff --git a/homeassistant/components/tplink/config_flow.py b/homeassistant/components/tplink/config_flow.py index 291a7e78c62..0914c4191cf 100644 --- a/homeassistant/components/tplink/config_flow.py +++ b/homeassistant/components/tplink/config_flow.py @@ -567,7 +567,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): ) async def _async_reload_requires_auth_entries(self) -> None: - """Reload any in progress config flow that now have credentials.""" + """Reload all config entries after auth update.""" _config_entries = self.hass.config_entries if self.source == SOURCE_REAUTH: @@ -579,11 +579,9 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN): context = flow["context"] if context.get("source") != SOURCE_REAUTH: continue - entry_id: str = context["entry_id"] + entry_id = context["entry_id"] if entry := _config_entries.async_get_entry(entry_id): await _config_entries.async_reload(entry.entry_id) - if entry.state is ConfigEntryState.LOADED: - _config_entries.flow.async_abort(flow["flow_id"]) @callback def _async_create_or_update_entry_from_device( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index f5f73842042..3064fdd54bb 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1521,10 +1521,9 @@ class ConfigEntriesFlowManager( # Clean up issue if this is a reauth flow if flow.context["source"] == SOURCE_REAUTH: - if (entry_id := flow.context.get("entry_id")) is not None and ( - entry := self.config_entries.async_get_entry(entry_id) - ) is not None: - issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}" + if (entry_id := flow.context.get("entry_id")) is not None: + # The config entry's domain is flow.handler + issue_id = f"config_entry_reauth_{flow.handler}_{entry_id}" ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id) async def async_finish_flow( @@ -2128,13 +2127,7 @@ class ConfigEntries: # If the configuration entry is removed during reauth, it should # abort any reauth flow that is active for the removed entry and # linked issues. - for progress_flow in self.hass.config_entries.flow.async_progress_by_handler( - entry.domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH} - ): - if "flow_id" in progress_flow: - self.hass.config_entries.flow.async_abort(progress_flow["flow_id"]) - issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}" - ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id) + _abort_reauth_flows(self.hass, entry.domain, entry_id) self._async_dispatch(ConfigEntryChange.REMOVED, entry) @@ -2266,6 +2259,9 @@ class ConfigEntries: # attempts. entry.async_cancel_retry_setup() + # Abort any in-progress reauth flow and linked issues + _abort_reauth_flows(self.hass, entry.domain, entry_id) + if entry.domain not in self.hass.config.components: # If the component is not loaded, just load it as # the config entry will be loaded as well. We need @@ -3786,3 +3782,13 @@ async def _async_get_flow_handler( return handler raise data_entry_flow.UnknownHandler + + +@callback +def _abort_reauth_flows(hass: HomeAssistant, domain: str, entry_id: str) -> None: + """Abort reauth flows for an entry.""" + for progress_flow in hass.config_entries.flow.async_progress_by_handler( + domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH} + ): + if "flow_id" in progress_flow: + hass.config_entries.flow.async_abort(progress_flow["flow_id"]) diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 6a288380cd0..e2e31ffce29 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -494,8 +494,11 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]): ) if flow.flow_id not in self._progress: - # The flow was removed during the step - raise UnknownFlow + # The flow was removed during the step, raise UnknownFlow + # unless the result is an abort + if result["type"] != FlowResultType.ABORT: + raise UnknownFlow + return result # Setup the flow handler's preview if needed if result.get("preview") is not None: @@ -547,7 +550,7 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]): flow.cur_step = result return result - # Abort and Success results both finish the flow + # Abort and Success results both finish the flow. self._async_remove_flow_progress(flow.flow_id) return result diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 2527a6a151d..13ecd855624 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -695,7 +695,7 @@ async def test_remove_entry_cancels_reauth( manager: config_entries.ConfigEntries, issue_registry: ir.IssueRegistry, ) -> None: - """Tests that removing a config entry, also aborts existing reauth flows.""" + """Tests that removing a config entry also aborts existing reauth flows.""" entry = MockConfigEntry(title="test_title", domain="test") mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed()) @@ -722,6 +722,40 @@ async def test_remove_entry_cancels_reauth( assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id) +async def test_reload_entry_cancels_reauth( + hass: HomeAssistant, + manager: config_entries.ConfigEntries, + issue_registry: ir.IssueRegistry, +) -> None: + """Tests that reloading a config entry also aborts existing reauth flows.""" + entry = MockConfigEntry(title="test_title", domain="test") + + mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed()) + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_platform(hass, "test.config_flow", None) + + entry.add_to_hass(hass) + await manager.async_setup(entry.entry_id) + await hass.async_block_till_done() + + flows = hass.config_entries.flow.async_progress_by_handler("test") + assert len(flows) == 1 + assert flows[0]["context"]["entry_id"] == entry.entry_id + assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH + assert entry.state is config_entries.ConfigEntryState.SETUP_ERROR + + issue_id = f"config_entry_reauth_test_{entry.entry_id}" + assert issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id) + + mock_setup_entry.return_value = True + mock_setup_entry.side_effect = None + await manager.async_reload(entry.entry_id) + + flows = hass.config_entries.flow.async_progress_by_handler("test") + assert len(flows) == 0 + assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id) + + async def test_remove_entry_handles_callback_error( hass: HomeAssistant, manager: config_entries.ConfigEntries ) -> None: diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index bcc40251bad..804b1fea405 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -219,8 +219,8 @@ async def test_abort_aborted_flow(manager: MockFlowManager) -> None: manager.async_abort(self.flow_id) return self.async_abort(reason="blah") - with pytest.raises(data_entry_flow.UnknownFlow): - await manager.async_init("test") + form = await manager.async_init("test") + assert form["reason"] == "blah" assert len(manager.async_progress()) == 0 assert len(manager.mock_created_entries) == 0