diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 2ff444da89f..4818de83cb9 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -2,7 +2,7 @@ from asyncio import Event from collections import OrderedDict import logging -from typing import List, Optional, cast +from typing import Any, Dict, List, Optional, cast import uuid import attr @@ -48,7 +48,7 @@ class DeviceEntry: is_new = attr.ib(type=bool, default=False) -def format_mac(mac): +def format_mac(mac: str) -> str: """Format the mac address string for entry into dev reg.""" to_test = mac @@ -260,7 +260,7 @@ class DeviceRegistry: return new - def async_remove_device(self, device_id): + def async_remove_device(self, device_id: str) -> None: """Remove a device from the device registry.""" del self.devices[device_id] self.hass.bus.async_fire( @@ -298,12 +298,12 @@ class DeviceRegistry: self.devices = devices @callback - def async_schedule_save(self): + def async_schedule_save(self) -> None: """Schedule saving the device registry.""" self._store.async_delay_save(self._data_to_save, SAVE_DELAY) @callback - def _data_to_save(self): + def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]: """Return data of device registry to store in a file.""" data = {} @@ -327,7 +327,7 @@ class DeviceRegistry: return data @callback - def async_clear_config_entry(self, config_entry_id): + def async_clear_config_entry(self, config_entry_id: str) -> None: """Clear config entry from registry entries.""" remove = [] for dev_id, device in self.devices.items(): diff --git a/homeassistant/helpers/discovery.py b/homeassistant/helpers/discovery.py index 8e4def77440..a6162dbde55 100644 --- a/homeassistant/helpers/discovery.py +++ b/homeassistant/helpers/discovery.py @@ -5,6 +5,8 @@ There are two different types of discoveries that can be fired/listened for. - listen_platform/discover_platform is for platforms. These are used by components to allow discovery of their platforms. """ +from typing import Callable, Collection, Union + from homeassistant import core, setup from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED from homeassistant.exceptions import HomeAssistantError @@ -18,7 +20,9 @@ ATTR_PLATFORM = "platform" @bind_hass -def listen(hass, service, callback): +def listen( + hass: core.HomeAssistant, service: Union[str, Collection[str]], callback: Callable +) -> None: """Set up listener for discovery of specific service. Service can be a string or a list/tuple. @@ -28,7 +32,9 @@ def listen(hass, service, callback): @core.callback @bind_hass -def async_listen(hass, service, callback): +def async_listen( + hass: core.HomeAssistant, service: Union[str, Collection[str]], callback: Callable +) -> None: """Set up listener for discovery of specific service. Service can be a string or a list/tuple. @@ -39,7 +45,7 @@ def async_listen(hass, service, callback): service = tuple(service) @core.callback - def discovery_event_listener(event): + def discovery_event_listener(event: core.Event) -> None: """Listen for discovery events.""" if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service: hass.async_add_job( @@ -73,7 +79,9 @@ async def async_discover(hass, service, discovered, component, hass_config): @bind_hass -def listen_platform(hass, component, callback): +def listen_platform( + hass: core.HomeAssistant, component: str, callback: Callable +) -> None: """Register a platform loader listener.""" run_callback_threadsafe( hass.loop, async_listen_platform, hass, component, callback @@ -81,7 +89,9 @@ def listen_platform(hass, component, callback): @bind_hass -def async_listen_platform(hass, component, callback): +def async_listen_platform( + hass: core.HomeAssistant, component: str, callback: Callable +) -> None: """Register a platform loader listener. This method must be run in the event loop. @@ -89,7 +99,7 @@ def async_listen_platform(hass, component, callback): service = EVENT_LOAD_PLATFORM.format(component) @core.callback - def discovery_platform_listener(event): + def discovery_platform_listener(event: core.Event) -> None: """Listen for platform discovery events.""" if event.data.get(ATTR_SERVICE) != service: return diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index becd96bf5f3..84aa8becafd 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -5,13 +5,14 @@ from itertools import chain import logging from homeassistant import config as conf_util +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL, ENTITY_MATCH_ALL, ) -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers.config_validation import make_entity_service_schema @@ -29,7 +30,7 @@ DATA_INSTANCES = "entity_components" @bind_hass -async def async_update_entity(hass, entity_id): +async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None: """Trigger an update for an entity.""" domain = entity_id.split(".", 1)[0] entity_comp = hass.data.get(DATA_INSTANCES, {}).get(domain) @@ -158,7 +159,7 @@ class EntityComponent: return await self._platforms[key].async_setup_entry(config_entry) - async def async_unload_entry(self, config_entry): + async def async_unload_entry(self, config_entry: ConfigEntry) -> bool: """Unload a config entry.""" key = config_entry.entry_id @@ -237,7 +238,7 @@ class EntityComponent: await self._platforms[key].async_setup(platform_config, discovery_info) @callback - def _async_update_group(self): + def _async_update_group(self) -> None: """Set up and/or update component group. This method must be run in the event loop. @@ -265,7 +266,7 @@ class EntityComponent: ) ) - async def _async_reset(self): + async def _async_reset(self) -> None: """Remove entities and reset the entity component to initial values. This method must be run in the event loop. @@ -283,7 +284,7 @@ class EntityComponent: "group", "remove", dict(object_id=slugify(self.group_name)) ) - async def async_remove_entity(self, entity_id): + async def async_remove_entity(self, entity_id: str) -> None: """Remove an entity managed by one of the platforms.""" for platform in self._platforms.values(): if entity_id in platform.entities: diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 133d1a5841f..e171a4cade8 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -1,6 +1,7 @@ """Class to manage the entities for a single platform.""" import asyncio from contextvars import ContextVar +from datetime import datetime from typing import Optional from homeassistant.const import DEVICE_DEFAULT_NAME @@ -64,14 +65,14 @@ class EntityPlatform: # which powers entity_component.add_entities if platform is None: self.parallel_updates = None - self.parallel_updates_semaphore = None + self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None return self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None) # semaphore will be created on demand self.parallel_updates_semaphore = None - def _get_parallel_updates_semaphore(self): + def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore: """Get or create a semaphore for parallel updates.""" if self.parallel_updates_semaphore is None: self.parallel_updates_semaphore = asyncio.Semaphore( @@ -406,7 +407,7 @@ class EntityPlatform: await entity.async_update_ha_state() - async def async_reset(self): + async def async_reset(self) -> None: """Remove all entities and reset data. This method must be run in the event loop. @@ -426,7 +427,7 @@ class EntityPlatform: self._async_unsub_polling() self._async_unsub_polling = None - async def async_remove_entity(self, entity_id): + async def async_remove_entity(self, entity_id: str) -> None: """Remove entity id from platform.""" await self.entities[entity_id].async_remove() @@ -437,7 +438,7 @@ class EntityPlatform: self._async_unsub_polling() self._async_unsub_polling = None - async def _update_entity_states(self, now): + async def _update_entity_states(self, now: datetime) -> None: """Update the states of all the polling entities. To protect from flooding the executor, we will update async entities