From e7fba46a0636b16919144ea07dc8253f619eba22 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 9 Mar 2022 13:18:19 +0100 Subject: [PATCH] Refactor helper_config_entry_flow (#67895) --- homeassistant/components/group/config_flow.py | 93 ++++---- .../components/switch/config_flow.py | 27 ++- .../helpers/helper_config_entry_flow.py | 210 ++++++++---------- tests/components/group/test_config_flow.py | 2 +- 4 files changed, 161 insertions(+), 171 deletions(-) diff --git a/homeassistant/components/group/config_flow.py b/homeassistant/components/group/config_flow.py index 82de2056da5..0547578b131 100644 --- a/homeassistant/components/group/config_flow.py +++ b/homeassistant/components/group/config_flow.py @@ -1,13 +1,18 @@ """Config flow for Group integration.""" from __future__ import annotations +from collections.abc import Mapping from typing import Any, cast import voluptuous as vol -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ENTITIES -from homeassistant.helpers import helper_config_entry_flow, selector +from homeassistant.core import callback +from homeassistant.helpers import selector +from homeassistant.helpers.helper_config_entry_flow import ( + HelperConfigFlowHandler, + HelperFlowStep, +) from . import DOMAIN @@ -30,52 +35,54 @@ def basic_group_config_schema(domain: str) -> vol.Schema: ) -STEPS = { - "init": vol.Schema( - { - vol.Required("group_type"): selector.selector( - { - "select": { - "options": [ - "cover", - "fan", - "light", - "media_player", - ] - } +INITIAL_STEP_SCHEMA = vol.Schema( + { + vol.Required("group_type"): selector.selector( + { + "select": { + "options": [ + "cover", + "fan", + "light", + "media_player", + ] } - ) - } - ), - "cover": basic_group_config_schema("cover"), - "fan": basic_group_config_schema("fan"), - "light": basic_group_config_schema("light"), - "media_player": basic_group_config_schema("media_player"), - "cover_options": basic_group_options_schema("cover"), - "fan_options": basic_group_options_schema("fan"), - "light_options": basic_group_options_schema("light"), - "media_player_options": basic_group_options_schema("media_player"), + } + ) + } +) + + +@callback +def choose_config_step(options: dict[str, Any]) -> str: + """Return next step_id when group_type is selected.""" + return cast(str, options["group_type"]) + + +CONFIG_FLOW = { + "user": HelperFlowStep(INITIAL_STEP_SCHEMA, next_step=choose_config_step), + "cover": HelperFlowStep(basic_group_config_schema("cover")), + "fan": HelperFlowStep(basic_group_config_schema("fan")), + "light": HelperFlowStep(basic_group_config_schema("light")), + "media_player": HelperFlowStep(basic_group_config_schema("media_player")), } -class GroupConfigFlowHandler( - helper_config_entry_flow.HelperConfigFlowHandler, domain=DOMAIN -): +OPTIONS_FLOW = { + "init": HelperFlowStep(None, next_step=choose_config_step), + "cover": HelperFlowStep(basic_group_options_schema("cover")), + "fan": HelperFlowStep(basic_group_options_schema("fan")), + "light": HelperFlowStep(basic_group_options_schema("light")), + "media_player": HelperFlowStep(basic_group_options_schema("media_player")), +} + + +class GroupConfigFlowHandler(HelperConfigFlowHandler, domain=DOMAIN): """Handle a config or options flow for Switch Light.""" - steps = STEPS + config_flow = CONFIG_FLOW + options_flow = OPTIONS_FLOW - def async_config_entry_title(self, user_input: dict[str, Any]) -> str: + def async_config_entry_title(self, options: Mapping[str, Any]) -> str: """Return config entry title.""" - return cast(str, user_input["name"]) if "name" in user_input else "" - - @staticmethod - def async_initial_options_step(config_entry: ConfigEntry) -> str: - """Return initial options step.""" - return f"{config_entry.options['group_type']}_options" - - def async_next_step(self, step_id: str, user_input: dict[str, Any]) -> str | None: - """Return next step_id.""" - if step_id == "init": - return cast(str, user_input["group_type"]) - return None + return cast(str, options["name"]) if "name" in options else "" diff --git a/homeassistant/components/switch/config_flow.py b/homeassistant/components/switch/config_flow.py index 1adc4ec0aee..efb3baf363f 100644 --- a/homeassistant/components/switch/config_flow.py +++ b/homeassistant/components/switch/config_flow.py @@ -1,6 +1,7 @@ """Config flow for Switch integration.""" from __future__ import annotations +from collections.abc import Mapping from typing import Any import voluptuous as vol @@ -14,13 +15,15 @@ from homeassistant.helpers import ( from .const import DOMAIN -STEPS = { - "init": vol.Schema( - { - vol.Required("entity_id"): selector.selector( - {"entity": {"domain": "switch"}} - ), - } +CONFIG_FLOW = { + "user": helper_config_entry_flow.HelperFlowStep( + vol.Schema( + { + vol.Required("entity_id"): selector.selector( + {"entity": {"domain": "switch"}} + ), + } + ) ) } @@ -30,16 +33,16 @@ class SwitchLightConfigFlowHandler( ): """Handle a config or options flow for Switch Light.""" - steps = STEPS + config_flow = CONFIG_FLOW - def async_config_entry_title(self, user_input: dict[str, Any]) -> str: + def async_config_entry_title(self, options: Mapping[str, Any]) -> str: """Return config entry title.""" registry = er.async_get(self.hass) - object_id = split_entity_id(user_input["entity_id"])[1] - entry = registry.async_get(user_input["entity_id"]) + object_id = split_entity_id(options["entity_id"])[1] + entry = registry.async_get(options["entity_id"]) if entry: return entry.name or entry.original_name or object_id - state = self.hass.states.get(user_input["entity_id"]) + state = self.hass.states.get(options["entity_id"]) if state: return state.name or object_id return object_id diff --git a/homeassistant/helpers/helper_config_entry_flow.py b/homeassistant/helpers/helper_config_entry_flow.py index c632ad60eae..7ef69b7f360 100644 --- a/homeassistant/helpers/helper_config_entry_flow.py +++ b/homeassistant/helpers/helper_config_entry_flow.py @@ -2,20 +2,32 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Mapping import copy -import types +from dataclasses import dataclass from typing import Any import voluptuous as vol from homeassistant import config_entries -from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import ( - RESULT_TYPE_CREATE_ENTRY, - FlowResult, - UnknownHandler, -) +from homeassistant.core import callback +from homeassistant.data_entry_flow import FlowResult, UnknownHandler + + +@dataclass +class HelperFlowStep: + """Define a helper config or options flow step.""" + + # Optional schema for requesting and validating user input. If schema validation + # fails, the step will be retried. If the schema is None, no user input is requested. + schema: vol.Schema | None + + # Optional function to identify next step. + # The next_step function is called if the schema validates successfully or if no + # schema is defined. The next_step function is passed the union of config entry + # options and user input from previous steps. + # If next_step returns None, the flow is ended with RESULT_TYPE_CREATE_ENTRY. + next_step: Callable[[dict[str, Any]], str | None] = lambda _: None class HelperCommonFlowHandler: @@ -24,61 +36,64 @@ class HelperCommonFlowHandler: def __init__( self, handler: HelperConfigFlowHandler | HelperOptionsFlowHandler, + flow: dict[str, HelperFlowStep], config_entry: config_entries.ConfigEntry | None, ) -> None: """Initialize a common handler.""" + self._flow = flow self._handler = handler self._options = dict(config_entry.options) if config_entry is not None else {} async def async_step( - self, step_id: str, _user_input: dict[str, Any] | None = None + self, step_id: str, user_input: dict[str, Any] | None = None ) -> FlowResult: """Handle a step.""" - errors = None - if _user_input is not None: - errors = {} - try: - user_input = await self._handler.async_validate_input( - self._handler.hass, step_id, _user_input - ) - except vol.Invalid as exc: - errors["base"] = str(exc) - else: - self._options.update(user_input) - if ( - next_step_id := self._handler.async_next_step(step_id, user_input) - ) is None: - title = self._handler.async_config_entry_title(user_input) - return self._handler.async_create_entry( - title=title, data=self._options - ) - return self._handler.async_show_form( - step_id=next_step_id, data_schema=self._handler.steps[next_step_id] - ) + next_step_id: str = step_id - schema = dict(self._handler.steps[step_id].schema) - for key in list(schema): - if key in self._options and isinstance(key, vol.Marker): - new_key = copy.copy(key) - new_key.description = {"suggested_value": self._options[key]} - val = schema.pop(key) - schema[new_key] = val + if user_input is not None: + # User input was validated successfully, update options + self._options.update(user_input) + if self._flow[next_step_id].next_step and ( + user_input is not None or self._flow[next_step_id].schema is None + ): + # Get next step + next_step_id_or_end_flow = self._flow[next_step_id].next_step(self._options) + if next_step_id_or_end_flow is None: + # Flow done, create entry or update config entry options + return self._handler.async_create_entry(data=self._options) + + next_step_id = next_step_id_or_end_flow + + if (data_schema := self._flow[next_step_id].schema) and data_schema.schema: + # Copy the schema, then set suggested field values to saved options + schema = dict(data_schema.schema) + for key in list(schema): + if key in self._options and isinstance(key, vol.Marker): + # Copy the marker to not modify the flow schema + new_key = copy.copy(key) + new_key.description = {"suggested_value": self._options[key]} + val = schema.pop(key) + schema[new_key] = val + data_schema = vol.Schema(schema) + + # Show form for next step return self._handler.async_show_form( - step_id=step_id, data_schema=vol.Schema(schema), errors=errors + step_id=next_step_id, data_schema=data_schema ) class HelperConfigFlowHandler(config_entries.ConfigFlow): """Handle a config flow for helper integrations.""" - steps: dict[str, vol.Schema] + config_flow: dict[str, HelperFlowStep] + options_flow: dict[str, HelperFlowStep] | None = None VERSION = 1 # pylint: disable-next=arguments-differ def __init_subclass__(cls, **kwargs: Any) -> None: - """Initialize a subclass, register if possible.""" + """Initialize a subclass.""" super().__init_subclass__(**kwargs) @callback @@ -86,30 +101,21 @@ class HelperConfigFlowHandler(config_entries.ConfigFlow): config_entry: config_entries.ConfigEntry, ) -> config_entries.OptionsFlow: """Get the options flow for this handler.""" - if ( - cls.async_initial_options_step - is HelperConfigFlowHandler.async_initial_options_step - ): + if cls.options_flow is None: raise UnknownHandler - return HelperOptionsFlowHandler( - config_entry, - cls.steps, - cls.async_config_entry_title, - cls.async_initial_options_step, - cls.async_next_step, - cls.async_validate_input, - ) + return HelperOptionsFlowHandler(config_entry, cls.options_flow) # Create an async_get_options_flow method cls.async_get_options_flow = _async_get_options_flow # type: ignore[assignment] + # Create flow step methods for each step defined in the flow schema - for step in cls.steps: - setattr(cls, f"async_step_{step}", cls.async_step) + for step in cls.config_flow: + setattr(cls, f"async_step_{step}", cls._async_step) def __init__(self) -> None: """Initialize config flow.""" - self._common_handler = HelperCommonFlowHandler(self, None) + self._common_handler = HelperCommonFlowHandler(self, self.config_flow, None) @classmethod @callback @@ -117,50 +123,34 @@ class HelperConfigFlowHandler(config_entries.ConfigFlow): cls, config_entry: config_entries.ConfigEntry ) -> bool: """Return options flow support for this handler.""" - return ( - cls.async_initial_options_step - is not HelperConfigFlowHandler.async_initial_options_step - ) + return cls.options_flow is not None - async def async_step_user( - self, user_input: dict[str, Any] | None = None - ) -> FlowResult: - """Handle the initial step.""" - return await self.async_step() - - async def async_step(self, user_input: dict[str, Any] | None = None) -> FlowResult: - """Handle a step.""" - step_id = self.cur_step["step_id"] if self.cur_step else "init" + async def _async_step(self, user_input: dict[str, Any] | None = None) -> FlowResult: + """Handle a config flow step.""" + step_id = self.cur_step["step_id"] if self.cur_step else "user" result = await self._common_handler.async_step(step_id, user_input) - if result["type"] == RESULT_TYPE_CREATE_ENTRY: - result["options"] = result["data"] - result["data"] = {} + return result # pylint: disable-next=no-self-use @abstractmethod - def async_config_entry_title(self, user_input: dict[str, Any]) -> str: - """Return config entry title.""" + def async_config_entry_title(self, options: Mapping[str, Any]) -> str: + """Return config entry title. - # pylint: disable-next=no-self-use - def async_next_step(self, step_id: str, user_input: dict[str, Any]) -> str | None: - """Return next step_id, or None to finish the flow.""" - return None + The options parameter contains config entry options, which is the union of user + input from the config flow steps. + """ - @staticmethod @callback - def async_initial_options_step( - config_entry: config_entries.ConfigEntry, - ) -> str: - """Return initial step_id of options flow.""" - raise UnknownHandler - - # pylint: disable-next=no-self-use - async def async_validate_input( - self, hass: HomeAssistant, step_id: str, user_input: dict[str, Any] - ) -> dict[str, Any]: - """Validate user input.""" - return user_input + def async_create_entry( # pylint: disable=arguments-differ + self, + data: Mapping[str, Any], + **kwargs: Any, + ) -> FlowResult: + """Finish config flow and create a config entry.""" + return super().async_create_entry( + data={}, options=data, title=self.async_config_entry_title(data), **kwargs + ) class HelperOptionsFlowHandler(config_entries.OptionsFlow): @@ -169,35 +159,25 @@ class HelperOptionsFlowHandler(config_entries.OptionsFlow): def __init__( self, config_entry: config_entries.ConfigEntry, - steps: dict[str, vol.Schema], - config_entry_title: Callable[[Any, dict[str, Any]], str], - initial_step: Callable[[config_entries.ConfigEntry], str], - next_step: Callable[[Any, str, dict[str, Any]], str | None], - validate: Callable[ - [Any, HomeAssistant, str, dict[str, Any]], Awaitable[dict[str, Any]] - ], + options_flow: dict[str, vol.Schema], ) -> None: """Initialize options flow.""" - self._common_handler = HelperCommonFlowHandler(self, config_entry) + self._common_handler = HelperCommonFlowHandler(self, options_flow, config_entry) self._config_entry = config_entry - self._initial_step = initial_step(config_entry) - self.async_config_entry_title = types.MethodType(config_entry_title, self) - self.async_next_step = types.MethodType(next_step, self) - self.async_validate_input = types.MethodType(validate, self) - self.steps = steps - for step in self.steps: - if step == "init": - continue - setattr(self, f"async_step_{step}", self.async_step) - async def async_step_init( - self, user_input: dict[str, Any] | None = None - ) -> FlowResult: - """Handle the initial step.""" - return await self.async_step(user_input) + for step in options_flow: + setattr(self, f"async_step_{step}", self._async_step) - async def async_step(self, user_input: dict[str, Any] | None = None) -> FlowResult: - """Handle a step.""" + async def _async_step(self, user_input: dict[str, Any] | None = None) -> FlowResult: + """Handle an options flow step.""" # pylint: disable-next=unsubscriptable-object # self.cur_step is a dict - step_id = self.cur_step["step_id"] if self.cur_step else self._initial_step + step_id = self.cur_step["step_id"] if self.cur_step else "init" return await self._common_handler.async_step(step_id, user_input) + + @callback + def async_create_entry( # pylint: disable=arguments-differ + self, + **kwargs: Any, + ) -> FlowResult: + """Finish config flow and create a config entry.""" + return super().async_create_entry(title="", **kwargs) diff --git a/tests/components/group/test_config_flow.py b/tests/components/group/test_config_flow.py index cc97ff8c95f..121ecc717e6 100644 --- a/tests/components/group/test_config_flow.py +++ b/tests/components/group/test_config_flow.py @@ -142,7 +142,7 @@ async def test_options(hass: HomeAssistant, group_type, member_state) -> None: result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] == RESULT_TYPE_FORM - assert result["step_id"] == f"{group_type}_options" + assert result["step_id"] == group_type assert get_suggested(result["data_schema"].schema, "entities") == members1 assert "name" not in result["data_schema"].schema