More helpers type improvements (#30145)

This commit is contained in:
Ville Skyttä 2019-12-22 20:51:39 +02:00 committed by Paulus Schoutsen
parent 70f8bfbd4f
commit 868eb3c735
5 changed files with 124 additions and 90 deletions

View File

@ -1,6 +1,6 @@
"""Helper to check the configuration file.""" """Helper to check the configuration file."""
from collections import OrderedDict, namedtuple from collections import OrderedDict
from typing import List from typing import List, NamedTuple, Optional
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -19,15 +19,20 @@ from homeassistant.config import (
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from homeassistant.requirements import ( from homeassistant.requirements import (
RequirementsNotFound, RequirementsNotFound,
async_get_integration_with_requirements, async_get_integration_with_requirements,
) )
import homeassistant.util.yaml.loader as yaml_loader import homeassistant.util.yaml.loader as yaml_loader
# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any
CheckConfigError = namedtuple("CheckConfigError", "message domain config") class CheckConfigError(NamedTuple):
"""Configuration check error."""
message: str
domain: Optional[str]
config: Optional[ConfigType]
@attr.s @attr.s
@ -36,7 +41,12 @@ class HomeAssistantConfig(OrderedDict):
errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list)) errors: List[CheckConfigError] = attr.ib(default=attr.Factory(list))
def add_error(self, message, domain=None, config=None): def add_error(
self,
message: str,
domain: Optional[str] = None,
config: Optional[ConfigType] = None,
) -> "HomeAssistantConfig":
"""Add a single error.""" """Add a single error."""
self.errors.append(CheckConfigError(str(message), domain, config)) self.errors.append(CheckConfigError(str(message), domain, config))
return self return self
@ -55,7 +65,9 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
config_dir = hass.config.config_dir config_dir = hass.config.config_dir
result = HomeAssistantConfig() result = HomeAssistantConfig()
def _pack_error(package, component, config, message): def _pack_error(
package: str, component: str, config: ConfigType, message: str
) -> None:
"""Handle errors from packages: _log_pkg_error.""" """Handle errors from packages: _log_pkg_error."""
message = "Package {} setup failed. Component {} {}".format( message = "Package {} setup failed. Component {} {}".format(
package, component, message package, component, message
@ -64,7 +76,7 @@ async def async_check_ha_config_file(hass: HomeAssistant) -> HomeAssistantConfig
pack_config = core_config[CONF_PACKAGES].get(package, config) pack_config = core_config[CONF_PACKAGES].get(package, config)
result.add_error(message, domain, pack_config) result.add_error(message, domain, pack_config)
def _comp_error(ex, domain, config): def _comp_error(ex: Exception, domain: str, config: ConfigType) -> None:
"""Handle errors from components: async_log_exception.""" """Handle errors from components: async_log_exception."""
result.add_error(_format_config_error(ex, domain, config), domain, config) result.add_error(_format_config_error(ex, domain, config), domain, config)

View File

@ -5,13 +5,26 @@ from datetime import (
time as time_sys, time as time_sys,
timedelta, timedelta,
) )
from enum import Enum
import inspect import inspect
import logging import logging
from numbers import Number from numbers import Number
import os import os
import re import re
from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed from socket import _GLOBAL_DEFAULT_TIMEOUT # type: ignore # private, not in typeshed
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from typing import (
Any,
Callable,
Dict,
Hashable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
@ -48,12 +61,11 @@ from homeassistant.const import (
) )
from homeassistant.core import split_entity_id, valid_entity_id from homeassistant.core import split_entity_id, valid_entity_id
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template as template_helper
from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import slugify as util_slugify from homeassistant.util import slugify as util_slugify
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
# pylint: disable=invalid-name # pylint: disable=invalid-name
TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'" TIME_PERIOD_ERROR = "offset {} should be format 'HH:MM' or 'HH:MM:SS'"
@ -126,7 +138,7 @@ def boolean(value: Any) -> bool:
raise vol.Invalid("invalid boolean value {}".format(value)) raise vol.Invalid("invalid boolean value {}".format(value))
def isdevice(value): def isdevice(value: Any) -> str:
"""Validate that value is a real device.""" """Validate that value is a real device."""
try: try:
os.stat(value) os.stat(value)
@ -135,19 +147,19 @@ def isdevice(value):
raise vol.Invalid("No device at {} found".format(value)) raise vol.Invalid("No device at {} found".format(value))
def matches_regex(regex): def matches_regex(regex: str) -> Callable[[Any], str]:
"""Validate that the value is a string that matches a regex.""" """Validate that the value is a string that matches a regex."""
regex = re.compile(regex) compiled = re.compile(regex)
def validator(value: Any) -> str: def validator(value: Any) -> str:
"""Validate that value matches the given regex.""" """Validate that value matches the given regex."""
if not isinstance(value, str): if not isinstance(value, str):
raise vol.Invalid("not a string value: {}".format(value)) raise vol.Invalid("not a string value: {}".format(value))
if not regex.match(value): if not compiled.match(value):
raise vol.Invalid( raise vol.Invalid(
"value {} does not match regular expression {}".format( "value {} does not match regular expression {}".format(
value, regex.pattern value, compiled.pattern
) )
) )
@ -156,14 +168,14 @@ def matches_regex(regex):
return validator return validator
def is_regex(value): def is_regex(value: Any) -> Pattern[Any]:
"""Validate that a string is a valid regular expression.""" """Validate that a string is a valid regular expression."""
try: try:
r = re.compile(value) r = re.compile(value)
return r return r
except TypeError: except TypeError:
raise vol.Invalid( raise vol.Invalid(
"value {} is of the wrong type for a regular " "expression".format(value) "value {} is of the wrong type for a regular expression".format(value)
) )
except re.error: except re.error:
raise vol.Invalid("value {} is not a valid regular expression".format(value)) raise vol.Invalid("value {} is not a valid regular expression".format(value))
@ -204,9 +216,9 @@ def ensure_list(value: Union[T, List[T], None]) -> List[T]:
def entity_id(value: Any) -> str: def entity_id(value: Any) -> str:
"""Validate Entity ID.""" """Validate Entity ID."""
value = string(value).lower() str_value = string(value).lower()
if valid_entity_id(value): if valid_entity_id(str_value):
return value return str_value
raise vol.Invalid("Entity ID {} is an invalid entity id".format(value)) raise vol.Invalid("Entity ID {} is an invalid entity id".format(value))
@ -253,17 +265,17 @@ def entities_domain(domain: str) -> Callable[[Union[str, List]], List[str]]:
return validate return validate
def enum(enumClass): def enum(enumClass: Type[Enum]) -> vol.All:
"""Create validator for specified enum.""" """Create validator for specified enum."""
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__) return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)
def icon(value): def icon(value: Any) -> str:
"""Validate icon.""" """Validate icon."""
value = str(value) str_value = str(value)
if ":" in value: if ":" in str_value:
return value return str_value
raise vol.Invalid('Icons should be specified in the form "prefix:name"') raise vol.Invalid('Icons should be specified in the form "prefix:name"')
@ -362,7 +374,7 @@ def time_period_seconds(value: Union[int, str]) -> timedelta:
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict) time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
def match_all(value): def match_all(value: T) -> T:
"""Validate that matches all values.""" """Validate that matches all values."""
return value return value
@ -382,12 +394,12 @@ def remove_falsy(value: List[T]) -> List[T]:
return [v for v in value if v] return [v for v in value if v]
def service(value): def service(value: Any) -> str:
"""Validate service.""" """Validate service."""
# Services use same format as entities so we can use same helper. # Services use same format as entities so we can use same helper.
value = string(value).lower() str_value = string(value).lower()
if valid_entity_id(value): if valid_entity_id(str_value):
return value return str_value
raise vol.Invalid("Service {} does not match format <domain>.<name>".format(value)) raise vol.Invalid("Service {} does not match format <domain>.<name>".format(value))
@ -407,7 +419,7 @@ def schema_with_slug_keys(value_schema: Union[T, Callable]) -> Callable:
for key in value.keys(): for key in value.keys():
slug(key) slug(key)
return schema(value) return cast(Dict, schema(value))
return verify return verify
@ -416,10 +428,10 @@ def slug(value: Any) -> str:
"""Validate value is a valid slug.""" """Validate value is a valid slug."""
if value is None: if value is None:
raise vol.Invalid("Slug should not be None") raise vol.Invalid("Slug should not be None")
value = str(value) str_value = str(value)
slg = util_slugify(value) slg = util_slugify(str_value)
if value == slg: if str_value == slg:
return value return str_value
raise vol.Invalid("invalid slug {} (try {})".format(value, slg)) raise vol.Invalid("invalid slug {} (try {})".format(value, slg))
@ -458,42 +470,41 @@ unit_system = vol.All(
) )
def template(value): def template(value: Optional[Any]) -> template_helper.Template:
"""Validate a jinja2 template.""" """Validate a jinja2 template."""
from homeassistant.helpers import template as template_helper
if value is None: if value is None:
raise vol.Invalid("template value is None") raise vol.Invalid("template value is None")
if isinstance(value, (list, dict, template_helper.Template)): if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string") raise vol.Invalid("template value should be a string")
value = template_helper.Template(str(value)) template_value = template_helper.Template(str(value)) # type: ignore
try: try:
value.ensure_valid() template_value.ensure_valid()
return value return cast(template_helper.Template, template_value)
except TemplateError as ex: except TemplateError as ex:
raise vol.Invalid("invalid template ({})".format(ex)) raise vol.Invalid("invalid template ({})".format(ex))
def template_complex(value): def template_complex(value: Any) -> Any:
"""Validate a complex jinja2 template.""" """Validate a complex jinja2 template."""
if isinstance(value, list): if isinstance(value, list):
return_value = value.copy() return_list = value.copy()
for idx, element in enumerate(return_value): for idx, element in enumerate(return_list):
return_value[idx] = template_complex(element) return_list[idx] = template_complex(element)
return return_value return return_list
if isinstance(value, dict): if isinstance(value, dict):
return_value = value.copy() return_dict = value.copy()
for key, element in return_value.items(): for key, element in return_dict.items():
return_value[key] = template_complex(element) return_dict[key] = template_complex(element)
return return_value return return_dict
if isinstance(value, str): if isinstance(value, str):
return template(value) return template(value)
return value return value
def datetime(value): def datetime(value: Any) -> datetime_sys:
"""Validate datetime.""" """Validate datetime."""
if isinstance(value, datetime_sys): if isinstance(value, datetime_sys):
return value return value
@ -509,7 +520,7 @@ def datetime(value):
return date_val return date_val
def time_zone(value): def time_zone(value: str) -> str:
"""Validate timezone.""" """Validate timezone."""
if dt_util.get_time_zone(value) is not None: if dt_util.get_time_zone(value) is not None:
return value return value
@ -522,7 +533,7 @@ def time_zone(value):
weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)]) weekdays = vol.All(ensure_list, [vol.In(WEEKDAYS)])
def socket_timeout(value): def socket_timeout(value: Optional[Any]) -> object:
"""Validate timeout float > 0.0. """Validate timeout float > 0.0.
None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object. None coerced to socket._GLOBAL_DEFAULT_TIMEOUT bare object.
@ -544,12 +555,12 @@ def url(value: Any) -> str:
url_in = str(value) url_in = str(value)
if urlparse(url_in).scheme in ["http", "https"]: if urlparse(url_in).scheme in ["http", "https"]:
return vol.Schema(vol.Url())(url_in) return cast(str, vol.Schema(vol.Url())(url_in))
raise vol.Invalid("invalid url") raise vol.Invalid("invalid url")
def x10_address(value): def x10_address(value: str) -> str:
"""Validate an x10 address.""" """Validate an x10 address."""
regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$") regex = re.compile(r"([A-Pa-p]{1})(?:[2-9]|1[0-6]?)$")
if not regex.match(value): if not regex.match(value):
@ -557,7 +568,7 @@ def x10_address(value):
return str(value).lower() return str(value).lower()
def uuid4_hex(value): def uuid4_hex(value: Any) -> str:
"""Validate a v4 UUID in hex format.""" """Validate a v4 UUID in hex format."""
try: try:
result = UUID(value, version=4) result = UUID(value, version=4)
@ -678,10 +689,12 @@ def deprecated(
# Validator helpers # Validator helpers
def key_dependency(key, dependency): def key_dependency(
key: Hashable, dependency: Hashable
) -> Callable[[Dict[Hashable, Any]], Dict[Hashable, Any]]:
"""Validate that all dependencies exist for key.""" """Validate that all dependencies exist for key."""
def validator(value): def validator(value: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
"""Test dependencies.""" """Test dependencies."""
if not isinstance(value, dict): if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict") raise vol.Invalid("key dependencies require a dict")
@ -696,7 +709,7 @@ def key_dependency(key, dependency):
return validator return validator
def custom_serializer(schema): def custom_serializer(schema: Any) -> Any:
"""Serialize additional types for voluptuous_serialize.""" """Serialize additional types for voluptuous_serialize."""
if schema is positive_time_period_dict: if schema is positive_time_period_dict:
return {"type": "positive_time_period_dict"} return {"type": "positive_time_period_dict"}

View File

@ -12,8 +12,7 @@ from homeassistant.loader import bind_hass
from .typing import HomeAssistantType from .typing import HomeAssistantType
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF = object() _UNDEF = object()
@ -71,10 +70,11 @@ def format_mac(mac: str) -> str:
class DeviceRegistry: class DeviceRegistry:
"""Class to hold a registry of devices.""" """Class to hold a registry of devices."""
def __init__(self, hass): devices: Dict[str, DeviceEntry]
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
self.hass = hass self.hass = hass
self.devices = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback @callback

View File

@ -11,7 +11,7 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
from typing import Any, Dict, Iterable, List, Optional, cast from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
import attr import attr
@ -23,6 +23,9 @@ from homeassistant.util.yaml import load_yaml
from .typing import HomeAssistantType from .typing import HomeAssistantType
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry # noqa: F401
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
PATH_REGISTRY = "entity_registry.yaml" PATH_REGISTRY = "entity_registry.yaml"
@ -48,7 +51,7 @@ class RegistryEntry:
unique_id = attr.ib(type=str) unique_id = attr.ib(type=str)
platform = attr.ib(type=str) platform = attr.ib(type=str)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
device_id = attr.ib(type=str, default=None) device_id: Optional[str] = attr.ib(default=None)
config_entry_id: Optional[str] = attr.ib(default=None) config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by = attr.ib( disabled_by = attr.ib(
type=Optional[str], type=Optional[str],
@ -135,16 +138,16 @@ class EntityRegistry:
@callback @callback
def async_get_or_create( def async_get_or_create(
self, self,
domain, domain: str,
platform, platform: str,
unique_id, unique_id: str,
*, *,
suggested_object_id=None, suggested_object_id: Optional[str] = None,
config_entry=None, config_entry: Optional["ConfigEntry"] = None,
device_id=None, device_id: Optional[str] = None,
known_object_ids=None, known_object_ids: Optional[Iterable[str]] = None,
disabled_by=None, disabled_by: Optional[str] = None,
): ) -> RegistryEntry:
"""Get entity. Create if it doesn't exist.""" """Get entity. Create if it doesn't exist."""
config_entry_id = None config_entry_id = None
if config_entry: if config_entry:
@ -153,7 +156,7 @@ class EntityRegistry:
entity_id = self.async_get_entity_id(domain, platform, unique_id) entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id: if entity_id:
return self._async_update_entity( return self._async_update_entity( # type: ignore
entity_id, entity_id,
config_entry_id=config_entry_id or _UNDEF, config_entry_id=config_entry_id or _UNDEF,
device_id=device_id or _UNDEF, device_id=device_id or _UNDEF,
@ -228,12 +231,15 @@ class EntityRegistry:
disabled_by=_UNDEF, disabled_by=_UNDEF,
): ):
"""Update properties of an entity.""" """Update properties of an entity."""
return self._async_update_entity( return cast( # cast until we have _async_update_entity type hinted
entity_id, RegistryEntry,
name=name, self._async_update_entity(
new_entity_id=new_entity_id, entity_id,
new_unique_id=new_unique_id, name=name,
disabled_by=disabled_by, new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
),
) )
@callback @callback

View File

@ -1,8 +1,7 @@
"""Helpers for logging allowing more advanced logging styles to be used.""" """Helpers for logging allowing more advanced logging styles to be used."""
import inspect import inspect
import logging import logging
from typing import Any, Mapping, MutableMapping, Optional, Tuple
# mypy: allow-untyped-defs, no-check-untyped-defs
class KeywordMessage: class KeywordMessage:
@ -12,13 +11,13 @@ class KeywordMessage:
Adapted from: https://stackoverflow.com/a/24683360/2267718 Adapted from: https://stackoverflow.com/a/24683360/2267718
""" """
def __init__(self, fmt, args, kwargs): def __init__(self, fmt: Any, args: Any, kwargs: Mapping[str, Any]) -> None:
"""Initialize a new BraceMessage object.""" """Initialize a new KeywordMessage object."""
self._fmt = fmt self._fmt = fmt
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
def __str__(self): def __str__(self) -> str:
"""Convert the object to a string for logging.""" """Convert the object to a string for logging."""
return str(self._fmt).format(*self._args, **self._kwargs) return str(self._fmt).format(*self._args, **self._kwargs)
@ -26,26 +25,30 @@ class KeywordMessage:
class KeywordStyleAdapter(logging.LoggerAdapter): class KeywordStyleAdapter(logging.LoggerAdapter):
"""Represents an adapter wrapping the logger allowing KeywordMessages.""" """Represents an adapter wrapping the logger allowing KeywordMessages."""
def __init__(self, logger, extra=None): def __init__(
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] = None
) -> None:
"""Initialize a new StyleAdapter for the provided logger.""" """Initialize a new StyleAdapter for the provided logger."""
super().__init__(logger, extra or {}) super().__init__(logger, extra or {})
def log(self, level, msg, *args, **kwargs): def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Log the message provided at the appropriate level.""" """Log the message provided at the appropriate level."""
if self.isEnabledFor(level): if self.isEnabledFor(level):
msg, log_kwargs = self.process(msg, kwargs) msg, log_kwargs = self.process(msg, kwargs)
self.logger._log( # pylint: disable=protected-access self.logger._log( # type: ignore # pylint: disable=protected-access
level, KeywordMessage(msg, args, kwargs), (), **log_kwargs level, KeywordMessage(msg, args, kwargs), (), **log_kwargs
) )
def process(self, msg, kwargs): def process(
self, msg: Any, kwargs: MutableMapping[str, Any]
) -> Tuple[Any, MutableMapping[str, Any]]:
"""Process the keyward args in preparation for logging.""" """Process the keyward args in preparation for logging."""
return ( return (
msg, msg,
{ {
k: kwargs[k] k: kwargs[k]
for k in inspect.getfullargspec( for k in inspect.getfullargspec(
self.logger._log # pylint: disable=protected-access self.logger._log # type: ignore # pylint: disable=protected-access
).args[1:] ).args[1:]
if k in kwargs if k in kwargs
}, },