Abort if a flow is removed during a step (#142138)

* Abort if a flow is removed during a step

* Reorganize code

* Only call _set_pending_import_done if an entry is created

* Try a new approach

* Add tests

* Update tests
This commit is contained in:
Erik Montnemery 2025-04-09 19:04:41 +02:00 committed by GitHub
parent 7f4d178781
commit f344314762
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 33 deletions

View File

@ -1503,6 +1503,22 @@ class ConfigEntriesFlowManager(
future.set_result(None)
self._discovery_event_debouncer.async_shutdown()
@callback
def async_flow_removed(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
) -> None:
"""Handle a removed config flow."""
flow = cast(ConfigFlow, flow)
# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
@ -1515,20 +1531,6 @@ class ConfigEntriesFlowManager(
"""
flow = cast(ConfigFlow, flow)
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)
# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# If there's a config entry with a matching unique ID,
# update the discovery key.
@ -1567,6 +1569,12 @@ class ConfigEntriesFlowManager(
)
return result
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)
# Avoid adding a config entry for a integration
# that only supports a single config entry, but already has an entry
if (

View File

@ -207,6 +207,13 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
Handler key is the domain of the component that we want to set up.
"""
@callback
def async_flow_removed(
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
) -> None:
"""Handle a removed data entry flow."""
@abc.abstractmethod
async def async_finish_flow(
self,
@ -457,6 +464,7 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Remove a flow from in progress."""
if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow
self.async_flow_removed(flow)
self._async_remove_flow_from_index(flow)
flow.async_cancel_progress_task()
try:
@ -485,6 +493,10 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
description_placeholders=err.description_placeholders,
)
if flow.flow_id not in self._progress:
# The flow was removed during the step
raise UnknownFlow
# Setup the flow handler's preview if needed
if result.get("preview") is not None:
await self._async_setup_preview(flow)

View File

@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
issue = await _test_reauth_issue(hass, manager, issue_registry)
manager.flow.async_abort(issue.data["flow_id"])
# This can be considered a bug, we should make sure the issue is always
# removed when the reauth flow is aborted.
assert len(issue_registry.issues) == 1
assert len(issue_registry.issues) == 0
async def _test_reauth_issue(

View File

@ -243,6 +243,23 @@ async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_abort(reason="reason")
manager.async_flow_removed = Mock()
await manager.async_init("test")
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_abort_calls_async_remove_with_exception(
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
) -> None:
@ -288,13 +305,7 @@ async def test_create_saves_data(manager: MockFlowManager) -> None:
async def test_create_aborted_flow(manager: MockFlowManager) -> None:
"""Test return create_entry from aborted flow.
Note: The entry is created even if the flow is already aborted, then the
flow raises an UnknownFlow exception. This behavior is not logical, and
we should consider changing it to not create the entry if the flow is
aborted.
"""
"""Test return create_entry from aborted flow."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
@ -308,14 +319,25 @@ async def test_create_aborted_flow(manager: MockFlowManager) -> None:
await manager.async_init("test")
assert len(manager.async_progress()) == 0
# The entry is created even if the flow is aborted
assert len(manager.mock_created_entries) == 1
# No entry should be created if the flow is aborted
assert len(manager.mock_created_entries) == 0
entry = manager.mock_created_entries[0]
assert entry["handler"] == "test"
assert entry["title"] == "Test Title"
assert entry["data"] == "Test Data"
assert entry["source"] is None
async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test create calling the async_flow_removed FlowManager method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_create_entry(title="Test Title", data="Test Data")
manager.async_flow_removed = Mock()
await manager.async_init("test")
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
await manager.async_configure("wrong_flow_id")
async def test_abort_raises_unknown_flow_if_not_in_progress(
async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
manager: MockFlowManager,
) -> None:
"""Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id")
manager.async_abort("wrong_flow_id")
async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_show_form(step_id="init")
manager.async_flow_removed = Mock()
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "init"
manager.async_flow_removed.assert_not_called()
manager.async_abort(result["flow_id"])
manager.async_flow_removed.assert_called_once()
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
@pytest.mark.parametrize(