diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index 43b1634cba7..da151f67dbb 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -208,6 +208,7 @@ SIGNAL_SET_LEVEL = "set_level" SIGNAL_STATE_ATTR = "update_state_attribute" SIGNAL_UPDATE_DEVICE = "{}_zha_update_device" SIGNAL_REMOVE_GROUP = "remove_group" +SIGNAL_GROUP_MEMBERSHIP_CHANGE = "group_membership_change" UNKNOWN = "unknown" UNKNOWN_MANUFACTURER = "unk_manufacturer" diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index bc7ff42d25f..fcc8a52360b 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -52,6 +52,7 @@ from .const import ( DEFAULT_DATABASE_NAME, DOMAIN, SIGNAL_ADD_ENTITIES, + SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_REMOVE, SIGNAL_REMOVE_GROUP, UNKNOWN_MANUFACTURER, @@ -256,6 +257,9 @@ class ZHAGateway: zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_member_removed - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) + async_dispatcher_send( + self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{zigpy_group.group_id}" + ) def group_member_added( self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType @@ -265,6 +269,9 @@ class ZHAGateway: zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_member_added - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) + async_dispatcher_send( + self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{zigpy_group.group_id}" + ) def group_added(self, zigpy_group: ZigpyGroupType) -> None: """Handle zigpy group added event.""" diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 63ed3a6edc7..fda26f54d58 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -3,12 +3,13 @@ import asyncio import logging import time -from typing import Any, Awaitable, Dict, List +from typing import Any, Awaitable, Dict, List, Optional -from homeassistant.core import callback +from homeassistant.core import CALLBACK_TYPE, State, callback from homeassistant.helpers import entity from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.restore_state import RestoreEntity from .core.const import ( @@ -18,7 +19,9 @@ from .core.const import ( DATA_ZHA, DATA_ZHA_BRIDGE_ID, DOMAIN, + SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_REMOVE, + SIGNAL_REMOVE_GROUP, ) from .core.helpers import LogMixin from .core.typing import CALLABLE_T, ChannelsType, ChannelType, ZhaDeviceType @@ -213,3 +216,75 @@ class ZhaEntity(BaseZhaEntity): for channel in self.cluster_channels.values(): if hasattr(channel, "async_update"): await channel.async_update() + + +class ZhaGroupEntity(BaseZhaEntity): + """A base class for ZHA group entities.""" + + def __init__( + self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs + ) -> None: + """Initialize a light group.""" + super().__init__(unique_id, zha_device, **kwargs) + self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" + self._group_id: int = group_id + self._entity_ids: List[str] = entity_ids + self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None + + async def async_added_to_hass(self) -> None: + """Register callbacks.""" + await super().async_added_to_hass() + await self.async_accept_signal( + None, + f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", + self.async_remove, + signal_override=True, + ) + + await self.async_accept_signal( + None, + f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{self._group_id}", + self._update_group_entities, + signal_override=True, + ) + + @callback + def async_state_changed_listener( + entity_id: str, old_state: State, new_state: State + ): + """Handle child updates.""" + self.async_schedule_update_ha_state(True) + + self._async_unsub_state_changed = async_track_state_change( + self.hass, self._entity_ids, async_state_changed_listener + ) + await self.async_update() + + def _update_group_entities(self): + """Update tracked entities when membership changes.""" + group = self.zha_device.gateway.get_group(self._group_id) + self._entity_ids = group.get_domain_entity_ids(self.platform.domain) + if self._async_unsub_state_changed is not None: + self._async_unsub_state_changed() + + @callback + def async_state_changed_listener( + entity_id: str, old_state: State, new_state: State + ): + """Handle child updates.""" + self.async_schedule_update_ha_state(True) + + self._async_unsub_state_changed = async_track_state_change( + self.hass, self._entity_ids, async_state_changed_listener + ) + + async def async_will_remove_from_hass(self) -> None: + """Handle removal from Home Assistant.""" + await super().async_will_remove_from_hass() + if self._async_unsub_state_changed is not None: + self._async_unsub_state_changed() + self._async_unsub_state_changed = None + + async def async_update(self) -> None: + """Update the state of the group entity.""" + pass diff --git a/homeassistant/components/zha/fan.py b/homeassistant/components/zha/fan.py index 027d0f8a1ee..c3cd88b0d6d 100644 --- a/homeassistant/components/zha/fan.py +++ b/homeassistant/components/zha/fan.py @@ -1,7 +1,7 @@ """Fans on Zigbee Home Automation networks.""" import functools import logging -from typing import List, Optional +from typing import List from zigpy.exceptions import DeliveryError import zigpy.zcl.clusters.hvac as hvac @@ -16,9 +16,8 @@ from homeassistant.components.fan import ( FanEntity, ) from homeassistant.const import STATE_UNAVAILABLE -from homeassistant.core import CALLBACK_TYPE, State, callback +from homeassistant.core import State, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.event import async_track_state_change from .core import discovery from .core.const import ( @@ -27,10 +26,9 @@ from .core.const import ( DATA_ZHA_DISPATCHERS, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - SIGNAL_REMOVE_GROUP, ) from .core.registries import ZHA_ENTITIES -from .entity import BaseZhaEntity, ZhaEntity +from .entity import ZhaEntity, ZhaGroupEntity _LOGGER = logging.getLogger(__name__) @@ -73,7 +71,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) -class BaseFan(BaseZhaEntity, FanEntity): +class BaseFan(FanEntity): """Base representation of a ZHA fan.""" def __init__(self, *args, **kwargs): @@ -120,9 +118,14 @@ class BaseFan(BaseZhaEntity, FanEntity): await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed]) self.async_set_state(0, "fan_mode", speed) + @callback + def async_set_state(self, attr_id, attr_name, value): + """Handle state update from channel.""" + pass + @STRICT_MATCH(channel_names=CHANNEL_FAN) -class ZhaFan(ZhaEntity, BaseFan): +class ZhaFan(BaseFan, ZhaEntity): """Representation of a ZHA fan.""" def __init__(self, unique_id, zha_device, channels, **kwargs): @@ -158,19 +161,15 @@ class ZhaFan(ZhaEntity, BaseFan): @GROUP_MATCH() -class FanGroup(BaseFan): +class FanGroup(BaseFan, ZhaGroupEntity): """Representation of a fan group.""" def __init__( self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs ) -> None: """Initialize a fan group.""" - super().__init__(unique_id, zha_device, **kwargs) - self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" - self._group_id: int = group_id + super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs) self._available: bool = False - self._entity_ids: List[str] = entity_ids - self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None group = self.zha_device.gateway.get_group(self._group_id) self._fan_channel = group.endpoint[hvac.Fan.cluster_id] @@ -185,35 +184,6 @@ class FanGroup(BaseFan): self._fan_channel.async_set_speed = async_set_speed - async def async_added_to_hass(self) -> None: - """Register callbacks.""" - await super().async_added_to_hass() - await self.async_accept_signal( - None, - f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", - self.async_remove, - signal_override=True, - ) - - @callback - def async_state_changed_listener( - entity_id: str, old_state: State, new_state: State - ): - """Handle child updates.""" - self.async_schedule_update_ha_state(True) - - self._async_unsub_state_changed = async_track_state_change( - self.hass, self._entity_ids, async_state_changed_listener - ) - await self.async_update() - - async def async_will_remove_from_hass(self) -> None: - """Handle removal from Home Assistant.""" - await super().async_will_remove_from_hass() - if self._async_unsub_state_changed is not None: - self._async_unsub_state_changed() - self._async_unsub_state_changed = None - async def async_update(self): """Attempt to retrieve on off state from the fan.""" all_states = [self.hass.states.get(x) for x in self._entity_ids] diff --git a/homeassistant/components/zha/light.py b/homeassistant/components/zha/light.py index 07cbc6af78c..c6ec5c2ccf9 100644 --- a/homeassistant/components/zha/light.py +++ b/homeassistant/components/zha/light.py @@ -30,12 +30,9 @@ from homeassistant.components.light import ( SUPPORT_WHITE_VALUE, ) from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE -from homeassistant.core import CALLBACK_TYPE, State, callback +from homeassistant.core import State, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.event import ( - async_track_state_change, - async_track_time_interval, -) +from homeassistant.helpers.event import async_track_time_interval import homeassistant.util.color as color_util from .core import discovery, helpers @@ -50,12 +47,12 @@ from .core.const import ( EFFECT_DEFAULT_VARIANT, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - SIGNAL_REMOVE_GROUP, SIGNAL_SET_LEVEL, ) +from .core.helpers import LogMixin from .core.registries import ZHA_ENTITIES from .core.typing import ZhaDeviceType -from .entity import BaseZhaEntity, ZhaEntity +from .entity import ZhaEntity, ZhaGroupEntity _LOGGER = logging.getLogger(__name__) @@ -100,7 +97,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) -class BaseLight(BaseZhaEntity, light.Light): +class BaseLight(LogMixin, light.Light): """Operations common to all light entities.""" def __init__(self, *args, **kwargs): @@ -307,7 +304,7 @@ class BaseLight(BaseZhaEntity, light.Light): @STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL}) -class Light(ZhaEntity, BaseLight): +class Light(BaseLight, ZhaEntity): """Representation of a ZHA or ZLL light.""" _REFRESH_INTERVAL = (45, 75) @@ -471,52 +468,19 @@ class HueLight(Light): @GROUP_MATCH() -class LightGroup(BaseLight): +class LightGroup(BaseLight, ZhaGroupEntity): """Representation of a light group.""" def __init__( self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs ) -> None: """Initialize a light group.""" - super().__init__(unique_id, zha_device, **kwargs) - self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" - self._group_id: int = group_id - self._entity_ids: List[str] = entity_ids + super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs) group = self.zha_device.gateway.get_group(self._group_id) self._on_off_channel = group.endpoint[OnOff.cluster_id] self._level_channel = group.endpoint[LevelControl.cluster_id] self._color_channel = group.endpoint[Color.cluster_id] self._identify_channel = group.endpoint[Identify.cluster_id] - self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None - - async def async_added_to_hass(self) -> None: - """Register callbacks.""" - await super().async_added_to_hass() - await self.async_accept_signal( - None, - f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", - self.async_remove, - signal_override=True, - ) - - @callback - def async_state_changed_listener( - entity_id: str, old_state: State, new_state: State - ): - """Handle child updates.""" - self.async_schedule_update_ha_state(True) - - self._async_unsub_state_changed = async_track_state_change( - self.hass, self._entity_ids, async_state_changed_listener - ) - await self.async_update() - - async def async_will_remove_from_hass(self) -> None: - """Handle removal from Home Assistant.""" - await super().async_will_remove_from_hass() - if self._async_unsub_state_changed is not None: - self._async_unsub_state_changed() - self._async_unsub_state_changed = None async def async_update(self) -> None: """Query all members and determine the light group state.""" diff --git a/homeassistant/components/zha/switch.py b/homeassistant/components/zha/switch.py index 90ec98ce1e3..328d9959ad2 100644 --- a/homeassistant/components/zha/switch.py +++ b/homeassistant/components/zha/switch.py @@ -1,16 +1,15 @@ """Switches on Zigbee Home Automation networks.""" import functools import logging -from typing import Any, List, Optional +from typing import Any, List from zigpy.zcl.clusters.general import OnOff from zigpy.zcl.foundation import Status from homeassistant.components.switch import DOMAIN, SwitchDevice from homeassistant.const import STATE_ON, STATE_UNAVAILABLE -from homeassistant.core import CALLBACK_TYPE, State, callback +from homeassistant.core import State, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.event import async_track_state_change from .core import discovery from .core.const import ( @@ -19,10 +18,9 @@ from .core.const import ( DATA_ZHA_DISPATCHERS, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - SIGNAL_REMOVE_GROUP, ) from .core.registries import ZHA_ENTITIES -from .entity import BaseZhaEntity, ZhaEntity +from .entity import ZhaEntity, ZhaGroupEntity _LOGGER = logging.getLogger(__name__) STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) @@ -43,7 +41,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) -class BaseSwitch(BaseZhaEntity, SwitchDevice): +class BaseSwitch(SwitchDevice): """Common base class for zha switches.""" def __init__(self, *args, **kwargs): @@ -77,7 +75,7 @@ class BaseSwitch(BaseZhaEntity, SwitchDevice): @STRICT_MATCH(channel_names=CHANNEL_ON_OFF) -class Switch(ZhaEntity, BaseSwitch): +class Switch(BaseSwitch, ZhaEntity): """ZHA switch.""" def __init__(self, unique_id, zha_device, channels, **kwargs): @@ -113,50 +111,17 @@ class Switch(ZhaEntity, BaseSwitch): @GROUP_MATCH() -class SwitchGroup(BaseSwitch): +class SwitchGroup(BaseSwitch, ZhaGroupEntity): """Representation of a switch group.""" def __init__( self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs ) -> None: """Initialize a switch group.""" - super().__init__(unique_id, zha_device, **kwargs) - self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" - self._group_id: int = group_id + super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs) self._available: bool = False - self._entity_ids: List[str] = entity_ids group = self.zha_device.gateway.get_group(self._group_id) self._on_off_channel = group.endpoint[OnOff.cluster_id] - self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None - - async def async_added_to_hass(self) -> None: - """Register callbacks.""" - await super().async_added_to_hass() - await self.async_accept_signal( - None, - f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", - self.async_remove, - signal_override=True, - ) - - @callback - def async_state_changed_listener( - entity_id: str, old_state: State, new_state: State - ): - """Handle child updates.""" - self.async_schedule_update_ha_state(True) - - self._async_unsub_state_changed = async_track_state_change( - self.hass, self._entity_ids, async_state_changed_listener - ) - await self.async_update() - - async def async_will_remove_from_hass(self) -> None: - """Handle removal from Home Assistant.""" - await super().async_will_remove_from_hass() - if self._async_unsub_state_changed is not None: - self._async_unsub_state_changed() - self._async_unsub_state_changed = None async def async_update(self) -> None: """Query all members and determine the light group state.""" diff --git a/tests/components/zha/test_light.py b/tests/components/zha/test_light.py index f832b9e86e0..9bdd4966a4a 100644 --- a/tests/components/zha/test_light.py +++ b/tests/components/zha/test_light.py @@ -30,6 +30,7 @@ ON = 1 OFF = 0 IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE3 = "03:2d:6f:00:0a:90:69:e8" LIGHT_ON_OFF = { 1: { @@ -140,6 +141,31 @@ async def device_light_2(hass, zigpy_device_mock, zha_device_joined): return zha_device +@pytest.fixture +async def device_light_3(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + general.LevelControl.cluster_id, + lighting.Color.cluster_id, + general.Groups.cluster_id, + general.Identify.cluster_id, + ], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee=IEEE_GROUPABLE_DEVICE3, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + @patch("zigpy.zcl.clusters.general.OnOff.read_attributes", new=MagicMock()) async def test_light_refresh(hass, zigpy_device_mock, zha_device_joined_restored): """Test zha light platform refresh.""" @@ -414,7 +440,7 @@ async def async_test_flash_from_hass(hass, cluster, entity_id, flash): async def async_test_zha_group_light_entity( - hass, device_light_1, device_light_2, coordinator + hass, device_light_1, device_light_2, device_light_3, coordinator ): """Test the light entity for a ZHA group.""" zha_gateway = get_zha_gateway(hass) @@ -445,6 +471,7 @@ async def async_test_zha_group_light_entity( dev1_cluster_on_off = device_light_1.endpoints[1].on_off dev2_cluster_on_off = device_light_2.endpoints[1].on_off + dev3_cluster_on_off = device_light_3.endpoints[1].on_off # test that the lights were created and that they are unavailable assert hass.states.get(entity_id).state == STATE_UNAVAILABLE @@ -503,3 +530,12 @@ async def async_test_zha_group_light_entity( # test that group light is now back on assert hass.states.get(entity_id).state == STATE_ON + + # test that group light is now off + await group_cluster_on_off.off() + assert hass.states.get(entity_id).state == STATE_OFF + + # add a new member and test that his state is also tracked + await zha_group.async_add_members([device_light_3.ieee]) + await dev3_cluster_on_off.on() + assert hass.states.get(entity_id).state == STATE_ON