Improve type hints in group (#78350)

This commit is contained in:
epenet 2022-09-14 11:36:28 +02:00 committed by GitHub
parent 03a24e3a05
commit 5cccb24830
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 77 deletions

View File

@ -3,10 +3,10 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Collection, Iterable
from contextvars import ContextVar from contextvars import ContextVar
import logging import logging
from typing import Any, Union, cast from typing import Any, Protocol, Union, cast
import voluptuous as vol import voluptuous as vol
@ -27,7 +27,15 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
ServiceCall,
State,
callback,
split_entity_id,
)
from homeassistant.helpers import config_validation as cv, entity_registry as er, start from homeassistant.helpers import config_validation as cv, entity_registry as er, start
from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -42,8 +50,6 @@ from homeassistant.loader import bind_hass
from .const import CONF_HIDE_MEMBERS from .const import CONF_HIDE_MEMBERS
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
DOMAIN = "group" DOMAIN = "group"
GROUP_ORDER = "group_order" GROUP_ORDER = "group_order"
@ -79,10 +85,19 @@ _LOGGER = logging.getLogger(__name__)
current_domain: ContextVar[str] = ContextVar("current_domain") current_domain: ContextVar[str] = ContextVar("current_domain")
def _conf_preprocess(value): class GroupProtocol(Protocol):
"""Define the format of group platforms."""
def async_describe_on_off_states(
self, hass: HomeAssistant, registry: GroupIntegrationRegistry
) -> None:
"""Describe group on off states."""
def _conf_preprocess(value: Any) -> dict[str, Any]:
"""Preprocess alternative configuration formats.""" """Preprocess alternative configuration formats."""
if not isinstance(value, dict): if not isinstance(value, dict):
value = {CONF_ENTITIES: value} return {CONF_ENTITIES: value}
return value return value
@ -135,14 +150,15 @@ class GroupIntegrationRegistry:
@bind_hass @bind_hass
def is_on(hass, entity_id): def is_on(hass: HomeAssistant, entity_id: str) -> bool:
"""Test if the group state is in its ON-state.""" """Test if the group state is in its ON-state."""
if REG_KEY not in hass.data: if REG_KEY not in hass.data:
# Integration not setup yet, it cannot be on # Integration not setup yet, it cannot be on
return False return False
if (state := hass.states.get(entity_id)) is not None: if (state := hass.states.get(entity_id)) is not None:
return state.state in hass.data[REG_KEY].on_off_mapping registry: GroupIntegrationRegistry = hass.data[REG_KEY]
return state.state in registry.on_off_mapping
return False return False
@ -408,10 +424,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def _process_group_platform(hass, domain, platform): async def _process_group_platform(
hass: HomeAssistant, domain: str, platform: GroupProtocol
) -> None:
"""Process a group platform.""" """Process a group platform."""
current_domain.set(domain) current_domain.set(domain)
platform.async_describe_on_off_states(hass, hass.data[REG_KEY]) registry: GroupIntegrationRegistry = hass.data[REG_KEY]
platform.async_describe_on_off_states(hass, registry)
async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None: async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None:
@ -423,7 +442,7 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
for object_id, conf in domain_config.items(): for object_id, conf in domain_config.items():
name: str = conf.get(CONF_NAME, object_id) name: str = conf.get(CONF_NAME, object_id)
entity_ids: Iterable[str] = conf.get(CONF_ENTITIES) or [] entity_ids: Collection[str] = conf.get(CONF_ENTITIES) or []
icon: str | None = conf.get(CONF_ICON) icon: str | None = conf.get(CONF_ICON)
mode = bool(conf.get(CONF_ALL)) mode = bool(conf.get(CONF_ALL))
order: int = hass.data[GROUP_ORDER] order: int = hass.data[GROUP_ORDER]
@ -456,15 +475,12 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None
class GroupEntity(Entity): class GroupEntity(Entity):
"""Representation of a Group of entities.""" """Representation of a Group of entities."""
@property _attr_should_poll = False
def should_poll(self) -> bool:
"""Disable polling for group."""
return False
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register listeners.""" """Register listeners."""
async def _update_at_start(_): async def _update_at_start(_: HomeAssistant) -> None:
self.async_update_group_state() self.async_update_group_state()
self.async_write_ha_state() self.async_write_ha_state()
@ -487,6 +503,10 @@ class GroupEntity(Entity):
class Group(Entity): class Group(Entity):
"""Track a group of entity ids.""" """Track a group of entity ids."""
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -494,7 +514,7 @@ class Group(Entity):
order: int | None = None, order: int | None = None,
icon: str | None = None, icon: str | None = None,
user_defined: bool = True, user_defined: bool = True,
entity_ids: Iterable[str] | None = None, entity_ids: Collection[str] | None = None,
mode: bool | None = None, mode: bool | None = None,
) -> None: ) -> None:
"""Initialize a group. """Initialize a group.
@ -503,25 +523,25 @@ class Group(Entity):
""" """
self.hass = hass self.hass = hass
self._name = name self._name = name
self._state = None self._state: str | None = None
self._icon = icon self._icon = icon
self._set_tracked(entity_ids) self._set_tracked(entity_ids)
self._on_off = None self._on_off: dict[str, bool] = {}
self._assumed = None self._assumed: dict[str, bool] = {}
self._on_states = None self._on_states: set[str] = set()
self.user_defined = user_defined self.user_defined = user_defined
self.mode = any self.mode = any
if mode: if mode:
self.mode = all self.mode = all
self._order = order self._order = order
self._assumed_state = False self._assumed_state = False
self._async_unsub_state_changed = None self._async_unsub_state_changed: CALLBACK_TYPE | None = None
@staticmethod @staticmethod
def create_group( def create_group(
hass: HomeAssistant, hass: HomeAssistant,
name: str, name: str,
entity_ids: Iterable[str] | None = None, entity_ids: Collection[str] | None = None,
user_defined: bool = True, user_defined: bool = True,
icon: str | None = None, icon: str | None = None,
object_id: str | None = None, object_id: str | None = None,
@ -541,7 +561,7 @@ class Group(Entity):
def async_create_group_entity( def async_create_group_entity(
hass: HomeAssistant, hass: HomeAssistant,
name: str, name: str,
entity_ids: Iterable[str] | None = None, entity_ids: Collection[str] | None = None,
user_defined: bool = True, user_defined: bool = True,
icon: str | None = None, icon: str | None = None,
object_id: str | None = None, object_id: str | None = None,
@ -577,7 +597,7 @@ class Group(Entity):
async def async_create_group( async def async_create_group(
hass: HomeAssistant, hass: HomeAssistant,
name: str, name: str,
entity_ids: Iterable[str] | None = None, entity_ids: Collection[str] | None = None,
user_defined: bool = True, user_defined: bool = True,
icon: str | None = None, icon: str | None = None,
object_id: str | None = None, object_id: str | None = None,
@ -597,37 +617,32 @@ class Group(Entity):
return group return group
@property @property
def should_poll(self): def name(self) -> str:
"""No need to poll because groups will update themselves."""
return False
@property
def name(self):
"""Return the name of the group.""" """Return the name of the group."""
return self._name return self._name
@name.setter @name.setter
def name(self, value): def name(self, value: str) -> None:
"""Set Group name.""" """Set Group name."""
self._name = value self._name = value
@property @property
def state(self): def state(self) -> str | None:
"""Return the state of the group.""" """Return the state of the group."""
return self._state return self._state
@property @property
def icon(self): def icon(self) -> str | None:
"""Return the icon of the group.""" """Return the icon of the group."""
return self._icon return self._icon
@icon.setter @icon.setter
def icon(self, value): def icon(self, value: str | None) -> None:
"""Set Icon for group.""" """Set Icon for group."""
self._icon = value self._icon = value
@property @property
def extra_state_attributes(self): def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes for the group.""" """Return the state attributes for the group."""
data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order} data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order}
if not self.user_defined: if not self.user_defined:
@ -636,17 +651,19 @@ class Group(Entity):
return data return data
@property @property
def assumed_state(self): def assumed_state(self) -> bool:
"""Test if any member has an assumed state.""" """Test if any member has an assumed state."""
return self._assumed_state return self._assumed_state
def update_tracked_entity_ids(self, entity_ids): def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None:
"""Update the member entity IDs.""" """Update the member entity IDs."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
self.async_update_tracked_entity_ids(entity_ids), self.hass.loop self.async_update_tracked_entity_ids(entity_ids), self.hass.loop
).result() ).result()
async def async_update_tracked_entity_ids(self, entity_ids): async def async_update_tracked_entity_ids(
self, entity_ids: Collection[str] | None
) -> None:
"""Update the member entity IDs. """Update the member entity IDs.
This method must be run in the event loop. This method must be run in the event loop.
@ -656,7 +673,7 @@ class Group(Entity):
self._reset_tracked_state() self._reset_tracked_state()
self._async_start() self._async_start()
def _set_tracked(self, entity_ids): def _set_tracked(self, entity_ids: Collection[str] | None) -> None:
"""Tuple of entities to be tracked.""" """Tuple of entities to be tracked."""
# tracking are the entities we want to track # tracking are the entities we want to track
# trackable are the entities we actually watch # trackable are the entities we actually watch
@ -666,10 +683,11 @@ class Group(Entity):
self.trackable = () self.trackable = ()
return return
excluded_domains = self.hass.data[REG_KEY].exclude_domains registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
excluded_domains = registry.exclude_domains
tracking = [] tracking: list[str] = []
trackable = [] trackable: list[str] = []
for ent_id in entity_ids: for ent_id in entity_ids:
ent_id_lower = ent_id.lower() ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0] domain = split_entity_id(ent_id_lower)[0]
@ -681,14 +699,14 @@ class Group(Entity):
self.tracking = tuple(tracking) self.tracking = tuple(tracking)
@callback @callback
def _async_start(self, *_): def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state.""" """Start tracking members and write state."""
self._reset_tracked_state() self._reset_tracked_state()
self._async_start_tracking() self._async_start_tracking()
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def _async_start_tracking(self): def _async_start_tracking(self) -> None:
"""Start tracking members. """Start tracking members.
This method must be run in the event loop. This method must be run in the event loop.
@ -701,7 +719,7 @@ class Group(Entity):
self._async_update_group_state() self._async_update_group_state()
@callback @callback
def _async_stop(self): def _async_stop(self) -> None:
"""Unregister the group from Home Assistant. """Unregister the group from Home Assistant.
This method must be run in the event loop. This method must be run in the event loop.
@ -711,20 +729,20 @@ class Group(Entity):
self._async_unsub_state_changed = None self._async_unsub_state_changed = None
@callback @callback
def async_update_group_state(self): def async_update_group_state(self) -> None:
"""Query all members and determine current group state.""" """Query all members and determine current group state."""
self._state = None self._state = None
self._async_update_group_state() self._async_update_group_state()
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant.""" """Handle addition to Home Assistant."""
self.async_on_remove(start.async_at_start(self.hass, self._async_start)) self.async_on_remove(start.async_at_start(self.hass, self._async_start))
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant.""" """Handle removal from Home Assistant."""
self._async_stop() self._async_stop()
async def _async_state_changed_listener(self, event): async def _async_state_changed_listener(self, event: Event) -> None:
"""Respond to a member state changing. """Respond to a member state changing.
This method must be run in the event loop. This method must be run in the event loop.
@ -742,7 +760,7 @@ class Group(Entity):
self._async_update_group_state(new_state) self._async_update_group_state(new_state)
self.async_write_ha_state() self.async_write_ha_state()
def _reset_tracked_state(self): def _reset_tracked_state(self) -> None:
"""Reset tracked state.""" """Reset tracked state."""
self._on_off = {} self._on_off = {}
self._assumed = {} self._assumed = {}
@ -752,13 +770,13 @@ class Group(Entity):
if (state := self.hass.states.get(entity_id)) is not None: if (state := self.hass.states.get(entity_id)) is not None:
self._see_state(state) self._see_state(state)
def _see_state(self, new_state): def _see_state(self, new_state: State) -> None:
"""Keep track of the the state.""" """Keep track of the the state."""
entity_id = new_state.entity_id entity_id = new_state.entity_id
domain = new_state.domain domain = new_state.domain
state = new_state.state state = new_state.state
registry = self.hass.data[REG_KEY] registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._assumed[entity_id] = new_state.attributes.get(ATTR_ASSUMED_STATE) self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE))
if domain not in registry.on_states_by_domain: if domain not in registry.on_states_by_domain:
# Handle the group of a group case # Handle the group of a group case
@ -769,12 +787,12 @@ class Group(Entity):
self._on_off[entity_id] = state in registry.on_off_mapping self._on_off[entity_id] = state in registry.on_off_mapping
else: else:
entity_on_state = registry.on_states_by_domain[domain] entity_on_state = registry.on_states_by_domain[domain]
if domain in self.hass.data[REG_KEY].on_states_by_domain: if domain in registry.on_states_by_domain:
self._on_states.update(entity_on_state) self._on_states.update(entity_on_state)
self._on_off[entity_id] = state in entity_on_state self._on_off[entity_id] = state in entity_on_state
@callback @callback
def _async_update_group_state(self, tr_state=None): def _async_update_group_state(self, tr_state: State | None = None) -> None:
"""Update group state. """Update group state.
Optionally you can provide the only state changed since last update Optionally you can provide the only state changed since last update
@ -818,4 +836,5 @@ class Group(Entity):
if group_is_on: if group_is_on:
self._state = on_state self._state = on_state
else: else:
self._state = self.hass.data[REG_KEY].on_off_mapping[on_state] registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]

