Helpers type hint improvements (#30106)

This commit is contained in:
Ville Skyttä 2019-12-21 09:23:48 +02:00 committed by GitHub
parent ecdc1adf90
commit 6604680793
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 23 deletions

View File

@ -2,7 +2,7 @@
from asyncio import Event from asyncio import Event
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import List, Optional, cast from typing import Any, Dict, List, Optional, cast
import uuid import uuid
import attr import attr
@ -48,7 +48,7 @@ class DeviceEntry:
is_new = attr.ib(type=bool, default=False) is_new = attr.ib(type=bool, default=False)
def format_mac(mac): def format_mac(mac: str) -> str:
"""Format the mac address string for entry into dev reg.""" """Format the mac address string for entry into dev reg."""
to_test = mac to_test = mac
@ -260,7 +260,7 @@ class DeviceRegistry:
return new return new
def async_remove_device(self, device_id): def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry.""" """Remove a device from the device registry."""
del self.devices[device_id] del self.devices[device_id]
self.hass.bus.async_fire( self.hass.bus.async_fire(
@ -298,12 +298,12 @@ class DeviceRegistry:
self.devices = devices self.devices = devices
@callback @callback
def async_schedule_save(self): def async_schedule_save(self) -> None:
"""Schedule saving the device registry.""" """Schedule saving the device registry."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY) self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback @callback
def _data_to_save(self): def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
"""Return data of device registry to store in a file.""" """Return data of device registry to store in a file."""
data = {} data = {}
@ -327,7 +327,7 @@ class DeviceRegistry:
return data return data
@callback @callback
def async_clear_config_entry(self, config_entry_id): def async_clear_config_entry(self, config_entry_id: str) -> None:
"""Clear config entry from registry entries.""" """Clear config entry from registry entries."""
remove = [] remove = []
for dev_id, device in self.devices.items(): for dev_id, device in self.devices.items():

View File

@ -5,6 +5,8 @@ There are two different types of discoveries that can be fired/listened for.
- listen_platform/discover_platform is for platforms. These are used by - listen_platform/discover_platform is for platforms. These are used by
components to allow discovery of their platforms. components to allow discovery of their platforms.
""" """
from typing import Callable, Collection, Union
from homeassistant import core, setup from homeassistant import core, setup
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -18,7 +20,9 @@ ATTR_PLATFORM = "platform"
@bind_hass @bind_hass
def listen(hass, service, callback): def listen(
hass: core.HomeAssistant, service: Union[str, Collection[str]], callback: Callable
) -> None:
"""Set up listener for discovery of specific service. """Set up listener for discovery of specific service.
Service can be a string or a list/tuple. Service can be a string or a list/tuple.
@ -28,7 +32,9 @@ def listen(hass, service, callback):
@core.callback @core.callback
@bind_hass @bind_hass
def async_listen(hass, service, callback): def async_listen(
hass: core.HomeAssistant, service: Union[str, Collection[str]], callback: Callable
) -> None:
"""Set up listener for discovery of specific service. """Set up listener for discovery of specific service.
Service can be a string or a list/tuple. Service can be a string or a list/tuple.
@ -39,7 +45,7 @@ def async_listen(hass, service, callback):
service = tuple(service) service = tuple(service)
@core.callback @core.callback
def discovery_event_listener(event): def discovery_event_listener(event: core.Event) -> None:
"""Listen for discovery events.""" """Listen for discovery events."""
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service: if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service:
hass.async_add_job( hass.async_add_job(
@ -73,7 +79,9 @@ async def async_discover(hass, service, discovered, component, hass_config):
@bind_hass @bind_hass
def listen_platform(hass, component, callback): def listen_platform(
hass: core.HomeAssistant, component: str, callback: Callable
) -> None:
"""Register a platform loader listener.""" """Register a platform loader listener."""
run_callback_threadsafe( run_callback_threadsafe(
hass.loop, async_listen_platform, hass, component, callback hass.loop, async_listen_platform, hass, component, callback
@ -81,7 +89,9 @@ def listen_platform(hass, component, callback):
@bind_hass @bind_hass
def async_listen_platform(hass, component, callback): def async_listen_platform(
hass: core.HomeAssistant, component: str, callback: Callable
) -> None:
"""Register a platform loader listener. """Register a platform loader listener.
This method must be run in the event loop. This method must be run in the event loop.
@ -89,7 +99,7 @@ def async_listen_platform(hass, component, callback):
service = EVENT_LOAD_PLATFORM.format(component) service = EVENT_LOAD_PLATFORM.format(component)
@core.callback @core.callback
def discovery_platform_listener(event): def discovery_platform_listener(event: core.Event) -> None:
"""Listen for platform discovery events.""" """Listen for platform discovery events."""
if event.data.get(ATTR_SERVICE) != service: if event.data.get(ATTR_SERVICE) != service:
return return

View File

@ -5,13 +5,14 @@ from itertools import chain
import logging import logging
from homeassistant import config as conf_util from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_ENTITY_NAMESPACE, CONF_ENTITY_NAMESPACE,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
ENTITY_MATCH_ALL, ENTITY_MATCH_ALL,
) )
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.config_validation import make_entity_service_schema
@ -29,7 +30,7 @@ DATA_INSTANCES = "entity_components"
@bind_hass @bind_hass
async def async_update_entity(hass, entity_id): async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
"""Trigger an update for an entity.""" """Trigger an update for an entity."""
domain = entity_id.split(".", 1)[0] domain = entity_id.split(".", 1)[0]
entity_comp = hass.data.get(DATA_INSTANCES, {}).get(domain) entity_comp = hass.data.get(DATA_INSTANCES, {}).get(domain)
@ -158,7 +159,7 @@ class EntityComponent:
return await self._platforms[key].async_setup_entry(config_entry) return await self._platforms[key].async_setup_entry(config_entry)
async def async_unload_entry(self, config_entry): async def async_unload_entry(self, config_entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
key = config_entry.entry_id key = config_entry.entry_id
@ -237,7 +238,7 @@ class EntityComponent:
await self._platforms[key].async_setup(platform_config, discovery_info) await self._platforms[key].async_setup(platform_config, discovery_info)
@callback @callback
def _async_update_group(self): def _async_update_group(self) -> None:
"""Set up and/or update component group. """Set up and/or update component group.
This method must be run in the event loop. This method must be run in the event loop.
@ -265,7 +266,7 @@ class EntityComponent:
) )
) )
async def _async_reset(self): async def _async_reset(self) -> None:
"""Remove entities and reset the entity component to initial values. """Remove entities and reset the entity component to initial values.
This method must be run in the event loop. This method must be run in the event loop.
@ -283,7 +284,7 @@ class EntityComponent:
"group", "remove", dict(object_id=slugify(self.group_name)) "group", "remove", dict(object_id=slugify(self.group_name))
) )
async def async_remove_entity(self, entity_id): async def async_remove_entity(self, entity_id: str) -> None:
"""Remove an entity managed by one of the platforms.""" """Remove an entity managed by one of the platforms."""
for platform in self._platforms.values(): for platform in self._platforms.values():
if entity_id in platform.entities: if entity_id in platform.entities:

