Improvement typing (#2735)

* Fix: Circular dependencies of internal files

* Change: dt.date for Date and dt.datetime for DateTime

* Use NewType if available

* FIX: Wrong version test

* Remove: Date and DateTime types due to error

* Change to HomeAssistantType

* General Improvement of Typing

* Improve typing config_validation

* Improve typing script

* General Typing Improvements

* Improve NewType check

* Improve typing db_migrator

* Improve util/__init__ typing

* Improve helpers/location typing

* Regroup imports and remove pylint: disable=ungrouped-imports

* General typing improvements
This commit is contained in:
Fabian Heredia Montiel 2016-08-07 18:26:35 -05:00 committed by Paulus Schoutsen
parent a3ca3e878b
commit 0377338a81
16 changed files with 139 additions and 66 deletions

View File

@ -274,7 +274,7 @@ def try_to_restart() -> None:
# thread left (which is us). Nothing we really do with it, but it might be # thread left (which is us). Nothing we really do with it, but it might be
# useful when debugging shutdown/restart issues. # useful when debugging shutdown/restart issues.
try: try:
nthreads = sum(thread.isAlive() and not thread.isDaemon() nthreads = sum(thread.is_alive() and not thread.daemon
for thread in threading.enumerate()) for thread in threading.enumerate())
if nthreads > 1: if nthreads > 1:
sys.stderr.write( sys.stderr.write(

View File

@ -12,6 +12,8 @@ from typing import Any, Optional, Dict
import voluptuous as vol import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType
import homeassistant.components as core_components import homeassistant.components as core_components
from homeassistant.components import group, persistent_notification from homeassistant.components import group, persistent_notification
import homeassistant.config as conf_util import homeassistant.config as conf_util
@ -216,7 +218,7 @@ def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str,
# pylint: disable=too-many-branches, too-many-statements, too-many-arguments # pylint: disable=too-many-branches, too-many-statements, too-many-arguments
def from_config_dict(config: Dict[str, Any], def from_config_dict(config: Dict[str, Any],
hass: Optional[core.HomeAssistant]=None, hass: Optional[HomeAssistantType]=None,
config_dir: Optional[str]=None, config_dir: Optional[str]=None,
enable_log: bool=True, enable_log: bool=True,
verbose: bool=False, verbose: bool=False,

View File

@ -4,6 +4,9 @@ import os
import shutil import shutil
from types import MappingProxyType from types import MappingProxyType
# pylint: disable=unused-import
from typing import Any, Tuple # NOQA
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
@ -37,7 +40,7 @@ DEFAULT_CORE_CONFIG = (
CONF_UNIT_SYSTEM_IMPERIAL)), CONF_UNIT_SYSTEM_IMPERIAL)),
(CONF_TIME_ZONE, 'UTC', 'time_zone', 'Pick yours from here: http://en.wiki' (CONF_TIME_ZONE, 'UTC', 'time_zone', 'Pick yours from here: http://en.wiki'
'pedia.org/wiki/List_of_tz_database_time_zones'), 'pedia.org/wiki/List_of_tz_database_time_zones'),
) ) # type: Tuple[Tuple[str, Any, Any, str], ...]
DEFAULT_CONFIG = """ DEFAULT_CONFIG = """
# Show links to resources in log and frontend # Show links to resources in log and frontend
introduction: introduction:

View File

@ -14,9 +14,13 @@ import threading
import time import time
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Callable # pylint: disable=unused-import
from typing import Optional, Any, Callable # NOQA
import voluptuous as vol import voluptuous as vol
from homeassistant.helpers.typing import UnitSystemType # NOQA
import homeassistant.util as util import homeassistant.util as util
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
import homeassistant.util.location as location import homeassistant.util.location as location
@ -713,15 +717,15 @@ class Config(object):
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
def __init__(self): def __init__(self):
"""Initialize a new config object.""" """Initialize a new config object."""
self.latitude = None self.latitude = None # type: Optional[float]
self.longitude = None self.longitude = None # type: Optional[float]
self.elevation = None self.elevation = None # type: Optional[int]
self.location_name = None self.location_name = None # type: Optional[str]
self.time_zone = None self.time_zone = None # type: Optional[str]
self.units = METRIC_SYSTEM self.units = METRIC_SYSTEM # type: UnitSystemType
# If True, pip install is skipped for requirements on startup # If True, pip install is skipped for requirements on startup
self.skip_pip = False self.skip_pip = False # type: bool
# List of loaded components # List of loaded components
self.components = [] self.components = []

