From 342e6a503acfa71e2ab2116a3c491fb29a2b90d9 Mon Sep 17 00:00:00 2001 From: "David F. Mulcahey" Date: Sat, 13 Jul 2024 21:25:15 -0400 Subject: [PATCH] Fix group operations in ZHA websocket API (#121881) --- homeassistant/components/zha/helpers.py | 40 ++++---- homeassistant/components/zha/websocket_api.py | 24 +++-- tests/components/zha/test_websocket_api.py | 98 ++++++++++++++++++- 3 files changed, 131 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/zha/helpers.py b/homeassistant/components/zha/helpers.py index 4f60e8b32b2..0691e2429d1 100644 --- a/homeassistant/components/zha/helpers.py +++ b/homeassistant/components/zha/helpers.py @@ -73,7 +73,7 @@ from zha.exceptions import ZHAException from zha.mixins import LogMixin from zha.zigbee.cluster_handlers import ClusterBindEvent, ClusterConfigureReportingEvent from zha.zigbee.device import ClusterHandlerConfigurationComplete, Device, ZHAEvent -from zha.zigbee.group import Group, GroupMember +from zha.zigbee.group import Group, GroupInfo, GroupMember from zigpy.config import ( CONF_DATABASE, CONF_DEVICE, @@ -290,7 +290,11 @@ class ZHAGroupProxy(LogMixin): def log(self, level: int, msg: str, *args: Any, **kwargs) -> None: """Log a message.""" msg = f"[%s](%s): {msg}" - args = (f"0x{self.group.group_id:04x}", self.group.endpoint.id, *args) + args = ( + f"0x{self.group.group_id:04x}", + self.group.endpoint.endpoint_id, + *args, + ) _LOGGER.log(level, msg, *args, **kwargs) @@ -673,8 +677,8 @@ class ZHAGatewayProxy(EventBase): @callback def handle_group_removed(self, event: GroupEvent) -> None: """Handle a group removed event.""" - self._send_group_gateway_message(event.group_info, ZHA_GW_MSG_GROUP_REMOVED) zha_group_proxy = self.group_proxies.pop(event.group_info.group_id) + self._send_group_gateway_message(zha_group_proxy, ZHA_GW_MSG_GROUP_REMOVED) zha_group_proxy.info("group_removed") self._cleanup_group_entity_registry_entries(zha_group_proxy) @@ -760,12 +764,14 @@ class ZHAGatewayProxy(EventBase): zha_device_proxy.device_id = device_registry_device.id return zha_device_proxy - def _async_get_or_create_group_proxy(self, zha_group: Group) -> ZHAGroupProxy: + def _async_get_or_create_group_proxy(self, group_info: GroupInfo) -> ZHAGroupProxy: """Get or create a ZHA group.""" - zha_group_proxy = self.group_proxies.get(zha_group.group_id) + zha_group_proxy = self.group_proxies.get(group_info.group_id) if zha_group_proxy is None: - zha_group_proxy = ZHAGroupProxy(zha_group, self) - self.group_proxies[zha_group.group_id] = zha_group_proxy + zha_group_proxy = ZHAGroupProxy( + self.gateway.groups[group_info.group_id], self + ) + self.group_proxies[group_info.group_id] = zha_group_proxy return zha_group_proxy def _create_entity_metadata( @@ -840,19 +846,17 @@ class ZHAGatewayProxy(EventBase): async_dispatcher_send(self.hass, SIGNAL_ADD_ENTITIES) def _send_group_gateway_message( - self, zigpy_group: zigpy.group.Group, gateway_message_type: str + self, zha_group_proxy: ZHAGroupProxy, gateway_message_type: str ) -> None: """Send the gateway event for a zigpy group event.""" - zha_group = self.group_proxies.get(zigpy_group.group_id) - if zha_group is not None: - async_dispatcher_send( - self.hass, - ZHA_GW_MSG, - { - ATTR_TYPE: gateway_message_type, - ZHA_GW_MSG_GROUP_INFO: zha_group.group_info, - }, - ) + async_dispatcher_send( + self.hass, + ZHA_GW_MSG, + { + ATTR_TYPE: gateway_message_type, + ZHA_GW_MSG_GROUP_INFO: zha_group_proxy.group_info, + }, + ) async def _async_remove_device( self, device: ZHADeviceProxy, entity_refs: list[EntityReference] | None diff --git a/homeassistant/components/zha/websocket_api.py b/homeassistant/components/zha/websocket_api.py index 053a941de8d..97c625a27ed 100644 --- a/homeassistant/components/zha/websocket_api.py +++ b/homeassistant/components/zha/websocket_api.py @@ -47,7 +47,7 @@ from zha.application.helpers import ( ) from zha.zigbee.cluster_handlers.const import CLUSTER_HANDLER_IAS_WD from zha.zigbee.device import Device -from zha.zigbee.group import GroupMember +from zha.zigbee.group import GroupMemberReference import zigpy.backups from zigpy.config import CONF_DEVICE from zigpy.config.validators import cv_boolean @@ -259,9 +259,9 @@ class ClusterBinding(NamedTuple): endpoint_id: int -def _cv_group_member(value: dict[str, Any]) -> GroupMember: +def _cv_group_member(value: dict[str, Any]) -> GroupMemberReference: """Transform a group member.""" - return GroupMember( + return GroupMemberReference( ieee=value[ATTR_IEEE], endpoint_id=value[ATTR_ENDPOINT_ID], ) @@ -519,7 +519,7 @@ async def websocket_add_group( zha_gateway = get_zha_gateway_proxy(hass) group_name: str = msg[GROUP_NAME] group_id: int | None = msg.get(GROUP_ID) - members: list[GroupMember] | None = msg.get(ATTR_MEMBERS) + members: list[GroupMemberReference] | None = msg.get(ATTR_MEMBERS) group = await zha_gateway.gateway.async_create_zigpy_group( group_name, members, group_id ) @@ -570,8 +570,9 @@ async def websocket_add_group_members( ) -> None: """Add members to a ZHA group.""" zha_gateway = get_zha_gateway(hass) + zha_gateway_proxy = get_zha_gateway_proxy(hass) group_id: int = msg[GROUP_ID] - members: list[GroupMember] = msg[ATTR_MEMBERS] + members: list[GroupMemberReference] = msg[ATTR_MEMBERS] if not (zha_group := zha_gateway.groups.get(group_id)): connection.send_message( @@ -582,8 +583,9 @@ async def websocket_add_group_members( return await zha_group.async_add_members(members) - ret_group = zha_group.group_info - connection.send_result(msg[ID], ret_group) + ret_group = zha_gateway_proxy.get_group_proxy(group_id) + assert ret_group + connection.send_result(msg[ID], ret_group.group_info) @websocket_api.require_admin @@ -600,8 +602,9 @@ async def websocket_remove_group_members( ) -> None: """Remove members from a ZHA group.""" zha_gateway = get_zha_gateway(hass) + zha_gateway_proxy = get_zha_gateway_proxy(hass) group_id: int = msg[GROUP_ID] - members: list[GroupMember] = msg[ATTR_MEMBERS] + members: list[GroupMemberReference] = msg[ATTR_MEMBERS] if not (zha_group := zha_gateway.groups.get(group_id)): connection.send_message( @@ -612,8 +615,9 @@ async def websocket_remove_group_members( return await zha_group.async_remove_members(members) - ret_group = zha_group.group_info - connection.send_result(msg[ID], ret_group) + ret_group = zha_gateway_proxy.get_group_proxy(group_id) + assert ret_group + connection.send_result(msg[ID], ret_group.group_info) @websocket_api.require_admin diff --git a/tests/components/zha/test_websocket_api.py b/tests/components/zha/test_websocket_api.py index ea8ea39aed9..f6afee9eb83 100644 --- a/tests/components/zha/test_websocket_api.py +++ b/tests/components/zha/test_websocket_api.py @@ -440,9 +440,16 @@ async def test_list_groupable_devices( assert len(device_endpoints) == 0 -async def test_add_group(zha_client) -> None: +async def test_add_group(hass: HomeAssistant, zha_client) -> None: """Test adding and getting a new ZHA zigbee group.""" - await zha_client.send_json({ID: 12, TYPE: "zha/group/add", GROUP_NAME: "new_group"}) + await zha_client.send_json( + { + ID: 12, + TYPE: "zha/group/add", + GROUP_NAME: "new_group", + "members": [{"ieee": IEEE_GROUPABLE_DEVICE, "endpoint_id": 1}], + } + ) msg = await zha_client.receive_json() assert msg["id"] == 12 @@ -450,8 +457,17 @@ async def test_add_group(zha_client) -> None: added_group = msg["result"] + groupable_device = get_zha_gateway_proxy(hass).device_proxies[ + EUI64.convert(IEEE_GROUPABLE_DEVICE) + ] + assert added_group["name"] == "new_group" - assert added_group["members"] == [] + assert len(added_group["members"]) == 1 + assert added_group["members"][0]["device"]["ieee"] == IEEE_GROUPABLE_DEVICE + assert ( + added_group["members"][0]["device"]["device_reg_id"] + == groupable_device.device_id + ) await zha_client.send_json({ID: 13, TYPE: "zha/groups"}) @@ -499,6 +515,82 @@ async def test_remove_group(zha_client) -> None: assert len(groups) == 0 +async def test_add_group_member(hass: HomeAssistant, zha_client) -> None: + """Test adding a ZHA zigbee group member.""" + await zha_client.send_json( + { + ID: 12, + TYPE: "zha/group/add", + GROUP_NAME: "new_group", + } + ) + + msg = await zha_client.receive_json() + assert msg["id"] == 12 + assert msg["type"] == TYPE_RESULT + + added_group = msg["result"] + + assert len(added_group["members"]) == 0 + + await zha_client.send_json( + { + ID: 13, + TYPE: "zha/group/members/add", + GROUP_ID: added_group["group_id"], + "members": [{"ieee": IEEE_GROUPABLE_DEVICE, "endpoint_id": 1}], + } + ) + + msg = await zha_client.receive_json() + assert msg["id"] == 13 + assert msg["type"] == TYPE_RESULT + + added_group = msg["result"] + + assert len(added_group["members"]) == 1 + assert added_group["name"] == "new_group" + assert added_group["members"][0]["device"]["ieee"] == IEEE_GROUPABLE_DEVICE + + +async def test_remove_group_member(hass: HomeAssistant, zha_client) -> None: + """Test removing a ZHA zigbee group member.""" + await zha_client.send_json( + { + ID: 12, + TYPE: "zha/group/add", + GROUP_NAME: "new_group", + "members": [{"ieee": IEEE_GROUPABLE_DEVICE, "endpoint_id": 1}], + } + ) + + msg = await zha_client.receive_json() + assert msg["id"] == 12 + assert msg["type"] == TYPE_RESULT + + added_group = msg["result"] + + assert added_group["name"] == "new_group" + assert len(added_group["members"]) == 1 + assert added_group["members"][0]["device"]["ieee"] == IEEE_GROUPABLE_DEVICE + + await zha_client.send_json( + { + ID: 13, + TYPE: "zha/group/members/remove", + GROUP_ID: added_group["group_id"], + "members": [{"ieee": IEEE_GROUPABLE_DEVICE, "endpoint_id": 1}], + } + ) + + msg = await zha_client.receive_json() + assert msg["id"] == 13 + assert msg["type"] == TYPE_RESULT + + added_group = msg["result"] + assert len(added_group["members"]) == 0 + + @pytest.fixture async def app_controller( hass: HomeAssistant, setup_zha, zigpy_app_controller: ControllerApplication