Improvement typing core (#2624)

* Add package typing

* Add util/location typing

* FIX: lint wrong order of imports

* Fix sometyping and add helpers/entity typing

* Mypy import trick

* Add asteroid to test requiremts to fix pylint issue

* Fix deprecated function isSet for is_set

* Add loader.py typing

* Improve typing bootstrap
This commit is contained in:
Fabian Heredia Montiel 2016-07-27 22:33:49 -05:00 committed by Paulus Schoutsen
parent 8c728d1b4e
commit ae97218582
9 changed files with 99 additions and 56 deletions

View File

@ -6,11 +6,12 @@ import os
import sys import sys
from collections import defaultdict from collections import defaultdict
from threading import RLock from threading import RLock
from types import ModuleType
from typing import Any, Optional, Dict from typing import Any, Optional, Dict
import voluptuous as vol import voluptuous as vol
import homeassistant.components as core_components import homeassistant.components as core_components
from homeassistant.components import group, persistent_notification from homeassistant.components import group, persistent_notification
import homeassistant.config as conf_util import homeassistant.config as conf_util
@ -32,7 +33,8 @@ ATTR_COMPONENT = 'component'
ERROR_LOG_FILENAME = 'home-assistant.log' ERROR_LOG_FILENAME = 'home-assistant.log'
def setup_component(hass, domain, config=None): def setup_component(hass: core.HomeAssistant, domain: str,
config: Optional[Dict]=None) -> bool:
"""Setup a component and all its dependencies.""" """Setup a component and all its dependencies."""
if domain in hass.config.components: if domain in hass.config.components:
return True return True
@ -55,7 +57,8 @@ def setup_component(hass, domain, config=None):
return True return True
def _handle_requirements(hass, component, name): def _handle_requirements(hass: core.HomeAssistant, component,
name: str) -> bool:
"""Install the requirements for a component.""" """Install the requirements for a component."""
if hass.config.skip_pip or not hasattr(component, 'REQUIREMENTS'): if hass.config.skip_pip or not hasattr(component, 'REQUIREMENTS'):
return True return True
@ -69,7 +72,7 @@ def _handle_requirements(hass, component, name):
return True return True
def _setup_component(hass, domain, config): def _setup_component(hass: core.HomeAssistant, domain: str, config) -> bool:
"""Setup a component for Home Assistant.""" """Setup a component for Home Assistant."""
# pylint: disable=too-many-return-statements,too-many-branches # pylint: disable=too-many-return-statements,too-many-branches
# pylint: disable=too-many-statements # pylint: disable=too-many-statements
@ -178,7 +181,8 @@ def _setup_component(hass, domain, config):
return True return True
def prepare_setup_platform(hass, config, domain, platform_name): def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str,
platform_name: str) -> Optional[ModuleType]:
"""Load a platform and makes sure dependencies are setup.""" """Load a platform and makes sure dependencies are setup."""
_ensure_loader_prepared(hass) _ensure_loader_prepared(hass)
@ -309,7 +313,8 @@ def from_config_file(config_path: str,
skip_pip=skip_pip) skip_pip=skip_pip)
def enable_logging(hass, verbose=False, log_rotate_days=None): def enable_logging(hass: core.HomeAssistant, verbose: bool=False,
log_rotate_days=None) -> None:
"""Setup the logging.""" """Setup the logging."""
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
fmt = ("%(log_color)s%(asctime)s %(levelname)s (%(threadName)s) " fmt = ("%(log_color)s%(asctime)s %(levelname)s (%(threadName)s) "
@ -360,12 +365,12 @@ def enable_logging(hass, verbose=False, log_rotate_days=None):
'Unable to setup error log %s (access denied)', err_log_path) 'Unable to setup error log %s (access denied)', err_log_path)
def _ensure_loader_prepared(hass): def _ensure_loader_prepared(hass: core.HomeAssistant) -> None:
"""Ensure Home Assistant loader is prepared.""" """Ensure Home Assistant loader is prepared."""
if not loader.PREPARED: if not loader.PREPARED:
loader.prepare(hass) loader.prepare(hass)
def _mount_local_lib_path(config_dir): def _mount_local_lib_path(config_dir: str) -> None:
"""Add local library to Python Path.""" """Add local library to Python Path."""
sys.path.insert(0, os.path.join(config_dir, 'deps')) sys.path.insert(0, os.path.join(config_dir, 'deps'))

