Match ZHA Custom ClusterHandler on a Custom Cluster using a unique id for the quirk (#101709)

* initial

* fix tests

* match on specific name and quirk name

* fix tests

* fix tests

* store cluster handlers in only one place

* edit tests

* use correct device for quirk id

* change quirk id

* fix tests

* even if there is a quirk id, it doesn't have to have a specific cluster handler

* add tests

* rename quirk_id

* add tests

* fix tests

* fix tests

* use quirk id from zha_quirks
This commit is contained in:
Caius-Bonus 2023-12-27 17:48:30 +01:00 committed by GitHub
parent 4330452212
commit 37707edc47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 28 deletions

View File

@ -5,8 +5,9 @@ import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType from zhaquirks.inovelli.types import AllLEDEffectType, SingleLEDEffectType
from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER from zhaquirks.quirk_ids import TUYA_PLUG_MANUFACTURER, XIAOMI_AQARA_VIBRATION_AQ1
import zigpy.zcl import zigpy.zcl
from zigpy.zcl.clusters.closures import DoorLock
from homeassistant.core import callback from homeassistant.core import callback
@ -24,6 +25,7 @@ from ..const import (
UNKNOWN, UNKNOWN,
) )
from . import AttrReportConfig, ClientClusterHandler, ClusterHandler from . import AttrReportConfig, ClientClusterHandler, ClusterHandler
from .general import MultistateInput
if TYPE_CHECKING: if TYPE_CHECKING:
from ..endpoint import Endpoint from ..endpoint import Endpoint
@ -403,3 +405,10 @@ class IkeaRemote(ClusterHandler):
"""Ikea Matter remote cluster handler.""" """Ikea Matter remote cluster handler."""
REPORT_CONFIG = () REPORT_CONFIG = ()
@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(
DoorLock.cluster_id, XIAOMI_AQARA_VIBRATION_AQ1
)
class XiaomiVibrationAQ1ClusterHandler(MultistateInput):
"""Xiaomi DoorLock Cluster is in fact a MultiStateInput Cluster."""

View File

@ -21,6 +21,24 @@ class DictRegistry(dict[int | str, _TypeT]):
return decorator return decorator
class NestedDictRegistry(dict[int | str, dict[int | str | None, _TypeT]]):
"""Dict Registry of multiple items per key."""
def register(
self, name: int | str, sub_name: int | str | None = None
) -> Callable[[_TypeT], _TypeT]:
"""Return decorator to register item with a specific and a quirk name."""
def decorator(cluster_handler: _TypeT) -> _TypeT:
"""Register decorated cluster handler or item."""
if name not in self:
self[name] = {}
self[name][sub_name] = cluster_handler
return cluster_handler
return decorator
class SetRegistry(set[int | str]): class SetRegistry(set[int | str]):
"""Set Registry of items.""" """Set Registry of items."""

View File

@ -203,9 +203,20 @@ class ProbeEndpoint:
if platform is None: if platform is None:
continue continue
cluster_handler_class = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( cluster_handler_classes = zha_regs.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler cluster_id, {None: ClusterHandler}
) )
quirk_id = (
endpoint.device.quirk_id
if endpoint.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
)
cluster_handler = cluster_handler_class(cluster, endpoint) cluster_handler = cluster_handler_class(cluster, endpoint)
self.probe_single_cluster(platform, cluster_handler, endpoint) self.probe_single_cluster(platform, cluster_handler, endpoint)

View File

@ -6,7 +6,6 @@ from collections.abc import Callable
import logging import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar from typing import TYPE_CHECKING, Any, Final, TypeVar
import zigpy
from zigpy.typing import EndpointType as ZigpyEndpointType from zigpy.typing import EndpointType as ZigpyEndpointType
from homeassistant.const import Platform from homeassistant.const import Platform
@ -15,7 +14,6 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from . import const, discovery, registries from . import const, discovery, registries
from .cluster_handlers import ClusterHandler from .cluster_handlers import ClusterHandler
from .cluster_handlers.general import MultistateInput
from .helpers import get_zha_data from .helpers import get_zha_data
if TYPE_CHECKING: if TYPE_CHECKING:
@ -116,8 +114,16 @@ class Endpoint:
def add_all_cluster_handlers(self) -> None: def add_all_cluster_handlers(self) -> None:
"""Create and add cluster handlers for all input clusters.""" """Create and add cluster handlers for all input clusters."""
for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items(): for cluster_id, cluster in self.zigpy_endpoint.in_clusters.items():
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( cluster_handler_classes = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler cluster_id, {None: ClusterHandler}
)
quirk_id = (
self.device.quirk_id
if self.device.quirk_id in cluster_handler_classes
else None
)
cluster_handler_class = cluster_handler_classes.get(
quirk_id, ClusterHandler
) )
# Allow cluster handler to filter out bad matches # Allow cluster handler to filter out bad matches
@ -129,15 +135,6 @@ class Endpoint:
cluster_id, cluster_id,
cluster_handler_class, cluster_handler_class,
) )
# really ugly hack to deal with xiaomi using the door lock cluster
# incorrectly.
if (
hasattr(cluster, "ep_attribute")
and cluster_id == zigpy.zcl.clusters.closures.DoorLock.cluster_id
and cluster.ep_attribute == "multistate_input"
):
cluster_handler_class = MultistateInput
# end of ugly hack
try: try:
cluster_handler = cluster_handler_class(cluster, self) cluster_handler = cluster_handler_class(cluster, self)

