diff --git a/homeassistant/components/asuswrt/config_flow.py b/homeassistant/components/asuswrt/config_flow.py index 94843a4c07c..414dbc65d8b 100644 --- a/homeassistant/components/asuswrt/config_flow.py +++ b/homeassistant/components/asuswrt/config_flow.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import os import socket -from typing import Any +from typing import Any, cast import voluptuous as vol @@ -13,7 +13,7 @@ from homeassistant.components.device_tracker import ( CONF_CONSIDER_HOME, DEFAULT_CONSIDER_HOME, ) -from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow +from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import ( CONF_HOST, CONF_MODE, @@ -26,6 +26,11 @@ from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import config_validation as cv from homeassistant.helpers.device_registry import format_mac +from homeassistant.helpers.schema_config_entry_flow import ( + SchemaCommonFlowHandler, + SchemaFlowFormStep, + SchemaOptionsFlowHandler, +) from .const import ( CONF_DNSMASQ, @@ -52,6 +57,35 @@ RESULT_UNKNOWN = "unknown" _LOGGER = logging.getLogger(__name__) +OPTIONS_SCHEMA = vol.Schema( + { + vol.Optional( + CONF_CONSIDER_HOME, default=DEFAULT_CONSIDER_HOME.total_seconds() + ): vol.All(vol.Coerce(int), vol.Clamp(min=0, max=900)), + vol.Optional(CONF_TRACK_UNKNOWN, default=DEFAULT_TRACK_UNKNOWN): bool, + vol.Required(CONF_INTERFACE, default=DEFAULT_INTERFACE): str, + vol.Required(CONF_DNSMASQ, default=DEFAULT_DNSMASQ): str, + } +) + + +def get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema: + """Get options schema.""" + options_flow: SchemaOptionsFlowHandler + options_flow = cast(SchemaOptionsFlowHandler, handler.parent_handler) + if options_flow.config_entry.data[CONF_MODE] == MODE_AP: + return OPTIONS_SCHEMA.extend( + { + vol.Optional(CONF_REQUIRE_IP, default=True): bool, + } + ) + return OPTIONS_SCHEMA + + +OPTIONS_FLOW = { + "init": SchemaFlowFormStep(get_options_schema), +} + def _is_file(value: str) -> bool: """Validate that the value is an existing file.""" @@ -203,62 +237,8 @@ class AsusWrtFlowHandler(ConfigFlow, domain=DOMAIN): @staticmethod @callback - def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow: - """Get the options flow for this handler.""" - return OptionsFlowHandler(config_entry) - - -class OptionsFlowHandler(OptionsFlow): - """Handle a option flow for AsusWrt.""" - - def __init__(self, config_entry: ConfigEntry) -> None: - """Initialize options flow.""" - self.config_entry = config_entry - - async def async_step_init( - self, user_input: dict[str, Any] | None = None - ) -> FlowResult: - """Handle options flow.""" - if user_input is not None: - return self.async_create_entry(title="", data=user_input) - - data_schema = vol.Schema( - { - vol.Optional( - CONF_CONSIDER_HOME, - default=self.config_entry.options.get( - CONF_CONSIDER_HOME, DEFAULT_CONSIDER_HOME.total_seconds() - ), - ): vol.All(vol.Coerce(int), vol.Clamp(min=0, max=900)), - vol.Optional( - CONF_TRACK_UNKNOWN, - default=self.config_entry.options.get( - CONF_TRACK_UNKNOWN, DEFAULT_TRACK_UNKNOWN - ), - ): bool, - vol.Required( - CONF_INTERFACE, - default=self.config_entry.options.get( - CONF_INTERFACE, DEFAULT_INTERFACE - ), - ): str, - vol.Required( - CONF_DNSMASQ, - default=self.config_entry.options.get( - CONF_DNSMASQ, DEFAULT_DNSMASQ - ), - ): str, - } - ) - - if self.config_entry.data[CONF_MODE] == MODE_AP: - data_schema = data_schema.extend( - { - vol.Optional( - CONF_REQUIRE_IP, - default=self.config_entry.options.get(CONF_REQUIRE_IP, True), - ): bool, - } - ) - - return self.async_show_form(step_id="init", data_schema=data_schema) + def async_get_options_flow( + config_entry: ConfigEntry, + ) -> SchemaOptionsFlowHandler: + """Get options flow for this handler.""" + return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW) diff --git a/tests/components/asuswrt/test_config_flow.py b/tests/components/asuswrt/test_config_flow.py index 22a780fc12e..f9af800166a 100644 --- a/tests/components/asuswrt/test_config_flow.py +++ b/tests/components/asuswrt/test_config_flow.py @@ -23,6 +23,7 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_USERNAME, ) +from homeassistant.core import HomeAssistant from tests.common import MockConfigEntry @@ -248,8 +249,8 @@ async def test_on_connect_failed(hass, side_effect, error): assert result["errors"] == {"base": error} -async def test_options_flow(hass): - """Test config flow options.""" +async def test_options_flow_ap(hass: HomeAssistant) -> None: + """Test config flow options for ap mode.""" config_entry = MockConfigEntry( domain=DOMAIN, data=CONFIG_DATA, @@ -264,6 +265,7 @@ async def test_options_flow(hass): assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["step_id"] == "init" + assert CONF_REQUIRE_IP in result["data_schema"].schema result = await hass.config_entries.options.async_configure( result["flow_id"], @@ -282,3 +284,37 @@ async def test_options_flow(hass): assert config_entry.options[CONF_INTERFACE] == "aaa" assert config_entry.options[CONF_DNSMASQ] == "bbb" assert config_entry.options[CONF_REQUIRE_IP] is False + + +async def test_options_flow_router(hass: HomeAssistant) -> None: + """Test config flow options for router mode.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={**CONFIG_DATA, CONF_MODE: "router"}, + ) + config_entry.add_to_hass(hass) + + with PATCH_SETUP_ENTRY: + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + result = await hass.config_entries.options.async_init(config_entry.entry_id) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "init" + assert CONF_REQUIRE_IP not in result["data_schema"].schema + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={ + CONF_CONSIDER_HOME: 20, + CONF_TRACK_UNKNOWN: True, + CONF_INTERFACE: "aaa", + CONF_DNSMASQ: "bbb", + }, + ) + + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert config_entry.options[CONF_CONSIDER_HOME] == 20 + assert config_entry.options[CONF_TRACK_UNKNOWN] is True + assert config_entry.options[CONF_INTERFACE] == "aaa" + assert config_entry.options[CONF_DNSMASQ] == "bbb"