Add handler to validate_user_input (#82681)

* Add handler to validate_user_input

* Adjust group config flow
This commit is contained in:
epenet 2022-11-25 09:29:54 +01:00 committed by GitHub
parent f3b3193f7a
commit a4dbb9a24e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 13 deletions

View File

@ -11,6 +11,7 @@ from homeassistant.const import CONF_ENTITIES
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er, selector from homeassistant.helpers import entity_registry as er, selector
from homeassistant.helpers.schema_config_entry_flow import ( from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler, SchemaConfigFlowHandler,
SchemaFlowFormStep, SchemaFlowFormStep,
SchemaFlowMenuStep, SchemaFlowMenuStep,
@ -104,11 +105,15 @@ def choose_options_step(options: dict[str, Any]) -> str:
return cast(str, options["group_type"]) return cast(str, options["group_type"])
def set_group_type(group_type: str) -> Callable[[dict[str, Any]], dict[str, Any]]: def set_group_type(
group_type: str,
) -> Callable[[SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any]]:
"""Set group type.""" """Set group type."""
@callback @callback
def _set_group_type(user_input: dict[str, Any]) -> dict[str, Any]: def _set_group_type(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Add group type to user input.""" """Add group type to user input."""
return {"group_type": group_type, **user_input} return {"group_type": group_type, **user_input}

View File

@ -36,6 +36,7 @@ from homeassistant.const import (
) )
from homeassistant.core import async_get_hass from homeassistant.core import async_get_hass
from homeassistant.helpers.schema_config_entry_flow import ( from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler, SchemaConfigFlowHandler,
SchemaFlowError, SchemaFlowError,
SchemaFlowFormStep, SchemaFlowFormStep,
@ -113,7 +114,9 @@ SENSOR_SETUP = {
} }
def validate_rest_setup(user_input: dict[str, Any]) -> dict[str, Any]: def validate_rest_setup(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Validate rest setup.""" """Validate rest setup."""
hass = async_get_hass() hass = async_get_hass()
rest_config: dict[str, Any] = COMBINED_SCHEMA(user_input) rest_config: dict[str, Any] = COMBINED_SCHEMA(user_input)
@ -124,7 +127,9 @@ def validate_rest_setup(user_input: dict[str, Any]) -> dict[str, Any]:
return user_input return user_input
def validate_sensor_setup(user_input: dict[str, Any]) -> dict[str, Any]: def validate_sensor_setup(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Validate sensor setup.""" """Validate sensor setup."""
return { return {
"sensor": [ "sensor": [

View File

@ -10,6 +10,7 @@ from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.const import CONF_ENTITY_ID, CONF_NAME from homeassistant.const import CONF_ENTITY_ID, CONF_NAME
from homeassistant.helpers import selector from homeassistant.helpers import selector
from homeassistant.helpers.schema_config_entry_flow import ( from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler, SchemaConfigFlowHandler,
SchemaFlowError, SchemaFlowError,
SchemaFlowFormStep, SchemaFlowFormStep,
@ -18,11 +19,13 @@ from homeassistant.helpers.schema_config_entry_flow import (
from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN
def _validate_mode(data: Any) -> Any: def _validate_mode(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Validate the threshold mode, and set limits to None if not set.""" """Validate the threshold mode, and set limits to None if not set."""
if CONF_LOWER not in data and CONF_UPPER not in data: if CONF_LOWER not in user_input and CONF_UPPER not in user_input:
raise SchemaFlowError("need_lower_upper") raise SchemaFlowError("need_lower_upper")
return {CONF_LOWER: None, CONF_UPPER: None, **data} return {CONF_LOWER: None, CONF_UPPER: None, **user_input}
OPTIONS_SCHEMA = vol.Schema( OPTIONS_SCHEMA = vol.Schema(

View File

@ -10,6 +10,7 @@ from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
from homeassistant.helpers import selector from homeassistant.helpers import selector
from homeassistant.helpers.schema_config_entry_flow import ( from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler, SchemaConfigFlowHandler,
SchemaFlowError, SchemaFlowError,
SchemaFlowFormStep, SchemaFlowFormStep,
@ -46,14 +47,16 @@ METER_TYPES = [
] ]
def _validate_config(data: Any) -> Any: def _validate_config(
handler: SchemaCommonFlowHandler, user_input: dict[str, Any]
) -> dict[str, Any]:
"""Validate config.""" """Validate config."""
try: try:
vol.Unique()(data[CONF_TARIFFS]) vol.Unique()(user_input[CONF_TARIFFS])
except vol.Invalid as exc: except vol.Invalid as exc:
raise SchemaFlowError("tariffs_not_unique") from exc raise SchemaFlowError("tariffs_not_unique") from exc
return data return user_input
OPTIONS_SCHEMA = vol.Schema( OPTIONS_SCHEMA = vol.Schema(

View File

@ -44,7 +44,9 @@ class SchemaFlowFormStep(SchemaFlowStep):
user input is requested. user input is requested.
""" """
validate_user_input: Callable[[dict[str, Any]], dict[str, Any]] = lambda x: x validate_user_input: Callable[
[SchemaCommonFlowHandler, dict[str, Any]], dict[str, Any]
] | None = None
"""Optional function to validate user input. """Optional function to validate user input.
- The `validate_user_input` function is called if the schema validates successfully. - The `validate_user_input` function is called if the schema validates successfully.
@ -124,10 +126,10 @@ class SchemaCommonFlowHandler:
): ):
user_input[str(key.schema)] = key.default() user_input[str(key.schema)] = key.default()
if user_input is not None and form_step.schema is not None: if user_input is not None and form_step.validate_user_input is not None:
# Do extra validation of user input # Do extra validation of user input
try: try:
user_input = form_step.validate_user_input(user_input) user_input = form_step.validate_user_input(self, user_input)
except SchemaFlowError as exc: except SchemaFlowError as exc:
return self._show_next_step(step_id, exc, user_input) return self._show_next_step(step_id, exc, user_input)