diff --git a/homeassistant/components/group/entity.py b/homeassistant/components/group/entity.py index a8fd9027984..5ac913dde8d 100644 --- a/homeassistant/components/group/entity.py +++ b/homeassistant/components/group/entity.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Collection, Mapping import logging from typing import Any -from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON +from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON from homeassistant.core import ( CALLBACK_TYPE, Event, @@ -131,6 +131,9 @@ class Group(Entity): _unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO}) _attr_should_poll = False + # In case there is only one active domain we use specific ON or OFF + # values, if all ON or OFF states are equal + single_active_domain: str | None tracking: tuple[str, ...] trackable: tuple[str, ...] @@ -287,6 +290,7 @@ class Group(Entity): if not entity_ids: self.tracking = () self.trackable = () + self.single_active_domain = None return registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] @@ -294,12 +298,22 @@ class Group(Entity): tracking: list[str] = [] trackable: list[str] = [] + self.single_active_domain = None + multiple_domains: bool = False for ent_id in entity_ids: ent_id_lower = ent_id.lower() domain = split_entity_id(ent_id_lower)[0] tracking.append(ent_id_lower) - if domain not in excluded_domains: - trackable.append(ent_id_lower) + if domain in excluded_domains: + continue + + trackable.append(ent_id_lower) + + if not multiple_domains and self.single_active_domain is None: + self.single_active_domain = domain + if self.single_active_domain != domain: + multiple_domains = True + self.single_active_domain = None self.trackable = tuple(trackable) self.tracking = tuple(tracking) @@ -395,10 +409,36 @@ class Group(Entity): self._on_off[entity_id] = state in registry.on_off_mapping else: entity_on_state = registry.on_states_by_domain[domain] - if domain in registry.on_states_by_domain: - self._on_states.update(entity_on_state) + self._on_states.update(entity_on_state) self._on_off[entity_id] = state in entity_on_state + def _detect_specific_on_off_state(self, group_is_on: bool) -> set[str]: + """Check if a specific ON or OFF state is possible.""" + # In case the group contains entities of the same domain with the same ON + # or an OFF state (one or more domains), we want to use that specific state. + # If we have more then one ON or OFF state we default to STATE_ON or STATE_OFF. + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + active_on_states: set[str] = set() + active_off_states: set[str] = set() + for entity_id in self.trackable: + if (state := self.hass.states.get(entity_id)) is None: + continue + current_state = state.state + if ( + group_is_on + and (domain_on_states := registry.on_states_by_domain.get(state.domain)) + and current_state in domain_on_states + ): + active_on_states.add(current_state) + # If we have more than one on state, the group state + # will result in STATE_ON and we can stop checking + if len(active_on_states) > 1: + break + elif current_state in registry.off_on_mapping: + active_off_states.add(current_state) + + return active_on_states if group_is_on else active_off_states + @callback def _async_update_group_state(self, tr_state: State | None = None) -> None: """Update group state. @@ -425,27 +465,48 @@ class Group(Entity): elif tr_state.attributes.get(ATTR_ASSUMED_STATE): self._assumed_state = True - num_on_states = len(self._on_states) + # If we do not have an on state for any domains + # we use None (which will be STATE_UNKNOWN) + if (num_on_states := len(self._on_states)) == 0: + self._state = None + return + + group_is_on = self.mode(self._on_off.values()) + # If all the entity domains we are tracking # have the same on state we use this state # and its hass.data[REG_KEY].on_off_mapping to off if num_on_states == 1: - on_state = list(self._on_states)[0] - # If we do not have an on state for any domains - # we use None (which will be STATE_UNKNOWN) - elif num_on_states == 0: - self._state = None - return + on_state = next(iter(self._on_states)) # If the entity domains have more than one - # on state, we use STATE_ON/STATE_OFF - else: + # on state, we use STATE_ON/STATE_OFF, unless there is + # only one specific `on` state in use for one specific domain + elif self.single_active_domain and num_on_states: + active_on_states = self._detect_specific_on_off_state(True) + on_state = ( + list(active_on_states)[0] if len(active_on_states) == 1 else STATE_ON + ) + elif group_is_on: on_state = STATE_ON - group_is_on = self.mode(self._on_off.values()) if group_is_on: self._state = on_state + return + + registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] + if ( + active_domain := self.single_active_domain + ) and active_domain in registry.off_state_by_domain: + # If there is only one domain used, + # then we return the off state for that domain.s + self._state = registry.off_state_by_domain[active_domain] else: - registry: GroupIntegrationRegistry = self.hass.data[REG_KEY] - self._state = registry.on_off_mapping[on_state] + active_off_states = self._detect_specific_on_off_state(False) + # If there is one off state in use then we return that specific state, + # also if there a multiple domains involved, e.g. + # person and device_tracker, with a shared state. + self._state = ( + list(active_off_states)[0] if len(active_off_states) == 1 else STATE_OFF + ) def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]: diff --git a/homeassistant/components/group/registry.py b/homeassistant/components/group/registry.py index 6cdb929d60c..474448db68a 100644 --- a/homeassistant/components/group/registry.py +++ b/homeassistant/components/group/registry.py @@ -49,9 +49,12 @@ class GroupIntegrationRegistry: def __init__(self) -> None: """Imitialize registry.""" - self.on_off_mapping: dict[str, str] = {STATE_ON: STATE_OFF} + self.on_off_mapping: dict[str, dict[str | None, str]] = { + STATE_ON: {None: STATE_OFF} + } self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON} self.on_states_by_domain: dict[str, set[str]] = {} + self.off_state_by_domain: dict[str, str] = {} self.exclude_domains: set[str] = set() def exclude_domain(self) -> None: @@ -60,11 +63,14 @@ class GroupIntegrationRegistry: def on_off_states(self, on_states: set, off_state: str) -> None: """Register on and off states for the current domain.""" + domain = current_domain.get() for on_state in on_states: if on_state not in self.on_off_mapping: - self.on_off_mapping[on_state] = off_state - + self.on_off_mapping[on_state] = {domain: off_state} + else: + self.on_off_mapping[on_state][domain] = off_state if len(on_states) == 1 and off_state not in self.off_on_mapping: self.off_on_mapping[off_state] = list(on_states)[0] - self.on_states_by_domain[current_domain.get()] = set(on_states) + self.on_states_by_domain[domain] = set(on_states) + self.off_state_by_domain[domain] = off_state diff --git a/tests/components/group/test_init.py b/tests/components/group/test_init.py index d3f2747933e..b9cdfcb1590 100644 --- a/tests/components/group/test_init.py +++ b/tests/components/group/test_init.py @@ -9,7 +9,7 @@ from unittest.mock import patch import pytest -from homeassistant.components import group +from homeassistant.components import group, vacuum from homeassistant.const import ( ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, @@ -659,6 +659,24 @@ async def test_is_on(hass: HomeAssistant) -> None: (STATE_ON, True), (STATE_OFF, False), ), + ( + ("vacuum", "vacuum"), + # Cleaning is the only on state + (vacuum.STATE_DOCKED, vacuum.STATE_CLEANING), + # Returning is the only on state + (vacuum.STATE_RETURNING, vacuum.STATE_PAUSED), + (vacuum.STATE_CLEANING, True), + (vacuum.STATE_RETURNING, True), + ), + ( + ("vacuum", "vacuum"), + # Multiple on states, so group state will be STATE_ON + (vacuum.STATE_RETURNING, vacuum.STATE_CLEANING), + # Only off states, so group state will be off + (vacuum.STATE_PAUSED, vacuum.STATE_IDLE), + (STATE_ON, True), + (STATE_OFF, False), + ), ], ) async def test_is_on_and_state_mixed_domains( @@ -1220,7 +1238,7 @@ async def test_group_climate_all_cool(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert hass.states.get("group.group_zero").state == STATE_ON + assert hass.states.get("group.group_zero").state == "cool" async def test_group_climate_all_off(hass: HomeAssistant) -> None: @@ -1334,7 +1352,7 @@ async def test_group_vacuum_on(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert hass.states.get("group.group_zero").state == STATE_ON + assert hass.states.get("group.group_zero").state == "cleaning" async def test_device_tracker_not_home(hass: HomeAssistant) -> None: