Match ZHA Custom ClusterHandler on a Custom Cluster using a unique id for the quirk (#101709)

* initial

* fix tests

* match on specific name and quirk name

* fix tests

* fix tests

* store cluster handlers in only one place

* edit tests

* use correct device for quirk id

* change quirk id

* fix tests

* even if there is a quirk id, it doesn't have to have a specific cluster handler

* add tests

* rename quirk_id

* add tests

* fix tests

* fix tests

* use quirk id from zha_quirks
This commit is contained in:
Caius-Bonus 2023-12-27 17:48:30 +01:00 committed by GitHub
parent 4330452212
commit 37707edc47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 28 deletions

View File

@ -5,8 +5,9 @@ import logging
from typing import TYPE_CHECKING, Any
from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType
from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER
from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER, XIAOMI_AQARA_VIBRATION_AQ1
import zigpy.zcl
from zigpy.zcl.clusters.closures import DoorLock
from homeassistant.core import callback
@ -24,6 +25,7 @@ from ..const import (
UNKNOWN,
)
from . import AttrReportConfig, ClientClusterHandler, ClusterHandler
from .general import MultistateInput
if TYPE_CHECKING:
from ..endpoint import Endpoint
@ -403,3 +405,10 @@ class IkeaRemote(ClusterHandler):
"""Ikea Matter remote cluster handler."""
REPORT_CONFIG = ()
@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(
DoorLock.cluster_id, XIAOMI_AQARA_VIBRATION_AQ1
)
class XiaomiVibrationAQ1ClusterHandler(MultistateInput):
"""Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster."""

View File

@ -21,6 +21,24 @@ class DictRegistry(dict[int | str, _TypeT]):
return decorator
class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]):
"""Dict Registry of multiple items per key."""
def register(
self, name: int | str, sub_name: int | str | None = None
) -> Callable[[_TypeT], _TypeT]:
"""Return decorator to register item with a specific and a quirk name."""
def decorator(cluster_handler: _TypeT) -> _TypeT:
"""Register decorated cluster handler or item."""
if name not in self:
self[name] = {}
self[name][sub_name] = cluster_handler
return cluster_handler
return decorator
class SetRegistry(set[int | str]):
"""Set Registry of items."""

View File

@ -203,9 +203,20 @@ class ProbeEndpoint:
if platform is None:
continue
cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)
quirk_id = (
endpoint.device.quirk_id
if endpoint.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)
cluster_handler = cluster_handler_class(cluster, endpoint)
self.probe_single_cluster(platform, cluster_handler, endpoint)

View File

@ -6,7 +6,6 @@ from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar
import zigpy
from zigpy.typing import EndpointType as ZigpyEndpointType
from homeassistant.const import Platform
@ -15,7 +14,6 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from . import const, discovery, registries
from .cluster_handlers import ClusterHandler
from .cluster_handlers.general import MultistateInput
from .helpers import get_zha_data
if TYPE_CHECKING:
@ -116,8 +114,16 @@ class Endpoint:
def add_all_cluster_handlers(self) -> None:
"""Create and add cluster handlers for all input clusters."""
for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler
cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, {None: ClusterHandler}
)
quirk_id = (
self.device.quirk_id
if self.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)
# Allow cluster handler to filter out bad matches
@ -129,15 +135,6 @@ class Endpoint:
cluster_id,
cluster_handler_class,
)
# really ugly hack to deal with xiaomi using the door lock cluster
# incorrectly.
if (
hasattr(cluster, "ep_attribute")
and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id
and cluster.ep_attribute == "multistate_input"
):
cluster_handler_class = MultistateInput
# end of ugly hack
try:
cluster_handler = cluster_handler_class(cluster, self)

View File

@ -15,7 +15,7 @@ from zigpy.types.named import EUI64
from homeassistant.const import Platform
from .decorators import DictRegistry, SetRegistry
from .decorators import DictRegistry, NestedDictRegistry, SetRegistry
if TYPE_CHECKING:
from ..entity import ZhaEntity, ZhaGroupEntity
@ -110,7 +110,9 @@ CLUSTER_HANDLER_ONLY_CLUSTERS = SetRegistry()
CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[
type[ClientClusterHandler]
] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[
type[ClusterHandler]
] = NestedDictRegistry()
WEIGHT_ATTR = attrgetter("weight")

