Refactor group state logic (#116318)

* Refactor group state logic

* Fix

* Add helper and tests for groups with entity platforms multiple ON states

* Adress comments

* Do not store object and avoid linear search

* User dataclass, cleanup multiline ternary

* Add test cases for grouped groups

* Remove dead code

* typo in comment

* Update metjod and module docstr
This commit is contained in:
Jan Bouwhuis 2024-05-02 21:55:46 +02:00 committed by GitHub
parent 8e7026d643
commit 41b688645a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 385 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from collections.abc import Callable, Collection, Mapping
import logging
from typing import Any
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
@ -24,7 +24,7 @@ from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_state_change_event
from .const import ATTR_AUTO, ATTR_ORDER, DOMAIN, GROUP_ORDER, REG_KEY
from .registry import GroupIntegrationRegistry
from .registry import GroupIntegrationRegistry, SingleStateType
ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -133,6 +133,7 @@ class Group(Entity):
_attr_should_poll = False
tracking: tuple[str, ...]
trackable: tuple[str, ...]
single_state_type_key: SingleStateType | None
def __init__(
self,
@ -153,7 +154,7 @@ class Group(Entity):
self._attr_name = name
self._state: str | None = None
self._attr_icon = icon
self._set_tracked(entity_ids)
self._entity_ids = entity_ids
self._on_off: dict[str, bool] = {}
self._assumed: dict[str, bool] = {}
self._on_states: set[str] = set()
@ -287,6 +288,7 @@ class Group(Entity):
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_state_type_key = None
return
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
@ -294,16 +296,42 @@ class Group(Entity):
tracking: list[str] = []
trackable: list[str] = []
single_state_type_set: set[SingleStateType] = set()
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[domain])
elif domain == DOMAIN:
# If a group contains another group we check if that group
# has a specific single state type
if ent_id in registry.state_group_mapping:
single_state_type_set.add(registry.state_group_mapping[ent_id])
else:
single_state_type_set.add(SingleStateType(STATE_ON, STATE_OFF))
if len(single_state_type_set) == 1:
self.single_state_type_key = next(iter(single_state_type_set))
# To support groups with nested groups we store the state type
# per group entity_id if there is a single state type
registry.state_group_mapping[self.entity_id] = self.single_state_type_key
else:
self.single_state_type_key = None
self.async_on_remove(self._async_deregister)
self.trackable = tuple(trackable)
self.tracking = tuple(tracking)
@callback
def _async_deregister(self) -> None:
"""Deregister group entity from the registry."""
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if self.entity_id in registry.state_group_mapping:
registry.state_group_mapping.pop(self.entity_id)
@callback
def _async_start(self, _: HomeAssistant | None = None) -> None:
"""Start tracking members and write state."""
@ -342,6 +370,7 @@ class Group(Entity):
async def async_added_to_hass(self) -> None:
"""Handle addition to Home Assistant."""
self._set_tracked(self._entity_ids)
self.async_on_remove(start.async_at_start(self.hass, self._async_start))
async def async_will_remove_from_hass(self) -> None:
@ -430,12 +459,14 @@ class Group(Entity):
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
on_state = next(iter(self._on_states))
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
if self.single_state_type_key:
on_state = self.single_state_type_key.on_state
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
@ -443,9 +474,10 @@ class Group(Entity):
group_is_on = self.mode(self._on_off.values())
if group_is_on:
self._state = on_state
elif self.single_state_type_key:
self._state = self.single_state_type_key.off_state
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
self._state = STATE_OFF
def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:

View File

@ -1,8 +1,12 @@
"""Provide the functionality to group entities."""
"""Provide the functionality to group entities.
Legacy group support will not be extended for new domains.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from dataclasses import dataclass
from typing import Protocol
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, callback
@ -12,9 +16,6 @@ from homeassistant.helpers.integration_platform import (
from .const import DOMAIN, REG_KEY
if TYPE_CHECKING:
from .entity import Group
async def async_setup(hass: HomeAssistant) -> None:
"""Set up the Group integration registry of integration platforms."""
@ -43,6 +44,14 @@ def _process_group_platform(
platform.async_describe_on_off_states(hass, registry)
@dataclass(frozen=True, slots=True)
class SingleStateType:
"""Dataclass to store a single state type."""
on_state: str
off_state: str
class GroupIntegrationRegistry:
"""Class to hold a registry of integrations."""
@ -53,8 +62,7 @@ class GroupIntegrationRegistry:
self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON}
self.on_states_by_domain: dict[str, set[str]] = {}
self.exclude_domains: set[str] = set()
self.state_group_mapping: dict[str, tuple[str, str]] = {}
self.group_entities: set[Group] = set()
self.state_group_mapping: dict[str, SingleStateType] = {}
@callback
def exclude_domain(self, domain: str) -> None:
@ -65,12 +73,16 @@ class GroupIntegrationRegistry:
def on_off_states(
self, domain: str, on_states: set[str], default_on_state: str, off_state: str
) -> None:
"""Register on and off states for the current domain."""
"""Register on and off states for the current domain.
Legacy group support will not be extended for new domains.
"""
for on_state in on_states:
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state
if len(on_states) == 1 and off_state not in self.off_on_mapping:
if off_state not in self.off_on_mapping:
self.off_on_mapping[off_state] = default_on_state
self.state_group_mapping[domain] = SingleStateType(default_on_state, off_state)
self.on_states_by_domain[domain] = on_states

View File

@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components import group
from homeassistant.components.group.registry import GroupIntegrationRegistry
from homeassistant.const import (
ATTR_ASSUMED_STATE,
ATTR_FRIENDLY_NAME,
@ -33,7 +34,116 @@ from homeassistant.setup import async_setup_component
from . import common
from tests.common import MockConfigEntry, assert_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
assert_setup_component,
mock_integration,
mock_platform,
)
async def help_test_mixed_entity_platforms_on_off_state_test(
hass: HomeAssistant,
on_off_states1: tuple[set[str], str, str],
on_off_states2: tuple[set[str], str, str],
entity_and_state1_state_2: tuple[str, str | None, str | None],
group_state1: str,
group_state2: str,
grouped_groups: bool = False,
) -> None:
"""Help test on_off_states on mixed entity platforms."""
class MockGroupPlatform1(MockPlatform):
"""Mock a group platform module for test1 integration."""
def async_describe_on_off_states(
self, hass: HomeAssistant, registry: GroupIntegrationRegistry
) -> None:
"""Describe group on off states."""
registry.on_off_states("test1", *on_off_states1)
class MockGroupPlatform2(MockPlatform):
"""Mock a group platform module for test2 integration."""
def async_describe_on_off_states(
self, hass: HomeAssistant, registry: GroupIntegrationRegistry
) -> None:
"""Describe group on off states."""
registry.on_off_states("test2", *on_off_states2)
mock_integration(hass, MockModule(domain="test1"))
mock_platform(hass, "test1.group", MockGroupPlatform1())
assert await async_setup_component(hass, "test1", {"test1": {}})
mock_integration(hass, MockModule(domain="test2"))
mock_platform(hass, "test2.group", MockGroupPlatform2())
assert await async_setup_component(hass, "test2", {"test2": {}})
if grouped_groups:
assert await async_setup_component(
hass,
"group",
{
"group": {
"test1": {
"entities": [
item[0]
for item in entity_and_state1_state_2
if item[0].startswith("test1.")
]
},
"test2": {
"entities": [
item[0]
for item in entity_and_state1_state_2
if item[0].startswith("test2.")
]
},
"test": {"entities": ["group.test1", "group.test2"]},
}
},
)
else:
assert await async_setup_component(
hass,
"group",
{
"group": {
"test": {
"entities": [item[0] for item in entity_and_state1_state_2]
},
}
},
)
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get("group.test")
assert state is not None
# Set first state
for entity_id, state1, _ in entity_and_state1_state_2:
hass.states.async_set(entity_id, state1)
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get("group.test")
assert state is not None
assert state.state == group_state1
# Set second state
for entity_id, _, state2 in entity_and_state1_state_2:
hass.states.async_set(entity_id, state2)
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get("group.test")
assert state is not None
assert state.state == group_state2
async def test_setup_group_with_mixed_groupable_states(hass: HomeAssistant) -> None:
@ -1560,6 +1670,7 @@ async def test_group_that_references_a_group_of_covers(hass: HomeAssistant) -> N
for entity_id in entity_ids:
hass.states.async_set(entity_id, "closed")
await hass.async_block_till_done()
assert await async_setup_component(hass, "cover", {})
assert await async_setup_component(
hass,
@ -1643,6 +1754,7 @@ async def test_group_that_references_two_types_of_groups(hass: HomeAssistant) ->
hass.states.async_set(entity_id, "home")
await hass.async_block_till_done()
assert await async_setup_component(hass, "cover", {})
assert await async_setup_component(hass, "device_tracker", {})
assert await async_setup_component(
hass,
@ -1884,3 +1996,216 @@ async def test_unhide_members_on_remove(
# Check the group members are unhidden
assert entity_registry.async_get(f"{group_type}.one").hidden_by == hidden_by
assert entity_registry.async_get(f"{group_type}.three").hidden_by == hidden_by
@pytest.mark.parametrize("grouped_groups", [False, True])
@pytest.mark.parametrize(
("on_off_states1", "on_off_states2"),
[
(
(
{
"on_beer",
"on_milk",
},
"on_beer", # default ON state test1
"off_water", # default OFF state test1
),
(
{
"on_beer",
"on_milk",
},
"on_milk", # default ON state test2
"off_wine", # default OFF state test2
),
),
],
)
@pytest.mark.parametrize(
("entity_and_state1_state_2", "group_state1", "group_state2"),
[
# All OFF states, no change, so group stays OFF
(
[
("test1.ent1", "off_water", "off_water"),
("test1.ent2", "off_water", "off_water"),
("test2.ent1", "off_wine", "off_wine"),
("test2.ent2", "off_wine", "off_wine"),
],
STATE_OFF,
STATE_OFF,
),
# All entities have state on_milk, but the state groups
# are different so the group status defaults to ON / OFF
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_milk"),
("test2.ent1", "off_wine", "on_milk"),
("test2.ent2", "off_wine", "on_milk"),
],
STATE_OFF,
STATE_ON,
),
# Only test1 entities in group, all at ON state
# group returns the default ON state `on_beer`
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_beer"),
],
"off_water",
"on_beer",
),
# Only test1 entities in group, all at ON state
# group returns the default ON state `on_beer`
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_milk"),
],
"off_water",
"on_beer",
),
# Only test2 entities in group, all at ON state
# group returns the default ON state `on_milk`
(
[
("test2.ent1", "off_wine", "on_milk"),
("test2.ent2", "off_wine", "on_milk"),
],
"off_wine",
"on_milk",
),
],
)
async def test_entity_platforms_with_multiple_on_states_no_state_match(
hass: HomeAssistant,
on_off_states1: tuple[set[str], str, str],
on_off_states2: tuple[set[str], str, str],
entity_and_state1_state_2: tuple[str, str | None, str | None],
group_state1: str,
group_state2: str,
grouped_groups: bool,
) -> None:
"""Test custom entity platforms with multiple ON states without state match.
The test group 1 an 2 non matching (default_state_on, state_off) pairs.
"""
await help_test_mixed_entity_platforms_on_off_state_test(
hass,
on_off_states1,
on_off_states2,
entity_and_state1_state_2,
group_state1,
group_state2,
grouped_groups,
)
@pytest.mark.parametrize("grouped_groups", [False, True])
@pytest.mark.parametrize(
("on_off_states1", "on_off_states2"),
[
(
(
{
"on_beer",
"on_milk",
},
"on_beer", # default ON state test1
"off_water", # default OFF state test1
),
(
{
"on_beer",
"on_wine",
},
"on_beer", # default ON state test2
"off_water", # default OFF state test2
),
),
],
)
@pytest.mark.parametrize(
("entity_and_state1_state_2", "group_state1", "group_state2"),
[
# All OFF states, no change, so group stays OFF
(
[
("test1.ent1", "off_water", "off_water"),
("test1.ent2", "off_water", "off_water"),
("test2.ent1", "off_water", "off_water"),
("test2.ent2", "off_water", "off_water"),
],
"off_water",
"off_water",
),
# All entities have ON state `on_milk`
# but the group state will default to on_beer
# which is the default ON state for both integrations.
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_milk"),
("test2.ent1", "off_water", "on_milk"),
("test2.ent2", "off_water", "on_milk"),
],
"off_water",
"on_beer",
),
# Only test1 entities in group, all at ON state
# group returns the default ON state `on_beer`
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_beer"),
],
"off_water",
"on_beer",
),
# Only test1 entities in group, all at ON state
# group returns the default ON state `on_beer`
(
[
("test1.ent1", "off_water", "on_milk"),
("test1.ent2", "off_water", "on_milk"),
],
"off_water",
"on_beer",
),
# Only test2 entities in group, all at ON state
# group returns the default ON state `on_milk`
(
[
("test2.ent1", "off_water", "on_wine"),
("test2.ent2", "off_water", "on_wine"),
],
"off_water",
"on_beer",
),
],
)
async def test_entity_platforms_with_multiple_on_states_with_state_match(
hass: HomeAssistant,
on_off_states1: tuple[set[str], str, str],
on_off_states2: tuple[set[str], str, str],
entity_and_state1_state_2: tuple[str, str | None, str | None],
group_state1: str,
group_state2: str,
grouped_groups: bool,
) -> None:
"""Test custom entity platforms with multiple ON states with a state match.
The integrations test1 and test2 have matching (default_state_on, state_off) pairs.
"""
await help_test_mixed_entity_platforms_on_off_state_test(
hass,
on_off_states1,
on_off_states2,
entity_and_state1_state_2,
group_state1,
group_state2,
grouped_groups,
)