View File

@ -158,14 +158,14 @@ class HomeAssistant(object):
except AttributeError: except AttributeError:
pass pass
try: try:
while not request_shutdown.isSet(): while not request_shutdown.is_set():
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally: finally:
self.stop() self.stop()
return RESTART_EXIT_CODE if request_restart.isSet() else 0 return RESTART_EXIT_CODE if request_restart.is_set() else 0
def stop(self) -> None: def stop(self) -> None:
"""Stop Home Assistant and shuts down all threads.""" """Stop Home Assistant and shuts down all threads."""
@ -233,7 +233,7 @@ class Event(object):
class EventBus(object): class EventBus(object):
"""Allows firing of and listening for events.""" """Allows firing of and listening for events."""
def __init__(self, pool: util.ThreadPool): def __init__(self, pool: util.ThreadPool) -> None:
"""Initialize a new event bus.""" """Initialize a new event bus."""
self._listeners = {} self._listeners = {}
self._lock = threading.Lock() self._lock = threading.Lock()
@ -792,7 +792,7 @@ def create_timer(hass, interval=TIMER_INTERVAL):
calc_now = dt_util.utcnow calc_now = dt_util.utcnow
while not stop_event.isSet(): while not stop_event.is_set():
now = calc_now() now = calc_now()
# First check checks if we are not on a second matching the # First check checks if we are not on a second matching the
@ -816,7 +816,7 @@ def create_timer(hass, interval=TIMER_INTERVAL):
last_fired_on_second = now.second last_fired_on_second = now.second
# Event might have been set while sleeping # Event might have been set while sleeping
if not stop_event.isSet(): if not stop_event.is_set():
try: try:
hass.bus.fire(EVENT_TIME_CHANGED, {ATTR_NOW: now}) hass.bus.fire(EVENT_TIME_CHANGED, {ATTR_NOW: now})
except HomeAssistantError: except HomeAssistantError:

View File

@ -1,10 +1,20 @@
"""Helper methods for components within Home Assistant.""" """Helper methods for components within Home Assistant."""
import re import re
from typing import Any, Iterable, Tuple, List, Dict
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_PLATFORM
# Typing Imports and TypeAlias
# pylint: disable=using-constant-test,unused-import
if False:
from logging import Logger # NOQA
def validate_config(config, items, logger): # pylint: disable=invalid-name
ConfigType = Dict[str, Any]
def validate_config(config: ConfigType, items: Dict, logger: 'Logger') -> bool:
"""Validate if all items are available in the configuration. """Validate if all items are available in the configuration.
config is the general dictionary with all the configurations. config is the general dictionary with all the configurations.
@ -29,7 +39,8 @@ def validate_config(config, items, logger):
return not errors_found return not errors_found
def config_per_platform(config, domain): def config_per_platform(config: ConfigType,
domain: str) -> Iterable[Tuple[Any, Any]]:
"""Generator to break a component config into different platforms. """Generator to break a component config into different platforms.
For example, will find 'switch', 'switch 2', 'switch 3', .. etc For example, will find 'switch', 'switch 2', 'switch 3', .. etc
@ -48,7 +59,7 @@ def config_per_platform(config, domain):
yield platform, item yield platform, item
def extract_domain_configs(config, domain): def extract_domain_configs(config: ConfigType, domain: str) -> List[str]:
"""Extract keys from config for given domain name.""" """Extract keys from config for given domain name."""
pattern = re.compile(r'^{}(| .+)$'.format(domain)) pattern = re.compile(r'^{}(| .+)$'.format(domain))
return [key for key in config.keys() if pattern.match(key)] return [key for key in config.keys() if pattern.match(key)]

View File

