Improve schema typing (2) (#120475)

This commit is contained in:
Marc Mueller 2024-06-26 02:25:30 +02:00 committed by GitHub
parent 2380696fcd
commit 49df0c4366
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 23 additions and 15 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from aurorapy.client import AuroraError, AuroraSerialClient
import serial.tools.list_ports
@ -78,7 +78,7 @@ class AuroraABBConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self):
"""Initialise the config flow."""
self.config = None
self._com_ports_list = None
self._com_ports_list: list[str] | None = None
self._default_com_port = None
async def async_step_user(
@ -92,6 +92,8 @@ class AuroraABBConfigFlow(ConfigFlow, domain=DOMAIN):
self._com_ports_list, self._default_com_port = result
if self._default_com_port is None:
return self.async_abort(reason="no_serial_ports")
if TYPE_CHECKING:
assert isinstance(self._com_ports_list, list)
# Handle the initial step.
if user_input is not None:

View File

@ -29,7 +29,7 @@ from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, VolSchemaType
from homeassistant.loader import IntegrationNotFound
from homeassistant.requirements import (
RequirementsNotFound,
@ -340,7 +340,7 @@ def async_get_entity_registry_entry_or_raise(
@callback
def async_validate_entity_schema(
hass: HomeAssistant, config: ConfigType, schema: vol.Schema
hass: HomeAssistant, config: ConfigType, schema: VolSchemaType
) -> ConfigType:
"""Validate schema and resolve entity registry entry id to entity_id."""
config = schema(config)

View File

@ -51,6 +51,7 @@ async def basic_group_options_schema(
domain: str | list[str], handler: SchemaCommonFlowHandler | None
) -> vol.Schema:
"""Generate options schema."""
entity_selector: selector.Selector[Any] | vol.Schema
if handler is None:
entity_selector = selector.selector(
{"entity": {"domain": domain, "multiple": True}}

View File

@ -13,7 +13,7 @@ from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, VolDictType
from .const import DOMAIN
from .schema import ga_validator
@ -32,7 +32,7 @@ CONF_KNX_INCOMING: Final = "incoming"
CONF_KNX_OUTGOING: Final = "outgoing"
TELEGRAM_TRIGGER_SCHEMA: Final = {
TELEGRAM_TRIGGER_SCHEMA: VolDictType = {
vol.Optional(CONF_KNX_DESTINATION): vol.All(cv.ensure_list, [ga_validator]),
vol.Optional(CONF_KNX_GROUP_VALUE_WRITE, default=True): cv.boolean,
vol.Optional(CONF_KNX_GROUP_VALUE_RESPONSE, default=True): cv.boolean,

View File

@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Callable
from enum import IntEnum
import logging
from typing import cast
from mysensors import BaseAsyncGateway, Message
from mysensors.sensor import ChildSensor
@ -151,7 +152,7 @@ def get_child_schema(
) -> vol.Schema:
"""Return a child schema."""
set_req = gateway.const.SetReq
child_schema = child.get_schema(gateway.protocol_version)
child_schema = cast(vol.Schema, child.get_schema(gateway.protocol_version))
return child_schema.extend(
{
vol.Required(

View File

@ -43,7 +43,7 @@ class SmaConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self) -> None:
"""Initialize."""
self._data = {
self._data: dict[str, Any] = {
CONF_HOST: vol.UNDEFINED,
CONF_SSL: False,
CONF_VERIFY_SSL: True,

View File

@ -40,7 +40,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.group import expand_entity_ids
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, VolSchemaType
from .const import (
ATTR_COMMAND_CLASS,
@ -479,7 +479,9 @@ def copy_available_params(
)
def get_value_state_schema(value: ZwaveValue) -> vol.Schema | None:
def get_value_state_schema(
value: ZwaveValue,
) -> VolSchemaType | vol.Coerce | vol.In | None:
"""Return device automation schema for a config entry."""
if isinstance(value, ConfigurationValue):
min_ = value.metadata.min

View File

@ -108,6 +108,7 @@ from homeassistant.util.yaml.objects import NodeStrClass
from . import script_variables as script_variables_helper, template as template_helper
from .frame import get_integration_logger
from .typing import VolDictType, VolSchemaType
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM', 'HH:MM:SS' or 'HH:MM:SS.F'"
@ -980,8 +981,8 @@ def removed(
def key_value_schemas(
key: str,
value_schemas: dict[Hashable, vol.Schema],
default_schema: vol.Schema | None = None,
value_schemas: dict[Hashable, VolSchemaType],
default_schema: VolSchemaType | None = None,
default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]:
"""Create a validator that validates based on a value for specific key.
@ -1355,7 +1356,7 @@ NUMERIC_STATE_THRESHOLD_SCHEMA = vol.Any(
vol.All(str, entity_domain(["input_number", "number", "sensor", "zone"])),
)
CONDITION_BASE_SCHEMA = {
CONDITION_BASE_SCHEMA: VolDictType = {
vol.Optional(CONF_ALIAS): string,
vol.Optional(CONF_ENABLED): vol.Any(boolean, template),
}

View File

@ -985,7 +985,7 @@ class EntityPlatform:
def async_register_entity_service(
self,
name: str,
schema: VolDictType | VolSchemaType,
schema: VolDictType | VolSchemaType | None,
func: str | Callable[..., Any],
required_features: Iterable[int] | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,

View File

@ -660,6 +660,7 @@ class ScriptTool(Tool):
description = config.get("description")
if not description:
description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=description)
else:

View File

@ -1182,7 +1182,7 @@ class SelectSelector(Selector[SelectSelectorConfig]):
for option in cast(Sequence[SelectOptionDict], config_options)
]
parent_schema = vol.In(options)
parent_schema: vol.In | vol.Any = vol.In(options)
if self.config["custom_value"]:
parent_schema = vol.Any(parent_schema, str)