mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Improve schema typing (3) (#120521)
This commit is contained in:
parent
afbd24adfe
commit
d527113d59
@ -58,9 +58,9 @@ class InputButtonStorageCollection(collection.DictStorageCollection):
|
||||
|
||||
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
|
||||
|
||||
async def _process_create_data(self, data: dict) -> vol.Schema:
|
||||
async def _process_create_data(self, data: dict) -> dict[str, str]:
|
||||
"""Validate the config is valid."""
|
||||
return self.CREATE_UPDATE_SCHEMA(data)
|
||||
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
|
||||
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
|
@ -163,9 +163,9 @@ class InputTextStorageCollection(collection.DictStorageCollection):
|
||||
|
||||
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))
|
||||
|
||||
async def _process_create_data(self, data: dict[str, Any]) -> vol.Schema:
|
||||
async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate the config is valid."""
|
||||
return self.CREATE_UPDATE_SCHEMA(data)
|
||||
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
|
||||
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict[str, Any]) -> str:
|
||||
|
@ -302,7 +302,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
||||
|
||||
|
||||
def preprocess_turn_on_alternatives(
|
||||
hass: HomeAssistant, params: dict[str, Any]
|
||||
hass: HomeAssistant, params: dict[str, Any] | dict[str | vol.Optional, Any]
|
||||
) -> None:
|
||||
"""Process extra data for turn light on request.
|
||||
|
||||
@ -406,7 +406,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
|
||||
# of the light base platform.
|
||||
hass.async_create_task(profiles.async_initialize(), eager_start=True)
|
||||
|
||||
def preprocess_data(data: dict[str, Any]) -> dict[str | vol.Optional, Any]:
|
||||
def preprocess_data(
|
||||
data: dict[str | vol.Optional, Any],
|
||||
) -> dict[str | vol.Optional, Any]:
|
||||
"""Preprocess the service data."""
|
||||
base: dict[str | vol.Optional, Any] = {
|
||||
entity_field: data.pop(entity_field)
|
||||
|
@ -226,14 +226,16 @@ class MotionEyeOptionsFlow(OptionsFlow):
|
||||
if self.show_advanced_options:
|
||||
# The input URL is not validated as being a URL, to allow for the possibility
|
||||
# the template input won't be a valid URL until after it's rendered
|
||||
stream_kwargs = {}
|
||||
description: dict[str, str] | None = None
|
||||
if CONF_STREAM_URL_TEMPLATE in self._config_entry.options:
|
||||
stream_kwargs["description"] = {
|
||||
description = {
|
||||
"suggested_value": self._config_entry.options[
|
||||
CONF_STREAM_URL_TEMPLATE
|
||||
]
|
||||
}
|
||||
|
||||
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, **stream_kwargs)] = str
|
||||
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, description=description)] = (
|
||||
str
|
||||
)
|
||||
|
||||
return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))
|
||||
|
@ -167,8 +167,9 @@ async def async_get_action_capabilities(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> dict[str, vol.Schema]:
|
||||
"""List action capabilities."""
|
||||
|
||||
return {"extra_fields": DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE], {})}
|
||||
if (fields := DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE])) is None:
|
||||
return {}
|
||||
return {"extra_fields": fields}
|
||||
|
||||
|
||||
async def _execute_service_based_action(
|
||||
|
@ -80,10 +80,8 @@ def validate_event_data(obj: dict) -> dict:
|
||||
except ValidationError as exc:
|
||||
# Filter out required field errors if keys can be missing, and if there are
|
||||
# still errors, raise an exception
|
||||
if errors := [
|
||||
error for error in exc.errors() if error["type"] != "value_error.missing"
|
||||
]:
|
||||
raise vol.MultipleInvalid(errors) from exc
|
||||
if [error for error in exc.errors() if error["type"] != "value_error.missing"]:
|
||||
raise vol.MultipleInvalid from exc
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Container, Iterable, Mapping
|
||||
from collections.abc import Callable, Container, Hashable, Iterable, Mapping
|
||||
from contextlib import suppress
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
@ -13,7 +13,7 @@ from enum import StrEnum
|
||||
from functools import partial
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Generic, Required, TypedDict
|
||||
from typing import Any, Generic, Required, TypedDict, cast
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
@ -120,7 +120,7 @@ class InvalidData(vol.Invalid): # type: ignore[misc]
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
path: list[str | vol.Marker] | None,
|
||||
path: list[Hashable] | None,
|
||||
error_message: str | None,
|
||||
schema_errors: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
@ -384,6 +384,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
if (
|
||||
data_schema := cur_step.get("data_schema")
|
||||
) is not None and user_input is not None:
|
||||
data_schema = cast(vol.Schema, data_schema)
|
||||
try:
|
||||
user_input = data_schema(user_input) # type: ignore[operator]
|
||||
except vol.Invalid as ex:
|
||||
@ -694,7 +695,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
):
|
||||
# Copy the marker to not modify the flow schema
|
||||
new_key = copy.copy(key)
|
||||
new_key.description = {"suggested_value": suggested_values[key]}
|
||||
new_key.description = {"suggested_value": suggested_values[key.schema]}
|
||||
schema[new_key] = val
|
||||
return vol.Schema(schema)
|
||||
|
||||
|
@ -981,7 +981,7 @@ def removed(
|
||||
|
||||
def key_value_schemas(
|
||||
key: str,
|
||||
value_schemas: dict[Hashable, VolSchemaType],
|
||||
value_schemas: dict[Hashable, VolSchemaType | Callable[[Any], dict[str, Any]]],
|
||||
default_schema: VolSchemaType | None = None,
|
||||
default_description: str | None = None,
|
||||
) -> Callable[[Any], dict[Hashable, Any]]:
|
||||
@ -1016,12 +1016,12 @@ def key_value_schemas(
|
||||
# Validator helpers
|
||||
|
||||
|
||||
def key_dependency(
|
||||
def key_dependency[_KT: Hashable, _VT](
|
||||
key: Hashable, dependency: Hashable
|
||||
) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]:
|
||||
) -> Callable[[dict[_KT, _VT]], dict[_KT, _VT]]:
|
||||
"""Validate that all dependencies exist for key."""
|
||||
|
||||
def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]:
|
||||
def validator(value: dict[_KT, _VT]) -> dict[_KT, _VT]:
|
||||
"""Test dependencies."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("key dependencies require a dict")
|
||||
@ -1405,13 +1405,13 @@ STATE_CONDITION_ATTRIBUTE_SCHEMA = vol.Schema(
|
||||
)
|
||||
|
||||
|
||||
def STATE_CONDITION_SCHEMA(value: Any) -> dict:
|
||||
def STATE_CONDITION_SCHEMA(value: Any) -> dict[str, Any]:
|
||||
"""Validate a state condition."""
|
||||
if not isinstance(value, dict):
|
||||
raise vol.Invalid("Expected a dictionary")
|
||||
|
||||
if CONF_ATTRIBUTE in value:
|
||||
validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
|
||||
validated: dict[str, Any] = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
|
||||
else:
|
||||
validated = STATE_CONDITION_STATE_SCHEMA(value)
|
||||
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Collection, Coroutine, Iterable
|
||||
from collections.abc import Callable, Collection, Coroutine, Iterable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
@ -37,6 +37,9 @@ from .typing import VolSchemaType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
type _SlotsType = dict[str, Any]
|
||||
type _IntentSlotsType = dict[
|
||||
str | tuple[str, str], VolSchemaType | Callable[[Any], Any]
|
||||
]
|
||||
|
||||
INTENT_TURN_OFF = "HassTurnOff"
|
||||
INTENT_TURN_ON = "HassTurnOn"
|
||||
@ -808,8 +811,8 @@ class DynamicServiceIntentHandler(IntentHandler):
|
||||
self,
|
||||
intent_type: str,
|
||||
speech: str | None = None,
|
||||
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
|
||||
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
|
||||
required_slots: _IntentSlotsType | None = None,
|
||||
optional_slots: _IntentSlotsType | None = None,
|
||||
required_domains: set[str] | None = None,
|
||||
required_features: int | None = None,
|
||||
required_states: set[str] | None = None,
|
||||
@ -825,7 +828,7 @@ class DynamicServiceIntentHandler(IntentHandler):
|
||||
self.description = description
|
||||
self.platforms = platforms
|
||||
|
||||
self.required_slots: dict[tuple[str, str], VolSchemaType] = {}
|
||||
self.required_slots: _IntentSlotsType = {}
|
||||
if required_slots:
|
||||
for key, value_schema in required_slots.items():
|
||||
if isinstance(key, str):
|
||||
@ -834,7 +837,7 @@ class DynamicServiceIntentHandler(IntentHandler):
|
||||
|
||||
self.required_slots[key] = value_schema
|
||||
|
||||
self.optional_slots: dict[tuple[str, str], VolSchemaType] = {}
|
||||
self.optional_slots: _IntentSlotsType = {}
|
||||
if optional_slots:
|
||||
for key, value_schema in optional_slots.items():
|
||||
if isinstance(key, str):
|
||||
@ -1108,8 +1111,8 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
|
||||
domain: str,
|
||||
service: str,
|
||||
speech: str | None = None,
|
||||
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
|
||||
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
|
||||
required_slots: _IntentSlotsType | None = None,
|
||||
optional_slots: _IntentSlotsType | None = None,
|
||||
required_domains: set[str] | None = None,
|
||||
required_features: int | None = None,
|
||||
required_states: set[str] | None = None,
|
||||
|
@ -175,7 +175,9 @@ class SchemaCommonFlowHandler:
|
||||
and key.default is not vol.UNDEFINED
|
||||
and key not in self._options
|
||||
):
|
||||
user_input[str(key.schema)] = key.default()
|
||||
user_input[str(key.schema)] = cast(
|
||||
Callable[[], Any], key.default
|
||||
)()
|
||||
|
||||
if user_input is not None and form_step.validate_user_input is not None:
|
||||
# Do extra validation of user input
|
||||
@ -215,7 +217,7 @@ class SchemaCommonFlowHandler:
|
||||
)
|
||||
):
|
||||
# Key not present, delete keys old value (if present) too
|
||||
values.pop(key, None)
|
||||
values.pop(key.schema, None)
|
||||
|
||||
async def _show_next_step_or_create_entry(
|
||||
self, form_step: SchemaFlowFormStep
|
||||
@ -491,7 +493,7 @@ def wrapped_entity_config_entry_title(
|
||||
def entity_selector_without_own_entities(
|
||||
handler: SchemaOptionsFlowHandler,
|
||||
entity_selector_config: selector.EntitySelectorConfig,
|
||||
) -> vol.Schema:
|
||||
) -> selector.EntitySelector:
|
||||
"""Return an entity selector which excludes own entities."""
|
||||
entity_registry = er.async_get(handler.hass)
|
||||
entities = er.async_entries_for_config_entry(
|
||||
|
Loading…
x
Reference in New Issue
Block a user