mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Abort if a flow is removed during a step (#142138)
* Abort if a flow is removed during a step * Reorganize code * Only call _set_pending_import_done if an entry is created * Try a new approach * Add tests * Update tests
This commit is contained in:
parent
7f4d178781
commit
f344314762
@ -1503,6 +1503,22 @@ class ConfigEntriesFlowManager(
|
||||
future.set_result(None)
|
||||
self._discovery_event_debouncer.async_shutdown()
|
||||
|
||||
@callback
|
||||
def async_flow_removed(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||
) -> None:
|
||||
"""Handle a removed config flow."""
|
||||
flow = cast(ConfigFlow, flow)
|
||||
|
||||
# Clean up issue if this is a reauth flow
|
||||
if flow.context["source"] == SOURCE_REAUTH:
|
||||
if (entry_id := flow.context.get("entry_id")) is not None and (
|
||||
entry := self.config_entries.async_get_entry(entry_id)
|
||||
) is not None:
|
||||
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
|
||||
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
|
||||
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||
@ -1515,20 +1531,6 @@ class ConfigEntriesFlowManager(
|
||||
"""
|
||||
flow = cast(ConfigFlow, flow)
|
||||
|
||||
# Mark the step as done.
|
||||
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
||||
# new entry, which needs the integration to be set up, which is waiting for
|
||||
# init to be done.
|
||||
self._set_pending_import_done(flow)
|
||||
|
||||
# Clean up issue if this is a reauth flow
|
||||
if flow.context["source"] == SOURCE_REAUTH:
|
||||
if (entry_id := flow.context.get("entry_id")) is not None and (
|
||||
entry := self.config_entries.async_get_entry(entry_id)
|
||||
) is not None:
|
||||
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
|
||||
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
|
||||
|
||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
# If there's a config entry with a matching unique ID,
|
||||
# update the discovery key.
|
||||
@ -1567,6 +1569,12 @@ class ConfigEntriesFlowManager(
|
||||
)
|
||||
return result
|
||||
|
||||
# Mark the step as done.
|
||||
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
||||
# new entry, which needs the integration to be set up, which is waiting for
|
||||
# init to be done.
|
||||
self._set_pending_import_done(flow)
|
||||
|
||||
# Avoid adding a config entry for a integration
|
||||
# that only supports a single config entry, but already has an entry
|
||||
if (
|
||||
|
@ -207,6 +207,13 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def async_flow_removed(
|
||||
self,
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
) -> None:
|
||||
"""Handle a removed data entry flow."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self,
|
||||
@ -457,6 +464,7 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
"""Remove a flow from in progress."""
|
||||
if (flow := self._progress.pop(flow_id, None)) is None:
|
||||
raise UnknownFlow
|
||||
self.async_flow_removed(flow)
|
||||
self._async_remove_flow_from_index(flow)
|
||||
flow.async_cancel_progress_task()
|
||||
try:
|
||||
@ -485,6 +493,10 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
description_placeholders=err.description_placeholders,
|
||||
)
|
||||
|
||||
if flow.flow_id not in self._progress:
|
||||
# The flow was removed during the step
|
||||
raise UnknownFlow
|
||||
|
||||
# Setup the flow handler's preview if needed
|
||||
if result.get("preview") is not None:
|
||||
await self._async_setup_preview(flow)
|
||||
|
@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
|
||||
issue = await _test_reauth_issue(hass, manager, issue_registry)
|
||||
|
||||
manager.flow.async_abort(issue.data["flow_id"])
|
||||
# This can be considered a bug, we should make sure the issue is always
|
||||
# removed when the reauth flow is aborted.
|
||||
assert len(issue_registry.issues) == 1
|
||||
assert len(issue_registry.issues) == 0
|
||||
|
||||
|
||||
async def _test_reauth_issue(
|
||||
|
@ -243,6 +243,23 @@ async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
|
||||
async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||
"""Test abort calling the async_flow_removed FlowManager method."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_abort(reason="reason")
|
||||
|
||||
manager.async_flow_removed = Mock()
|
||||
await manager.async_init("test")
|
||||
|
||||
manager.async_flow_removed.assert_called_once()
|
||||
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
|
||||
async def test_abort_calls_async_remove_with_exception(
|
||||
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
@ -288,13 +305,7 @@ async def test_create_saves_data(manager: MockFlowManager) -> None:
|
||||
|
||||
|
||||
async def test_create_aborted_flow(manager: MockFlowManager) -> None:
|
||||
"""Test return create_entry from aborted flow.
|
||||
|
||||
Note: The entry is created even if the flow is already aborted, then the
|
||||
flow raises an UnknownFlow exception. This behavior is not logical, and
|
||||
we should consider changing it to not create the entry if the flow is
|
||||
aborted.
|
||||
"""
|
||||
"""Test return create_entry from aborted flow."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
@ -308,14 +319,25 @@ async def test_create_aborted_flow(manager: MockFlowManager) -> None:
|
||||
await manager.async_init("test")
|
||||
assert len(manager.async_progress()) == 0
|
||||
|
||||
# The entry is created even if the flow is aborted
|
||||
assert len(manager.mock_created_entries) == 1
|
||||
# No entry should be created if the flow is aborted
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
entry = manager.mock_created_entries[0]
|
||||
assert entry["handler"] == "test"
|
||||
assert entry["title"] == "Test Title"
|
||||
assert entry["data"] == "Test Data"
|
||||
assert entry["source"] is None
|
||||
|
||||
async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||
"""Test create calling the async_flow_removed FlowManager method."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_create_entry(title="Test Title", data="Test Data")
|
||||
|
||||
manager.async_flow_removed = Mock()
|
||||
await manager.async_init("test")
|
||||
|
||||
manager.async_flow_removed.assert_called_once()
|
||||
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 1
|
||||
|
||||
|
||||
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
|
||||
@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
|
||||
await manager.async_configure("wrong_flow_id")
|
||||
|
||||
|
||||
async def test_abort_raises_unknown_flow_if_not_in_progress(
|
||||
async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
|
||||
manager: MockFlowManager,
|
||||
) -> None:
|
||||
"""Test abort raises UnknownFlow if the flow is not in progress."""
|
||||
with pytest.raises(data_entry_flow.UnknownFlow):
|
||||
await manager.async_abort("wrong_flow_id")
|
||||
manager.async_abort("wrong_flow_id")
|
||||
|
||||
|
||||
async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||
"""Test abort calling the async_flow_removed FlowManager method."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
return self.async_show_form(step_id="init")
|
||||
|
||||
manager.async_flow_removed = Mock()
|
||||
result = await manager.async_init("test")
|
||||
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
manager.async_flow_removed.assert_not_called()
|
||||
|
||||
manager.async_abort(result["flow_id"])
|
||||
manager.async_flow_removed.assert_called_once()
|
||||
|
||||
assert len(manager.async_progress()) == 0
|
||||
assert len(manager.mock_created_entries) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
x
Reference in New Issue
Block a user