diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index 11c78a9b271..67f8096134e 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -15,6 +15,7 @@ from homeassistant.const import ( CONF_NAME, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE, + EVENT_HOMEASSISTANT_START, SERVICE_RELOAD, STATE_CLOSED, STATE_HOME, @@ -28,7 +29,7 @@ from homeassistant.const import ( STATE_UNKNOWN, STATE_UNLOCKED, ) -from homeassistant.core import callback +from homeassistant.core import CoreState, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity_component import EntityComponent @@ -341,6 +342,33 @@ async def _async_process_config(hass, config, component): ) +class GroupEntity(Entity): + """Representation of a Group of entities.""" + + @property + def should_poll(self) -> bool: + """Disable polling for group.""" + return False + + async def async_added_to_hass(self) -> None: + """Register listeners.""" + assert self.hass is not None + + async def _update_at_start(_): + await self.async_update_ha_state(True) + + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _update_at_start) + + async def async_defer_or_update_ha_state(self) -> None: + """Only update once at start.""" + assert self.hass is not None + + if self.hass.state != CoreState.running: + return + + await self.async_update_ha_state(True) + + class Group(Entity): """Track a group of entity ids.""" @@ -545,6 +573,7 @@ class Group(Entity): if self._async_unsub_state_changed is None: return + self.async_set_context(event.context) self._async_update_group_state(event.data.get("new_state")) self.async_write_ha_state() diff --git a/homeassistant/components/group/cover.py b/homeassistant/components/group/cover.py index 02de871cb7a..6a2111747d6 100644 --- a/homeassistant/components/group/cover.py +++ b/homeassistant/components/group/cover.py @@ -39,10 +39,12 @@ from homeassistant.const import ( STATE_OPEN, STATE_OPENING, ) -from homeassistant.core import State, callback +from homeassistant.core import State import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_track_state_change_event +from . import GroupEntity + # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs @@ -68,7 +70,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= async_add_entities([CoverGroup(config[CONF_NAME], config[CONF_ENTITIES])]) -class CoverGroup(CoverEntity): +class CoverGroup(GroupEntity, CoverEntity): """Representation of a CoverGroup.""" def __init__(self, name, entities): @@ -94,14 +96,13 @@ class CoverGroup(CoverEntity): KEY_POSITION: set(), } - @callback - def _update_supported_features_event(self, event): - self.update_supported_features( + async def _update_supported_features_event(self, event): + self.async_set_context(event.context) + await self.async_update_supported_features( event.data.get("entity_id"), event.data.get("new_state") ) - @callback - def update_supported_features( + async def async_update_supported_features( self, entity_id: str, new_state: Optional[State], update_state: bool = True, ) -> None: """Update dictionaries with supported features.""" @@ -111,7 +112,7 @@ class CoverGroup(CoverEntity): for values in self._tilts.values(): values.discard(entity_id) if update_state: - self.async_schedule_update_ha_state(True) + await self.async_defer_or_update_ha_state() return features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) @@ -143,17 +144,22 @@ class CoverGroup(CoverEntity): self._tilts[KEY_POSITION].discard(entity_id) if update_state: - self.async_schedule_update_ha_state(True) + await self.async_defer_or_update_ha_state() async def async_added_to_hass(self): """Register listeners.""" for entity_id in self._entities: new_state = self.hass.states.get(entity_id) - self.update_supported_features(entity_id, new_state, update_state=False) - async_track_state_change_event( - self.hass, self._entities, self._update_supported_features_event + await self.async_update_supported_features( + entity_id, new_state, update_state=False + ) + assert self.hass is not None + self.async_on_remove( + async_track_state_change_event( + self.hass, self._entities, self._update_supported_features_event + ) ) - await self.async_update() + await super().async_added_to_hass() @property def name(self): @@ -165,11 +171,6 @@ class CoverGroup(CoverEntity): """Enable buttons even if at end position.""" return self._assumed_state - @property - def should_poll(self): - """Disable polling for cover group.""" - return False - @property def supported_features(self): """Flag supported features for the cover.""" diff --git a/homeassistant/components/group/light.py b/homeassistant/components/group/light.py index 1b33a0a6e88..289bb8df3f0 100644 --- a/homeassistant/components/group/light.py +++ b/homeassistant/components/group/light.py @@ -36,12 +36,14 @@ from homeassistant.const import ( STATE_ON, STATE_UNAVAILABLE, ) -from homeassistant.core import CALLBACK_TYPE, State, callback +from homeassistant.core import State import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.util import color as color_util +from . import GroupEntity + # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs @@ -76,7 +78,7 @@ async def async_setup_platform( ) -class LightGroup(light.LightEntity): +class LightGroup(GroupEntity, light.LightEntity): """Representation of a light group.""" def __init__(self, name: str, entity_ids: List[str]) -> None: @@ -94,27 +96,22 @@ class LightGroup(light.LightEntity): self._effect_list: Optional[List[str]] = None self._effect: Optional[str] = None self._supported_features: int = 0 - self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None async def async_added_to_hass(self) -> None: """Register callbacks.""" - @callback - def async_state_changed_listener(*_): + async def async_state_changed_listener(event): """Handle child updates.""" - self.async_schedule_update_ha_state(True) + self.async_set_context(event.context) + await self.async_defer_or_update_ha_state() - assert self.hass is not None - self._async_unsub_state_changed = async_track_state_change_event( - self.hass, self._entity_ids, async_state_changed_listener + assert self.hass + self.async_on_remove( + async_track_state_change_event( + self.hass, self._entity_ids, async_state_changed_listener + ) ) - await self.async_update() - - async def async_will_remove_from_hass(self): - """Handle removal from Home Assistant.""" - if self._async_unsub_state_changed is not None: - self._async_unsub_state_changed() - self._async_unsub_state_changed = None + await super().async_added_to_hass() @property def name(self) -> str: diff --git a/tests/components/group/test_cover.py b/tests/components/group/test_cover.py index 98460762389..efdbc40ee46 100644 --- a/tests/components/group/test_cover.py +++ b/tests/components/group/test_cover.py @@ -78,6 +78,8 @@ async def setup_comp(hass, config_count): with assert_setup_component(count, DOMAIN): await async_setup_component(hass, DOMAIN, config) await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() @pytest.mark.parametrize("config_count", [(CONFIG_ATTRIBUTES, 1)]) diff --git a/tests/components/group/test_light.py b/tests/components/group/test_light.py index 685db475b8c..8c659a5ebf6 100644 --- a/tests/components/group/test_light.py +++ b/tests/components/group/test_light.py @@ -47,6 +47,8 @@ async def test_default_state(hass): }, ) await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() state = hass.states.get("light.bedroom_group") assert state is not None @@ -73,6 +75,9 @@ async def test_state_reporting(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set("light.test1", STATE_ON) hass.states.async_set("light.test2", STATE_UNAVAILABLE) @@ -107,6 +112,9 @@ async def test_brightness(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", STATE_ON, {ATTR_BRIGHTNESS: 255, ATTR_SUPPORTED_FEATURES: 1} @@ -147,6 +155,9 @@ async def test_color(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", STATE_ON, {ATTR_HS_COLOR: (0, 100), ATTR_SUPPORTED_FEATURES: 16} @@ -184,6 +195,9 @@ async def test_white_value(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", STATE_ON, {ATTR_WHITE_VALUE: 255, ATTR_SUPPORTED_FEATURES: 128} @@ -219,6 +233,9 @@ async def test_color_temp(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", STATE_ON, {"color_temp": 2, ATTR_SUPPORTED_FEATURES: 2} @@ -262,6 +279,8 @@ async def test_emulated_color_temp_group(hass): }, ) await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set("light.bed_light", STATE_ON, {ATTR_SUPPORTED_FEATURES: 2}) hass.states.async_set( @@ -306,6 +325,9 @@ async def test_min_max_mireds(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", @@ -350,6 +372,9 @@ async def test_effect_list(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", @@ -402,6 +427,9 @@ async def test_effect(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set( "light.test1", STATE_ON, {ATTR_EFFECT: "None", ATTR_SUPPORTED_FEATURES: 6} @@ -447,6 +475,9 @@ async def test_supported_features(hass): } }, ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() hass.states.async_set("light.test1", STATE_ON, {ATTR_SUPPORTED_FEATURES: 0}) await hass.async_block_till_done() @@ -489,6 +520,8 @@ async def test_service_calls(hass): }, ) await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() assert hass.states.get("light.light_group").state == STATE_ON await hass.services.async_call( @@ -559,6 +592,9 @@ async def test_invalid_service_calls(hass): await group.async_setup_platform( hass, {"entities": ["light.test1", "light.test2"]}, add_entities ) + await hass.async_block_till_done() + await hass.async_start() + await hass.async_block_till_done() assert add_entities.call_count == 1 grouped_light = add_entities.call_args[0][0][0]