Use SchemaOptionsFlowHandler in asuswrt (#82806)

This commit is contained in:
epenet 2022-11-28 09:56:08 +01:00 committed by GitHub
parent 67e4f2c202
commit 8a20a90324
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 63 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging import logging
import os import os
import socket import socket
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -13,7 +13,7 @@ from homeassistant.components.device_tracker import (
CONF_CONSIDER_HOME, CONF_CONSIDER_HOME,
DEFAULT_CONSIDER_HOME, DEFAULT_CONSIDER_HOME,
) )
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_MODE, CONF_MODE,
@ -26,6 +26,11 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.device_registry import format_mac from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaFlowFormStep,
SchemaOptionsFlowHandler,
)
from .const import ( from .const import (
CONF_DNSMASQ, CONF_DNSMASQ,
@ -52,6 +57,35 @@ RESULT_UNKNOWN = "unknown"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
OPTIONS_SCHEMA = vol.Schema(
{
vol.Optional(
CONF_CONSIDER_HOME, default=DEFAULT_CONSIDER_HOME.total_seconds()
): vol.All(vol.Coerce(int), vol.Clamp(min=0, max=900)),
vol.Optional(CONF_TRACK_UNKNOWN, default=DEFAULT_TRACK_UNKNOWN): bool,
vol.Required(CONF_INTERFACE, default=DEFAULT_INTERFACE): str,
vol.Required(CONF_DNSMASQ, default=DEFAULT_DNSMASQ): str,
}
)
def get_options_schema(handler: SchemaCommonFlowHandler) -> vol.Schema:
"""Get options schema."""
options_flow: SchemaOptionsFlowHandler
options_flow = cast(SchemaOptionsFlowHandler, handler.parent_handler)
if options_flow.config_entry.data[CONF_MODE] == MODE_AP:
return OPTIONS_SCHEMA.extend(
{
vol.Optional(CONF_REQUIRE_IP, default=True): bool,
}
)
return OPTIONS_SCHEMA
OPTIONS_FLOW = {
"init": SchemaFlowFormStep(get_options_schema),
}
def _is_file(value: str) -> bool: def _is_file(value: str) -> bool:
"""Validate that the value is an existing file.""" """Validate that the value is an existing file."""
@ -203,62 +237,8 @@ class AsusWrtFlowHandler(ConfigFlow, domain=DOMAIN):
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow: def async_get_options_flow(
"""Get the options flow for this handler.""" config_entry: ConfigEntry,
return OptionsFlowHandler(config_entry) ) -> SchemaOptionsFlowHandler:
"""Get options flow for this handler."""
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
class OptionsFlowHandler(OptionsFlow):
"""Handle a option flow for AsusWrt."""
def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle options flow."""
if user_input is not None:
return self.async_create_entry(title="", data=user_input)
data_schema = vol.Schema(
{
vol.Optional(
CONF_CONSIDER_HOME,
default=self.config_entry.options.get(
CONF_CONSIDER_HOME, DEFAULT_CONSIDER_HOME.total_seconds()
),
): vol.All(vol.Coerce(int), vol.Clamp(min=0, max=900)),
vol.Optional(
CONF_TRACK_UNKNOWN,
default=self.config_entry.options.get(
CONF_TRACK_UNKNOWN, DEFAULT_TRACK_UNKNOWN
),
): bool,
vol.Required(
CONF_INTERFACE,
default=self.config_entry.options.get(
CONF_INTERFACE, DEFAULT_INTERFACE
),
): str,
vol.Required(
CONF_DNSMASQ,
default=self.config_entry.options.get(
CONF_DNSMASQ, DEFAULT_DNSMASQ
),
): str,
}
)
if self.config_entry.data[CONF_MODE] == MODE_AP:
data_schema = data_schema.extend(
{
vol.Optional(
CONF_REQUIRE_IP,
default=self.config_entry.options.get(CONF_REQUIRE_IP, True),
): bool,
}
)
return self.async_show_form(step_id="init", data_schema=data_schema)

View File

@ -23,6 +23,7 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_PROTOCOL,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -248,8 +249,8 @@ async def test_on_connect_failed(hass, side_effect, error):
assert result["errors"] == {"base": error} assert result["errors"] == {"base": error}
async def test_options_flow(hass): async def test_options_flow_ap(hass: HomeAssistant) -> None:
"""Test config flow options.""" """Test config flow options for ap mode."""
config_entry = MockConfigEntry( config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data=CONFIG_DATA, data=CONFIG_DATA,
@ -264,6 +265,7 @@ async def test_options_flow(hass):
assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "init" assert result["step_id"] == "init"
assert CONF_REQUIRE_IP in result["data_schema"].schema
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
@ -282,3 +284,37 @@ async def test_options_flow(hass):
assert config_entry.options[CONF_INTERFACE] == "aaa" assert config_entry.options[CONF_INTERFACE] == "aaa"
assert config_entry.options[CONF_DNSMASQ] == "bbb" assert config_entry.options[CONF_DNSMASQ] == "bbb"
assert config_entry.options[CONF_REQUIRE_IP] is False assert config_entry.options[CONF_REQUIRE_IP] is False
async def test_options_flow_router(hass: HomeAssistant) -> None:
"""Test config flow options for router mode."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={**CONFIG_DATA, CONF_MODE: "router"},
)
config_entry.add_to_hass(hass)
with PATCH_SETUP_ENTRY:
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "init"
assert CONF_REQUIRE_IP not in result["data_schema"].schema
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={
CONF_CONSIDER_HOME: 20,
CONF_TRACK_UNKNOWN: True,
CONF_INTERFACE: "aaa",
CONF_DNSMASQ: "bbb",
},
)
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert config_entry.options[CONF_CONSIDER_HOME] == 20
assert config_entry.options[CONF_TRACK_UNKNOWN] is True
assert config_entry.options[CONF_INTERFACE] == "aaa"
assert config_entry.options[CONF_DNSMASQ] == "bbb"