View File

@ -1,6 +1,7 @@
"""Class to manage the entities for a single platform.""" """Class to manage the entities for a single platform."""
import asyncio import asyncio
from contextvars import ContextVar from contextvars import ContextVar
from datetime import datetime
from typing import Optional from typing import Optional
from homeassistant.const import DEVICE_DEFAULT_NAME from homeassistant.const import DEVICE_DEFAULT_NAME
@ -64,14 +65,14 @@ class EntityPlatform:
# which powers entity_component.add_entities # which powers entity_component.add_entities
if platform is None: if platform is None:
self.parallel_updates = None self.parallel_updates = None
self.parallel_updates_semaphore = None self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None
return return
self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None) self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None)
# semaphore will be created on demand # semaphore will be created on demand
self.parallel_updates_semaphore = None self.parallel_updates_semaphore = None
def _get_parallel_updates_semaphore(self): def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore:
"""Get or create a semaphore for parallel updates.""" """Get or create a semaphore for parallel updates."""
if self.parallel_updates_semaphore is None: if self.parallel_updates_semaphore is None:
self.parallel_updates_semaphore = asyncio.Semaphore( self.parallel_updates_semaphore = asyncio.Semaphore(
@ -406,7 +407,7 @@ class EntityPlatform:
await entity.async_update_ha_state() await entity.async_update_ha_state()
async def async_reset(self): async def async_reset(self) -> None:
"""Remove all entities and reset data. """Remove all entities and reset data.
This method must be run in the event loop. This method must be run in the event loop.
@ -426,7 +427,7 @@ class EntityPlatform:
self._async_unsub_polling() self._async_unsub_polling()
self._async_unsub_polling = None self._async_unsub_polling = None
async def async_remove_entity(self, entity_id): async def async_remove_entity(self, entity_id: str) -> None:
"""Remove entity id from platform.""" """Remove entity id from platform."""
await self.entities[entity_id].async_remove() await self.entities[entity_id].async_remove()
@ -437,7 +438,7 @@ class EntityPlatform:
self._async_unsub_polling() self._async_unsub_polling()
self._async_unsub_polling = None self._async_unsub_polling = None
async def _update_entity_states(self, now): async def _update_entity_states(self, now: datetime) -> None:
"""Update the states of all the polling entities. """Update the states of all the polling entities.
To protect from flooding the executor, we will update async entities To protect from flooding the executor, we will update async entities