View File

@ -3,6 +3,7 @@ import asyncio
from collections.abc import Callable
import logging
import math
from types import NoneType
from unittest import mock
from unittest.mock import AsyncMock, patch
@ -11,12 +12,17 @@ import zigpy.device
import zigpy.endpoint
from zigpy.endpoint import Endpoint as ZigpyEndpoint
import zigpy.profiles.zha
import zigpy.quirks as zigpy_quirks
import zigpy.types as t
from zigpy.zcl import foundation
import zigpy.zcl.clusters
from zigpy.zcl.clusters import CLUSTERS_BY_ID
import zigpy.zdo.types as zdo_t
import homeassistant.components.zha.core.cluster_handlers as cluster_handlers
from homeassistant.components.zha.core.cluster_handlers.lighting import (
ColorClusterHandler,
)
import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.endpoint import Endpoint
@ -97,7 +103,9 @@ def poll_control_ch(endpoint, zigpy_device_mock):
)
cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id)
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id
).get(None)
return cluster_handler_class(cluster, endpoint)
@ -258,8 +266,8 @@ async def test_in_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None, cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)
await cluster_handler.async_configure()
@ -322,8 +330,8 @@ async def test_out_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id]
cluster.bind_only = True
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler
)
cluster_id, {None: cluster_handlers.ClusterHandler}
).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint)
await cluster_handler.async_configure()
@ -334,13 +342,46 @@ async def test_out_cluster_handler_config(
def test_cluster_handler_registry() -> None:
"""Test ZIGBEE cluster handler Registry."""
# get all quirk ID from zigpy quirks registry
all_quirk_ids = {}
for cluster_id in CLUSTERS_BY_ID:
all_quirk_ids[cluster_id] = {None}
for manufacturer in zigpy_quirks._DEVICE_REGISTRY.registry.values():
for model_quirk_list in manufacturer.values():
for quirk in model_quirk_list:
quirk_id = getattr(quirk, zha_const.ATTR_QUIRK_ID, None)
device_description = getattr(quirk, "replacement", None) or getattr(
quirk, "signature", None
)
for endpoint in device_description["endpoints"].values():
cluster_ids = set()
if "input_clusters" in endpoint:
cluster_ids.update(endpoint["input_clusters"])
if "output_clusters" in endpoint:
cluster_ids.update(endpoint["output_clusters"])
for cluster_id in cluster_ids:
if not isinstance(cluster_id, int):
cluster_id = cluster_id.cluster_id
if cluster_id not in all_quirk_ids:
all_quirk_ids[cluster_id] = {None}
all_quirk_ids[cluster_id].add(quirk_id)
del quirk, model_quirk_list, manufacturer
for (
cluster_id,
cluster_handler,
cluster_handler_classes,
) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items():
assert isinstance(cluster_id, int)
assert 0 <= cluster_id <= 0xFFFF
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)
assert cluster_id in all_quirk_ids
assert isinstance(cluster_handler_classes, dict)
for quirk_id, cluster_handler in cluster_handler_classes.items():
assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str)
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)
assert quirk_id in all_quirk_ids[cluster_id]
def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None:
@ -818,7 +859,8 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None:
],
)
mock_zha_device = mock.AsyncMock(spec_set=ZHADevice)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
# The cluster handler throws an error when matching this cluster
@ -827,14 +869,84 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None:
# And one is also logged at runtime
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY,
{cluster.cluster_id: TestZigbeeClusterHandler},
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{None: TestZigbeeClusterHandler},
), caplog.at_level(logging.WARNING):
zha_endpoint.add_all_cluster_handlers()
assert "missing_attr" in caplog.text
async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""
class TestZigbeeClusterHandler(ColorClusterHandler):
pass
mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)
cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()
assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler
)
async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""
class TestZigbeeClusterHandler(ColorClusterHandler):
pass
mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)
cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = "__test_quirk_id"
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()
assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler
)
# parametrize side effects:
@pytest.mark.parametrize(
("side_effect", "expected_error"),