From d4f78e8552ae3e4d800435e8f25e0c73feb71c85 Mon Sep 17 00:00:00 2001 From: Fabian Heredia Montiel Date: Sat, 23 Jul 2016 13:07:08 -0500 Subject: [PATCH] Type Hints - Core/Utils/Helpers Part 1 (#2592) * Fix deprecated(moved) import * Add util/dt typing * Green on mypy util/dt * Fix some errors * First part of yping util/yaml * Add more typing to util/yaml --- homeassistant/util/__init__.py | 8 +++-- homeassistant/util/color.py | 44 +++++++++++++------------ homeassistant/util/dt.py | 53 +++++++++++++++++-------------- homeassistant/util/temperature.py | 8 ++--- homeassistant/util/yaml.py | 49 +++++++++++++++++----------- tests/util/test_dt.py | 5 ++- 6 files changed, 96 insertions(+), 71 deletions(-) diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index 67719c18208..1f2584d655a 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -1,5 +1,5 @@ """Helper methods for various modules.""" -import collections +from collections.abc import MutableSet from itertools import chain import threading import queue @@ -12,6 +12,8 @@ import string from functools import wraps from types import MappingProxyType +from typing import Any + from .dt import as_local, utcnow RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') @@ -36,7 +38,7 @@ def slugify(text): return RE_SLUGIFY.sub("", text) -def repr_helper(inp): +def repr_helper(inp: Any) -> str: """Help creating a more readable string representation of objects.""" if isinstance(inp, (dict, MappingProxyType)): return ", ".join( @@ -128,7 +130,7 @@ class OrderedEnum(enum.Enum): return NotImplemented -class OrderedSet(collections.MutableSet): +class OrderedSet(MutableSet): """Ordered set taken from http://code.activestate.com/recipes/576694/.""" def __init__(self, iterable=None): diff --git a/homeassistant/util/color.py b/homeassistant/util/color.py index dd504b57065..e9671c77328 100644 --- a/homeassistant/util/color.py +++ b/homeassistant/util/color.py @@ -1,7 +1,8 @@ """Color util methods.""" import logging import math -# pylint: disable=unused-import + +from typing import Tuple _LOGGER = logging.getLogger(__name__) @@ -36,14 +37,14 @@ def color_name_to_rgb(color_name): # http://www.developers.meethue.com/documentation/color-conversions-rgb-xy # License: Code is given as is. Use at your own risk and discretion. # pylint: disable=invalid-name -def color_RGB_to_xy(R, G, B): +def color_RGB_to_xy(iR: int, iG: int, iB: int) -> Tuple[float, float, int]: """Convert from RGB color to XY color.""" - if R + G + B == 0: - return 0, 0, 0 + if iR + iG + iB == 0: + return 0.0, 0.0, 0 - R = R / 255 - B = B / 255 - G = G / 255 + R = iR / 255 + B = iB / 255 + G = iG / 255 # Gamma correction R = pow((R + 0.055) / (1.0 + 0.055), @@ -72,9 +73,10 @@ def color_RGB_to_xy(R, G, B): # taken from # https://github.com/benknight/hue-python-rgb-converter/blob/master/rgb_cie.py # Copyright (c) 2014 Benjamin Knight / MIT License. -def color_xy_brightness_to_RGB(vX, vY, brightness): +def color_xy_brightness_to_RGB(vX: float, vY: float, + ibrightness: int) -> Tuple[int, int, int]: """Convert from XYZ to RGB.""" - brightness /= 255. + brightness = ibrightness / 255. if brightness == 0: return (0, 0, 0) @@ -106,17 +108,18 @@ def color_xy_brightness_to_RGB(vX, vY, brightness): if max_component > 1: r, g, b = map(lambda x: x / max_component, [r, g, b]) - r, g, b = map(lambda x: int(x * 255), [r, g, b]) + ir, ig, ib = map(lambda x: int(x * 255), [r, g, b]) - return (r, g, b) + return (ir, ig, ib) -def _match_max_scale(input_colors, output_colors): +def _match_max_scale(input_colors: Tuple[int, ...], + output_colors: Tuple[int, ...]) -> Tuple[int, ...]: """Match the maximum value of the output to the input.""" max_in = max(input_colors) max_out = max(output_colors) if max_out == 0: - factor = 0 + factor = 0.0 else: factor = max_in / max_out return tuple(int(round(i * factor)) for i in output_colors) @@ -176,7 +179,8 @@ def color_temperature_to_rgb(color_temperature_kelvin): return (red, green, blue) -def _bound(color_component, minimum=0, maximum=255): +def _bound(color_component: float, minimum: float=0, + maximum: float=255) -> float: """ Bound the given color component value between the given min and max values. @@ -188,7 +192,7 @@ def _bound(color_component, minimum=0, maximum=255): return min(color_component_out, maximum) -def _get_red(temperature): +def _get_red(temperature: float) -> float: """Get the red component of the temperature in RGB space.""" if temperature <= 66: return 255 @@ -196,7 +200,7 @@ def _get_red(temperature): return _bound(tmp_red) -def _get_green(temperature): +def _get_green(temperature: float) -> float: """Get the green component of the given color temp in RGB space.""" if temperature <= 66: green = 99.4708025861 * math.log(temperature) - 161.1195681661 @@ -205,13 +209,13 @@ def _get_green(temperature): return _bound(green) -def _get_blue(tmp_internal): +def _get_blue(temperature: float) -> float: """Get the blue component of the given color temperature in RGB space.""" - if tmp_internal >= 66: + if temperature >= 66: return 255 - if tmp_internal <= 19: + if temperature <= 19: return 0 - blue = 138.5177312231 * math.log(tmp_internal - 10) - 305.0447927307 + blue = 138.5177312231 * math.log(temperature - 10) - 305.0447927307 return _bound(blue) diff --git a/homeassistant/util/dt.py b/homeassistant/util/dt.py index b8b7a691859..a5724ee90e1 100644 --- a/homeassistant/util/dt.py +++ b/homeassistant/util/dt.py @@ -2,10 +2,13 @@ import datetime as dt import re +# pylint: disable=unused-import +from typing import Any, Union, Optional, Tuple # NOQA + import pytz DATE_STR_FORMAT = "%Y-%m-%d" -UTC = DEFAULT_TIME_ZONE = pytz.utc +UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC # Copyright (c) Django Software Foundation and individual contributors. @@ -19,16 +22,17 @@ DATETIME_RE = re.compile( ) -def set_default_time_zone(time_zone): +def set_default_time_zone(time_zone: dt.tzinfo) -> None: """Set a default time zone to be used when none is specified.""" global DEFAULT_TIME_ZONE # pylint: disable=global-statement + # NOTE: Remove in the future in favour of typing assert isinstance(time_zone, dt.tzinfo) DEFAULT_TIME_ZONE = time_zone -def get_time_zone(time_zone_str): +def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]: """Get time zone from string. Return None if unable to determine.""" try: return pytz.timezone(time_zone_str) @@ -36,17 +40,17 @@ def get_time_zone(time_zone_str): return None -def utcnow(): +def utcnow() -> dt.datetime: """Get now in UTC time.""" return dt.datetime.now(UTC) -def now(time_zone=None): +def now(time_zone: dt.tzinfo=None) -> dt.datetime: """Get now in specified time zone.""" return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE) -def as_utc(dattim): +def as_utc(dattim: dt.datetime) -> dt.datetime: """Return a datetime as UTC time. Assumes datetime without tzinfo to be in the DEFAULT_TIME_ZONE. @@ -70,7 +74,7 @@ def as_timestamp(dt_value): return parsed_dt.timestamp() -def as_local(dattim): +def as_local(dattim: dt.datetime) -> dt.datetime: """Convert a UTC datetime object to local time zone.""" if dattim.tzinfo == DEFAULT_TIME_ZONE: return dattim @@ -80,12 +84,13 @@ def as_local(dattim): return dattim.astimezone(DEFAULT_TIME_ZONE) -def utc_from_timestamp(timestamp): +def utc_from_timestamp(timestamp: float) -> dt.datetime: """Return a UTC time from a timestamp.""" return dt.datetime.utcfromtimestamp(timestamp).replace(tzinfo=UTC) -def start_of_local_day(dt_or_d=None): +def start_of_local_day(dt_or_d: + Union[dt.date, dt.datetime]=None) -> dt.datetime: """Return local datetime object of start of day from date or datetime.""" if dt_or_d is None: dt_or_d = now().date() @@ -98,7 +103,7 @@ def start_of_local_day(dt_or_d=None): # Copyright (c) Django Software Foundation and individual contributors. # All rights reserved. # https://github.com/django/django/blob/master/LICENSE -def parse_datetime(dt_str): +def parse_datetime(dt_str: str) -> dt.datetime: """Parse a string and return a datetime.datetime. This function supports time zone offsets. When the input contains one, @@ -109,25 +114,27 @@ def parse_datetime(dt_str): match = DATETIME_RE.match(dt_str) if not match: return None - kws = match.groupdict() + kws = match.groupdict() # type: Dict[str, Any] if kws['microsecond']: kws['microsecond'] = kws['microsecond'].ljust(6, '0') - tzinfo = kws.pop('tzinfo') - if tzinfo == 'Z': + tzinfo_str = kws.pop('tzinfo') + if tzinfo_str == 'Z': tzinfo = UTC - elif tzinfo is not None: - offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0 - offset_hours = int(tzinfo[1:3]) + elif tzinfo_str is not None: + offset_mins = int(tzinfo_str[-2:]) if len(tzinfo_str) > 3 else 0 + offset_hours = int(tzinfo_str[1:3]) offset = dt.timedelta(hours=offset_hours, minutes=offset_mins) - if tzinfo[0] == '-': + if tzinfo_str[0] == '-': offset = -offset tzinfo = dt.timezone(offset) + else: + tzinfo = None kws = {k: int(v) for k, v in kws.items() if v is not None} kws['tzinfo'] = tzinfo return dt.datetime(**kws) -def parse_date(dt_str): +def parse_date(dt_str: str) -> dt.date: """Convert a date string to a date object.""" try: return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date() @@ -154,7 +161,7 @@ def parse_time(time_str): # Found in this gist: https://gist.github.com/zhangsen/1199964 -def get_age(date): +def get_age(date: dt.datetime) -> str: # pylint: disable=too-many-return-statements """ Take a datetime and return its "age" as a string. @@ -164,14 +171,14 @@ def get_age(date): be returned. Make sure date is not in the future, or else it won't work. """ - def formatn(number, unit): + def formatn(number: int, unit: str) -> str: """Add "unit" if it's plural.""" if number == 1: return "1 %s" % unit elif number > 1: return "%d %ss" % (number, unit) - def q_n_r(first, second): + def q_n_r(first: int, second: int) -> Tuple[int, int]: """Return quotient and remaining.""" return first // second, first % second @@ -196,7 +203,5 @@ def get_age(date): minute, second = q_n_r(second, 60) if minute > 0: return formatn(minute, 'minute') - if second > 0: - return formatn(second, 'second') - return "0 second" + return formatn(second, 'second') if second > 0 else "0 seconds" diff --git a/homeassistant/util/temperature.py b/homeassistant/util/temperature.py index 293ddcf44cf..59112a709ca 100644 --- a/homeassistant/util/temperature.py +++ b/homeassistant/util/temperature.py @@ -3,7 +3,7 @@ import logging -def fahrenheit_to_celcius(fahrenheit): +def fahrenheit_to_celcius(fahrenheit: float) -> float: """**DEPRECATED** Convert a Fahrenheit temperature to Celsius.""" logging.getLogger(__name__).warning( 'fahrenheit_to_celcius is now fahrenheit_to_celsius ' @@ -11,12 +11,12 @@ def fahrenheit_to_celcius(fahrenheit): return fahrenheit_to_celsius(fahrenheit) -def fahrenheit_to_celsius(fahrenheit): +def fahrenheit_to_celsius(fahrenheit: float) -> float: """Convert a Fahrenheit temperature to Celsius.""" return (fahrenheit - 32.0) / 1.8 -def celcius_to_fahrenheit(celcius): +def celcius_to_fahrenheit(celcius: float) -> float: """**DEPRECATED** Convert a Celsius temperature to Fahrenheit.""" logging.getLogger(__name__).warning( 'celcius_to_fahrenheit is now celsius_to_fahrenheit correcting ' @@ -24,6 +24,6 @@ def celcius_to_fahrenheit(celcius): return celsius_to_fahrenheit(celcius) -def celsius_to_fahrenheit(celsius): +def celsius_to_fahrenheit(celsius: float) -> float: """Convert a Celsius temperature to Fahrenheit.""" return celsius * 1.8 + 32.0 diff --git a/homeassistant/util/yaml.py b/homeassistant/util/yaml.py index feafdc2c6ff..8b2521e3e9b 100644 --- a/homeassistant/util/yaml.py +++ b/homeassistant/util/yaml.py @@ -2,6 +2,7 @@ import logging import os from collections import OrderedDict +from typing import Union, List, Dict import glob import yaml @@ -21,15 +22,16 @@ _SECRET_YAML = 'secrets.yaml' class SafeLineLoader(yaml.SafeLoader): """Loader class that keeps track of line numbers.""" - def compose_node(self, parent, index): + def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node: """Annotate a node with the first line it was seen.""" - last_line = self.line - node = super(SafeLineLoader, self).compose_node(parent, index) + last_line = self.line # type: int + node = super(SafeLineLoader, + self).compose_node(parent, index) # type: yaml.nodes.Node node.__line__ = last_line + 1 return node -def load_yaml(fname): +def load_yaml(fname: str) -> Union[List, Dict]: """Load a YAML file.""" try: with open(fname, encoding='utf-8') as conf_file: @@ -41,7 +43,8 @@ def load_yaml(fname): raise HomeAssistantError(exc) -def _include_yaml(loader, node): +def _include_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node) -> Union[List, Dict]: """Load another YAML file and embeds it using the !include tag. Example: @@ -51,9 +54,10 @@ def _include_yaml(loader, node): return load_yaml(fname) -def _include_dir_named_yaml(loader, node): +def _include_dir_named_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load multiple files from directory as a dictionary.""" - mapping = OrderedDict() + mapping = OrderedDict() # type: OrderedDict files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') for fname in glob.glob(files): filename = os.path.splitext(os.path.basename(fname))[0] @@ -61,9 +65,10 @@ def _include_dir_named_yaml(loader, node): return mapping -def _include_dir_merge_named_yaml(loader, node): +def _include_dir_merge_named_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load multiple files from directory as a merged dictionary.""" - mapping = OrderedDict() + mapping = OrderedDict() # type: OrderedDict files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') for fname in glob.glob(files): if os.path.basename(fname) == _SECRET_YAML: @@ -74,17 +79,20 @@ def _include_dir_merge_named_yaml(loader, node): return mapping -def _include_dir_list_yaml(loader, node): +def _include_dir_list_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load multiple files from directory as a list.""" files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') return [load_yaml(f) for f in glob.glob(files) if os.path.basename(f) != _SECRET_YAML] -def _include_dir_merge_list_yaml(loader, node): +def _include_dir_merge_list_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load multiple files from directory as a merged list.""" - files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') - merged_list = [] + files = os.path.join(os.path.dirname(loader.name), + node.value, '*.yaml') # type: str + merged_list = [] # type: List for fname in glob.glob(files): if os.path.basename(fname) == _SECRET_YAML: continue @@ -94,12 +102,13 @@ def _include_dir_merge_list_yaml(loader, node): return merged_list -def _ordered_dict(loader, node): +def _ordered_dict(loader: SafeLineLoader, + node: yaml.nodes.MappingNode) -> OrderedDict: """Load YAML mappings into an ordered dictionary to preserve key order.""" loader.flatten_mapping(node) nodes = loader.construct_pairs(node) - seen = {} + seen = {} # type: Dict min_line = None for (key, _), (node, _) in zip(nodes, node.value): line = getattr(node, '__line__', 'unknown') @@ -116,12 +125,13 @@ def _ordered_dict(loader, node): seen[key] = line processed = OrderedDict(nodes) - processed.__config_file__ = loader.name - processed.__line__ = min_line + setattr(processed, '__config_file__', loader.name) + setattr(processed, '__line__', min_line) return processed -def _env_var_yaml(loader, node): +def _env_var_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load environment variables and embed it into the configuration YAML.""" if node.value in os.environ: return os.environ[node.value] @@ -131,7 +141,8 @@ def _env_var_yaml(loader, node): # pylint: disable=protected-access -def _secret_yaml(loader, node): +def _secret_yaml(loader: SafeLineLoader, + node: yaml.nodes.Node): """Load secrets and embed it into the configuration YAML.""" # Create secret cache on loader and load secrets.yaml if not hasattr(loader, '_SECRET_CACHE'): diff --git a/tests/util/test_dt.py b/tests/util/test_dt.py index bf5284a0b04..e8114e93e24 100644 --- a/tests/util/test_dt.py +++ b/tests/util/test_dt.py @@ -137,7 +137,10 @@ class TestDateUtil(unittest.TestCase): def test_get_age(self): """Test get_age.""" diff = dt_util.now() - timedelta(seconds=0) - self.assertEqual(dt_util.get_age(diff), "0 second") + self.assertEqual(dt_util.get_age(diff), "0 seconds") + + diff = dt_util.now() - timedelta(seconds=1) + self.assertEqual(dt_util.get_age(diff), "1 second") diff = dt_util.now() - timedelta(seconds=30) self.assertEqual(dt_util.get_age(diff), "30 seconds")