mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Abort if a flow is removed during a step (#142138)
* Abort if a flow is removed during a step * Reorganize code * Only call _set_pending_import_done if an entry is created * Try a new approach * Add tests * Update tests
This commit is contained in:
parent
7f4d178781
commit
f344314762
@ -1503,6 +1503,22 @@ class ConfigEntriesFlowManager(
|
|||||||
future.set_result(None)
|
future.set_result(None)
|
||||||
self._discovery_event_debouncer.async_shutdown()
|
self._discovery_event_debouncer.async_shutdown()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_flow_removed(
|
||||||
|
self,
|
||||||
|
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||||
|
) -> None:
|
||||||
|
"""Handle a removed config flow."""
|
||||||
|
flow = cast(ConfigFlow, flow)
|
||||||
|
|
||||||
|
# Clean up issue if this is a reauth flow
|
||||||
|
if flow.context["source"] == SOURCE_REAUTH:
|
||||||
|
if (entry_id := flow.context.get("entry_id")) is not None and (
|
||||||
|
entry := self.config_entries.async_get_entry(entry_id)
|
||||||
|
) is not None:
|
||||||
|
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
|
||||||
|
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
|
||||||
|
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self,
|
self,
|
||||||
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||||
@ -1515,20 +1531,6 @@ class ConfigEntriesFlowManager(
|
|||||||
"""
|
"""
|
||||||
flow = cast(ConfigFlow, flow)
|
flow = cast(ConfigFlow, flow)
|
||||||
|
|
||||||
# Mark the step as done.
|
|
||||||
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
|
||||||
# new entry, which needs the integration to be set up, which is waiting for
|
|
||||||
# init to be done.
|
|
||||||
self._set_pending_import_done(flow)
|
|
||||||
|
|
||||||
# Clean up issue if this is a reauth flow
|
|
||||||
if flow.context["source"] == SOURCE_REAUTH:
|
|
||||||
if (entry_id := flow.context.get("entry_id")) is not None and (
|
|
||||||
entry := self.config_entries.async_get_entry(entry_id)
|
|
||||||
) is not None:
|
|
||||||
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
|
|
||||||
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
|
|
||||||
|
|
||||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||||
# If there's a config entry with a matching unique ID,
|
# If there's a config entry with a matching unique ID,
|
||||||
# update the discovery key.
|
# update the discovery key.
|
||||||
@ -1567,6 +1569,12 @@ class ConfigEntriesFlowManager(
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# Mark the step as done.
|
||||||
|
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
||||||
|
# new entry, which needs the integration to be set up, which is waiting for
|
||||||
|
# init to be done.
|
||||||
|
self._set_pending_import_done(flow)
|
||||||
|
|
||||||
# Avoid adding a config entry for a integration
|
# Avoid adding a config entry for a integration
|
||||||
# that only supports a single config entry, but already has an entry
|
# that only supports a single config entry, but already has an entry
|
||||||
if (
|
if (
|
||||||
|
@ -207,6 +207,13 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
|||||||
Handler key is the domain of the component that we want to set up.
|
Handler key is the domain of the component that we want to set up.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_flow_removed(
|
||||||
|
self,
|
||||||
|
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||||
|
) -> None:
|
||||||
|
"""Handle a removed data entry flow."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self,
|
self,
|
||||||
@ -457,6 +464,7 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
|||||||
"""Remove a flow from in progress."""
|
"""Remove a flow from in progress."""
|
||||||
if (flow := self._progress.pop(flow_id, None)) is None:
|
if (flow := self._progress.pop(flow_id, None)) is None:
|
||||||
raise UnknownFlow
|
raise UnknownFlow
|
||||||
|
self.async_flow_removed(flow)
|
||||||
self._async_remove_flow_from_index(flow)
|
self._async_remove_flow_from_index(flow)
|
||||||
flow.async_cancel_progress_task()
|
flow.async_cancel_progress_task()
|
||||||
try:
|
try:
|
||||||
@ -485,6 +493,10 @@ class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
|||||||
description_placeholders=err.description_placeholders,
|
description_placeholders=err.description_placeholders,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if flow.flow_id not in self._progress:
|
||||||
|
# The flow was removed during the step
|
||||||
|
raise UnknownFlow
|
||||||
|
|
||||||
# Setup the flow handler's preview if needed
|
# Setup the flow handler's preview if needed
|
||||||
if result.get("preview") is not None:
|
if result.get("preview") is not None:
|
||||||
await self._async_setup_preview(flow)
|
await self._async_setup_preview(flow)
|
||||||
|
@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
|
|||||||
issue = await _test_reauth_issue(hass, manager, issue_registry)
|
issue = await _test_reauth_issue(hass, manager, issue_registry)
|
||||||
|
|
||||||
manager.flow.async_abort(issue.data["flow_id"])
|
manager.flow.async_abort(issue.data["flow_id"])
|
||||||
# This can be considered a bug, we should make sure the issue is always
|
assert len(issue_registry.issues) == 0
|
||||||
# removed when the reauth flow is aborted.
|
|
||||||
assert len(issue_registry.issues) == 1
|
|
||||||
|
|
||||||
|
|
||||||
async def _test_reauth_issue(
|
async def _test_reauth_issue(
|
||||||
|
@ -243,6 +243,23 @@ async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
|
|||||||
assert len(manager.mock_created_entries) == 0
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||||
|
"""Test abort calling the async_flow_removed FlowManager method."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_abort(reason="reason")
|
||||||
|
|
||||||
|
manager.async_flow_removed = Mock()
|
||||||
|
await manager.async_init("test")
|
||||||
|
|
||||||
|
manager.async_flow_removed.assert_called_once()
|
||||||
|
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_abort_calls_async_remove_with_exception(
|
async def test_abort_calls_async_remove_with_exception(
|
||||||
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
|
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -288,13 +305,7 @@ async def test_create_saves_data(manager: MockFlowManager) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_create_aborted_flow(manager: MockFlowManager) -> None:
|
async def test_create_aborted_flow(manager: MockFlowManager) -> None:
|
||||||
"""Test return create_entry from aborted flow.
|
"""Test return create_entry from aborted flow."""
|
||||||
|
|
||||||
Note: The entry is created even if the flow is already aborted, then the
|
|
||||||
flow raises an UnknownFlow exception. This behavior is not logical, and
|
|
||||||
we should consider changing it to not create the entry if the flow is
|
|
||||||
aborted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@manager.mock_reg_handler("test")
|
@manager.mock_reg_handler("test")
|
||||||
class TestFlow(data_entry_flow.FlowHandler):
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
@ -308,14 +319,25 @@ async def test_create_aborted_flow(manager: MockFlowManager) -> None:
|
|||||||
await manager.async_init("test")
|
await manager.async_init("test")
|
||||||
assert len(manager.async_progress()) == 0
|
assert len(manager.async_progress()) == 0
|
||||||
|
|
||||||
# The entry is created even if the flow is aborted
|
# No entry should be created if the flow is aborted
|
||||||
assert len(manager.mock_created_entries) == 1
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
entry = manager.mock_created_entries[0]
|
|
||||||
assert entry["handler"] == "test"
|
async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||||
assert entry["title"] == "Test Title"
|
"""Test create calling the async_flow_removed FlowManager method."""
|
||||||
assert entry["data"] == "Test Data"
|
|
||||||
assert entry["source"] is None
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_create_entry(title="Test Title", data="Test Data")
|
||||||
|
|
||||||
|
manager.async_flow_removed = Mock()
|
||||||
|
await manager.async_init("test")
|
||||||
|
|
||||||
|
manager.async_flow_removed.assert_called_once()
|
||||||
|
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
assert len(manager.mock_created_entries) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
|
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
|
||||||
@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
|
|||||||
await manager.async_configure("wrong_flow_id")
|
await manager.async_configure("wrong_flow_id")
|
||||||
|
|
||||||
|
|
||||||
async def test_abort_raises_unknown_flow_if_not_in_progress(
|
async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
|
||||||
manager: MockFlowManager,
|
manager: MockFlowManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test abort raises UnknownFlow if the flow is not in progress."""
|
"""Test abort raises UnknownFlow if the flow is not in progress."""
|
||||||
with pytest.raises(data_entry_flow.UnknownFlow):
|
with pytest.raises(data_entry_flow.UnknownFlow):
|
||||||
await manager.async_abort("wrong_flow_id")
|
manager.async_abort("wrong_flow_id")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
|
||||||
|
"""Test abort calling the async_flow_removed FlowManager method."""
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_show_form(step_id="init")
|
||||||
|
|
||||||
|
manager.async_flow_removed = Mock()
|
||||||
|
result = await manager.async_init("test")
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "init"
|
||||||
|
|
||||||
|
manager.async_flow_removed.assert_not_called()
|
||||||
|
|
||||||
|
manager.async_abort(result["flow_id"])
|
||||||
|
manager.async_flow_removed.assert_called_once()
|
||||||
|
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user