Add type hints to helpers.condition (#20266)

This commit is contained in:
Ville Skyttä 2019-01-21 01:03:12 +02:00 committed by Fabian Affolter
parent 5b8cb10ad7
commit 58bb6f2e99
6 changed files with 101 additions and 55 deletions

View File

@ -56,7 +56,7 @@ def async_active_zone(hass, latitude, longitude, radius=0):
return closest return closest
def in_zone(zone, latitude, longitude, radius=0): def in_zone(zone, latitude, longitude, radius=0) -> bool:
"""Test if given latitude, longitude is in given zone. """Test if given latitude, longitude is in given zone.
Async friendly. Async friendly.

View File

@ -678,7 +678,7 @@ class State:
"State max length is 255 characters.").format(entity_id)) "State max length is 255 characters.").format(entity_id))
self.entity_id = entity_id.lower() self.entity_id = entity_id.lower()
self.state = state self.state = state # type: str
self.attributes = MappingProxyType(attributes or {}) self.attributes = MappingProxyType(attributes or {})
self.last_updated = last_updated or dt_util.utcnow() self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated self.last_changed = last_changed or self.last_updated

View File

@ -1,12 +1,14 @@
"""Offer reusable conditions.""" """Offer reusable conditions."""
from datetime import timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import sys import sys
from typing import Callable, Container, Optional, Union, cast
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, State
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.const import ( from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
@ -29,25 +31,30 @@ _LOGGER = logging.getLogger(__name__)
# pylint: disable=invalid-name # pylint: disable=invalid-name
def _threaded_factory(async_factory): def _threaded_factory(async_factory:
Callable[[ConfigType, bool], Callable[..., bool]]) \
-> Callable[[ConfigType, bool], Callable[..., bool]]:
"""Create threaded versions of async factories.""" """Create threaded versions of async factories."""
@ft.wraps(async_factory) @ft.wraps(async_factory)
def factory(config, config_validation=True): def factory(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Threaded factory.""" """Threaded factory."""
async_check = async_factory(config, config_validation) async_check = async_factory(config, config_validation)
def condition_if(hass, variables=None): def condition_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate condition.""" """Validate condition."""
return run_callback_threadsafe( return cast(bool, run_callback_threadsafe(
hass.loop, async_check, hass, variables, hass.loop, async_check, hass, variables,
).result() ).result())
return condition_if return condition_if
return factory return factory
def async_from_config(config: ConfigType, config_validation: bool = True): def async_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Turn a condition configuration into a method. """Turn a condition configuration into a method.
Should be run on the event loop. Should be run on the event loop.
@ -64,20 +71,22 @@ def async_from_config(config: ConfigType, config_validation: bool = True):
raise HomeAssistantError('Invalid condition "{}" specified {}'.format( raise HomeAssistantError('Invalid condition "{}" specified {}'.format(
config.get(CONF_CONDITION), config)) config.get(CONF_CONDITION), config))
return factory(config, config_validation) return cast(Callable[..., bool], factory(config, config_validation))
from_config = _threaded_factory(async_from_config) from_config = _threaded_factory(async_from_config)
def async_and_from_config(config: ConfigType, config_validation: bool = True): def async_and_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'AND'.""" """Create multi condition matcher using 'AND'."""
if config_validation: if config_validation:
config = cv.AND_CONDITION_SCHEMA(config) config = cv.AND_CONDITION_SCHEMA(config)
checks = None checks = None
def if_and_condition(hass: HomeAssistant, def if_and_condition(hass: HomeAssistant,
variables=None) -> bool: variables: TemplateVarsType = None) -> bool:
"""Test and condition.""" """Test and condition."""
nonlocal checks nonlocal checks
@ -101,14 +110,16 @@ def async_and_from_config(config: ConfigType, config_validation: bool = True):
and_from_config = _threaded_factory(async_and_from_config) and_from_config = _threaded_factory(async_and_from_config)
def async_or_from_config(config: ConfigType, config_validation: bool = True): def async_or_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Create multi condition matcher using 'OR'.""" """Create multi condition matcher using 'OR'."""
if config_validation: if config_validation:
config = cv.OR_CONDITION_SCHEMA(config) config = cv.OR_CONDITION_SCHEMA(config)
checks = None checks = None
def if_or_condition(hass: HomeAssistant, def if_or_condition(hass: HomeAssistant,
variables=None) -> bool: variables: TemplateVarsType = None) -> bool:
"""Test and condition.""" """Test and condition."""
nonlocal checks nonlocal checks
@ -131,17 +142,22 @@ def async_or_from_config(config: ConfigType, config_validation: bool = True):
or_from_config = _threaded_factory(async_or_from_config) or_from_config = _threaded_factory(async_or_from_config)
def numeric_state(hass: HomeAssistant, entity, below=None, above=None, def numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
value_template=None, variables=None): below: Optional[float] = None, above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition.""" """Test a numeric state condition."""
return run_callback_threadsafe( return cast(bool, run_callback_threadsafe(
hass.loop, async_numeric_state, hass, entity, below, above, hass.loop, async_numeric_state, hass, entity, below, above,
value_template, variables, value_template, variables,
).result() ).result())
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None, def async_numeric_state(hass: HomeAssistant, entity: Union[None, str, State],
value_template=None, variables=None): below: Optional[float] = None,
above: Optional[float] = None,
value_template: Optional[Template] = None,
variables: TemplateVarsType = None) -> bool:
"""Test a numeric state condition.""" """Test a numeric state condition."""
if isinstance(entity, str): if isinstance(entity, str):
entity = hass.states.get(entity) entity = hass.states.get(entity)
@ -164,22 +180,24 @@ def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
return False return False
try: try:
value = float(value) fvalue = float(value)
except ValueError: except ValueError:
_LOGGER.warning("Value cannot be processed as a number: %s " _LOGGER.warning("Value cannot be processed as a number: %s "
"(Offending entity: %s)", entity, value) "(Offending entity: %s)", entity, value)
return False return False
if below is not None and value >= below: if below is not None and fvalue >= below:
return False return False
if above is not None and value <= above: if above is not None and fvalue <= above:
return False return False
return True return True
def async_numeric_state_from_config(config, config_validation=True): def async_numeric_state_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
if config_validation: if config_validation:
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config) config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
@ -188,7 +206,8 @@ def async_numeric_state_from_config(config, config_validation=True):
above = config.get(CONF_ABOVE) above = config.get(CONF_ABOVE)
value_template = config.get(CONF_VALUE_TEMPLATE) value_template = config.get(CONF_VALUE_TEMPLATE)
def if_numeric_state(hass, variables=None): def if_numeric_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test numeric state condition.""" """Test numeric state condition."""
if value_template is not None: if value_template is not None:
value_template.hass = hass value_template.hass = hass
@ -202,7 +221,8 @@ def async_numeric_state_from_config(config, config_validation=True):
numeric_state_from_config = _threaded_factory(async_numeric_state_from_config) numeric_state_from_config = _threaded_factory(async_numeric_state_from_config)
def state(hass, entity, req_state, for_period=None): def state(hass: HomeAssistant, entity: Union[None, str, State], req_state: str,
for_period: Optional[timedelta] = None) -> bool:
"""Test if state matches requirements. """Test if state matches requirements.
Async friendly. Async friendly.
@ -212,6 +232,7 @@ def state(hass, entity, req_state, for_period=None):
if entity is None: if entity is None:
return False return False
assert isinstance(entity, State)
is_state = entity.state == req_state is_state = entity.state == req_state
@ -221,22 +242,26 @@ def state(hass, entity, req_state, for_period=None):
return dt_util.utcnow() - for_period > entity.last_changed return dt_util.utcnow() - for_period > entity.last_changed
def state_from_config(config, config_validation=True): def state_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
if config_validation: if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config) config = cv.STATE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
req_state = config.get(CONF_STATE) req_state = cast(str, config.get(CONF_STATE))
for_period = config.get('for') for_period = config.get('for')
def if_state(hass, variables=None): def if_state(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
return state(hass, entity_id, req_state, for_period) return state(hass, entity_id, req_state, for_period)
return if_state return if_state
def sun(hass, before=None, after=None, before_offset=None, after_offset=None): def sun(hass: HomeAssistant, before: Optional[str] = None,
after: Optional[str] = None, before_offset: Optional[timedelta] = None,
after_offset: Optional[timedelta] = None) -> bool:
"""Test if current time matches sun requirements.""" """Test if current time matches sun requirements."""
utcnow = dt_util.utcnow() utcnow = dt_util.utcnow()
today = dt_util.as_local(utcnow).date() today = dt_util.as_local(utcnow).date()
@ -254,22 +279,27 @@ def sun(hass, before=None, after=None, before_offset=None, after_offset=None):
# There is no sunset today # There is no sunset today
return False return False
if before == SUN_EVENT_SUNRISE and utcnow > sunrise + before_offset: if before == SUN_EVENT_SUNRISE and \
utcnow > cast(datetime, sunrise) + before_offset:
return False return False
if before == SUN_EVENT_SUNSET and utcnow > sunset + before_offset: if before == SUN_EVENT_SUNSET and \
utcnow > cast(datetime, sunset) + before_offset:
return False return False
if after == SUN_EVENT_SUNRISE and utcnow < sunrise + after_offset: if after == SUN_EVENT_SUNRISE and \
utcnow < cast(datetime, sunrise) + after_offset:
return False return False
if after == SUN_EVENT_SUNSET and utcnow < sunset + after_offset: if after == SUN_EVENT_SUNSET and \
utcnow < cast(datetime, sunset) + after_offset:
return False return False
return True return True
def sun_from_config(config, config_validation=True): def sun_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with sun based condition.""" """Wrap action method with sun based condition."""
if config_validation: if config_validation:
config = cv.SUN_CONDITION_SCHEMA(config) config = cv.SUN_CONDITION_SCHEMA(config)
@ -278,21 +308,24 @@ def sun_from_config(config, config_validation=True):
before_offset = config.get('before_offset') before_offset = config.get('before_offset')
after_offset = config.get('after_offset') after_offset = config.get('after_offset')
def time_if(hass, variables=None): def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition.""" """Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset) return sun(hass, before, after, before_offset, after_offset)
return time_if return time_if
def template(hass, value_template, variables=None): def template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches.""" """Test if template condition matches."""
return run_callback_threadsafe( return cast(bool, run_callback_threadsafe(
hass.loop, async_template, hass, value_template, variables, hass.loop, async_template, hass, value_template, variables,
).result() ).result())
def async_template(hass, value_template, variables=None): def async_template(hass: HomeAssistant, value_template: Template,
variables: TemplateVarsType = None) -> bool:
"""Test if template condition matches.""" """Test if template condition matches."""
try: try:
value = value_template.async_render(variables) value = value_template.async_render(variables)
@ -303,13 +336,16 @@ def async_template(hass, value_template, variables=None):
return value.lower() == 'true' return value.lower() == 'true'
def async_template_from_config(config, config_validation=True): def async_template_from_config(config: ConfigType,
config_validation: bool = True) \
-> Callable[..., bool]:
"""Wrap action method with state based condition.""" """Wrap action method with state based condition."""
if config_validation: if config_validation:
config = cv.TEMPLATE_CONDITION_SCHEMA(config) config = cv.TEMPLATE_CONDITION_SCHEMA(config)
value_template = config.get(CONF_VALUE_TEMPLATE) value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))
def template_if(hass, variables=None): def template_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate template based if-condition.""" """Validate template based if-condition."""
value_template.hass = hass value_template.hass = hass
@ -321,7 +357,9 @@ def async_template_from_config(config, config_validation=True):
template_from_config = _threaded_factory(async_template_from_config) template_from_config = _threaded_factory(async_template_from_config)
def time(before=None, after=None, weekday=None): def time(before: Optional[dt_util.dt.time] = None,
after: Optional[dt_util.dt.time] = None,
weekday: Union[None, str, Container[str]] = None) -> bool:
"""Test if local time condition matches. """Test if local time condition matches.
Handle the fact that time is continuous and we may be testing for Handle the fact that time is continuous and we may be testing for
@ -354,7 +392,8 @@ def time(before=None, after=None, weekday=None):
return True return True
def time_from_config(config, config_validation=True): def time_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with time based condition.""" """Wrap action method with time based condition."""
if config_validation: if config_validation:
config = cv.TIME_CONDITION_SCHEMA(config) config = cv.TIME_CONDITION_SCHEMA(config)
@ -362,14 +401,16 @@ def time_from_config(config, config_validation=True):
after = config.get(CONF_AFTER) after = config.get(CONF_AFTER)
weekday = config.get(CONF_WEEKDAY) weekday = config.get(CONF_WEEKDAY)
def time_if(hass, variables=None): def time_if(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition.""" """Validate time based if-condition."""
return time(before, after, weekday) return time(before, after, weekday)
return time_if return time_if
def zone(hass, zone_ent, entity): def zone(hass: HomeAssistant, zone_ent: Union[None, str, State],
entity: Union[None, str, State]) -> bool:
"""Test if zone-condition matches. """Test if zone-condition matches.
Async friendly. Async friendly.
@ -396,14 +437,16 @@ def zone(hass, zone_ent, entity):
entity.attributes.get(ATTR_GPS_ACCURACY, 0)) entity.attributes.get(ATTR_GPS_ACCURACY, 0))
def zone_from_config(config, config_validation=True): def zone_from_config(config: ConfigType,
config_validation: bool = True) -> Callable[..., bool]:
"""Wrap action method with zone based condition.""" """Wrap action method with zone based condition."""
if config_validation: if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config) config = cv.ZONE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID) entity_id = config.get(CONF_ENTITY_ID)
zone_entity_id = config.get(CONF_ZONE) zone_entity_id = config.get(CONF_ZONE)
def if_in_zone(hass, variables=None): def if_in_zone(hass: HomeAssistant,
variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
return zone(hass, zone_entity_id, entity_id) return zone(hass, zone_entity_id, entity_id)

View File

@ -18,6 +18,7 @@ from homeassistant.const import (
from homeassistant.core import State, valid_entity_id from homeassistant.core import State, valid_entity_id
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import location as loc_helper from homeassistant.helpers import location as loc_helper
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import convert from homeassistant.util import convert
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -115,7 +116,7 @@ class Template:
"""Extract all entities for state_changed listener.""" """Extract all entities for state_changed listener."""
return extract_entities(self.template, variables) return extract_entities(self.template, variables)
def render(self, variables=None, **kwargs): def render(self, variables: TemplateVarsType = None, **kwargs):
"""Render given template.""" """Render given template."""
if variables is not None: if variables is not None:
kwargs.update(variables) kwargs.update(variables)
@ -123,7 +124,8 @@ class Template:
return run_callback_threadsafe( return run_callback_threadsafe(
self.hass.loop, self.async_render, kwargs).result() self.hass.loop, self.async_render, kwargs).result()
def async_render(self, variables=None, **kwargs): def async_render(self, variables: TemplateVarsType = None,
**kwargs) -> str:
"""Render given template. """Render given template.
This method must be run in the event loop. This method must be run in the event loop.

View File

@ -1,5 +1,5 @@
"""Typing Helpers for Home Assistant.""" """Typing Helpers for Home Assistant."""
from typing import Dict, Any, Tuple from typing import Dict, Any, Tuple, Optional
import homeassistant.core import homeassistant.core
@ -9,6 +9,7 @@ GPSType = Tuple[float, float]
ConfigType = Dict[str, Any] ConfigType = Dict[str, Any]
HomeAssistantType = homeassistant.core.HomeAssistant HomeAssistantType = homeassistant.core.HomeAssistant
ServiceDataType = Dict[str, Any] ServiceDataType = Dict[str, Any]
TemplateVarsType = Optional[Dict[str, Any]]
# Custom type for recorder Queries # Custom type for recorder Queries
QueryType = Any QueryType = Any

View File

@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps = deps =
-r{toxinidir}/requirements_test.txt -r{toxinidir}/requirements_test.txt
commands = commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py' /bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'