From 37707edc478eff9bd123f669effbeff9f317268a Mon Sep 17 00:00:00 2001 From: Caius-Bonus <123886836+Caius-Bonus@users.noreply.github.com> Date: Wed, 27 Dec 2023 17:48:30 +0100 Subject: [PATCH] Match ZHA Custom ClusterHandler on a Custom Cluster using a unique id for the quirk (#101709) * initial * fix tests * match on specific name and quirk name * fix tests * fix tests * store cluster handlers in only one place * edit tests * use correct device for quirk id * change quirk id * fix tests * even if there is a quirk id, it doesn't have to have a specific cluster handler * add tests * rename quirk_id * add tests * fix tests * fix tests * use quirk id from zha_quirks --- .../cluster_handlers/manufacturerspecific.py | 11 +- .../components/zha/core/decorators.py | 18 +++ .../components/zha/core/discovery.py | 15 +- homeassistant/components/zha/core/endpoint.py | 23 ++- .../components/zha/core/registries.py | 6 +- tests/components/zha/test_cluster_handlers.py | 132 ++++++++++++++++-- 6 files changed, 177 insertions(+), 28 deletions(-) diff --git a/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py b/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py index 99c1e954a0e..556eb907605 100644 --- a/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py +++ b/homeassistant/components/zha/core/cluster_handlers/manufacturerspecific.py @@ -5,8 +5,9 @@ import logging from typing import TYPE_CHECKING, Any from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType -from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER +from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER, XIAOMI_AQARA_VIBRATION_AQ1 import zigpy.zcl +from zigpy.zcl.clusters.closures import DoorLock from homeassistant.core import callback @@ -24,6 +25,7 @@ from ..const import ( UNKNOWN, ) from . import AttrReportConfig, ClientClusterHandler, ClusterHandler +from .general import MultistateInput if TYPE_CHECKING: from ..endpoint import Endpoint @@ -403,3 +405,10 @@ class IkeaRemote(ClusterHandler): """Ikea Matter remote cluster handler.""" REPORT_CONFIG = () + + +@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register( + DoorLock.cluster_id, XIAOMI_AQARA_VIBRATION_AQ1 +) +class XiaomiVibrationAQ1ClusterHandler(MultistateInput): + """Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster.""" diff --git a/homeassistant/components/zha/core/decorators.py b/homeassistant/components/zha/core/decorators.py index 71bfd510bea..192f6848989 100644 --- a/homeassistant/components/zha/core/decorators.py +++ b/homeassistant/components/zha/core/decorators.py @@ -21,6 +21,24 @@ class DictRegistry(dict[int | str, _TypeT]): return decorator +class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]): + """Dict Registry of multiple items per key.""" + + def register( + self, name: int | str, sub_name: int | str | None = None + ) -> Callable[[_TypeT], _TypeT]: + """Return decorator to register item with a specific and a quirk name.""" + + def decorator(cluster_handler: _TypeT) -> _TypeT: + """Register decorated cluster handler or item.""" + if name not in self: + self[name] = {} + self[name][sub_name] = cluster_handler + return cluster_handler + + return decorator + + class SetRegistry(set[int | str]): """Set Registry of items.""" diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 90ed68f9b00..1944f632e9a 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -203,9 +203,20 @@ class ProbeEndpoint: if platform is None: continue - cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, ClusterHandler + cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id, {None: ClusterHandler} ) + + quirk_id = ( + endpoint.device.quirk_id + if endpoint.device.quirk_id in cluster_handler_classes + else None + ) + + cluster_handler_class = cluster_handler_classes.get( + quirk_id, ClusterHandler + ) + cluster_handler = cluster_handler_class(cluster, endpoint) self.probe_single_cluster(platform, cluster_handler, endpoint) diff --git a/homeassistant/components/zha/core/endpoint.py b/homeassistant/components/zha/core/endpoint.py index c87ee60d6b3..04c253128ee 100644 --- a/homeassistant/components/zha/core/endpoint.py +++ b/homeassistant/components/zha/core/endpoint.py @@ -6,7 +6,6 @@ from collections.abc import Callable import logging from typing import TYPE_CHECKING, Any, Final, TypeVar -import zigpy from zigpy.typing import EndpointType as ZigpyEndpointType from homeassistant.const import Platform @@ -15,7 +14,6 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send from . import const, discovery, registries from .cluster_handlers import ClusterHandler -from .cluster_handlers.general import MultistateInput from .helpers import get_zha_data if TYPE_CHECKING: @@ -116,8 +114,16 @@ class Endpoint: def add_all_cluster_handlers(self) -> None: """Create and add cluster handlers for all input clusters.""" for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items(): - cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, ClusterHandler + cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id, {None: ClusterHandler} + ) + quirk_id = ( + self.device.quirk_id + if self.device.quirk_id in cluster_handler_classes + else None + ) + cluster_handler_class = cluster_handler_classes.get( + quirk_id, ClusterHandler ) # Allow cluster handler to filter out bad matches @@ -129,15 +135,6 @@ class Endpoint: cluster_id, cluster_handler_class, ) - # really ugly hack to deal with xiaomi using the door lock cluster - # incorrectly. - if ( - hasattr(cluster, "ep_attribute") - and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id - and cluster.ep_attribute == "multistate_input" - ): - cluster_handler_class = MultistateInput - # end of ugly hack try: cluster_handler = cluster_handler_class(cluster, self) diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 87f59f31e9b..b302116694d 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -15,7 +15,7 @@ from zigpy.types.named import EUI64 from homeassistant.const import Platform -from .decorators import DictRegistry, SetRegistry +from .decorators import DictRegistry, NestedDictRegistry, SetRegistry if TYPE_CHECKING: from ..entity import ZhaEntity, ZhaGroupEntity @@ -110,7 +110,9 @@ CLUSTER_HANDLER_ONLY_CLUSTERS = SetRegistry() CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[ type[ClientClusterHandler] ] = DictRegistry() -ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry() +ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[ + type[ClusterHandler] +] = NestedDictRegistry() WEIGHT_ATTR = attrgetter("weight") diff --git a/tests/components/zha/test_cluster_handlers.py b/tests/components/zha/test_cluster_handlers.py index e3d5741acd8..39f201e668e 100644 --- a/tests/components/zha/test_cluster_handlers.py +++ b/tests/components/zha/test_cluster_handlers.py @@ -3,6 +3,7 @@ import asyncio from collections.abc import Callable import logging import math +from types import NoneType from unittest import mock from unittest.mock import AsyncMock, patch @@ -11,12 +12,17 @@ import zigpy.device import zigpy.endpoint from zigpy.endpoint import Endpoint as ZigpyEndpoint import zigpy.profiles.zha +import zigpy.quirks as zigpy_quirks import zigpy.types as t from zigpy.zcl import foundation import zigpy.zcl.clusters +from zigpy.zcl.clusters import CLUSTERS_BY_ID import zigpy.zdo.types as zdo_t import homeassistant.components.zha.core.cluster_handlers as cluster_handlers +from homeassistant.components.zha.core.cluster_handlers.lighting import ( + ColorClusterHandler, +) import homeassistant.components.zha.core.const as zha_const from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.endpoint import Endpoint @@ -97,7 +103,9 @@ def poll_control_ch(endpoint, zigpy_device_mock): ) cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] - cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id) + cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( + cluster_id + ).get(None) return cluster_handler_class(cluster, endpoint) @@ -258,8 +266,8 @@ async def test_in_cluster_handler_config( cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, cluster_handlers.ClusterHandler - ) + cluster_id, {None, cluster_handlers.ClusterHandler} + ).get(None) cluster_handler = cluster_handler_class(cluster, endpoint) await cluster_handler.async_configure() @@ -322,8 +330,8 @@ async def test_out_cluster_handler_config( cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id] cluster.bind_only = True cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( - cluster_id, cluster_handlers.ClusterHandler - ) + cluster_id, {None: cluster_handlers.ClusterHandler} + ).get(None) cluster_handler = cluster_handler_class(cluster, endpoint) await cluster_handler.async_configure() @@ -334,13 +342,46 @@ async def test_out_cluster_handler_config( def test_cluster_handler_registry() -> None: """Test ZIGBEE cluster handler Registry.""" + + # get all quirk ID from zigpy quirks registry + all_quirk_ids = {} + for cluster_id in CLUSTERS_BY_ID: + all_quirk_ids[cluster_id] = {None} + for manufacturer in zigpy_quirks._DEVICE_REGISTRY.registry.values(): + for model_quirk_list in manufacturer.values(): + for quirk in model_quirk_list: + quirk_id = getattr(quirk, zha_const.ATTR_QUIRK_ID, None) + device_description = getattr(quirk, "replacement", None) or getattr( + quirk, "signature", None + ) + + for endpoint in device_description["endpoints"].values(): + cluster_ids = set() + if "input_clusters" in endpoint: + cluster_ids.update(endpoint["input_clusters"]) + if "output_clusters" in endpoint: + cluster_ids.update(endpoint["output_clusters"]) + for cluster_id in cluster_ids: + if not isinstance(cluster_id, int): + cluster_id = cluster_id.cluster_id + if cluster_id not in all_quirk_ids: + all_quirk_ids[cluster_id] = {None} + all_quirk_ids[cluster_id].add(quirk_id) + + del quirk, model_quirk_list, manufacturer + for ( cluster_id, - cluster_handler, + cluster_handler_classes, ) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items(): assert isinstance(cluster_id, int) assert 0 <= cluster_id <= 0xFFFF - assert issubclass(cluster_handler, cluster_handlers.ClusterHandler) + assert cluster_id in all_quirk_ids + assert isinstance(cluster_handler_classes, dict) + for quirk_id, cluster_handler in cluster_handler_classes.items(): + assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str) + assert issubclass(cluster_handler, cluster_handlers.ClusterHandler) + assert quirk_id in all_quirk_ids[cluster_id] def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None: @@ -818,7 +859,8 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None: ], ) - mock_zha_device = mock.AsyncMock(spec_set=ZHADevice) + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = None zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) # The cluster handler throws an error when matching this cluster @@ -827,14 +869,84 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None: # And one is also logged at runtime with patch.dict( - registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY, - {cluster.cluster_id: TestZigbeeClusterHandler}, + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {None: TestZigbeeClusterHandler}, ), caplog.at_level(logging.WARNING): zha_endpoint.add_all_cluster_handlers() assert "missing_attr" in caplog.text +async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None: + """Test setting up a cluster handler that matches a standard cluster.""" + + class TestZigbeeClusterHandler(ColorClusterHandler): + pass + + mock_device = mock.AsyncMock(spec_set=zigpy.device.Device) + zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1) + + cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id) + cluster.configure_reporting_multiple = AsyncMock( + spec_set=cluster.configure_reporting_multiple, + return_value=[ + foundation.ConfigureReportingResponseRecord( + status=foundation.Status.SUCCESS + ) + ], + ) + + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = None + zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) + + with patch.dict( + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {"__test_quirk_id": TestZigbeeClusterHandler}, + ): + zha_endpoint.add_all_cluster_handlers() + + assert len(zha_endpoint.all_cluster_handlers) == 1 + assert isinstance( + list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler + ) + + +async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None: + """Test setting up a cluster handler that matches a standard cluster.""" + + class TestZigbeeClusterHandler(ColorClusterHandler): + pass + + mock_device = mock.AsyncMock(spec_set=zigpy.device.Device) + zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1) + + cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id) + cluster.configure_reporting_multiple = AsyncMock( + spec_set=cluster.configure_reporting_multiple, + return_value=[ + foundation.ConfigureReportingResponseRecord( + status=foundation.Status.SUCCESS + ) + ], + ) + + mock_zha_device = mock.AsyncMock(spec=ZHADevice) + mock_zha_device.quirk_id = "__test_quirk_id" + zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) + + with patch.dict( + registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id], + {"__test_quirk_id": TestZigbeeClusterHandler}, + ): + zha_endpoint.add_all_cluster_handlers() + + assert len(zha_endpoint.all_cluster_handlers) == 1 + assert isinstance( + list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler + ) + + # parametrize side effects: @pytest.mark.parametrize( ("side_effect", "expected_error"),