Improve schema typing (3) (#120521)

This commit is contained in:
Marc Mueller 2024-06-26 11:30:07 +02:00 committed by GitHub
parent afbd24adfe
commit d527113d59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 44 additions and 35 deletions

View File

@ -58,9 +58,9 @@ class InputButtonStorageCollection(collection.DictStorageCollection):
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
async def _process_create_data(self, data: dict) -> vol.Schema: async def _process_create_data(self, data: dict) -> dict[str, str]:
"""Validate the config is valid.""" """Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data) return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
@callback @callback
def _get_suggested_id(self, info: dict) -> str: def _get_suggested_id(self, info: dict) -> str:

View File

@ -163,9 +163,9 @@ class InputTextStorageCollection(collection.DictStorageCollection):
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text)) CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))
async def _process_create_data(self, data: dict[str, Any]) -> vol.Schema: async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the config is valid.""" """Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data) return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
@callback @callback
def _get_suggested_id(self, info: dict[str, Any]) -> str: def _get_suggested_id(self, info: dict[str, Any]) -> str:

View File

@ -302,7 +302,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
def preprocess_turn_on_alternatives( def preprocess_turn_on_alternatives(
hass: HomeAssistant, params: dict[str, Any] hass: HomeAssistant, params: dict[str, Any] | dict[str | vol.Optional, Any]
) -> None: ) -> None:
"""Process extra data for turn light on request. """Process extra data for turn light on request.
@ -406,7 +406,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
# of the light base platform. # of the light base platform.
hass.async_create_task(profiles.async_initialize(), eager_start=True) hass.async_create_task(profiles.async_initialize(), eager_start=True)
def preprocess_data(data: dict[str, Any]) -> dict[str | vol.Optional, Any]: def preprocess_data(
data: dict[str | vol.Optional, Any],
) -> dict[str | vol.Optional, Any]:
"""Preprocess the service data.""" """Preprocess the service data."""
base: dict[str | vol.Optional, Any] = { base: dict[str | vol.Optional, Any] = {
entity_field: data.pop(entity_field) entity_field: data.pop(entity_field)

View File

@ -226,14 +226,16 @@ class MotionEyeOptionsFlow(OptionsFlow):
if self.show_advanced_options: if self.show_advanced_options:
# The input URL is not validated as being a URL, to allow for the possibility # The input URL is not validated as being a URL, to allow for the possibility
# the template input won't be a valid URL until after it's rendered # the template input won't be a valid URL until after it's rendered
stream_kwargs = {} description: dict[str, str] | None = None
if CONF_STREAM_URL_TEMPLATE in self._config_entry.options: if CONF_STREAM_URL_TEMPLATE in self._config_entry.options:
stream_kwargs["description"] = { description = {
"suggested_value": self._config_entry.options[ "suggested_value": self._config_entry.options[
CONF_STREAM_URL_TEMPLATE CONF_STREAM_URL_TEMPLATE
] ]
} }
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, **stream_kwargs)] = str schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, description=description)] = (
str
)
return self.async_show_form(step_id="init", data_schema=vol.Schema(schema)) return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))

View File

@ -167,8 +167,9 @@ async def async_get_action_capabilities(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]: ) -> dict[str, vol.Schema]:
"""List action capabilities.""" """List action capabilities."""
if (fields := DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE])) is None:
return {"extra_fields": DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE], {})} return {}
return {"extra_fields": fields}
async def _execute_service_based_action( async def _execute_service_based_action(

View File

@ -80,10 +80,8 @@ def validate_event_data(obj: dict) -> dict:
except ValidationError as exc: except ValidationError as exc:
# Filter out required field errors if keys can be missing, and if there are # Filter out required field errors if keys can be missing, and if there are
# still errors, raise an exception # still errors, raise an exception
if errors := [ if [error for error in exc.errors() if error["type"] != "value_error.missing"]:
error for error in exc.errors() if error["type"] != "value_error.missing" raise vol.MultipleInvalid from exc
]:
raise vol.MultipleInvalid(errors) from exc
return obj return obj

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import abc import abc
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Container, Iterable, Mapping from collections.abc import Callable, Container, Hashable, Iterable, Mapping
from contextlib import suppress from contextlib import suppress
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
@ -13,7 +13,7 @@ from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict from typing import Any, Generic, Required, TypedDict, cast
from typing_extensions import TypeVar from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
@ -120,7 +120,7 @@ class InvalidData(vol.Invalid): # type: ignore[misc]
def __init__( def __init__(
self, self,
message: str, message: str,
path: list[str | vol.Marker] | None, path: list[Hashable] | None,
error_message: str | None, error_message: str | None,
schema_errors: dict[str, Any], schema_errors: dict[str, Any],
**kwargs: Any, **kwargs: Any,
@ -384,6 +384,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
if ( if (
data_schema := cur_step.get("data_schema") data_schema := cur_step.get("data_schema")
) is not None and user_input is not None: ) is not None and user_input is not None:
data_schema = cast(vol.Schema, data_schema)
try: try:
user_input = data_schema(user_input) # type: ignore[operator] user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex: except vol.Invalid as ex:
@ -694,7 +695,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
): ):
# Copy the marker to not modify the flow schema # Copy the marker to not modify the flow schema
new_key = copy.copy(key) new_key = copy.copy(key)
new_key.description = {"suggested_value": suggested_values[key]} new_key.description = {"suggested_value": suggested_values[key.schema]}
schema[new_key] = val schema[new_key] = val
return vol.Schema(schema) return vol.Schema(schema)

View File

@ -981,7 +981,7 @@ def removed(
def key_value_schemas( def key_value_schemas(
key: str, key: str,
value_schemas: dict[Hashable, VolSchemaType], value_schemas: dict[Hashable, VolSchemaType | Callable[[Any], dict[str, Any]]],
default_schema: VolSchemaType | None = None, default_schema: VolSchemaType | None = None,
default_description: str | None = None, default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]: ) -> Callable[[Any], dict[Hashable, Any]]:
@ -1016,12 +1016,12 @@ def key_value_schemas(
# Validator helpers # Validator helpers
def key_dependency( def key_dependency[_KT: Hashable, _VT](
key: Hashable, dependency: Hashable key: Hashable, dependency: Hashable
) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]: ) -> Callable[[dict[_KT, _VT]], dict[_KT, _VT]]:
"""Validate that all dependencies exist for key.""" """Validate that all dependencies exist for key."""
def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]: def validator(value: dict[_KT, _VT]) -> dict[_KT, _VT]:
"""Test dependencies.""" """Test dependencies."""
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict") raise vol.Invalid("key dependencies require a dict")
@ -1405,13 +1405,13 @@ STATE_CONDITION_ATTRIBUTE_SCHEMA = vol.Schema(
) )
def STATE_CONDITION_SCHEMA(value: Any) -> dict: def STATE_CONDITION_SCHEMA(value: Any) -> dict[str, Any]:
"""Validate a state condition.""" """Validate a state condition."""
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary") raise vol.Invalid("Expected a dictionary")
if CONF_ATTRIBUTE in value: if CONF_ATTRIBUTE in value:
validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value) validated: dict[str, Any] = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
else: else:
validated = STATE_CONDITION_STATE_SCHEMA(value) validated = STATE_CONDITION_STATE_SCHEMA(value)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Collection, Coroutine, Iterable from collections.abc import Callable, Collection, Coroutine, Iterable
import dataclasses import dataclasses
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
@ -37,6 +37,9 @@ from .typing import VolSchemaType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
type _SlotsType = dict[str, Any] type _SlotsType = dict[str, Any]
type _IntentSlotsType = dict[
str | tuple[str, str], VolSchemaType | Callable[[Any], Any]
]
INTENT_TURN_OFF = "HassTurnOff" INTENT_TURN_OFF = "HassTurnOff"
INTENT_TURN_ON = "HassTurnOn" INTENT_TURN_ON = "HassTurnOn"
@ -808,8 +811,8 @@ class DynamicServiceIntentHandler(IntentHandler):
self, self,
intent_type: str, intent_type: str,
speech: str | None = None, speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, required_slots: _IntentSlotsType | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None, required_domains: set[str] | None = None,
required_features: int | None = None, required_features: int | None = None,
required_states: set[str] | None = None, required_states: set[str] | None = None,
@ -825,7 +828,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.description = description self.description = description
self.platforms = platforms self.platforms = platforms
self.required_slots: dict[tuple[str, str], VolSchemaType] = {} self.required_slots: _IntentSlotsType = {}
if required_slots: if required_slots:
for key, value_schema in required_slots.items(): for key, value_schema in required_slots.items():
if isinstance(key, str): if isinstance(key, str):
@ -834,7 +837,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.required_slots[key] = value_schema self.required_slots[key] = value_schema
self.optional_slots: dict[tuple[str, str], VolSchemaType] = {} self.optional_slots: _IntentSlotsType = {}
if optional_slots: if optional_slots:
for key, value_schema in optional_slots.items(): for key, value_schema in optional_slots.items():
if isinstance(key, str): if isinstance(key, str):
@ -1108,8 +1111,8 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
domain: str, domain: str,
service: str, service: str,
speech: str | None = None, speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, required_slots: _IntentSlotsType | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None, required_domains: set[str] | None = None,
required_features: int | None = None, required_features: int | None = None,
required_states: set[str] | None = None, required_states: set[str] | None = None,

View File

@ -175,7 +175,9 @@ class SchemaCommonFlowHandler:
and key.default is not vol.UNDEFINED and key.default is not vol.UNDEFINED
and key not in self._options and key not in self._options
): ):
user_input[str(key.schema)] = key.default() user_input[str(key.schema)] = cast(
Callable[[], Any], key.default
)()
if user_input is not None and form_step.validate_user_input is not None: if user_input is not None and form_step.validate_user_input is not None:
# Do extra validation of user input # Do extra validation of user input
@ -215,7 +217,7 @@ class SchemaCommonFlowHandler:
) )
): ):
# Key not present, delete keys old value (if present) too # Key not present, delete keys old value (if present) too
values.pop(key, None) values.pop(key.schema, None)
async def _show_next_step_or_create_entry( async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep self, form_step: SchemaFlowFormStep
@ -491,7 +493,7 @@ def wrapped_entity_config_entry_title(
def entity_selector_without_own_entities( def entity_selector_without_own_entities(
handler: SchemaOptionsFlowHandler, handler: SchemaOptionsFlowHandler,
entity_selector_config: selector.EntitySelectorConfig, entity_selector_config: selector.EntitySelectorConfig,
) -> vol.Schema: ) -> selector.EntitySelector:
"""Return an entity selector which excludes own entities.""" """Return an entity selector which excludes own entities."""
entity_registry = er.async_get(handler.hass) entity_registry = er.async_get(handler.hass)
entities = er.async_entries_for_config_entry( entities = er.async_entries_for_config_entry(