Ensure the context is passed to group changes (#39221)

This commit is contained in:
J. Nick Koston 2020-08-25 17:22:10 -05:00 committed by GitHub
parent 20398cc0a6
commit 63ebea1706
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 100 additions and 35 deletions

View File

@ -15,6 +15,7 @@ from homeassistant.const import (
CONF_NAME, CONF_NAME,
ENTITY_MATCH_ALL, ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE, ENTITY_MATCH_NONE,
EVENT_HOMEASSISTANT_START,
SERVICE_RELOAD, SERVICE_RELOAD,
STATE_CLOSED, STATE_CLOSED,
STATE_HOME, STATE_HOME,
@ -28,7 +29,7 @@ from homeassistant.const import (
STATE_UNKNOWN, STATE_UNKNOWN,
STATE_UNLOCKED, STATE_UNLOCKED,
) )
from homeassistant.core import callback from homeassistant.core import CoreState, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity, async_generate_entity_id from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
@ -341,6 +342,33 @@ async def _async_process_config(hass, config, component):
) )
class GroupEntity(Entity):
"""Representation of a Group of entities."""
@property
def should_poll(self) -> bool:
"""Disable polling for group."""
return False
async def async_added_to_hass(self) -> None:
"""Register listeners."""
assert self.hass is not None
async def _update_at_start(_):
await self.async_update_ha_state(True)
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _update_at_start)
async def async_defer_or_update_ha_state(self) -> None:
"""Only update once at start."""
assert self.hass is not None
if self.hass.state != CoreState.running:
return
await self.async_update_ha_state(True)
class Group(Entity): class Group(Entity):
"""Track a group of entity ids.""" """Track a group of entity ids."""
@ -545,6 +573,7 @@ class Group(Entity):
if self._async_unsub_state_changed is None: if self._async_unsub_state_changed is None:
return return
self.async_set_context(event.context)
self._async_update_group_state(event.data.get("new_state")) self._async_update_group_state(event.data.get("new_state"))
self.async_write_ha_state() self.async_write_ha_state()

View File

@ -39,10 +39,12 @@ from homeassistant.const import (
STATE_OPEN, STATE_OPEN,
STATE_OPENING, STATE_OPENING,
) )
from homeassistant.core import State, callback from homeassistant.core import State
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
from . import GroupEntity
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs # mypy: no-check-untyped-defs
@ -68,7 +70,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
async_add_entities([CoverGroup(config[CONF_NAME], config[CONF_ENTITIES])]) async_add_entities([CoverGroup(config[CONF_NAME], config[CONF_ENTITIES])])
class CoverGroup(CoverEntity): class CoverGroup(GroupEntity, CoverEntity):
"""Representation of a CoverGroup.""" """Representation of a CoverGroup."""
def __init__(self, name, entities): def __init__(self, name, entities):
@ -94,14 +96,13 @@ class CoverGroup(CoverEntity):
KEY_POSITION: set(), KEY_POSITION: set(),
} }
@callback async def _update_supported_features_event(self, event):
def _update_supported_features_event(self, event): self.async_set_context(event.context)
self.update_supported_features( await self.async_update_supported_features(
event.data.get("entity_id"), event.data.get("new_state") event.data.get("entity_id"), event.data.get("new_state")
) )
@callback async def async_update_supported_features(
def update_supported_features(
self, entity_id: str, new_state: Optional[State], update_state: bool = True, self, entity_id: str, new_state: Optional[State], update_state: bool = True,
) -> None: ) -> None:
"""Update dictionaries with supported features.""" """Update dictionaries with supported features."""
@ -111,7 +112,7 @@ class CoverGroup(CoverEntity):
for values in self._tilts.values(): for values in self._tilts.values():
values.discard(entity_id) values.discard(entity_id)
if update_state: if update_state:
self.async_schedule_update_ha_state(True) await self.async_defer_or_update_ha_state()
return return
features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
@ -143,17 +144,22 @@ class CoverGroup(CoverEntity):
self._tilts[KEY_POSITION].discard(entity_id) self._tilts[KEY_POSITION].discard(entity_id)
if update_state: if update_state:
self.async_schedule_update_ha_state(True) await self.async_defer_or_update_ha_state()
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register listeners.""" """Register listeners."""
for entity_id in self._entities: for entity_id in self._entities:
new_state = self.hass.states.get(entity_id) new_state = self.hass.states.get(entity_id)
self.update_supported_features(entity_id, new_state, update_state=False) await self.async_update_supported_features(
async_track_state_change_event( entity_id, new_state, update_state=False
self.hass, self._entities, self._update_supported_features_event )
assert self.hass is not None
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
) )
await self.async_update() await super().async_added_to_hass()
@property @property
def name(self): def name(self):
@ -165,11 +171,6 @@ class CoverGroup(CoverEntity):
"""Enable buttons even if at end position.""" """Enable buttons even if at end position."""
return self._assumed_state return self._assumed_state
@property
def should_poll(self):
"""Disable polling for cover group."""
return False
@property @property
def supported_features(self): def supported_features(self):
"""Flag supported features for the cover.""" """Flag supported features for the cover."""

View File

@ -36,12 +36,14 @@ from homeassistant.const import (
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
) )
from homeassistant.core import CALLBACK_TYPE, State, callback from homeassistant.core import State
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.util import color as color_util from homeassistant.util import color as color_util
from . import GroupEntity
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs # mypy: no-check-untyped-defs
@ -76,7 +78,7 @@ async def async_setup_platform(
) )
class LightGroup(light.LightEntity): class LightGroup(GroupEntity, light.LightEntity):
"""Representation of a light group.""" """Representation of a light group."""
def __init__(self, name: str, entity_ids: List[str]) -> None: def __init__(self, name: str, entity_ids: List[str]) -> None:
@ -94,27 +96,22 @@ class LightGroup(light.LightEntity):
self._effect_list: Optional[List[str]] = None self._effect_list: Optional[List[str]] = None
self._effect: Optional[str] = None self._effect: Optional[str] = None
self._supported_features: int = 0 self._supported_features: int = 0
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register callbacks.""" """Register callbacks."""
@callback async def async_state_changed_listener(event):
def async_state_changed_listener(*_):
"""Handle child updates.""" """Handle child updates."""
self.async_schedule_update_ha_state(True) self.async_set_context(event.context)
await self.async_defer_or_update_ha_state()
assert self.hass is not None assert self.hass
self._async_unsub_state_changed = async_track_state_change_event( self.async_on_remove(
self.hass, self._entity_ids, async_state_changed_listener async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
) )
await self.async_update() await super().async_added_to_hass()
async def async_will_remove_from_hass(self):
"""Handle removal from Home Assistant."""
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
@property @property
def name(self) -> str: def name(self) -> str:

View File

@ -78,6 +78,8 @@ async def setup_comp(hass, config_count):
with assert_setup_component(count, DOMAIN): with assert_setup_component(count, DOMAIN):
await async_setup_component(hass, DOMAIN, config) await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
@pytest.mark.parametrize("config_count", [(CONFIG_ATTRIBUTES, 1)]) @pytest.mark.parametrize("config_count", [(CONFIG_ATTRIBUTES, 1)])

View File

@ -47,6 +47,8 @@ async def test_default_state(hass):
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
state = hass.states.get("light.bedroom_group") state = hass.states.get("light.bedroom_group")
assert state is not None assert state is not None
@ -73,6 +75,9 @@ async def test_state_reporting(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("light.test1", STATE_ON) hass.states.async_set("light.test1", STATE_ON)
hass.states.async_set("light.test2", STATE_UNAVAILABLE) hass.states.async_set("light.test2", STATE_UNAVAILABLE)
@ -107,6 +112,9 @@ async def test_brightness(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", STATE_ON, {ATTR_BRIGHTNESS: 255, ATTR_SUPPORTED_FEATURES: 1} "light.test1", STATE_ON, {ATTR_BRIGHTNESS: 255, ATTR_SUPPORTED_FEATURES: 1}
@ -147,6 +155,9 @@ async def test_color(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", STATE_ON, {ATTR_HS_COLOR: (0, 100), ATTR_SUPPORTED_FEATURES: 16} "light.test1", STATE_ON, {ATTR_HS_COLOR: (0, 100), ATTR_SUPPORTED_FEATURES: 16}
@ -184,6 +195,9 @@ async def test_white_value(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", STATE_ON, {ATTR_WHITE_VALUE: 255, ATTR_SUPPORTED_FEATURES: 128} "light.test1", STATE_ON, {ATTR_WHITE_VALUE: 255, ATTR_SUPPORTED_FEATURES: 128}
@ -219,6 +233,9 @@ async def test_color_temp(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", STATE_ON, {"color_temp": 2, ATTR_SUPPORTED_FEATURES: 2} "light.test1", STATE_ON, {"color_temp": 2, ATTR_SUPPORTED_FEATURES: 2}
@ -262,6 +279,8 @@ async def test_emulated_color_temp_group(hass):
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("light.bed_light", STATE_ON, {ATTR_SUPPORTED_FEATURES: 2}) hass.states.async_set("light.bed_light", STATE_ON, {ATTR_SUPPORTED_FEATURES: 2})
hass.states.async_set( hass.states.async_set(
@ -306,6 +325,9 @@ async def test_min_max_mireds(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", "light.test1",
@ -350,6 +372,9 @@ async def test_effect_list(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", "light.test1",
@ -402,6 +427,9 @@ async def test_effect(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"light.test1", STATE_ON, {ATTR_EFFECT: "None", ATTR_SUPPORTED_FEATURES: 6} "light.test1", STATE_ON, {ATTR_EFFECT: "None", ATTR_SUPPORTED_FEATURES: 6}
@ -447,6 +475,9 @@ async def test_supported_features(hass):
} }
}, },
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
hass.states.async_set("light.test1", STATE_ON, {ATTR_SUPPORTED_FEATURES: 0}) hass.states.async_set("light.test1", STATE_ON, {ATTR_SUPPORTED_FEATURES: 0})
await hass.async_block_till_done() await hass.async_block_till_done()
@ -489,6 +520,8 @@ async def test_service_calls(hass):
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
assert hass.states.get("light.light_group").state == STATE_ON assert hass.states.get("light.light_group").state == STATE_ON
await hass.services.async_call( await hass.services.async_call(
@ -559,6 +592,9 @@ async def test_invalid_service_calls(hass):
await group.async_setup_platform( await group.async_setup_platform(
hass, {"entities": ["light.test1", "light.test2"]}, add_entities hass, {"entities": ["light.test1", "light.test2"]}, add_entities
) )
await hass.async_block_till_done()
await hass.async_start()
await hass.async_block_till_done()
assert add_entities.call_count == 1 assert add_entities.call_count == 1
grouped_light = add_entities.call_args[0][0][0] grouped_light = add_entities.call_args[0][0][0]