diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 64750b2ff50..cdc4023f32c 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -5,6 +5,7 @@ import abc import asyncio from collections.abc import Iterable, Mapping from dataclasses import dataclass +import logging from types import MappingProxyType from typing import Any, TypedDict @@ -16,6 +17,8 @@ from .exceptions import HomeAssistantError from .helpers.frame import report from .util import uuid as uuid_util +_LOGGER = logging.getLogger(__name__) + class FlowResultType(StrEnum): """Result type for a data entry flow.""" @@ -337,6 +340,11 @@ class FlowManager(abc.ABC): if not self._handler_progress_index[handler]: del self._handler_progress_index[handler] + try: + flow.async_remove() + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error removing %s config flow: %s", flow.handler, err) + async def _async_handle_step( self, flow: Any, @@ -568,6 +576,10 @@ class FlowHandler: description_placeholders=description_placeholders, ) + @callback + def async_remove(self) -> None: + """Notification that the config flow has been removed.""" + @callback def _create_abort_data( diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 136c97808d3..1d60e20a3f0 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,6 +1,7 @@ """Test the flow classes.""" import asyncio -from unittest.mock import patch +import logging +from unittest.mock import Mock, patch import pytest import voluptuous as vol @@ -149,6 +150,45 @@ async def test_abort_removes_instance(manager): assert len(manager.mock_created_entries) == 0 +async def test_abort_calls_async_remove(manager): + """Test abort calling the async_remove FlowHandler method.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_abort(reason="reason") + + async_remove = Mock() + + await manager.async_init("test") + + TestFlow.async_remove.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + +async def test_abort_calls_async_remove_with_exception(manager, caplog): + """Test abort calling the async_remove FlowHandler method, with an exception.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_abort(reason="reason") + + async_remove = Mock(side_effect=[RuntimeError("error")]) + + with caplog.at_level(logging.ERROR): + await manager.async_init("test") + + assert "Error removing test config flow: error" in caplog.text + + TestFlow.async_remove.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + async def test_create_saves_data(manager): """Test creating a config entry."""