From bb667abd514c0246ca0bc40a7a7319f05494c566 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 15 Mar 2024 15:45:43 -1000 Subject: [PATCH] Cleanup some circular imports in group (#113554) --- homeassistant/components/group/__init__.py | 545 +----------------- .../components/group/binary_sensor.py | 2 +- homeassistant/components/group/config_flow.py | 6 +- homeassistant/components/group/const.py | 13 + homeassistant/components/group/cover.py | 2 +- homeassistant/components/group/entity.py | 477 +++++++++++++++ homeassistant/components/group/event.py | 2 +- homeassistant/components/group/fan.py | 2 +- homeassistant/components/group/light.py | 2 +- homeassistant/components/group/lock.py | 2 +- homeassistant/components/group/registry.py | 68 +++ homeassistant/components/group/sensor.py | 4 +- homeassistant/components/group/switch.py | 2 +- 13 files changed, 588 insertions(+), 539 deletions(-) create mode 100644 homeassistant/components/group/entity.py create mode 100644 homeassistant/components/group/registry.py diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index 778c4da9c9f..120c2d18290 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -2,74 +2,53 @@ from __future__ import annotations -from abc import abstractmethod import asyncio -from collections.abc import Callable, Collection, Mapping -from contextvars import ContextVar +from collections.abc import Collection import logging -from typing import Any, Protocol +from typing import Any import voluptuous as vol from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( - ATTR_ASSUMED_STATE, - ATTR_ENTITY_ID, + ATTR_ENTITY_ID, # noqa: F401 ATTR_ICON, ATTR_NAME, CONF_ENTITIES, CONF_ICON, CONF_NAME, SERVICE_RELOAD, - STATE_OFF, - STATE_ON, Platform, ) -from homeassistant.core import ( - CALLBACK_TYPE, - Event, - HomeAssistant, - ServiceCall, - State, - callback, - split_entity_id, -) -from homeassistant.helpers import config_validation as cv, entity_registry as er, start -from homeassistant.helpers.entity import Entity, async_generate_entity_id +from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.event import ( - EventStateChangedData, - async_track_state_change_event, -) from homeassistant.helpers.group import ( expand_entity_ids as _expand_entity_ids, get_entity_ids as _get_entity_ids, ) -from homeassistant.helpers.integration_platform import ( - async_process_integration_platforms, -) from homeassistant.helpers.reload import async_reload_integration_platforms from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass -from .const import ( +from .const import ( # noqa: F401 + ATTR_ADD_ENTITIES, + ATTR_ALL, + ATTR_AUTO, + ATTR_ENTITIES, + ATTR_OBJECT_ID, + ATTR_ORDER, + ATTR_REMOVE_ENTITIES, CONF_HIDE_MEMBERS, DOMAIN, # noqa: F401 + GROUP_ORDER, + REG_KEY, ) - -GROUP_ORDER = "group_order" - -ENTITY_ID_FORMAT = DOMAIN + ".{}" +from .entity import Group, async_get_component +from .registry import GroupIntegrationRegistry, async_setup as async_setup_registry CONF_ALL = "all" -ATTR_ADD_ENTITIES = "add_entities" -ATTR_REMOVE_ENTITIES = "remove_entities" -ATTR_AUTO = "auto" -ATTR_ENTITIES = "entities" -ATTR_OBJECT_ID = "object_id" -ATTR_ORDER = "order" -ATTR_ALL = "all" SERVICE_SET = "set" SERVICE_REMOVE = "remove" @@ -86,23 +65,8 @@ PLATFORMS = [ Platform.SWITCH, ] -REG_KEY = f"{DOMAIN}_registry" - -ENTITY_PREFIX = f"{DOMAIN}." - _LOGGER = logging.getLogger(__name__) -current_domain: ContextVar[str] = ContextVar("current_domain") - - -class GroupProtocol(Protocol): - """Define the format of group platforms.""" - - def async_describe_on_off_states( - self, hass: HomeAssistant, registry: GroupIntegrationRegistry - ) -> None: - """Describe group on off states.""" - def _conf_preprocess(value: Any) -> dict[str, Any]: """Preprocess alternative configuration formats.""" @@ -129,36 +93,6 @@ CONFIG_SCHEMA = vol.Schema( ) -def _async_get_component(hass: HomeAssistant) -> EntityComponent[Group]: - if (component := hass.data.get(DOMAIN)) is None: - component = hass.data[DOMAIN] = EntityComponent[Group](_LOGGER, DOMAIN, hass) - return component - - -class GroupIntegrationRegistry: - """Class to hold a registry of integrations.""" - - on_off_mapping: dict[str, str] = {STATE_ON: STATE_OFF} - off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON} - on_states_by_domain: dict[str, set] = {} - exclude_domains: set = set() - - def exclude_domain(self) -> None: - """Exclude the current domain.""" - self.exclude_domains.add(current_domain.get()) - - def on_off_states(self, on_states: set, off_state: str) -> None: - """Register on and off states for the current domain.""" - for on_state in on_states: - if on_state not in self.on_off_mapping: - self.on_off_mapping[on_state] = off_state - - if len(on_states) == 1 and off_state not in self.off_on_mapping: - self.off_on_mapping[off_state] = list(on_states)[0] - - self.on_states_by_domain[current_domain.get()] = set(on_states) - - @bind_hass def is_on(hass: HomeAssistant, entity_id: str) -> bool: """Test if the group state is in its ON-state.""" @@ -241,11 +175,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: component: EntityComponent[Group] = hass.data[DOMAIN] - hass.data[REG_KEY] = GroupIntegrationRegistry() - - await async_process_integration_platforms( - hass, DOMAIN, _process_group_platform, wait_for_platforms=True - ) + await async_setup_registry(hass) await _async_process_config(hass, config) @@ -387,16 +317,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -@callback -def _process_group_platform( - hass: HomeAssistant, domain: str, platform: GroupProtocol -) -> None: - """Process a group platform.""" - current_domain.set(domain) - registry: GroupIntegrationRegistry = hass.data[REG_KEY] - platform.async_describe_on_off_states(hass, registry) - - async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None: """Process group configuration.""" hass.data.setdefault(GROUP_ORDER, 0) @@ -434,431 +354,4 @@ async def _async_process_config(hass: HomeAssistant, config: ConfigType) -> None hass.data[GROUP_ORDER] += 1 # If called before the platform async_setup is called (test cases) - await _async_get_component(hass).async_add_entities(entities) - - -class GroupEntity(Entity): - """Representation of a Group of entities.""" - - _unrecorded_attributes = frozenset({ATTR_ENTITY_ID}) - - _attr_should_poll = False - _entity_ids: list[str] - - @callback - def async_start_preview( - self, - preview_callback: Callable[[str, Mapping[str, Any]], None], - ) -> CALLBACK_TYPE: - """Render a preview.""" - - for entity_id in self._entity_ids: - if (state := self.hass.states.get(entity_id)) is None: - continue - self.async_update_supported_features(entity_id, state) - - @callback - def async_state_changed_listener( - event: Event[EventStateChangedData] | None, - ) -> None: - """Handle child updates.""" - self.async_update_group_state() - if event: - self.async_update_supported_features( - event.data["entity_id"], event.data["new_state"] - ) - calculated_state = self._async_calculate_state() - preview_callback(calculated_state.state, calculated_state.attributes) - - async_state_changed_listener(None) - return async_track_state_change_event( - self.hass, self._entity_ids, async_state_changed_listener - ) - - async def async_added_to_hass(self) -> None: - """Register listeners.""" - for entity_id in self._entity_ids: - if (state := self.hass.states.get(entity_id)) is None: - continue - self.async_update_supported_features(entity_id, state) - - @callback - def async_state_changed_listener( - event: Event[EventStateChangedData], - ) -> None: - """Handle child updates.""" - self.async_set_context(event.context) - self.async_update_supported_features( - event.data["entity_id"], event.data["new_state"] - ) - self.async_defer_or_update_ha_state() - - self.async_on_remove( - async_track_state_change_event( - self.hass, self._entity_ids, async_state_changed_listener - ) - ) - self.async_on_remove(start.async_at_start(self.hass, self._update_at_start)) - - @callback - def _update_at_start(self, _: HomeAssistant) -> None: - """Update the group state at start.""" - self.async_update_group_state() - self.async_write_ha_state() - - @callback - def async_defer_or_update_ha_state(self) -> None: - """Only update once at start.""" - if not self.hass.is_running: - return - - self.async_update_group_state() - self.async_write_ha_state() - - @abstractmethod - @callback - def async_update_group_state(self) -> None: - """Abstract method to update the entity.""" - - @callback - def async_update_supported_features( - self, - entity_id: str, - new_state: State | None, - ) -> None: - """Update dictionaries with supported features.""" - - -class Group(Entity): - """Track a group of entity ids.""" - - _unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO}) - - _attr_should_poll = False - tracking: tuple[str, ...] - trackable: tuple[str, ...] - - def __init__( - self, - hass: HomeAssistant, - name: str, - *, - created_by_service: bool, - entity_ids: Collection[str] | None, - icon: str | None, - mode: bool | None, - order: int | None, - ) -> None: - """Initialize a group. - - This Object has factory function for creation. - """ - self.hass = hass - self._name = name - self._state: str | None = None - self._icon = icon - self._set_tracked(entity_ids) - self._on_off: dict[str, bool] = {} - self._assumed: dict[str, bool] = {} - self._on_states: set[str] = set() - self.created_by_service = created_by_service - self.mode = any - if mode: - self.mode = all - self._order = order - self._assumed_state = False - self._async_unsub_state_changed: CALLBACK_TYPE | None = None - - @staticmethod - @callback - def async_create_group_entity( - hass: HomeAssistant, - name: str, - *, - created_by_service: bool, - entity_ids: Collection[str] | None, - icon: str | None, - mode: bool | None, - object_id: str | None, - order: int | None, - ) -> Group: - """Create a group entity.""" - if order is None: - hass.data.setdefault(GROUP_ORDER, 0) - order = hass.data[GROUP_ORDER] - # Keep track of the group order without iterating - # every state in the state machine every time - # we setup a new group - hass.data[GROUP_ORDER] += 1 - - group = Group( - hass, - name, - created_by_service=created_by_service, - entity_ids=entity_ids, - icon=icon, - mode=mode, - order=order, - ) - - group.entity_id = async_generate_entity_id( - ENTITY_ID_FORMAT, object_id or name, hass=hass - ) - - return group - - @staticmethod - async def async_create_group( - hass: HomeAssistant, - name: str, - *, - created_by_service: bool, - entity_ids: Collection[str] | None, - icon: str | None, - mode: bool | None, - object_id: str | None, - order: int | None, - ) -> Group: - """Initialize a group. - - This method must be run in the event loop. - """ - group = Group.async_create_group_entity( - hass, - name, - created_by_service=created_by_service, - entity_ids=entity_ids, - icon=icon, - mode=mode, - object_id=object_id, - order=order, - ) - - # If called before the platform async_setup is called (test cases) - await _async_get_component(hass).async_add_entities([group]) - return group - - @property - def name(self) -> str: - """Return the name of the group.""" - return self._name - - @name.setter - def name(self, value: str) -> None: - """Set Group name.""" - self._name = value - - @property - def state(self) -> str | None: - """Return the state of the group.""" - return self._state - - @property - def icon(self) -> str | None: - """Return the icon of the group.""" - return self._icon - - @icon.setter - def icon(self, value: str | None) -> None: - """Set Icon for group.""" - self._icon = value - - @property - def extra_state_attributes(self) -> dict[str, Any]: - """Return the state attributes for the group.""" - data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order} - if self.created_by_service: - data[ATTR_AUTO] = True - - return data - - @property - def assumed_state(self) -> bool: - """Test if any member has an assumed state.""" - return self._assumed_state - - def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None: - """Update the member entity IDs.""" - asyncio.run_coroutine_threadsafe( - self.async_update_tracked_entity_ids(entity_ids), self.hass.loop - ).result() - - async def async_update_tracked_entity_ids( - self, entity_ids: Collection[str] | None - ) -> None: - """Update the member entity IDs. - - This method must be run in the event loop. - """ - self._async_stop() - self._set_tracked(entity_ids) - self._reset_tracked_state() - self._async_start() - - def _set_tracked(self, entity_ids: Collection[str] | None) -> None: - """Tuple of entities to be tracked.""" - # tracking are the entities we want to track - # trackable are the entities we actually watch - - if not entity_ids: - self.tracking = () - self.trackable = () - return - - registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] - excluded_domains = registry.exclude_domains - - tracking: list[str] = [] - trackable: list[str] = [] - for ent_id in entity_ids: - ent_id_lower = ent_id.lower() - domain = split_entity_id(ent_id_lower)[0] - tracking.append(ent_id_lower) - if domain not in excluded_domains: - trackable.append(ent_id_lower) - - self.trackable = tuple(trackable) - self.tracking = tuple(tracking) - - @callback - def _async_start(self, _: HomeAssistant | None = None) -> None: - """Start tracking members and write state.""" - self._reset_tracked_state() - self._async_start_tracking() - self.async_write_ha_state() - - @callback - def _async_start_tracking(self) -> None: - """Start tracking members. - - This method must be run in the event loop. - """ - if self.trackable and self._async_unsub_state_changed is None: - self._async_unsub_state_changed = async_track_state_change_event( - self.hass, self.trackable, self._async_state_changed_listener - ) - - self._async_update_group_state() - - @callback - def _async_stop(self) -> None: - """Unregister the group from Home Assistant. - - This method must be run in the event loop. - """ - if self._async_unsub_state_changed: - self._async_unsub_state_changed() - self._async_unsub_state_changed = None - - @callback - def async_update_group_state(self) -> None: - """Query all members and determine current group state.""" - self._state = None - self._async_update_group_state() - - async def async_added_to_hass(self) -> None: - """Handle addition to Home Assistant.""" - self.async_on_remove(start.async_at_start(self.hass, self._async_start)) - - async def async_will_remove_from_hass(self) -> None: - """Handle removal from Home Assistant.""" - self._async_stop() - - async def _async_state_changed_listener( - self, event: Event[EventStateChangedData] - ) -> None: - """Respond to a member state changing. - - This method must be run in the event loop. - """ - # removed - if self._async_unsub_state_changed is None: - return - - self.async_set_context(event.context) - - if (new_state := event.data["new_state"]) is None: - # The state was removed from the state machine - self._reset_tracked_state() - - self._async_update_group_state(new_state) - self.async_write_ha_state() - - def _reset_tracked_state(self) -> None: - """Reset tracked state.""" - self._on_off = {} - self._assumed = {} - self._on_states = set() - - for entity_id in self.trackable: - if (state := self.hass.states.get(entity_id)) is not None: - self._see_state(state) - - def _see_state(self, new_state: State) -> None: - """Keep track of the state.""" - entity_id = new_state.entity_id - domain = new_state.domain - state = new_state.state - registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] - self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE)) - - if domain not in registry.on_states_by_domain: - # Handle the group of a group case - if state in registry.on_off_mapping: - self._on_states.add(state) - elif state in registry.off_on_mapping: - self._on_states.add(registry.off_on_mapping[state]) - self._on_off[entity_id] = state in registry.on_off_mapping - else: - entity_on_state = registry.on_states_by_domain[domain] - if domain in registry.on_states_by_domain: - self._on_states.update(entity_on_state) - self._on_off[entity_id] = state in entity_on_state - - @callback - def _async_update_group_state(self, tr_state: State | None = None) -> None: - """Update group state. - - Optionally you can provide the only state changed since last update - allowing this method to take shortcuts. - - This method must be run in the event loop. - """ - # To store current states of group entities. Might not be needed. - if tr_state: - self._see_state(tr_state) - - if not self._on_off: - return - - if ( - tr_state is None - or self._assumed_state - and not tr_state.attributes.get(ATTR_ASSUMED_STATE) - ): - self._assumed_state = self.mode(self._assumed.values()) - - elif tr_state.attributes.get(ATTR_ASSUMED_STATE): - self._assumed_state = True - - num_on_states = len(self._on_states) - # If all the entity domains we are tracking - # have the same on state we use this state - # and its hass.data[REG_KEY].on_off_mapping to off - if num_on_states == 1: - on_state = list(self._on_states)[0] - # If we do not have an on state for any domains - # we use None (which will be STATE_UNKNOWN) - elif num_on_states == 0: - self._state = None - return - # If the entity domains have more than one - # on state, we use STATE_ON/STATE_OFF - else: - on_state = STATE_ON - group_is_on = self.mode(self._on_off.values()) - if group_is_on: - self._state = on_state - else: - registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] - self._state = registry.on_off_mapping[on_state] + await async_get_component(hass).async_add_entities(entities) diff --git a/homeassistant/components/group/binary_sensor.py b/homeassistant/components/group/binary_sensor.py index 16665e8970f..3fbadfb156c 100644 --- a/homeassistant/components/group/binary_sensor.py +++ b/homeassistant/components/group/binary_sensor.py @@ -29,7 +29,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity DEFAULT_NAME = "Binary Sensor Group" diff --git a/homeassistant/components/group/config_flow.py b/homeassistant/components/group/config_flow.py index 237eb570417..f3e2405d86a 100644 --- a/homeassistant/components/group/config_flow.py +++ b/homeassistant/components/group/config_flow.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable, Coroutine, Mapping from functools import partial -from typing import TYPE_CHECKING, Any, cast +from typing import Any, cast import voluptuous as vol @@ -22,12 +22,10 @@ from homeassistant.helpers.schema_config_entry_flow import ( entity_selector_without_own_entities, ) -if TYPE_CHECKING: - from . import GroupEntity - from .binary_sensor import CONF_ALL, async_create_preview_binary_sensor from .const import CONF_HIDE_MEMBERS, CONF_IGNORE_NON_NUMERIC, DOMAIN from .cover import async_create_preview_cover +from .entity import GroupEntity from .event import async_create_preview_event from .fan import async_create_preview_fan from .light import async_create_preview_light diff --git a/homeassistant/components/group/const.py b/homeassistant/components/group/const.py index e64358181ca..0fdd429269f 100644 --- a/homeassistant/components/group/const.py +++ b/homeassistant/components/group/const.py @@ -4,3 +4,16 @@ CONF_HIDE_MEMBERS = "hide_members" CONF_IGNORE_NON_NUMERIC = "ignore_non_numeric" DOMAIN = "group" + +REG_KEY = f"{DOMAIN}_registry" + +GROUP_ORDER = "group_order" + + +ATTR_ADD_ENTITIES = "add_entities" +ATTR_REMOVE_ENTITIES = "remove_entities" +ATTR_AUTO = "auto" +ATTR_ENTITIES = "entities" +ATTR_OBJECT_ID = "object_id" +ATTR_ORDER = "order" +ATTR_ALL = "all" diff --git a/homeassistant/components/group/cover.py b/homeassistant/components/group/cover.py index 8d521314f96..02e5ebbc7cd 100644 --- a/homeassistant/components/group/cover.py +++ b/homeassistant/components/group/cover.py @@ -43,7 +43,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity from .util import reduce_attribute KEY_OPEN_CLOSE = "open_close" diff --git a/homeassistant/components/group/entity.py b/homeassistant/components/group/entity.py new file mode 100644 index 00000000000..3df068f5e23 --- /dev/null +++ b/homeassistant/components/group/entity.py @@ -0,0 +1,477 @@ +"""Provide entity classes for group entities.""" + +from __future__ import annotations + +from abc import abstractmethod +import asyncio +from collections.abc import Callable, Collection, Mapping +import logging +from typing import Any + +from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + HomeAssistant, + State, + callback, + split_entity_id, +) +from homeassistant.helpers import start +from homeassistant.helpers.entity import Entity, async_generate_entity_id +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.event import ( + EventStateChangedData, + async_track_state_change_event, +) + +from .const import ( + ATTR_AUTO, + ATTR_ORDER, + DOMAIN, # noqa: F401 + GROUP_ORDER, + REG_KEY, +) +from .registry import GroupIntegrationRegistry + +ENTITY_ID_FORMAT = DOMAIN + ".{}" + +_PACKAGE_LOGGER = logging.getLogger(__package__) + +_LOGGER = logging.getLogger(__name__) + + +class GroupEntity(Entity): + """Representation of a Group of entities.""" + + _unrecorded_attributes = frozenset({ATTR_ENTITY_ID}) + + _attr_should_poll = False + _entity_ids: list[str] + + @callback + def async_start_preview( + self, + preview_callback: Callable[[str, Mapping[str, Any]], None], + ) -> CALLBACK_TYPE: + """Render a preview.""" + + for entity_id in self._entity_ids: + if (state := self.hass.states.get(entity_id)) is None: + continue + self.async_update_supported_features(entity_id, state) + + @callback + def async_state_changed_listener( + event: Event[EventStateChangedData] | None, + ) -> None: + """Handle child updates.""" + self.async_update_group_state() + if event: + self.async_update_supported_features( + event.data["entity_id"], event.data["new_state"] + ) + calculated_state = self._async_calculate_state() + preview_callback(calculated_state.state, calculated_state.attributes) + + async_state_changed_listener(None) + return async_track_state_change_event( + self.hass, self._entity_ids, async_state_changed_listener + ) + + async def async_added_to_hass(self) -> None: + """Register listeners.""" + for entity_id in self._entity_ids: + if (state := self.hass.states.get(entity_id)) is None: + continue + self.async_update_supported_features(entity_id, state) + + @callback + def async_state_changed_listener( + event: Event[EventStateChangedData], + ) -> None: + """Handle child updates.""" + self.async_set_context(event.context) + self.async_update_supported_features( + event.data["entity_id"], event.data["new_state"] + ) + self.async_defer_or_update_ha_state() + + self.async_on_remove( + async_track_state_change_event( + self.hass, self._entity_ids, async_state_changed_listener + ) + ) + self.async_on_remove(start.async_at_start(self.hass, self._update_at_start)) + + @callback + def _update_at_start(self, _: HomeAssistant) -> None: + """Update the group state at start.""" + self.async_update_group_state() + self.async_write_ha_state() + + @callback + def async_defer_or_update_ha_state(self) -> None: + """Only update once at start.""" + if not self.hass.is_running: + return + + self.async_update_group_state() + self.async_write_ha_state() + + @abstractmethod + @callback + def async_update_group_state(self) -> None: + """Abstract method to update the entity.""" + + @callback + def async_update_supported_features( + self, + entity_id: str, + new_state: State | None, + ) -> None: + """Update dictionaries with supported features.""" + + +class Group(Entity): + """Track a group of entity ids.""" + + _unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO}) + + _attr_should_poll = False + tracking: tuple[str, ...] + trackable: tuple[str, ...] + + def __init__( + self, + hass: HomeAssistant, + name: str, + *, + created_by_service: bool, + entity_ids: Collection[str] | None, + icon: str | None, + mode: bool | None, + order: int | None, + ) -> None: + """Initialize a group. + + This Object has factory function for creation. + """ + self.hass = hass + self._name = name + self._state: str | None = None + self._icon = icon + self._set_tracked(entity_ids) + self._on_off: dict[str, bool] = {} + self._assumed: dict[str, bool] = {} + self._on_states: set[str] = set() + self.created_by_service = created_by_service + self.mode = any + if mode: + self.mode = all + self._order = order + self._assumed_state = False + self._async_unsub_state_changed: CALLBACK_TYPE | None = None + + @staticmethod + @callback + def async_create_group_entity( + hass: HomeAssistant, + name: str, + *, + created_by_service: bool, + entity_ids: Collection[str] | None, + icon: str | None, + mode: bool | None, + object_id: str | None, + order: int | None, + ) -> Group: + """Create a group entity.""" + if order is None: + hass.data.setdefault(GROUP_ORDER, 0) + order = hass.data[GROUP_ORDER] + # Keep track of the group order without iterating + # every state in the state machine every time + # we setup a new group + hass.data[GROUP_ORDER] += 1 + + group = Group( + hass, + name, + created_by_service=created_by_service, + entity_ids=entity_ids, + icon=icon, + mode=mode, + order=order, + ) + + group.entity_id = async_generate_entity_id( + ENTITY_ID_FORMAT, object_id or name, hass=hass + ) + + return group + + @staticmethod + async def async_create_group( + hass: HomeAssistant, + name: str, + *, + created_by_service: bool, + entity_ids: Collection[str] | None, + icon: str | None, + mode: bool | None, + object_id: str | None, + order: int | None, + ) -> Group: + """Initialize a group. + + This method must be run in the event loop. + """ + group = Group.async_create_group_entity( + hass, + name, + created_by_service=created_by_service, + entity_ids=entity_ids, + icon=icon, + mode=mode, + object_id=object_id, + order=order, + ) + + # If called before the platform async_setup is called (test cases) + await async_get_component(hass).async_add_entities([group]) + return group + + @property + def name(self) -> str: + """Return the name of the group.""" + return self._name + + @name.setter + def name(self, value: str) -> None: + """Set Group name.""" + self._name = value + + @property + def state(self) -> str | None: + """Return the state of the group.""" + return self._state + + @property + def icon(self) -> str | None: + """Return the icon of the group.""" + return self._icon + + @icon.setter + def icon(self, value: str | None) -> None: + """Set Icon for group.""" + self._icon = value + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return the state attributes for the group.""" + data = {ATTR_ENTITY_ID: self.tracking, ATTR_ORDER: self._order} + if self.created_by_service: + data[ATTR_AUTO] = True + + return data + + @property + def assumed_state(self) -> bool: + """Test if any member has an assumed state.""" + return self._assumed_state + + def update_tracked_entity_ids(self, entity_ids: Collection[str] | None) -> None: + """Update the member entity IDs.""" + asyncio.run_coroutine_threadsafe( + self.async_update_tracked_entity_ids(entity_ids), self.hass.loop + ).result() + + async def async_update_tracked_entity_ids( + self, entity_ids: Collection[str] | None + ) -> None: + """Update the member entity IDs. + + This method must be run in the event loop. + """ + self._async_stop() + self._set_tracked(entity_ids) + self._reset_tracked_state() + self._async_start() + + def _set_tracked(self, entity_ids: Collection[str] | None) -> None: + """Tuple of entities to be tracked.""" + # tracking are the entities we want to track + # trackable are the entities we actually watch + + if not entity_ids: + self.tracking = () + self.trackable = () + return + + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + excluded_domains = registry.exclude_domains + + tracking: list[str] = [] + trackable: list[str] = [] + for ent_id in entity_ids: + ent_id_lower = ent_id.lower() + domain = split_entity_id(ent_id_lower)[0] + tracking.append(ent_id_lower) + if domain not in excluded_domains: + trackable.append(ent_id_lower) + + self.trackable = tuple(trackable) + self.tracking = tuple(tracking) + + @callback + def _async_start(self, _: HomeAssistant | None = None) -> None: + """Start tracking members and write state.""" + self._reset_tracked_state() + self._async_start_tracking() + self.async_write_ha_state() + + @callback + def _async_start_tracking(self) -> None: + """Start tracking members. + + This method must be run in the event loop. + """ + if self.trackable and self._async_unsub_state_changed is None: + self._async_unsub_state_changed = async_track_state_change_event( + self.hass, self.trackable, self._async_state_changed_listener + ) + + self._async_update_group_state() + + @callback + def _async_stop(self) -> None: + """Unregister the group from Home Assistant. + + This method must be run in the event loop. + """ + if self._async_unsub_state_changed: + self._async_unsub_state_changed() + self._async_unsub_state_changed = None + + @callback + def async_update_group_state(self) -> None: + """Query all members and determine current group state.""" + self._state = None + self._async_update_group_state() + + async def async_added_to_hass(self) -> None: + """Handle addition to Home Assistant.""" + self.async_on_remove(start.async_at_start(self.hass, self._async_start)) + + async def async_will_remove_from_hass(self) -> None: + """Handle removal from Home Assistant.""" + self._async_stop() + + async def _async_state_changed_listener( + self, event: Event[EventStateChangedData] + ) -> None: + """Respond to a member state changing. + + This method must be run in the event loop. + """ + # removed + if self._async_unsub_state_changed is None: + return + + self.async_set_context(event.context) + + if (new_state := event.data["new_state"]) is None: + # The state was removed from the state machine + self._reset_tracked_state() + + self._async_update_group_state(new_state) + self.async_write_ha_state() + + def _reset_tracked_state(self) -> None: + """Reset tracked state.""" + self._on_off = {} + self._assumed = {} + self._on_states = set() + + for entity_id in self.trackable: + if (state := self.hass.states.get(entity_id)) is not None: + self._see_state(state) + + def _see_state(self, new_state: State) -> None: + """Keep track of the state.""" + entity_id = new_state.entity_id + domain = new_state.domain + state = new_state.state + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + self._assumed[entity_id] = bool(new_state.attributes.get(ATTR_ASSUMED_STATE)) + + if domain not in registry.on_states_by_domain: + # Handle the group of a group case + if state in registry.on_off_mapping: + self._on_states.add(state) + elif state in registry.off_on_mapping: + self._on_states.add(registry.off_on_mapping[state]) + self._on_off[entity_id] = state in registry.on_off_mapping + else: + entity_on_state = registry.on_states_by_domain[domain] + if domain in registry.on_states_by_domain: + self._on_states.update(entity_on_state) + self._on_off[entity_id] = state in entity_on_state + + @callback + def _async_update_group_state(self, tr_state: State | None = None) -> None: + """Update group state. + + Optionally you can provide the only state changed since last update + allowing this method to take shortcuts. + + This method must be run in the event loop. + """ + # To store current states of group entities. Might not be needed. + if tr_state: + self._see_state(tr_state) + + if not self._on_off: + return + + if ( + tr_state is None + or self._assumed_state + and not tr_state.attributes.get(ATTR_ASSUMED_STATE) + ): + self._assumed_state = self.mode(self._assumed.values()) + + elif tr_state.attributes.get(ATTR_ASSUMED_STATE): + self._assumed_state = True + + num_on_states = len(self._on_states) + # If all the entity domains we are tracking + # have the same on state we use this state + # and its hass.data[REG_KEY].on_off_mapping to off + if num_on_states == 1: + on_state = list(self._on_states)[0] + # If we do not have an on state for any domains + # we use None (which will be STATE_UNKNOWN) + elif num_on_states == 0: + self._state = None + return + # If the entity domains have more than one + # on state, we use STATE_ON/STATE_OFF + else: + on_state = STATE_ON + group_is_on = self.mode(self._on_off.values()) + if group_is_on: + self._state = on_state + else: + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + self._state = registry.on_off_mapping[on_state] + + +def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]: + """Get the group entity component.""" + if (component := hass.data.get(DOMAIN)) is None: + component = hass.data[DOMAIN] = EntityComponent[Group]( + _PACKAGE_LOGGER, DOMAIN, hass + ) + return component diff --git a/homeassistant/components/group/event.py b/homeassistant/components/group/event.py index 8095a0e89c1..61ddb3e0645 100644 --- a/homeassistant/components/group/event.py +++ b/homeassistant/components/group/event.py @@ -34,7 +34,7 @@ from homeassistant.helpers.event import ( ) from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity DEFAULT_NAME = "Event group" diff --git a/homeassistant/components/group/fan.py b/homeassistant/components/group/fan.py index c8add0b6724..b70a4ff1531 100644 --- a/homeassistant/components/group/fan.py +++ b/homeassistant/components/group/fan.py @@ -40,7 +40,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity from .util import attribute_equal, most_frequent_attribute, reduce_attribute SUPPORTED_FLAGS = { diff --git a/homeassistant/components/group/light.py b/homeassistant/components/group/light.py index d014ca5d618..9adced828c7 100644 --- a/homeassistant/components/group/light.py +++ b/homeassistant/components/group/light.py @@ -51,7 +51,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity from .util import find_state_attributes, mean_tuple, reduce_attribute DEFAULT_NAME = "Light Group" diff --git a/homeassistant/components/group/lock.py b/homeassistant/components/group/lock.py index 08c2c053b0e..b0cf36bd6b1 100644 --- a/homeassistant/components/group/lock.py +++ b/homeassistant/components/group/lock.py @@ -34,7 +34,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity DEFAULT_NAME = "Lock Group" diff --git a/homeassistant/components/group/registry.py b/homeassistant/components/group/registry.py new file mode 100644 index 00000000000..1441d39d331 --- /dev/null +++ b/homeassistant/components/group/registry.py @@ -0,0 +1,68 @@ +"""Provide the functionality to group entities.""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import Protocol + +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.integration_platform import ( + async_process_integration_platforms, +) + +from .const import DOMAIN, REG_KEY + +current_domain: ContextVar[str] = ContextVar("current_domain") + + +async def async_setup(hass: HomeAssistant) -> None: + """Set up the Group integration registry of integration platforms.""" + hass.data[REG_KEY] = GroupIntegrationRegistry() + + await async_process_integration_platforms( + hass, DOMAIN, _process_group_platform, wait_for_platforms=True + ) + + +class GroupProtocol(Protocol): + """Define the format of group platforms.""" + + def async_describe_on_off_states( + self, hass: HomeAssistant, registry: GroupIntegrationRegistry + ) -> None: + """Describe group on off states.""" + + +@callback +def _process_group_platform( + hass: HomeAssistant, domain: str, platform: GroupProtocol +) -> None: + """Process a group platform.""" + current_domain.set(domain) + registry: GroupIntegrationRegistry = hass.data[REG_KEY] + platform.async_describe_on_off_states(hass, registry) + + +class GroupIntegrationRegistry: + """Class to hold a registry of integrations.""" + + on_off_mapping: dict[str, str] = {STATE_ON: STATE_OFF} + off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON} + on_states_by_domain: dict[str, set] = {} + exclude_domains: set = set() + + def exclude_domain(self) -> None: + """Exclude the current domain.""" + self.exclude_domains.add(current_domain.get()) + + def on_off_states(self, on_states: set, off_state: str) -> None: + """Register on and off states for the current domain.""" + for on_state in on_states: + if on_state not in self.on_off_mapping: + self.on_off_mapping[on_state] = off_state + + if len(on_states) == 1 and off_state not in self.off_on_mapping: + self.off_on_mapping[off_state] = list(on_states)[0] + + self.on_states_by_domain[current_domain.get()] = set(on_states) diff --git a/homeassistant/components/group/sensor.py b/homeassistant/components/group/sensor.py index 7334831211d..5de668c7bb0 100644 --- a/homeassistant/components/group/sensor.py +++ b/homeassistant/components/group/sensor.py @@ -52,8 +52,8 @@ from homeassistant.helpers.issue_registry import ( ) from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType -from . import DOMAIN as GROUP_DOMAIN, GroupEntity -from .const import CONF_IGNORE_NON_NUMERIC +from .const import CONF_IGNORE_NON_NUMERIC, DOMAIN as GROUP_DOMAIN +from .entity import GroupEntity DEFAULT_NAME = "Sensor Group" diff --git a/homeassistant/components/group/switch.py b/homeassistant/components/group/switch.py index ec70f137b33..7be6b188e72 100644 --- a/homeassistant/components/group/switch.py +++ b/homeassistant/components/group/switch.py @@ -25,7 +25,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import GroupEntity +from .entity import GroupEntity DEFAULT_NAME = "Switch Group" CONF_ALL = "all"