diff --git a/homeassistant/helpers/schema_config_entry_flow.py b/homeassistant/helpers/schema_config_entry_flow.py index 39fa5164f62..c1936b4bd1e 100644 --- a/homeassistant/helpers/schema_config_entry_flow.py +++ b/homeassistant/helpers/schema_config_entry_flow.py @@ -151,14 +151,19 @@ class SchemaCommonFlowHandler: user_input: dict[str, Any] | None = None, ) -> FlowResult: """Show form for next step.""" - form_step: SchemaFlowFormStep = cast( - SchemaFlowFormStep, self._flow[next_step_id] - ) - options = dict(self._options) if user_input: options.update(user_input) + if isinstance(self._flow[next_step_id], SchemaFlowMenuStep): + menu_step = cast(SchemaFlowMenuStep, self._flow[next_step_id]) + return self._handler.async_show_menu( + step_id=next_step_id, + menu_options=menu_step.options, + ) + + form_step = cast(SchemaFlowFormStep, self._flow[next_step_id]) + if ( data_schema := self._get_schema(form_step, self._options) ) and data_schema.schema: @@ -197,10 +202,10 @@ class SchemaCommonFlowHandler: self, step_id: str, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a menu step.""" - form_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id]) + menu_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id]) return self._handler.async_show_menu( step_id=step_id, - menu_options=form_step.options, + menu_options=menu_step.options, ) diff --git a/tests/helpers/test_schema_config_entry_flow.py b/tests/helpers/test_schema_config_entry_flow.py new file mode 100644 index 00000000000..e54582db606 --- /dev/null +++ b/tests/helpers/test_schema_config_entry_flow.py @@ -0,0 +1,67 @@ +"""Tests for the schema based data entry flows.""" +from __future__ import annotations + +from unittest.mock import patch + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType +from homeassistant.helpers.schema_config_entry_flow import ( + SchemaConfigFlowHandler, + SchemaFlowFormStep, + SchemaFlowMenuStep, +) + +from tests.common import mock_platform + +TEST_DOMAIN = "test" + + +async def test_menu_step(hass: HomeAssistant) -> None: + """Test menu step.""" + + MENU_1 = ["option1", "option2"] + MENU_2 = ["option3", "option4"] + + CONFIG_FLOW: dict[str, SchemaFlowFormStep | SchemaFlowMenuStep] = { + "user": SchemaFlowMenuStep(MENU_1), + "option1": SchemaFlowFormStep(vol.Schema({}), next_step=lambda _: "menu2"), + "menu2": SchemaFlowMenuStep(MENU_2), + "option3": SchemaFlowFormStep(vol.Schema({})), + } + + class TestConfigFlow(SchemaConfigFlowHandler, domain=TEST_DOMAIN): + """Handle a config or options flow for Derivative.""" + + config_flow = CONFIG_FLOW + + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestConfigFlow}): + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": "user"} + ) + assert result["type"] == FlowResultType.MENU + assert result["step_id"] == "user" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"next_step_id": "option1"}, + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "option1" + + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result["type"] == FlowResultType.MENU + assert result["step_id"] == "menu2" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"next_step_id": "option3"}, + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "option3" + + result = await hass.config_entries.flow.async_configure(result["flow_id"], {}) + assert result["type"] == FlowResultType.CREATE_ENTRY