Abort if a flow is removed during a step (#142138)

* Abort if a flow is removed during a step

* Reorganize code

* Only call _set_pending_import_done if an entry is created

* Try a new approach

* Add tests

* Update tests
This commit is contained in:
Erik Montnemery 2025-04-09 19:04:41 +02:00 committed by GitHub
parent 7f4d178781
commit f344314762
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 33 deletions

View File

@ -1503,6 +1503,22 @@ class ConfigEntriesFlowManager(
future.set_result(None) future.set_result(None)
self._discovery_event_debouncer.async_shutdown() self._discovery_event_debouncer.async_shutdown()
@callback
def async_flow_removed(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
) -> None:
"""Handle a removed config flow."""
flow = cast(ConfigFlow, flow)
# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
@ -1515,20 +1531,6 @@ class ConfigEntriesFlowManager(
""" """
flow = cast(ConfigFlow, flow) flow = cast(ConfigFlow, flow)
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)
# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# If there's a config entry with a matching unique ID, # If there's a config entry with a matching unique ID,
# update the discovery key. # update the discovery key.
@ -1567,6 +1569,12 @@ class ConfigEntriesFlowManager(
) )
return result return result
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)
# Avoid adding a config entry for a integration # Avoid adding a config entry for a integration
# that only supports a single config entry, but already has an entry # that only supports a single config entry, but already has an entry
if ( if (

View File

@ -207,6 +207,13 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
""" """
@callback
def async_flow_removed(
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
) -> None:
"""Handle a removed data entry flow."""
@abc.abstractmethod @abc.abstractmethod
async def async_finish_flow( async def async_finish_flow(
self, self,
@ -457,6 +464,7 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if (flow := self._progress.pop(flow_id, None)) is None: if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow raise UnknownFlow
self.async_flow_removed(flow)
self._async_remove_flow_from_index(flow) self._async_remove_flow_from_index(flow)
flow.async_cancel_progress_task() flow.async_cancel_progress_task()
try: try:
@ -485,6 +493,10 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
description_placeholders=err.description_placeholders, description_placeholders=err.description_placeholders,
) )
if flow.flow_id not in self._progress:
# The flow was removed during the step
raise UnknownFlow
# Setup the flow handler's preview if needed # Setup the flow handler's preview if needed
if result.get("preview") is not None: if result.get("preview") is not None:
await self._async_setup_preview(flow) await self._async_setup_preview(flow)

View File

@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
issue = await _test_reauth_issue(hass, manager, issue_registry) issue = await _test_reauth_issue(hass, manager, issue_registry)
manager.flow.async_abort(issue.data["flow_id"]) manager.flow.async_abort(issue.data["flow_id"])
# This can be considered a bug, we should make sure the issue is always assert len(issue_registry.issues) == 0
# removed when the reauth flow is aborted.
assert len(issue_registry.issues) == 1
async def _test_reauth_issue( async def _test_reauth_issue(

View File

@ -243,6 +243,23 @@ async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
assert len(manager.mock_created_entries) == 0 assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_abort(reason="reason")
manager.async_flow_removed = Mock()
await manager.async_init("test")
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_remove_with_exception( async def test_abort_calls_async_remove_with_exception(
manager: MockFlowManager, caplog: pytest.LogCaptureFixture manager: MockFlowManager, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
@ -288,13 +305,7 @@ async def test_create_saves_data(manager: MockFlowManager) -> None:
async def test_create_aborted_flow(manager: MockFlowManager) -> None: async def test_create_aborted_flow(manager: MockFlowManager) -> None:
"""Test return create_entry from aborted flow. """Test return create_entry from aborted flow."""
Note: The entry is created even if the flow is already aborted, then the
flow raises an UnknownFlow exception. This behavior is not logical, and
we should consider changing it to not create the entry if the flow is
aborted.
"""
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler): class TestFlow(data_entry_flow.FlowHandler):
@ -308,14 +319,25 @@ async def test_create_aborted_flow(manager: MockFlowManager) -> None:
await manager.async_init("test") await manager.async_init("test")
assert len(manager.async_progress()) == 0 assert len(manager.async_progress()) == 0
# The entry is created even if the flow is aborted # No entry should be created if the flow is aborted
assert len(manager.mock_created_entries) == 1 assert len(manager.mock_created_entries) == 0
entry = manager.mock_created_entries[0]
assert entry["handler"] == "test" async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
assert entry["title"] == "Test Title" """Test create calling the async_flow_removed FlowManager method."""
assert entry["data"] == "Test Data"
assert entry["source"] is None @manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_create_entry(title="Test Title", data="Test Data")
manager.async_flow_removed = Mock()
await manager.async_init("test")
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
async def test_discovery_init_flow(manager: MockFlowManager) -> None: async def test_discovery_init_flow(manager: MockFlowManager) -> None:
@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
await manager.async_configure("wrong_flow_id") await manager.async_configure("wrong_flow_id")
async def test_abort_raises_unknown_flow_if_not_in_progress( async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
manager: MockFlowManager, manager: MockFlowManager,
) -> None: ) -> None:
"""Test abort raises UnknownFlow if the flow is not in progress.""" """Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow): with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id") manager.async_abort("wrong_flow_id")
async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_show_form(step_id="init")
manager.async_flow_removed = Mock()
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "init"
manager.async_flow_removed.assert_not_called()
manager.async_abort(result["flow_id"])
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
@pytest.mark.parametrize( @pytest.mark.parametrize(