View File

@ -103,6 +103,7 @@ class MediaPlayerGroup(MediaPlayerEntity):
"""Representation of a Media Group.""" """Representation of a Media Group."""
_attr_available: bool = False _attr_available: bool = False
_attr_should_poll = False
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None: def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a Media Group entity.""" """Initialize a Media Group entity."""
@ -216,11 +217,6 @@ class MediaPlayerGroup(MediaPlayerEntity):
"""Flag supported features.""" """Flag supported features."""
return self._supported_features return self._supported_features
@property
def should_poll(self) -> bool:
"""No polling needed for a media group."""
return False
@property @property
def extra_state_attributes(self) -> dict: def extra_state_attributes(self) -> dict:
"""Return the state attributes for the media group.""" """Return the state attributes for the media group."""

View File

@ -1,7 +1,10 @@
"""Group platform for notify component.""" """Group platform for notify component."""
from __future__ import annotations
import asyncio import asyncio
from collections.abc import Mapping from collections.abc import Coroutine, Mapping
from copy import deepcopy from copy import deepcopy
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -13,9 +16,9 @@ from homeassistant.components.notify import (
BaseNotificationService, BaseNotificationService,
) )
from homeassistant.const import ATTR_SERVICE from homeassistant.const import ATTR_SERVICE
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
CONF_SERVICES = "services" CONF_SERVICES = "services"
@ -29,46 +32,50 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
def update(input_dict, update_source): def update(input_dict: dict[str, Any], update_source: dict[str, Any]) -> dict[str, Any]:
"""Deep update a dictionary. """Deep update a dictionary.
Async friendly. Async friendly.
""" """
for key, val in update_source.items(): for key, val in update_source.items():
if isinstance(val, Mapping): if isinstance(val, Mapping):
recurse = update(input_dict.get(key, {}), val) recurse = update(input_dict.get(key, {}), val) # type: ignore[arg-type]
input_dict[key] = recurse input_dict[key] = recurse
else: else:
input_dict[key] = update_source[key] input_dict[key] = update_source[key]
return input_dict return input_dict
async def async_get_service(hass, config, discovery_info=None): async def async_get_service(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> GroupNotifyPlatform:
"""Get the Group notification service.""" """Get the Group notification service."""
return GroupNotifyPlatform(hass, config.get(CONF_SERVICES)) return GroupNotifyPlatform(hass, config[CONF_SERVICES])
class GroupNotifyPlatform(BaseNotificationService): class GroupNotifyPlatform(BaseNotificationService):
"""Implement the notification service for the group notify platform.""" """Implement the notification service for the group notify platform."""
def __init__(self, hass, entities): def __init__(self, hass: HomeAssistant, entities: list[dict[str, Any]]) -> None:
"""Initialize the service.""" """Initialize the service."""
self.hass = hass self.hass = hass
self.entities = entities self.entities = entities
async def async_send_message(self, message="", **kwargs): async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Send message to all entities in the group.""" """Send message to all entities in the group."""
payload = {ATTR_MESSAGE: message} payload: dict[str, Any] = {ATTR_MESSAGE: message}
payload.update({key: val for key, val in kwargs.items() if val}) payload.update({key: val for key, val in kwargs.items() if val})
tasks = [] tasks: list[Coroutine[Any, Any, bool | None]] = []
for entity in self.entities: for entity in self.entities:
sending_payload = deepcopy(payload.copy()) sending_payload = deepcopy(payload.copy())
if entity.get(ATTR_DATA) is not None: if (data := entity.get(ATTR_DATA)) is not None:
update(sending_payload, entity.get(ATTR_DATA)) update(sending_payload, data)
tasks.append( tasks.append(
self.hass.services.async_call( self.hass.services.async_call(
DOMAIN, entity.get(ATTR_SERVICE), sending_payload DOMAIN, entity[ATTR_SERVICE], sending_payload
) )
) )