@ -2,6 +2,8 @@
import logging import logging
import re import re
from typing import Any, Optional, List, Dict
from homeassistant.const import ( from homeassistant.const import (
ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ICON, ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ICON,
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON, ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
@ -10,8 +12,12 @@ from homeassistant.const import (
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
# Entity attributes that we will overwrite # Entity attributes that we will overwrite
_OVERWRITE = {} _OVERWRITE = {} # type: Dict[str, Any]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -19,7 +25,9 @@ _LOGGER = logging.getLogger(__name__)
ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$") ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
def generate_entity_id(entity_id_format, name, current_ids=None, hass=None): def generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None,
hass: 'Optional[HomeAssistant]'=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs.""" """Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower() name = (name or DEVICE_DEFAULT_NAME).lower()
if current_ids is None: if current_ids is None:
@ -32,19 +40,19 @@ def generate_entity_id(entity_id_format, name, current_ids=None, hass=None):
entity_id_format.format(slugify(name)), current_ids) entity_id_format.format(slugify(name)), current_ids)
def set_customize(customize): def set_customize(customize: Dict[str, Any]) -> None:
"""Overwrite all current customize settings.""" """Overwrite all current customize settings."""
global _OVERWRITE global _OVERWRITE
_OVERWRITE = {key.lower(): val for key, val in customize.items()} _OVERWRITE = {key.lower(): val for key, val in customize.items()}
def split_entity_id(entity_id): def split_entity_id(entity_id: str) -> List[str]:
"""Split a state entity_id into domain, object_id.""" """Split a state entity_id into domain, object_id."""
return entity_id.split(".", 1) return entity_id.split(".", 1)
def valid_entity_id(entity_id): def valid_entity_id(entity_id: str) -> bool:
"""Test if an entity ID is a valid format.""" """Test if an entity ID is a valid format."""
return ENTITY_ID_PATTERN.match(entity_id) is not None return ENTITY_ID_PATTERN.match(entity_id) is not None
@ -57,7 +65,7 @@ class Entity(object):
# The properties and methods here are safe to overwrite when inheriting # The properties and methods here are safe to overwrite when inheriting
# this class. These may be used to customize the behavior of the entity. # this class. These may be used to customize the behavior of the entity.
@property @property
def should_poll(self): def should_poll(self) -> bool:
"""Return True if entity has to be polled for state. """Return True if entity has to be polled for state.
False if entity pushes its state to HA. False if entity pushes its state to HA.
@ -65,17 +73,17 @@ class Entity(object):
return True return True
@property @property
def unique_id(self): def unique_id(self) -> str:
"""Return an unique ID.""" """Return an unique ID."""
return "{}.{}".format(self.__class__, id(self)) return "{}.{}".format(self.__class__, id(self))
@property @property
def name(self): def name(self) -> Optional[str]:
"""Return the name of the entity.""" """Return the name of the entity."""
return None return None
@property @property
def state(self): def state(self) -> str:
"""Return the state of the entity.""" """Return the state of the entity."""
return STATE_UNKNOWN return STATE_UNKNOWN
@ -111,22 +119,22 @@ class Entity(object):
return None return None
@property @property
def hidden(self): def hidden(self) -> bool:
"""Return True if the entity should be hidden from UIs.""" """Return True if the entity should be hidden from UIs."""
return False return False
@property @property
def available(self): def available(self) -> bool:
"""Return True if entity is available.""" """Return True if entity is available."""
return True return True
@property @property
def assumed_state(self): def assumed_state(self) -> bool:
"""Return True if unable to access real state of the entity.""" """Return True if unable to access real state of the entity."""
return False return False
@property @property
def force_update(self): def force_update(self) -> bool:
"""Return True if state updates should be forced. """Return True if state updates should be forced.
If True, a state change will be triggered anytime the state property is If True, a state change will be triggered anytime the state property is
@ -138,14 +146,14 @@ class Entity(object):
"""Retrieve latest state.""" """Retrieve latest state."""
pass pass
entity_id = None entity_id = None # type: str
# DO NOT OVERWRITE # DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they # These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may # are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation. # produce undesirable effects in the entity's operation.
hass = None hass = None # type: Optional[HomeAssistant]
def update_ha_state(self, force_refresh=False): def update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity. """Update Home Assistant with current state of entity.
@ -232,24 +240,24 @@ class ToggleEntity(Entity):
# pylint: disable=no-self-use # pylint: disable=no-self-use
@property @property
def state(self): def state(self) -> str:
"""Return the state.""" """Return the state."""
return STATE_ON if self.is_on else STATE_OFF return STATE_ON if self.is_on else STATE_OFF
@property @property
def is_on(self): def is_on(self) -> bool:
"""Return True if entity is on.""" """Return True if entity is on."""
raise NotImplementedError() raise NotImplementedError()
def turn_on(self, **kwargs): def turn_on(self, **kwargs) -> None:
"""Turn the entity on.""" """Turn the entity on."""
raise NotImplementedError() raise NotImplementedError()
def turn_off(self, **kwargs): def turn_off(self, **kwargs) -> None:
"""Turn the entity off.""" """Turn the entity off."""
raise NotImplementedError() raise NotImplementedError()
def toggle(self, **kwargs): def toggle(self, **kwargs) -> None:
"""Toggle the entity off.""" """Toggle the entity off."""
if self.is_on: if self.is_on:
self.turn_off(**kwargs) self.turn_off(**kwargs)

View File

@ -16,21 +16,30 @@ import os
import pkgutil import pkgutil
import sys import sys
from types import ModuleType
# pylint: disable=unused-import
from typing import Optional, Sequence, Set, Dict # NOQA
from homeassistant.const import PLATFORM_FORMAT from homeassistant.const import PLATFORM_FORMAT
from homeassistant.util import OrderedSet from homeassistant.util import OrderedSet
# Typing imports
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
PREPARED = False PREPARED = False
# List of available components # List of available components
AVAILABLE_COMPONENTS = [] AVAILABLE_COMPONENTS = [] # type: List[str]
# Dict of loaded components mapped name => module # Dict of loaded components mapped name => module
_COMPONENT_CACHE = {} _COMPONENT_CACHE = {} # type: Dict[str, ModuleType]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def prepare(hass): def prepare(hass: 'HomeAssistant'):
"""Prepare the loading of components.""" """Prepare the loading of components."""
global PREPARED # pylint: disable=global-statement global PREPARED # pylint: disable=global-statement
@ -71,19 +80,19 @@ def prepare(hass):
PREPARED = True PREPARED = True
def set_component(comp_name, component): def set_component(comp_name: str, component: ModuleType) -> None:
"""Set a component in the cache.""" """Set a component in the cache."""
_check_prepared() _check_prepared()
_COMPONENT_CACHE[comp_name] = component _COMPONENT_CACHE[comp_name] = component
def get_platform(domain, platform): def get_platform(domain: str, platform: str) -> Optional[ModuleType]:
"""Try to load specified platform.""" """Try to load specified platform."""
return get_component(PLATFORM_FORMAT.format(domain, platform)) return get_component(PLATFORM_FORMAT.format(domain, platform))
def get_component(comp_name): def get_component(comp_name) -> Optional[ModuleType]:
"""Try to load specified component. """Try to load specified component.
Looks in config dir first, then built-in components. Looks in config dir first, then built-in components.
@ -148,7 +157,7 @@ def get_component(comp_name):
return None return None
def load_order_components(components): def load_order_components(components: Sequence[str]) -> OrderedSet:
"""Take in a list of components we want to load. """Take in a list of components we want to load.
- filters out components we cannot load - filters out components we cannot load
@ -178,7 +187,7 @@ def load_order_components(components):
return load_order return load_order
def load_order_component(comp_name): def load_order_component(comp_name: str) -> OrderedSet:
"""Return an OrderedSet of components in the correct order of loading. """Return an OrderedSet of components in the correct order of loading.
Raises HomeAssistantError if a circular dependency is detected. Raises HomeAssistantError if a circular dependency is detected.
@ -187,7 +196,8 @@ def load_order_component(comp_name):
return _load_order_component(comp_name, OrderedSet(), set()) return _load_order_component(comp_name, OrderedSet(), set())
def _load_order_component(comp_name, load_order, loading): def _load_order_component(comp_name: str, load_order: OrderedSet,
loading: Set) -> OrderedSet:
"""Recursive function to get load order of components.""" """Recursive function to get load order of components."""
component = get_component(comp_name) component = get_component(comp_name)
@ -224,7 +234,7 @@ def _load_order_component(comp_name, load_order, loading):
return load_order return load_order
def _check_prepared(): def _check_prepared() -> None:
"""Issue a warning if loader.prepare() has never been called.""" """Issue a warning if loader.prepare() has never been called."""
if not PREPARED: if not PREPARED:
_LOGGER.warning(( _LOGGER.warning((

View File

@ -12,7 +12,7 @@ import string
from functools import wraps from functools import wraps
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any, Sequence
from .dt import as_local, utcnow from .dt import as_local, utcnow
@ -31,7 +31,7 @@ def sanitize_path(path):
return RE_SANITIZE_PATH.sub("", path) return RE_SANITIZE_PATH.sub("", path)
def slugify(text): def slugify(text: str) -> str:
"""Slugify a given text.""" """Slugify a given text."""
text = text.lower().replace(" ", "_") text = text.lower().replace(" ", "_")
@ -59,17 +59,18 @@ def convert(value, to_type, default=None):
return default return default
def ensure_unique_string(preferred_string, current_strings): def ensure_unique_string(preferred_string: str,
current_strings: Sequence[str]) -> str:
"""Return a string that is not present in current_strings. """Return a string that is not present in current_strings.
If preferred string exists will append _2, _3, .. If preferred string exists will append _2, _3, ..
""" """
test_string = preferred_string test_string = preferred_string
current_strings = set(current_strings) current_strings_set = set(current_strings)
tries = 1 tries = 1
while test_string in current_strings: while test_string in current_strings_set:
tries += 1 tries += 1
test_string = "{}_{}".format(preferred_string, tries) test_string = "{}_{}".format(preferred_string, tries)

View File

@ -5,8 +5,11 @@ detect_location_info and elevation are mocked by default during tests.
""" """
import collections import collections
import math import math
from typing import Any, Optional, Tuple, Dict
import requests import requests
ELEVATION_URL = 'http://maps.googleapis.com/maps/api/elevation/json' ELEVATION_URL = 'http://maps.googleapis.com/maps/api/elevation/json'
FREEGEO_API = 'https://freegeoip.io/json/' FREEGEO_API = 'https://freegeoip.io/json/'
IP_API = 'http://ip-api.com/json' IP_API = 'http://ip-api.com/json'
@ -81,7 +84,8 @@ def elevation(latitude, longitude):
# Source: https://github.com/maurycyp/vincenty # Source: https://github.com/maurycyp/vincenty
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE # License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
# pylint: disable=too-many-locals, invalid-name, unused-variable # pylint: disable=too-many-locals, invalid-name, unused-variable
def vincenty(point1, point2, miles=False): def vincenty(point1: Tuple[float, float], point2: Tuple[float, float],
miles: bool=False) -> Optional[float]:
""" """
Vincenty formula (inverse method) to calculate the distance. Vincenty formula (inverse method) to calculate the distance.
@ -148,7 +152,7 @@ def vincenty(point1, point2, miles=False):
return round(s, 6) return round(s, 6)
def _get_freegeoip(): def _get_freegeoip() -> Optional[Dict[str, Any]]:
"""Query freegeoip.io for location data.""" """Query freegeoip.io for location data."""
try: try:
raw_info = requests.get(FREEGEO_API, timeout=5).json() raw_info = requests.get(FREEGEO_API, timeout=5).json()
@ -169,7 +173,7 @@ def _get_freegeoip():
} }
def _get_ip_api(): def _get_ip_api() -> Optional[Dict[str, Any]]:
"""Query ip-api.com for location data.""" """Query ip-api.com for location data."""
try: try:
raw_info = requests.get(IP_API, timeout=5).json() raw_info = requests.get(IP_API, timeout=5).json()

View File

@ -6,13 +6,16 @@ import sys
import threading import threading
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional
import pkg_resources import pkg_resources
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
INSTALL_LOCK = threading.Lock() INSTALL_LOCK = threading.Lock()
def install_package(package, upgrade=True, target=None): def install_package(package: str, upgrade: bool=True,
target: Optional[str]=None) -> bool:
"""Install a package on PyPi. Accepts pip compatible package strings. """Install a package on PyPi. Accepts pip compatible package strings.
Return boolean if install successful. Return boolean if install successful.
@ -36,7 +39,7 @@ def install_package(package, upgrade=True, target=None):
return False return False
def check_package_exists(package, lib_dir): def check_package_exists(package: str, lib_dir: str) -> bool:
"""Check if a package is installed globally or in lib_dir. """Check if a package is installed globally or in lib_dir.
Returns True when the requirement is met. Returns True when the requirement is met.

View File

@ -1,5 +1,6 @@
flake8>=2.6.0 flake8>=2.6.0
pylint>=1.5.6 pylint>=1.5.6
astroid>=1.4.8
coveralls>=1.1 coveralls>=1.1
pytest>=2.9.2 pytest>=2.9.2
pytest-cov>=2.2.1 pytest-cov>=2.2.1