View File

@ -3,6 +3,8 @@ from datetime import timedelta
import logging import logging
import sys import sys
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.components import ( from homeassistant.components import (
zone as zone_cmp, sun as sun_cmp) zone as zone_cmp, sun as sun_cmp)
from homeassistant.const import ( from homeassistant.const import (
@ -21,7 +23,7 @@ FROM_CONFIG_FORMAT = '{}_from_config'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def from_config(config, config_validation=True): def from_config(config: ConfigType, config_validation: bool=True):
"""Turn a condition configuration into a method.""" """Turn a condition configuration into a method."""
factory = getattr( factory = getattr(
sys.modules[__name__], sys.modules[__name__],
@ -34,13 +36,14 @@ def from_config(config, config_validation=True):
return factory(config, config_validation) return factory(config, config_validation)
def and_from_config(config, config_validation=True): def and_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'AND'.""" """Create multi condition matcher using 'AND'."""
if config_validation: if config_validation:
config = cv.AND_CONDITION_SCHEMA(config) config = cv.AND_CONDITION_SCHEMA(config)
checks = [from_config(entry) for entry in config['conditions']] checks = [from_config(entry) for entry in config['conditions']]
def if_and_condition(hass, variables=None): def if_and_condition(hass: HomeAssistantType,
variables=None) -> bool:
"""Test and condition.""" """Test and condition."""
for check in checks: for check in checks:
try: try:
@ -55,13 +58,14 @@ def and_from_config(config, config_validation=True):
return if_and_condition return if_and_condition
def or_from_config(config, config_validation=True): def or_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'OR'.""" """Create multi condition matcher using 'OR'."""
if config_validation: if config_validation:
config = cv.OR_CONDITION_SCHEMA(config) config = cv.OR_CONDITION_SCHEMA(config)
checks = [from_config(entry) for entry in config['conditions']] checks = [from_config(entry) for entry in config['conditions']]
def if_or_condition(hass, variables=None): def if_or_condition(hass: HomeAssistantType,
variables=None) -> bool:
"""Test and condition.""" """Test and condition."""
for check in checks: for check in checks:
try: try:
@ -76,8 +80,8 @@ def or_from_config(config, config_validation=True):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def numeric_state(hass, entity, below=None, above=None, value_template=None, def numeric_state(hass: HomeAssistantType, entity, below=None, above=None,
variables=None): value_template=None, variables=None):
"""Test a numeric state condition.""" """Test a numeric state condition."""
if isinstance(entity, str): if isinstance(entity, str):
entity = hass.states.get(entity) entity = hass.states.get(entity)
@ -93,7 +97,7 @@ def numeric_state(hass, entity, below=None, above=None, value_template=None,
try: try:
value = render(hass, value_template, variables) value = render(hass, value_template, variables)
except TemplateError as ex: except TemplateError as ex:
_LOGGER.error(ex) _LOGGER.error("Template error: %s", ex)
return False return False
try: try:

View File

@ -1,6 +1,8 @@
"""Helpers for config validation using voluptuous.""" """Helpers for config validation using voluptuous."""
from datetime import timedelta from datetime import timedelta
from typing import Any, Union, TypeVar, Callable, Sequence, List, Dict
import jinja2 import jinja2
import voluptuous as vol import voluptuous as vol
@ -28,12 +30,15 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180),
msg='invalid longitude') msg='invalid longitude')
sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE)) sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE))
# typing typevar
T = TypeVar('T')
# Adapted from: # Adapted from:
# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 # https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666
def has_at_least_one_key(*keys): def has_at_least_one_key(*keys: str) -> Callable:
"""Validator that at least one key exists.""" """Validator that at least one key exists."""
def validate(obj): def validate(obj: Dict) -> Dict:
"""Test keys exist in dict.""" """Test keys exist in dict."""
if not isinstance(obj, dict): if not isinstance(obj, dict):
raise vol.Invalid('expected dictionary') raise vol.Invalid('expected dictionary')
@ -46,7 +51,7 @@ def has_at_least_one_key(*keys):
return validate return validate
def boolean(value): def boolean(value: Any) -> bool:
"""Validate and coerce a boolean value.""" """Validate and coerce a boolean value."""
if isinstance(value, str): if isinstance(value, str):
value = value.lower() value = value.lower()
@ -63,12 +68,12 @@ def isfile(value):
return vol.IsFile('not a file')(value) return vol.IsFile('not a file')(value)
def ensure_list(value): def ensure_list(value: Union[T, Sequence[T]]) -> List[T]:
"""Wrap value in list if it is not one.""" """Wrap value in list if it is not one."""
return value if isinstance(value, list) else [value] return value if isinstance(value, list) else [value]
def entity_id(value): def entity_id(value: Any) -> str:
"""Validate Entity ID.""" """Validate Entity ID."""
value = string(value).lower() value = string(value).lower()
if valid_entity_id(value): if valid_entity_id(value):
@ -76,7 +81,7 @@ def entity_id(value):
raise vol.Invalid('Entity ID {} is an invalid entity id'.format(value)) raise vol.Invalid('Entity ID {} is an invalid entity id'.format(value))
def entity_ids(value): def entity_ids(value: Union[str, Sequence]) -> List[str]:
"""Validate Entity IDs.""" """Validate Entity IDs."""
if value is None: if value is None:
raise vol.Invalid('Entity IDs can not be None') raise vol.Invalid('Entity IDs can not be None')
@ -109,7 +114,7 @@ time_period_dict = vol.All(
lambda value: timedelta(**value)) lambda value: timedelta(**value))
def time_period_str(value): def time_period_str(value: str) -> timedelta:
"""Validate and transform time offset.""" """Validate and transform time offset."""
if isinstance(value, int): if isinstance(value, int):
raise vol.Invalid('Make sure you wrap time values in quotes') raise vol.Invalid('Make sure you wrap time values in quotes')
@ -182,7 +187,7 @@ def platform_validator(domain):
return validator return validator
def positive_timedelta(value): def positive_timedelta(value: timedelta) -> timedelta:
"""Validate timedelta is positive.""" """Validate timedelta is positive."""
if value < timedelta(0): if value < timedelta(0):
raise vol.Invalid('Time period should be positive') raise vol.Invalid('Time period should be positive')
@ -209,14 +214,14 @@ def slug(value):
raise vol.Invalid('invalid slug {} (try {})'.format(value, slg)) raise vol.Invalid('invalid slug {} (try {})'.format(value, slg))
def string(value): def string(value: Any) -> str:
"""Coerce value to string, except for None.""" """Coerce value to string, except for None."""
if value is not None: if value is not None:
return str(value) return str(value)
raise vol.Invalid('string value is None') raise vol.Invalid('string value is None')
def temperature_unit(value): def temperature_unit(value) -> str:
"""Validate and transform temperature unit.""" """Validate and transform temperature unit."""
value = str(value).upper() value = str(value).upper()
if value == 'C': if value == 'C':

