diff --git a/homeassistant/components/zha/core/channels/general.py b/homeassistant/components/zha/core/channels/general.py index d51c03b33c9..783188248f3 100644 --- a/homeassistant/components/zha/core/channels/general.py +++ b/homeassistant/components/zha/core/channels/general.py @@ -1,6 +1,9 @@ """General channels module for Zigbee Home Automation.""" +import asyncio import logging +from typing import Any, List, Optional +import zigpy.exceptions import zigpy.zcl.clusters.general as general from homeassistant.core import callback @@ -332,11 +335,41 @@ class Partition(ZigbeeChannel): pass +@registries.CHANNEL_ONLY_CLUSTERS.register(general.PollControl.cluster_id) @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PollControl.cluster_id) class PollControl(ZigbeeChannel): """Poll Control channel.""" - pass + CHECKIN_INTERVAL = 55 * 60 * 4 # 55min + CHECKIN_FAST_POLL_TIMEOUT = 2 * 4 # 2s + LONG_POLL = 6 * 4 # 6s + + async def async_configure(self) -> None: + """Configure channel: set check-in interval.""" + try: + res = await self.cluster.write_attributes( + {"checkin_interval": self.CHECKIN_INTERVAL} + ) + self.debug("%ss check-in interval set: %s", self.CHECKIN_INTERVAL / 4, res) + except (asyncio.TimeoutError, zigpy.exceptions.ZigbeeException) as ex: + self.debug("Couldn't set check-in interval: %s", ex) + await super().async_configure() + + @callback + def cluster_command( + self, tsn: int, command_id: int, args: Optional[List[Any]] + ) -> None: + """Handle commands received to this cluster.""" + cmd_name = self.cluster.client_commands.get(command_id, [command_id])[0] + self.debug("Received %s tsn command '%s': %s", tsn, cmd_name, args) + self.zha_send_event(cmd_name, args) + if cmd_name == "checkin": + self.cluster.create_catching_task(self.check_in_response(tsn)) + + async def check_in_response(self, tsn: int) -> None: + """Respond to checkin command.""" + await self.checkin_response(True, self.CHECKIN_FAST_POLL_TIMEOUT, tsn=tsn) + await self.set_long_poll_interval(self.LONG_POLL) @registries.DEVICE_TRACKER_CLUSTERS.register(general.PowerConfiguration.cluster_id) diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index c21d05aa364..3eb6f407f32 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -127,13 +127,15 @@ async def async_enable_traffic(hass, zha_devices): await hass.async_block_till_done() -def make_zcl_header(command_id: int, global_command: bool = True) -> zcl_f.ZCLHeader: +def make_zcl_header( + command_id: int, global_command: bool = True, tsn: int = 1 +) -> zcl_f.ZCLHeader: """Cluster.handle_message() ZCL Header helper.""" if global_command: frc = zcl_f.FrameControl(zcl_f.FrameType.GLOBAL_COMMAND) else: frc = zcl_f.FrameControl(zcl_f.FrameType.CLUSTER_COMMAND) - return zcl_f.ZCLHeader(frc, tsn=1, command_id=command_id) + return zcl_f.ZCLHeader(frc, tsn=tsn, command_id=command_id) def reset_clusters(clusters): diff --git a/tests/components/zha/test_channels.py b/tests/components/zha/test_channels.py index 9eac267273b..ec9c172430c 100644 --- a/tests/components/zha/test_channels.py +++ b/tests/components/zha/test_channels.py @@ -5,13 +5,14 @@ from unittest import mock import asynctest import pytest import zigpy.types as t +import zigpy.zcl.clusters import homeassistant.components.zha.core.channels as zha_channels import homeassistant.components.zha.core.channels.base as base_channels import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.registries as registries -from .common import get_zha_gateway +from .common import get_zha_gateway, make_zcl_header @pytest.fixture @@ -42,6 +43,37 @@ def channel_pool(): return ch_pool_mock +@pytest.fixture +def poll_control_ch(channel_pool, zigpy_device_mock): + """Poll control channel fixture.""" + cluster_id = zigpy.zcl.clusters.general.PollControl.cluster_id + zigpy_dev = zigpy_device_mock( + {1: {"in_clusters": [cluster_id], "out_clusters": [], "device_type": 0x1234}}, + "00:11:22:33:44:55:66:77", + "test manufacturer", + "test model", + ) + + cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] + channel_class = registries.ZIGBEE_CHANNEL_REGISTRY.get(cluster_id) + return channel_class(cluster, channel_pool) + + +@pytest.fixture +async def poll_control_device(zha_device_restored, zigpy_device_mock): + """Poll control device fixture.""" + cluster_id = zigpy.zcl.clusters.general.PollControl.cluster_id + zigpy_dev = zigpy_device_mock( + {1: {"in_clusters": [cluster_id], "out_clusters": [], "device_type": 0x1234}}, + "00:11:22:33:44:55:66:77", + "test manufacturer", + "test model", + ) + + zha_device = await zha_device_restored(zigpy_dev) + return zha_device + + @pytest.mark.parametrize( "cluster_id, bind_count, attrs", [ @@ -371,3 +403,65 @@ async def test_ep_channels_configure(channel): assert ch_3.warning.call_count == 2 assert ch_5.warning.call_count == 2 + + +async def test_poll_control_configure(poll_control_ch): + """Test poll control channel configuration.""" + await poll_control_ch.async_configure() + assert poll_control_ch.cluster.write_attributes.call_count == 1 + assert poll_control_ch.cluster.write_attributes.call_args[0][0] == { + "checkin_interval": poll_control_ch.CHECKIN_INTERVAL + } + + +async def test_poll_control_checkin_response(poll_control_ch): + """Test poll control channel checkin response.""" + rsp_mock = asynctest.CoroutineMock() + set_interval_mock = asynctest.CoroutineMock() + cluster = poll_control_ch.cluster + patch_1 = mock.patch.object(cluster, "checkin_response", rsp_mock) + patch_2 = mock.patch.object(cluster, "set_long_poll_interval", set_interval_mock) + + with patch_1, patch_2: + await poll_control_ch.check_in_response(33) + + assert rsp_mock.call_count == 1 + assert set_interval_mock.call_count == 1 + + await poll_control_ch.check_in_response(33) + assert cluster.endpoint.request.call_count == 2 + assert cluster.endpoint.request.await_count == 2 + assert cluster.endpoint.request.call_args_list[0][0][1] == 33 + assert cluster.endpoint.request.call_args_list[0][0][0] == 0x0020 + assert cluster.endpoint.request.call_args_list[1][0][0] == 0x0020 + + +async def test_poll_control_cluster_command(hass, poll_control_device): + """Test poll control channel response to cluster command.""" + checkin_mock = asynctest.CoroutineMock() + poll_control_ch = poll_control_device.channels.pools[0].all_channels["1:0x0020"] + cluster = poll_control_ch.cluster + + events = [] + hass.bus.async_listen("zha_event", lambda x: events.append(x)) + await hass.async_block_till_done() + + with mock.patch.object(poll_control_ch, "check_in_response", checkin_mock): + tsn = 22 + hdr = make_zcl_header(0, global_command=False, tsn=tsn) + assert not events + cluster.handle_message( + hdr, [mock.sentinel.args, mock.sentinel.args2, mock.sentinel.args3] + ) + await hass.async_block_till_done() + + assert checkin_mock.call_count == 1 + assert checkin_mock.await_count == 1 + assert checkin_mock.await_args[0][0] == tsn + assert len(events) == 1 + data = events[0].data + assert data["command"] == "checkin" + assert data["args"][0] is mock.sentinel.args + assert data["args"][1] is mock.sentinel.args2 + assert data["args"][2] is mock.sentinel.args3 + assert data["unique_id"] == "00:11:22:33:44:55:66:77:1:0x0020"