View File

@ -15,7 +15,7 @@ from zigpy.types.named import EUI64
from homeassistant.const import Platform from homeassistant.const import Platform
from .decorators import DictRegistry, SetRegistry from .decorators import DictRegistry, NestedDictRegistry, SetRegistry
if TYPE_CHECKING: if TYPE_CHECKING:
from ..entity import ZhaEntity, ZhaGroupEntity from ..entity import ZhaEntity, ZhaGroupEntity
@ -110,7 +110,9 @@ CLUSTER_HANDLER_ONLY_CLUSTERS = SetRegistry()
CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[ CLIENT_CLUSTER_HANDLER_REGISTRY: DictRegistry[
type[ClientClusterHandler] type[ClientClusterHandler]
] = DictRegistry() ] = DictRegistry()
ZIGBEE_CLUSTER_HANDLER_REGISTRY: DictRegistry[type[ClusterHandler]] = DictRegistry() ZIGBEE_CLUSTER_HANDLER_REGISTRY: NestedDictRegistry[
type[ClusterHandler]
] = NestedDictRegistry()
WEIGHT_ATTR = attrgetter("weight") WEIGHT_ATTR = attrgetter("weight")

View File

@ -3,6 +3,7 @@ import asyncio
from collections.abc import Callable from collections.abc import Callable
import logging import logging
import math import math
from types import NoneType
from unittest import mock from unittest import mock
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -11,12 +12,17 @@ import zigpy.device
import zigpy.endpoint import zigpy.endpoint
from zigpy.endpoint import Endpoint as ZigpyEndpoint from zigpy.endpoint import Endpoint as ZigpyEndpoint
import zigpy.profiles.zha import zigpy.profiles.zha
import zigpy.quirks as zigpy_quirks
import zigpy.types as t import zigpy.types as t
from zigpy.zcl import foundation from zigpy.zcl import foundation
import zigpy.zcl.clusters import zigpy.zcl.clusters
from zigpy.zcl.clusters import CLUSTERS_BY_ID
import zigpy.zdo.types as zdo_t import zigpy.zdo.types as zdo_t
import homeassistant.components.zha.core.cluster_handlers as cluster_handlers import homeassistant.components.zha.core.cluster_handlers as cluster_handlers
from homeassistant.components.zha.core.cluster_handlers.lighting import (
ColorClusterHandler,
)
import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.const as zha_const
from homeassistant.components.zha.core.device import ZHADevice from homeassistant.components.zha.core.device import ZHADevice
from homeassistant.components.zha.core.endpoint import Endpoint from homeassistant.components.zha.core.endpoint import Endpoint
@ -97,7 +103,9 @@ def poll_control_ch(endpoint, zigpy_device_mock):
) )
cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(cluster_id) cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id
).get(None)
return cluster_handler_class(cluster, endpoint) return cluster_handler_class(cluster, endpoint)
@ -258,8 +266,8 @@ async def test_in_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id]
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler cluster_id, {None, cluster_handlers.ClusterHandler}
) ).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint) cluster_handler = cluster_handler_class(cluster, endpoint)
await cluster_handler.async_configure() await cluster_handler.async_configure()
@ -322,8 +330,8 @@ async def test_out_cluster_handler_config(
cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id] cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id]
cluster.bind_only = True cluster.bind_only = True
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, cluster_handlers.ClusterHandler cluster_id, {None: cluster_handlers.ClusterHandler}
) ).get(None)
cluster_handler = cluster_handler_class(cluster, endpoint) cluster_handler = cluster_handler_class(cluster, endpoint)
await cluster_handler.async_configure() await cluster_handler.async_configure()
@ -334,13 +342,46 @@ async def test_out_cluster_handler_config(
def test_cluster_handler_registry() -> None: def test_cluster_handler_registry() -> None:
"""Test ZIGBEE cluster handler Registry.""" """Test ZIGBEE cluster handler Registry."""
# get all quirk ID from zigpy quirks registry
all_quirk_ids = {}
for cluster_id in CLUSTERS_BY_ID:
all_quirk_ids[cluster_id] = {None}
for manufacturer in zigpy_quirks._DEVICE_REGISTRY.registry.values():
for model_quirk_list in manufacturer.values():
for quirk in model_quirk_list:
quirk_id = getattr(quirk, zha_const.ATTR_QUIRK_ID, None)
device_description = getattr(quirk, "replacement", None) or getattr(
quirk, "signature", None
)
for endpoint in device_description["endpoints"].values():
cluster_ids = set()
if "input_clusters" in endpoint:
cluster_ids.update(endpoint["input_clusters"])
if "output_clusters" in endpoint:
cluster_ids.update(endpoint["output_clusters"])
for cluster_id in cluster_ids:
if not isinstance(cluster_id, int):
cluster_id = cluster_id.cluster_id
if cluster_id not in all_quirk_ids:
all_quirk_ids[cluster_id] = {None}
all_quirk_ids[cluster_id].add(quirk_id)
del quirk, model_quirk_list, manufacturer
for ( for (
cluster_id, cluster_id,
cluster_handler, cluster_handler_classes,
) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items(): ) in registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.items():
assert isinstance(cluster_id, int) assert isinstance(cluster_id, int)
assert 0 <= cluster_id <= 0xFFFF assert 0 <= cluster_id <= 0xFFFF
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler) assert cluster_id in all_quirk_ids
assert isinstance(cluster_handler_classes, dict)
for quirk_id, cluster_handler in cluster_handler_classes.items():
assert isinstance(quirk_id, NoneType) or isinstance(quirk_id, str)
assert issubclass(cluster_handler, cluster_handlers.ClusterHandler)
assert quirk_id in all_quirk_ids[cluster_id]
def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None: def test_epch_unclaimed_cluster_handlers(cluster_handler) -> None:
@ -818,7 +859,8 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None:
], ],
) )
mock_zha_device = mock.AsyncMock(spec_set=ZHADevice) mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device) zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
# The cluster handler throws an error when matching this cluster # The cluster handler throws an error when matching this cluster
@ -827,14 +869,84 @@ async def test_invalid_cluster_handler(hass: HomeAssistant, caplog) -> None:
# And one is also logged at runtime # And one is also logged at runtime
with patch.dict( with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY, registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{cluster.cluster_id: TestZigbeeClusterHandler}, {None: TestZigbeeClusterHandler},
), caplog.at_level(logging.WARNING): ), caplog.at_level(logging.WARNING):
zha_endpoint.add_all_cluster_handlers() zha_endpoint.add_all_cluster_handlers()
assert "missing_attr" in caplog.text assert "missing_attr" in caplog.text
async def test_standard_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""
class TestZigbeeClusterHandler(ColorClusterHandler):
pass
mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)
cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = None
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()
assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], ColorClusterHandler
)
async def test_quirk_id_cluster_handler(hass: HomeAssistant, caplog) -> None:
"""Test setting up a cluster handler that matches a standard cluster."""
class TestZigbeeClusterHandler(ColorClusterHandler):
pass
mock_device = mock.AsyncMock(spec_set=zigpy.device.Device)
zigpy_ep = zigpy.endpoint.Endpoint(mock_device, endpoint_id=1)
cluster = zigpy_ep.add_input_cluster(zigpy.zcl.clusters.lighting.Color.cluster_id)
cluster.configure_reporting_multiple = AsyncMock(
spec_set=cluster.configure_reporting_multiple,
return_value=[
foundation.ConfigureReportingResponseRecord(
status=foundation.Status.SUCCESS
)
],
)
mock_zha_device = mock.AsyncMock(spec=ZHADevice)
mock_zha_device.quirk_id = "__test_quirk_id"
zha_endpoint = Endpoint(zigpy_ep, mock_zha_device)
with patch.dict(
registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY[cluster.cluster_id],
{"__test_quirk_id": TestZigbeeClusterHandler},
):
zha_endpoint.add_all_cluster_handlers()
assert len(zha_endpoint.all_cluster_handlers) == 1
assert isinstance(
list(zha_endpoint.all_cluster_handlers.values())[0], TestZigbeeClusterHandler
)
# parametrize side effects: # parametrize side effects:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "expected_error"), ("side_effect", "expected_error"),