View File

@ -12,9 +12,7 @@ from homeassistant.const import (
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
# pylint: disable=using-constant-test,unused-import from homeassistant.helpers.typing import HomeAssistantType
if False:
from homeassistant.core import HomeAssistant # NOQA
# Entity attributes that we will overwrite # Entity attributes that we will overwrite
_OVERWRITE = {} # type: Dict[str, Any] _OVERWRITE = {} # type: Dict[str, Any]
@ -27,7 +25,7 @@ ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
def generate_entity_id(entity_id_format: str, name: Optional[str], def generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None, current_ids: Optional[List[str]]=None,
hass: 'Optional[HomeAssistant]'=None) -> str: hass: Optional[HomeAssistantType]=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs.""" """Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower() name = (name or DEVICE_DEFAULT_NAME).lower()
if current_ids is None: if current_ids is None:
@ -153,7 +151,7 @@ class Entity(object):
# are used to perform a very specific function. Overwriting these may # are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation. # produce undesirable effects in the entity's operation.
hass = None # type: Optional[HomeAssistant] hass = None # type: Optional[HomeAssistantType]
def update_ha_state(self, force_refresh=False): def update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity. """Update Home Assistant with current state of entity.

View File

@ -1,9 +1,13 @@
"""Event Decorators for custom components.""" """Event Decorators for custom components."""
import functools import functools
# pylint: disable=unused-import
from typing import Optional # NOQA
from homeassistant.helpers.typing import HomeAssistantType # NOQA
from homeassistant.helpers import event from homeassistant.helpers import event
HASS = None HASS = None # type: Optional[HomeAssistantType]
def track_state_change(entity_ids, from_state=None, to_state=None): def track_state_change(entity_ids, from_state=None, to_state=None):

View File

@ -1,18 +1,21 @@
"""Location helpers for Home Assistant.""" """Location helpers for Home Assistant."""
from typing import Sequence
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State from homeassistant.core import State
from homeassistant.util import location as loc_util from homeassistant.util import location as loc_util
def has_location(state): def has_location(state: State) -> bool:
"""Test if state contains a valid location.""" """Test if state contains a valid location."""
return (isinstance(state, State) and return (isinstance(state, State) and
isinstance(state.attributes.get(ATTR_LATITUDE), float) and isinstance(state.attributes.get(ATTR_LATITUDE), float) and
isinstance(state.attributes.get(ATTR_LONGITUDE), float)) isinstance(state.attributes.get(ATTR_LONGITUDE), float))
def closest(latitude, longitude, states): def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State:
"""Return closest state to point.""" """Return closest state to point."""
with_location = [state for state in states if has_location(state)] with_location = [state for state in states if has_location(state)]

View File

@ -3,8 +3,12 @@ import logging
import threading import threading
from itertools import islice from itertools import islice
from typing import Optional, Sequence
import voluptuous as vol import voluptuous as vol
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
import homeassistant.util.dt as date_util import homeassistant.util.dt as date_util
from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION
from homeassistant.helpers.event import track_point_in_utc_time from homeassistant.helpers.event import track_point_in_utc_time
@ -22,7 +26,8 @@ CONF_EVENT_DATA = "event_data"
CONF_DELAY = "delay" CONF_DELAY = "delay"
def call_from_config(hass, config, variables=None): def call_from_config(hass: HomeAssistantType, config: ConfigType,
variables: Optional[Sequence]=None) -> None:
"""Call a script based on a config entry.""" """Call a script based on a config entry."""
Script(hass, config).run(variables) Script(hass, config).run(variables)
@ -31,7 +36,8 @@ class Script():
"""Representation of a script.""" """Representation of a script."""
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
def __init__(self, hass, sequence, name=None, change_listener=None): def __init__(self, hass: HomeAssistantType, sequence, name: str=None,
change_listener=None) -> None:
"""Initialize the script.""" """Initialize the script."""
self.hass = hass self.hass = hass
self.sequence = cv.SCRIPT_SCHEMA(sequence) self.sequence = cv.SCRIPT_SCHEMA(sequence)
@ -45,11 +51,11 @@ class Script():
self._delay_listener = None self._delay_listener = None
@property @property
def is_running(self): def is_running(self) -> bool:
"""Return true if script is on.""" """Return true if script is on."""
return self._cur != -1 return self._cur != -1
def run(self, variables=None): def run(self, variables: Optional[Sequence]=None) -> None:
"""Run script.""" """Run script."""
with self._lock: with self._lock:
if self._cur == -1: if self._cur == -1:
@ -101,7 +107,7 @@ class Script():
if self._change_listener: if self._change_listener:
self._change_listener() self._change_listener()
def stop(self): def stop(self) -> None:
"""Stop running script.""" """Stop running script."""
with self._lock: with self._lock:
if self._cur == -1: if self._cur == -1:

View File

@ -2,15 +2,20 @@
import functools import functools
import logging import logging
# pylint: disable=unused-import
from typing import Optional # NOQA
import voluptuous as vol import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType # NOQA
from homeassistant.const import ATTR_ENTITY_ID from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template from homeassistant.helpers import template
from homeassistant.loader import get_component from homeassistant.loader import get_component
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
HASS = None HASS = None # type: Optional[HomeAssistantType]
CONF_SERVICE = 'service' CONF_SERVICE = 'service'
CONF_SERVICE_TEMPLATE = 'service_template' CONF_SERVICE_TEMPLATE = 'service_template'

View File

@ -1,12 +1,39 @@
"""Typing Helpers for Home-Assistant.""" """Typing Helpers for Home-Assistant."""
from typing import Dict, Any from typing import Dict, Any
import homeassistant.core # NOTE: NewType added to typing in 3.5.2 in June, 2016; Since 3.5.2 includes
# security fixes everyone on 3.5 should upgrade "soon"
try:
from typing import NewType
except ImportError:
NewType = None
# HACK: mypy/pytype will import, other interpreters will not; this is to avoid
# circular dependencies where the type is needed.
# All homeassistant types should be imported this way.
# Documentation
# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.helpers.unit_system import UnitSystem # NOQA
# ENDHACK
# pylint: disable=invalid-name # pylint: disable=invalid-name
if NewType:
ConfigType = Dict[str, Any] ConfigType = NewType('ConfigType', Dict[str, Any])
HomeAssistantType = homeassistant.core.HomeAssistant HomeAssistantType = NewType('HomeAssistantType', 'HomeAssistant')
UnitSystemType = NewType('UnitSystemType', 'UnitSystem')
# Custom type for recorder Queries # Custom type for recorder Queries
QueryType = Any QueryType = NewType('QueryType', Any)
# Duplicates for 3.5.1
# pylint: disable=invalid-name
else:
ConfigType = Dict[str, Any] # type: ignore
HomeAssistantType = 'HomeAssistant' # type: ignore
UnitSystemType = 'UnitSystemType' # type: ignore
# Custom type for recorder Queries
QueryType = Any # type: ignore

View File

@ -15,6 +15,8 @@ import time
import threading import threading
import urllib.parse import urllib.parse
from typing import Optional
import requests import requests
import homeassistant.bootstrap as bootstrap import homeassistant.bootstrap as bootstrap
@ -42,7 +44,7 @@ class APIStatus(enum.Enum):
CANNOT_CONNECT = "cannot_connect" CANNOT_CONNECT = "cannot_connect"
UNKNOWN = "unknown" UNKNOWN = "unknown"
def __str__(self): def __str__(self) -> str:
"""Return the state.""" """Return the state."""
return self.value return self.value
@ -51,7 +53,8 @@ class API(object):
"""Object to pass around Home Assistant API location and credentials.""" """Object to pass around Home Assistant API location and credentials."""
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
def __init__(self, host, api_password=None, port=None, use_ssl=False): def __init__(self, host: str, api_password: Optional[str]=None,
port: Optional[int]=None, use_ssl: bool=False) -> None:
"""Initalize the API.""" """Initalize the API."""
self.host = host self.host = host
self.port = port or SERVER_PORT self.port = port or SERVER_PORT
@ -68,7 +71,7 @@ class API(object):
if api_password is not None: if api_password is not None:
self._headers[HTTP_HEADER_HA_AUTH] = api_password self._headers[HTTP_HEADER_HA_AUTH] = api_password
def validate_api(self, force_validate=False): def validate_api(self, force_validate: bool=False) -> bool:
"""Test if we can communicate with the API.""" """Test if we can communicate with the API."""
if self.status is None or force_validate: if self.status is None or force_validate:
self.status = validate_api(self) self.status = validate_api(self)
@ -100,7 +103,7 @@ class API(object):
_LOGGER.exception(error) _LOGGER.exception(error)
raise HomeAssistantError(error) raise HomeAssistantError(error)
def __repr__(self): def __repr__(self) -> str:
"""Return the representation of the API.""" """Return the representation of the API."""
return "API({}, {}, {})".format( return "API({}, {}, {})".format(
self.host, self.api_password, self.port) self.host, self.api_password, self.port)

View File

@ -4,6 +4,10 @@ import argparse
import os.path import os.path
import sqlite3 import sqlite3
import sys import sys
from datetime import datetime
from typing import Optional
try: try:
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -16,7 +20,7 @@ import homeassistant.config as config_util
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
def ts_to_dt(timestamp): def ts_to_dt(timestamp: Optional[float]) -> Optional[datetime]:
"""Turn a datetime into an integer for in the DB.""" """Turn a datetime into an integer for in the DB."""
if timestamp is None: if timestamp is None:
return None return None
@ -26,8 +30,8 @@ def ts_to_dt(timestamp):
# Based on code at # Based on code at
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console # http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def print_progress(iteration, total, prefix='', suffix='', decimals=2, def print_progress(iteration: int, total: int, prefix: str='', suffix: str='',
bar_length=68): decimals: int=2, bar_length: int=68) -> None:
"""Print progress bar. """Print progress bar.
Call in a loop to create terminal progress bar Call in a loop to create terminal progress bar
@ -49,7 +53,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=2,
print("\n") print("\n")
def run(args): def run(args) -> int:
"""The actual script body.""" """The actual script body."""
# pylint: disable=too-many-locals,invalid-name,too-many-statements # pylint: disable=too-many-locals,invalid-name,too-many-statements
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -75,7 +79,7 @@ def run(args):
args = parser.parse_args() args = parser.parse_args()
config_dir = os.path.join(os.getcwd(), args.config) config_dir = os.path.join(os.getcwd(), args.config) # type: str
# Test if configuration directory exists # Test if configuration directory exists
if not os.path.isdir(config_dir): if not os.path.isdir(config_dir):

View File

@ -12,21 +12,24 @@ import string
from functools import wraps from functools import wraps
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Sequence from typing import Any, Optional, TypeVar, Callable, Sequence
from .dt import as_local, utcnow from .dt import as_local, utcnow
T = TypeVar('T')
U = TypeVar('U')
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
RE_SLUGIFY = re.compile(r'[^a-z0-9_]+') RE_SLUGIFY = re.compile(r'[^a-z0-9_]+')
def sanitize_filename(filename): def sanitize_filename(filename: str) -> str:
r"""Sanitize a filename by removing .. / and \\.""" r"""Sanitize a filename by removing .. / and \\."""
return RE_SANITIZE_FILENAME.sub("", filename) return RE_SANITIZE_FILENAME.sub("", filename)
def sanitize_path(path): def sanitize_path(path: str) -> str:
"""Sanitize a path by removing ~ and ..""" """Sanitize a path by removing ~ and .."""
return RE_SANITIZE_PATH.sub("", path) return RE_SANITIZE_PATH.sub("", path)
@ -50,7 +53,8 @@ def repr_helper(inp: Any) -> str:
return str(inp) return str(inp)
def convert(value, to_type, default=None): def convert(value: T, to_type: Callable[[T], U],
default: Optional[U]=None) -> Optional[U]:
"""Convert value to to_type, returns default if fails.""" """Convert value to to_type, returns default if fails."""
try: try:
return default if value is None else to_type(value) return default if value is None else to_type(value)

View File

@ -8,7 +8,7 @@ from typing import Any, Union, Optional, Tuple # NOQA
import pytz import pytz
DATE_STR_FORMAT = "%Y-%m-%d" DATE_STR_FORMAT = "%Y-%m-%d"
UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC UTC = DEFAULT_TIME_ZONE = pytz.utc # type: dt.tzinfo
# Copyright (c) Django Software Foundation and individual contributors. # Copyright (c) Django Software Foundation and individual contributors.
@ -93,11 +93,10 @@ def start_of_local_day(dt_or_d:
Union[dt.date, dt.datetime]=None) -> dt.datetime: Union[dt.date, dt.datetime]=None) -> dt.datetime:
"""Return local datetime object of start of day from date or datetime.""" """Return local datetime object of start of day from date or datetime."""
if dt_or_d is None: if dt_or_d is None:
dt_or_d = now().date() date = now().date() # type: dt.date
elif isinstance(dt_or_d, dt.datetime): elif isinstance(dt_or_d, dt.datetime):
dt_or_d = dt_or_d.date() date = dt_or_d.date()
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(date, dt.time()))
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(dt_or_d, dt.time()))
# Copyright (c) Django Software Foundation and individual contributors. # Copyright (c) Django Software Foundation and individual contributors.
@ -118,6 +117,8 @@ def parse_datetime(dt_str: str) -> dt.datetime:
if kws['microsecond']: if kws['microsecond']:
kws['microsecond'] = kws['microsecond'].ljust(6, '0') kws['microsecond'] = kws['microsecond'].ljust(6, '0')
tzinfo_str = kws.pop('tzinfo') tzinfo_str = kws.pop('tzinfo')
tzinfo = None # type: Optional[dt.tzinfo]
if tzinfo_str == 'Z': if tzinfo_str == 'Z':
tzinfo = UTC tzinfo = UTC
elif tzinfo_str is not None: elif tzinfo_str is not None: