Replace ZHA quirk class matching with quirk ID matching (#102482)

* Use fixed quirk IDs for matching instead of quirk class

* Change tests for quirk id (WIP)

* Do not default `quirk_id` to `quirk_class`

* Implement test for checking if quirk ID exists

* Change `quirk_id` for test slightly (underscore instead of dot)
This commit is contained in:
TheJulianJES 2023-10-24 23:18:10 +02:00 committed by GitHub
parent 5ee14f7f7d
commit fd8fdba7e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 80 deletions

View File

@ -48,6 +48,7 @@ ATTR_POWER_SOURCE = "power_source"
ATTR_PROFILE_ID = "profile_id"
ATTR_QUIRK_APPLIED = "quirk_applied"
ATTR_QUIRK_CLASS = "quirk_class"
ATTR_QUIRK_ID = "quirk_id"
ATTR_ROUTES = "routes"
ATTR_RSSI = "rssi"
ATTR_SIGNATURE = "signature"

View File

@ -59,6 +59,7 @@ from .const import (
ATTR_POWER_SOURCE,
ATTR_QUIRK_APPLIED,
ATTR_QUIRK_CLASS,
ATTR_QUIRK_ID,
ATTR_ROUTES,
ATTR_RSSI,
ATTR_SIGNATURE,
@ -135,6 +136,7 @@ class ZHADevice(LogMixin):
f"{self._zigpy_device.__class__.__module__}."
f"{self._zigpy_device.__class__.__name__}"
)
self.quirk_id = getattr(self._zigpy_device, ATTR_QUIRK_ID, None)
if self.is_mains_powered:
self.consider_unavailable_time = async_get_zha_config_value(
@ -537,6 +539,7 @@ class ZHADevice(LogMixin):
ATTR_NAME: self.name or ieee,
ATTR_QUIRK_APPLIED: self.quirk_applied,
ATTR_QUIRK_CLASS: self.quirk_class,
ATTR_QUIRK_ID: self.quirk_id,
ATTR_MANUFACTURER_CODE: self.manufacturer_code,
ATTR_POWER_SOURCE: self.power_source,
ATTR_LQI: self.lqi,

View File

@ -122,7 +122,7 @@ class ProbeEndpoint:
endpoint.device.manufacturer,
endpoint.device.model,
cluster_handlers,
endpoint.device.quirk_class,
endpoint.device.quirk_id,
)
if platform_entity_class is None:
return
@ -181,7 +181,7 @@ class ProbeEndpoint:
endpoint.device.manufacturer,
endpoint.device.model,
cluster_handler_list,
endpoint.device.quirk_class,
endpoint.device.quirk_id,
)
if entity_class is None:
return
@ -226,14 +226,14 @@ class ProbeEndpoint:
endpoint.device.manufacturer,
endpoint.device.model,
list(endpoint.all_cluster_handlers.values()),
endpoint.device.quirk_class,
endpoint.device.quirk_id,
)
else:
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
endpoint.device.manufacturer,
endpoint.device.model,
endpoint.unclaimed_cluster_handlers(),
endpoint.device.quirk_class,
endpoint.device.quirk_id,
)
endpoint.claim_cluster_handlers(claimed)

View File

@ -147,7 +147,7 @@ class MatchRule:
aux_cluster_handlers: frozenset[str] | Callable = attr.ib(
factory=_get_empty_frozenset, converter=set_or_callable
)
quirk_classes: frozenset[str] | Callable = attr.ib(
quirk_ids: frozenset[str] | Callable = attr.ib(
factory=_get_empty_frozenset, converter=set_or_callable
)
@ -165,10 +165,8 @@ class MatchRule:
multiple cluster handlers a better priority over rules matching a single cluster handler.
"""
weight = 0
if self.quirk_classes:
weight += 501 - (
1 if callable(self.quirk_classes) else len(self.quirk_classes)
)
if self.quirk_ids:
weight += 501 - (1 if callable(self.quirk_ids) else len(self.quirk_ids))
if self.models:
weight += 401 - (1 if callable(self.models) else len(self.models))
@ -204,19 +202,31 @@ class MatchRule:
return claimed
def strict_matched(
self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str
self,
manufacturer: str,
model: str,
cluster_handlers: list,
quirk_id: str | None,
) -> bool:
"""Return True if this device matches the criteria."""
return all(self._matched(manufacturer, model, cluster_handlers, quirk_class))
return all(self._matched(manufacturer, model, cluster_handlers, quirk_id))
def loose_matched(
self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str
self,
manufacturer: str,
model: str,
cluster_handlers: list,
quirk_id: str | None,
) -> bool:
"""Return True if this device matches the criteria."""
return any(self._matched(manufacturer, model, cluster_handlers, quirk_class))
return any(self._matched(manufacturer, model, cluster_handlers, quirk_id))
def _matched(
self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str
self,
manufacturer: str,
model: str,
cluster_handlers: list,
quirk_id: str | None,
) -> list:
"""Return a list of field matches."""
if not any(attr.asdict(self).values()):
@ -243,14 +253,11 @@ class MatchRule:
else:
matches.append(model in self.models)
if self.quirk_classes:
if callable(self.quirk_classes):
matches.append(self.quirk_classes(quirk_class))
if self.quirk_ids and quirk_id:
if callable(self.quirk_ids):
matches.append(self.quirk_ids(quirk_id))
else:
matches.append(
quirk_class.split(".")[-2:]
in [x.split(".")[-2:] for x in self.quirk_classes]
)
matches.append(quirk_id in self.quirk_ids)
return matches
@ -292,13 +299,13 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
cluster_handlers: list[ClusterHandler],
quirk_class: str,
quirk_id: str | None,
default: type[ZhaEntity] | None = None,
) -> tuple[type[ZhaEntity] | None, list[ClusterHandler]]:
"""Match a ZHA ClusterHandler to a ZHA Entity class."""
matches = self._strict_registry[component]
for match in sorted(matches, key=WEIGHT_ATTR, reverse=True):
if match.strict_matched(manufacturer, model, cluster_handlers, quirk_class):
if match.strict_matched(manufacturer, model, cluster_handlers, quirk_id):
claimed = match.claim_cluster_handlers(cluster_handlers)
return self._strict_registry[component][match], claimed
@ -309,7 +316,7 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
cluster_handlers: list[ClusterHandler],
quirk_class: str,
quirk_id: str | None,
) -> tuple[
dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler]
]:
@ -323,7 +330,7 @@ class ZHAEntityRegistry:
sorted_matches = sorted(matches, key=WEIGHT_ATTR, reverse=True)
for match in sorted_matches:
if match.strict_matched(
manufacturer, model, cluster_handlers, quirk_class
manufacturer, model, cluster_handlers, quirk_id
):
claimed = match.claim_cluster_handlers(cluster_handlers)
for ent_class in stop_match_groups[stop_match_grp][match]:
@ -342,7 +349,7 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
cluster_handlers: list[ClusterHandler],
quirk_class: str,
quirk_id: str | None,
) -> tuple[
dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler]
]:
@ -359,7 +366,7 @@ class ZHAEntityRegistry:
sorted_matches = sorted(matches, key=WEIGHT_ATTR, reverse=True)
for match in sorted_matches:
if match.strict_matched(
manufacturer, model, cluster_handlers, quirk_class
manufacturer, model, cluster_handlers, quirk_id
):
claimed = match.claim_cluster_handlers(cluster_handlers)
for ent_class in stop_match_groups[stop_match_grp][match]:
@ -385,7 +392,7 @@ class ZHAEntityRegistry:
manufacturers: Callable | set[str] | str | None = None,
models: Callable | set[str] | str | None = None,
aux_cluster_handlers: Callable | set[str] | str | None = None,
quirk_classes: set[str] | str | None = None,
quirk_ids: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a strict match rule."""
@ -395,7 +402,7 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_cluster_handlers,
quirk_classes,
quirk_ids,
)
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
@ -417,7 +424,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str | None = None,
aux_cluster_handlers: Callable | set[str] | str | None = None,
stop_on_match_group: int | str | None = None,
quirk_classes: set[str] | str | None = None,
quirk_ids: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
@ -427,7 +434,7 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_cluster_handlers,
quirk_classes,
quirk_ids,
)
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
@ -452,7 +459,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str | None = None,
aux_cluster_handlers: Callable | set[str] | str | None = None,
stop_on_match_group: int | str | None = None,
quirk_classes: set[str] | str | None = None,
quirk_ids: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
@ -462,7 +469,7 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_cluster_handlers,
quirk_classes,
quirk_ids,
)
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:

View File

@ -1,15 +1,14 @@
"""Test ZHA registries."""
from __future__ import annotations
import importlib
import inspect
import typing
from unittest import mock
import pytest
import zhaquirks
import zigpy.quirks as zigpy_quirks
from homeassistant.components.zha.binary_sensor import IASZone
from homeassistant.components.zha.core.const import ATTR_QUIRK_ID
import homeassistant.components.zha.core.registries as registries
from homeassistant.helpers import entity_registry as er
@ -19,7 +18,7 @@ if typing.TYPE_CHECKING:
MANUFACTURER = "mock manufacturer"
MODEL = "mock model"
QUIRK_CLASS = "mock.test.quirk.class"
QUIRK_CLASS_SHORT = "quirk.class"
QUIRK_ID = "quirk_id"
@pytest.fixture
@ -29,6 +28,7 @@ def zha_device():
dev.manufacturer = MANUFACTURER
dev.model = MODEL
dev.quirk_class = QUIRK_CLASS
dev.quirk_id = QUIRK_ID
return dev
@ -107,17 +107,17 @@ def cluster_handlers(cluster_handler):
),
False,
),
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
(registries.MatchRule(quirk_classes="no match"), False),
(registries.MatchRule(quirk_ids=QUIRK_ID), True),
(registries.MatchRule(quirk_ids="no match"), False),
(
registries.MatchRule(
quirk_classes=QUIRK_CLASS, aux_cluster_handlers="aux_cluster_handler"
quirk_ids=QUIRK_ID, aux_cluster_handlers="aux_cluster_handler"
),
True,
),
(
registries.MatchRule(
quirk_classes="no match", aux_cluster_handlers="aux_cluster_handler"
quirk_ids="no match", aux_cluster_handlers="aux_cluster_handler"
),
False,
),
@ -128,7 +128,7 @@ def cluster_handlers(cluster_handler):
cluster_handler_names={"on_off", "level"},
manufacturers=MANUFACTURER,
models=MODEL,
quirk_classes=QUIRK_CLASS,
quirk_ids=QUIRK_ID,
),
True,
),
@ -187,33 +187,31 @@ def cluster_handlers(cluster_handler):
(
registries.MatchRule(
cluster_handler_names="on_off",
quirk_classes={"random quirk", QUIRK_CLASS},
quirk_ids={"random quirk", QUIRK_ID},
),
True,
),
(
registries.MatchRule(
cluster_handler_names="on_off",
quirk_classes={"random quirk", "another quirk"},
quirk_ids={"random quirk", "another quirk"},
),
False,
),
(
registries.MatchRule(
cluster_handler_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS
cluster_handler_names="on_off", quirk_ids=lambda x: x == QUIRK_ID
),
True,
),
(
registries.MatchRule(
cluster_handler_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS
cluster_handler_names="on_off", quirk_ids=lambda x: x != QUIRK_ID
),
False,
),
(
registries.MatchRule(
cluster_handler_names="on_off", quirk_classes=QUIRK_CLASS_SHORT
),
registries.MatchRule(cluster_handler_names="on_off", quirk_ids=QUIRK_ID),
True,
),
],
@ -221,8 +219,7 @@ def cluster_handlers(cluster_handler):
def test_registry_matching(rule, matched, cluster_handlers) -> None:
"""Test strict rule matching."""
assert (
rule.strict_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_CLASS)
is matched
rule.strict_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_ID) is matched
)
@ -314,8 +311,8 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None:
(registries.MatchRule(manufacturers=MANUFACTURER), True),
(registries.MatchRule(models=MODEL), True),
(registries.MatchRule(models="no match"), False),
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
(registries.MatchRule(quirk_classes="no match"), False),
(registries.MatchRule(quirk_ids=QUIRK_ID), True),
(registries.MatchRule(quirk_ids="no match"), False),
# match everything
(
registries.MatchRule(
@ -323,7 +320,7 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None:
cluster_handler_names={"on_off", "level"},
manufacturers=MANUFACTURER,
models=MODEL,
quirk_classes=QUIRK_CLASS,
quirk_ids=QUIRK_ID,
),
True,
),
@ -332,8 +329,7 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None:
def test_registry_loose_matching(rule, matched, cluster_handlers) -> None:
"""Test loose rule matching."""
assert (
rule.loose_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_CLASS)
is matched
rule.loose_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_ID) is matched
)
@ -397,12 +393,12 @@ def entity_registry():
@pytest.mark.parametrize(
("manufacturer", "model", "quirk_class", "match_name"),
("manufacturer", "model", "quirk_id", "match_name"),
(
("random manufacturer", "random model", "random.class", "OnOff"),
("random manufacturer", MODEL, "random.class", "OnOffModel"),
(MANUFACTURER, "random model", "random.class", "OnOffManufacturer"),
("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"),
("random manufacturer", "random model", QUIRK_ID, "OnOffQuirk"),
(MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"),
(MANUFACTURER, "some model", "random.class", "OnOffMultimodel"),
),
@ -412,7 +408,7 @@ def test_weighted_match(
entity_registry: er.EntityRegistry,
manufacturer,
model,
quirk_class,
quirk_id,
match_name,
) -> None:
"""Test weightedd match."""
@ -453,7 +449,7 @@ def test_weighted_match(
pass
@entity_registry.strict_match(
s.component, cluster_handler_names="on_off", quirk_classes=QUIRK_CLASS
s.component, cluster_handler_names="on_off", quirk_ids=QUIRK_ID
)
class OnOffQuirk:
pass
@ -462,7 +458,7 @@ def test_weighted_match(
ch_level = cluster_handler("level", 8)
match, claimed = entity_registry.get_entity(
s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class
s.component, manufacturer, model, [ch_on_off, ch_level], quirk_id
)
assert match.__name__ == match_name
@ -490,7 +486,7 @@ def test_multi_sensor_match(
"manufacturer",
"model",
cluster_handlers=[ch_se, ch_illuminati],
quirk_class="quirk_class",
quirk_id="quirk_id",
)
assert s.binary_sensor in match
@ -520,7 +516,7 @@ def test_multi_sensor_match(
"manufacturer",
"model",
cluster_handlers={ch_se, ch_illuminati},
quirk_class="quirk_class",
quirk_id="quirk_id",
)
assert s.binary_sensor in match
@ -554,18 +550,10 @@ def iter_all_rules() -> typing.Iterable[registries.MatchRule, list[type[ZhaEntit
def test_quirk_classes() -> None:
"""Make sure that quirk_classes in components matches are valid."""
def find_quirk_class(base_obj, quirk_mod, quirk_cls):
"""Find a specific quirk class."""
module = importlib.import_module(quirk_mod)
clss = dict(inspect.getmembers(module, inspect.isclass))
# Check quirk_cls in module classes
return quirk_cls in clss
"""Make sure that all quirk IDs in components matches exist."""
def quirk_class_validator(value):
"""Validate quirk classes during self test."""
"""Validate quirk IDs during self test."""
if callable(value):
# Callables cannot be tested
return
@ -576,16 +564,22 @@ def test_quirk_classes() -> None:
quirk_class_validator(v)
return
quirk_tok = value.rsplit(".", 1)
if len(quirk_tok) != 2:
# quirk_class is at least __module__.__class__
raise ValueError(f"Invalid quirk class : '{value}'")
if value not in all_quirk_ids:
raise ValueError(f"Quirk ID '{value}' does not exist.")
if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]):
raise ValueError(f"Quirk class '{value}' does not exists.")
# get all quirk ID from zigpy quirks registry
all_quirk_ids = []
for manufacturer in zigpy_quirks._DEVICE_REGISTRY._registry.values():
for model_quirk_list in manufacturer.values():
for quirk in model_quirk_list:
quirk_id = getattr(quirk, ATTR_QUIRK_ID, None)
if quirk_id is not None and quirk_id not in all_quirk_ids:
all_quirk_ids.append(quirk_id)
del quirk, model_quirk_list, manufacturer
# validate all quirk IDs used in component match rules
for rule, _ in iter_all_rules():
quirk_class_validator(rule.quirk_classes)
quirk_class_validator(rule.quirk_ids)
def test_entity_names() -> None: