Force updates for ZHA light group entity members (#37961)

* Force updates for ZHA light group entity members

* add a 3 second debouncer to the forced refresh

* lint
This commit is contained in:
David F. Mulcahey 2020-07-18 14:47:32 -04:00 committed by GitHub
parent f173805c2f
commit 2354d0117b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
"""Lights on Zigbee Home Automation networks.""" """Lights on Zigbee Home Automation networks."""
import asyncio
from collections import Counter from collections import Counter
from datetime import timedelta from datetime import timedelta
import functools import functools
@ -31,6 +32,7 @@ from homeassistant.components.light import (
) )
from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import State, callback from homeassistant.core import State, callback
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
@ -406,7 +408,7 @@ class Light(BaseLight, ZhaEntity):
async def async_get_state(self, from_cache=True): async def async_get_state(self, from_cache=True):
"""Attempt to retrieve on off state from the light.""" """Attempt to retrieve on off state from the light."""
self.debug("polling current state") self.debug("polling current state - from cache: %s", from_cache)
if self._on_off_channel: if self._on_off_channel:
state = await self._on_off_channel.get_attribute_value( state = await self._on_off_channel.get_attribute_value(
"on_off", from_cache=from_cache "on_off", from_cache=from_cache
@ -494,6 +496,30 @@ class LightGroup(BaseLight, ZhaGroupEntity):
self._level_channel = group.endpoint[LevelControl.cluster_id] self._level_channel = group.endpoint[LevelControl.cluster_id]
self._color_channel = group.endpoint[Color.cluster_id] self._color_channel = group.endpoint[Color.cluster_id]
self._identify_channel = group.endpoint[Identify.cluster_id] self._identify_channel = group.endpoint[Identify.cluster_id]
self._debounced_member_refresh = None
async def async_added_to_hass(self):
"""Run when about to be added to hass."""
await super().async_added_to_hass()
if self._debounced_member_refresh is None:
force_refresh_debouncer = Debouncer(
self.hass,
_LOGGER,
cooldown=3,
immediate=True,
function=self._force_member_updates,
)
self._debounced_member_refresh = force_refresh_debouncer
async def async_turn_on(self, **kwargs):
"""Turn the entity on."""
await super().async_turn_on(**kwargs)
await self._debounced_member_refresh.async_call()
async def async_turn_off(self, **kwargs):
"""Turn the entity off."""
await super().async_turn_off(**kwargs)
await self._debounced_member_refresh.async_call()
async def async_update(self) -> None: async def async_update(self) -> None:
"""Query all members and determine the light group state.""" """Query all members and determine the light group state."""
@ -541,3 +567,11 @@ class LightGroup(BaseLight, ZhaGroupEntity):
# Bitwise-and the supported features with the GroupedLight's features # Bitwise-and the supported features with the GroupedLight's features
# so that we don't break in the future when a new feature is added. # so that we don't break in the future when a new feature is added.
self._supported_features &= SUPPORT_GROUP_LIGHT self._supported_features &= SUPPORT_GROUP_LIGHT
async def _force_member_updates(self):
"""Force the update of member entities to ensure the states are correct for bulbs that don't report their state."""
component = self.hass.data[light.DOMAIN]
entities = [component.get_entity(entity_id) for entity_id in self._entity_ids]
tasks = [entity.async_get_state(from_cache=False) for entity in entities]
if tasks:
await asyncio.gather(*tasks)