diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index c286d0112e9..9874fddc598 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -48,6 +48,7 @@ ATTR_POWER_SOURCE = "power_source" ATTR_PROFILE_ID = "profile_id" ATTR_QUIRK_APPLIED = "quirk_applied" ATTR_QUIRK_CLASS = "quirk_class" +ATTR_QUIRK_ID = "quirk_id" ATTR_ROUTES = "routes" ATTR_RSSI = "rssi" ATTR_SIGNATURE = "signature" diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index 8f5b087f068..44acbb172fc 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -59,6 +59,7 @@ from .const import ( ATTR_POWER_SOURCE, ATTR_QUIRK_APPLIED, ATTR_QUIRK_CLASS, + ATTR_QUIRK_ID, ATTR_ROUTES, ATTR_RSSI, ATTR_SIGNATURE, @@ -135,6 +136,7 @@ class ZHADevice(LogMixin): f"{self._zigpy_device.__class__.__module__}." f"{self._zigpy_device.__class__.__name__}" ) + self.quirk_id = getattr(self._zigpy_device, ATTR_QUIRK_ID, None) if self.is_mains_powered: self.consider_unavailable_time = async_get_zha_config_value( @@ -537,6 +539,7 @@ class ZHADevice(LogMixin): ATTR_NAME: self.name or ieee, ATTR_QUIRK_APPLIED: self.quirk_applied, ATTR_QUIRK_CLASS: self.quirk_class, + ATTR_QUIRK_ID: self.quirk_id, ATTR_MANUFACTURER_CODE: self.manufacturer_code, ATTR_POWER_SOURCE: self.power_source, ATTR_LQI: self.lqi, diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index a56e7044d3a..90ed68f9b00 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -122,7 +122,7 @@ class ProbeEndpoint: endpoint.device.manufacturer, endpoint.device.model, cluster_handlers, - endpoint.device.quirk_class, + endpoint.device.quirk_id, ) if platform_entity_class is None: return @@ -181,7 +181,7 @@ class ProbeEndpoint: endpoint.device.manufacturer, endpoint.device.model, cluster_handler_list, - endpoint.device.quirk_class, + endpoint.device.quirk_id, ) if entity_class is None: return @@ -226,14 +226,14 @@ class ProbeEndpoint: endpoint.device.manufacturer, endpoint.device.model, list(endpoint.all_cluster_handlers.values()), - endpoint.device.quirk_class, + endpoint.device.quirk_id, ) else: matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity( endpoint.device.manufacturer, endpoint.device.model, endpoint.unclaimed_cluster_handlers(), - endpoint.device.quirk_class, + endpoint.device.quirk_id, ) endpoint.claim_cluster_handlers(claimed) diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 74f724bdc49..4bdedebfff9 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -147,7 +147,7 @@ class MatchRule: aux_cluster_handlers: frozenset[str] | Callable = attr.ib( factory=_get_empty_frozenset, converter=set_or_callable ) - quirk_classes: frozenset[str] | Callable = attr.ib( + quirk_ids: frozenset[str] | Callable = attr.ib( factory=_get_empty_frozenset, converter=set_or_callable ) @@ -165,10 +165,8 @@ class MatchRule: multiple cluster handlers a better priority over rules matching a single cluster handler. """ weight = 0 - if self.quirk_classes: - weight += 501 - ( - 1 if callable(self.quirk_classes) else len(self.quirk_classes) - ) + if self.quirk_ids: + weight += 501 - (1 if callable(self.quirk_ids) else len(self.quirk_ids)) if self.models: weight += 401 - (1 if callable(self.models) else len(self.models)) @@ -204,19 +202,31 @@ class MatchRule: return claimed def strict_matched( - self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str + self, + manufacturer: str, + model: str, + cluster_handlers: list, + quirk_id: str | None, ) -> bool: """Return True if this device matches the criteria.""" - return all(self._matched(manufacturer, model, cluster_handlers, quirk_class)) + return all(self._matched(manufacturer, model, cluster_handlers, quirk_id)) def loose_matched( - self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str + self, + manufacturer: str, + model: str, + cluster_handlers: list, + quirk_id: str | None, ) -> bool: """Return True if this device matches the criteria.""" - return any(self._matched(manufacturer, model, cluster_handlers, quirk_class)) + return any(self._matched(manufacturer, model, cluster_handlers, quirk_id)) def _matched( - self, manufacturer: str, model: str, cluster_handlers: list, quirk_class: str + self, + manufacturer: str, + model: str, + cluster_handlers: list, + quirk_id: str | None, ) -> list: """Return a list of field matches.""" if not any(attr.asdict(self).values()): @@ -243,14 +253,11 @@ class MatchRule: else: matches.append(model in self.models) - if self.quirk_classes: - if callable(self.quirk_classes): - matches.append(self.quirk_classes(quirk_class)) + if self.quirk_ids and quirk_id: + if callable(self.quirk_ids): + matches.append(self.quirk_ids(quirk_id)) else: - matches.append( - quirk_class.split(".")[-2:] - in [x.split(".")[-2:] for x in self.quirk_classes] - ) + matches.append(quirk_id in self.quirk_ids) return matches @@ -292,13 +299,13 @@ class ZHAEntityRegistry: manufacturer: str, model: str, cluster_handlers: list[ClusterHandler], - quirk_class: str, + quirk_id: str | None, default: type[ZhaEntity] | None = None, ) -> tuple[type[ZhaEntity] | None, list[ClusterHandler]]: """Match a ZHA ClusterHandler to a ZHA Entity class.""" matches = self._strict_registry[component] for match in sorted(matches, key=WEIGHT_ATTR, reverse=True): - if match.strict_matched(manufacturer, model, cluster_handlers, quirk_class): + if match.strict_matched(manufacturer, model, cluster_handlers, quirk_id): claimed = match.claim_cluster_handlers(cluster_handlers) return self._strict_registry[component][match], claimed @@ -309,7 +316,7 @@ class ZHAEntityRegistry: manufacturer: str, model: str, cluster_handlers: list[ClusterHandler], - quirk_class: str, + quirk_id: str | None, ) -> tuple[ dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler] ]: @@ -323,7 +330,7 @@ class ZHAEntityRegistry: sorted_matches = sorted(matches, key=WEIGHT_ATTR, reverse=True) for match in sorted_matches: if match.strict_matched( - manufacturer, model, cluster_handlers, quirk_class + manufacturer, model, cluster_handlers, quirk_id ): claimed = match.claim_cluster_handlers(cluster_handlers) for ent_class in stop_match_groups[stop_match_grp][match]: @@ -342,7 +349,7 @@ class ZHAEntityRegistry: manufacturer: str, model: str, cluster_handlers: list[ClusterHandler], - quirk_class: str, + quirk_id: str | None, ) -> tuple[ dict[Platform, list[EntityClassAndClusterHandlers]], list[ClusterHandler] ]: @@ -359,7 +366,7 @@ class ZHAEntityRegistry: sorted_matches = sorted(matches, key=WEIGHT_ATTR, reverse=True) for match in sorted_matches: if match.strict_matched( - manufacturer, model, cluster_handlers, quirk_class + manufacturer, model, cluster_handlers, quirk_id ): claimed = match.claim_cluster_handlers(cluster_handlers) for ent_class in stop_match_groups[stop_match_grp][match]: @@ -385,7 +392,7 @@ class ZHAEntityRegistry: manufacturers: Callable | set[str] | str | None = None, models: Callable | set[str] | str | None = None, aux_cluster_handlers: Callable | set[str] | str | None = None, - quirk_classes: set[str] | str | None = None, + quirk_ids: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a strict match rule.""" @@ -395,7 +402,7 @@ class ZHAEntityRegistry: manufacturers, models, aux_cluster_handlers, - quirk_classes, + quirk_ids, ) def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT: @@ -417,7 +424,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str | None = None, aux_cluster_handlers: Callable | set[str] | str | None = None, stop_on_match_group: int | str | None = None, - quirk_classes: set[str] | str | None = None, + quirk_ids: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" @@ -427,7 +434,7 @@ class ZHAEntityRegistry: manufacturers, models, aux_cluster_handlers, - quirk_classes, + quirk_ids, ) def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: @@ -452,7 +459,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str | None = None, aux_cluster_handlers: Callable | set[str] | str | None = None, stop_on_match_group: int | str | None = None, - quirk_classes: set[str] | str | None = None, + quirk_ids: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" @@ -462,7 +469,7 @@ class ZHAEntityRegistry: manufacturers, models, aux_cluster_handlers, - quirk_classes, + quirk_ids, ) def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: diff --git a/tests/components/zha/test_registries.py b/tests/components/zha/test_registries.py index 2eb61402a95..68ff116adea 100644 --- a/tests/components/zha/test_registries.py +++ b/tests/components/zha/test_registries.py @@ -1,15 +1,14 @@ """Test ZHA registries.""" from __future__ import annotations -import importlib -import inspect import typing from unittest import mock import pytest -import zhaquirks +import zigpy.quirks as zigpy_quirks from homeassistant.components.zha.binary_sensor import IASZone +from homeassistant.components.zha.core.const import ATTR_QUIRK_ID import homeassistant.components.zha.core.registries as registries from homeassistant.helpers import entity_registry as er @@ -19,7 +18,7 @@ if typing.TYPE_CHECKING: MANUFACTURER = "mock manufacturer" MODEL = "mock model" QUIRK_CLASS = "mock.test.quirk.class" -QUIRK_CLASS_SHORT = "quirk.class" +QUIRK_ID = "quirk_id" @pytest.fixture @@ -29,6 +28,7 @@ def zha_device(): dev.manufacturer = MANUFACTURER dev.model = MODEL dev.quirk_class = QUIRK_CLASS + dev.quirk_id = QUIRK_ID return dev @@ -107,17 +107,17 @@ def cluster_handlers(cluster_handler): ), False, ), - (registries.MatchRule(quirk_classes=QUIRK_CLASS), True), - (registries.MatchRule(quirk_classes="no match"), False), + (registries.MatchRule(quirk_ids=QUIRK_ID), True), + (registries.MatchRule(quirk_ids="no match"), False), ( registries.MatchRule( - quirk_classes=QUIRK_CLASS, aux_cluster_handlers="aux_cluster_handler" + quirk_ids=QUIRK_ID, aux_cluster_handlers="aux_cluster_handler" ), True, ), ( registries.MatchRule( - quirk_classes="no match", aux_cluster_handlers="aux_cluster_handler" + quirk_ids="no match", aux_cluster_handlers="aux_cluster_handler" ), False, ), @@ -128,7 +128,7 @@ def cluster_handlers(cluster_handler): cluster_handler_names={"on_off", "level"}, manufacturers=MANUFACTURER, models=MODEL, - quirk_classes=QUIRK_CLASS, + quirk_ids=QUIRK_ID, ), True, ), @@ -187,33 +187,31 @@ def cluster_handlers(cluster_handler): ( registries.MatchRule( cluster_handler_names="on_off", - quirk_classes={"random quirk", QUIRK_CLASS}, + quirk_ids={"random quirk", QUIRK_ID}, ), True, ), ( registries.MatchRule( cluster_handler_names="on_off", - quirk_classes={"random quirk", "another quirk"}, + quirk_ids={"random quirk", "another quirk"}, ), False, ), ( registries.MatchRule( - cluster_handler_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS + cluster_handler_names="on_off", quirk_ids=lambda x: x == QUIRK_ID ), True, ), ( registries.MatchRule( - cluster_handler_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS + cluster_handler_names="on_off", quirk_ids=lambda x: x != QUIRK_ID ), False, ), ( - registries.MatchRule( - cluster_handler_names="on_off", quirk_classes=QUIRK_CLASS_SHORT - ), + registries.MatchRule(cluster_handler_names="on_off", quirk_ids=QUIRK_ID), True, ), ], @@ -221,8 +219,7 @@ def cluster_handlers(cluster_handler): def test_registry_matching(rule, matched, cluster_handlers) -> None: """Test strict rule matching.""" assert ( - rule.strict_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_CLASS) - is matched + rule.strict_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_ID) is matched ) @@ -314,8 +311,8 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None: (registries.MatchRule(manufacturers=MANUFACTURER), True), (registries.MatchRule(models=MODEL), True), (registries.MatchRule(models="no match"), False), - (registries.MatchRule(quirk_classes=QUIRK_CLASS), True), - (registries.MatchRule(quirk_classes="no match"), False), + (registries.MatchRule(quirk_ids=QUIRK_ID), True), + (registries.MatchRule(quirk_ids="no match"), False), # match everything ( registries.MatchRule( @@ -323,7 +320,7 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None: cluster_handler_names={"on_off", "level"}, manufacturers=MANUFACTURER, models=MODEL, - quirk_classes=QUIRK_CLASS, + quirk_ids=QUIRK_ID, ), True, ), @@ -332,8 +329,7 @@ def test_registry_matching(rule, matched, cluster_handlers) -> None: def test_registry_loose_matching(rule, matched, cluster_handlers) -> None: """Test loose rule matching.""" assert ( - rule.loose_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_CLASS) - is matched + rule.loose_matched(MANUFACTURER, MODEL, cluster_handlers, QUIRK_ID) is matched ) @@ -397,12 +393,12 @@ def entity_registry(): @pytest.mark.parametrize( - ("manufacturer", "model", "quirk_class", "match_name"), + ("manufacturer", "model", "quirk_id", "match_name"), ( ("random manufacturer", "random model", "random.class", "OnOff"), ("random manufacturer", MODEL, "random.class", "OnOffModel"), (MANUFACTURER, "random model", "random.class", "OnOffManufacturer"), - ("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"), + ("random manufacturer", "random model", QUIRK_ID, "OnOffQuirk"), (MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"), (MANUFACTURER, "some model", "random.class", "OnOffMultimodel"), ), @@ -412,7 +408,7 @@ def test_weighted_match( entity_registry: er.EntityRegistry, manufacturer, model, - quirk_class, + quirk_id, match_name, ) -> None: """Test weightedd match.""" @@ -453,7 +449,7 @@ def test_weighted_match( pass @entity_registry.strict_match( - s.component, cluster_handler_names="on_off", quirk_classes=QUIRK_CLASS + s.component, cluster_handler_names="on_off", quirk_ids=QUIRK_ID ) class OnOffQuirk: pass @@ -462,7 +458,7 @@ def test_weighted_match( ch_level = cluster_handler("level", 8) match, claimed = entity_registry.get_entity( - s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class + s.component, manufacturer, model, [ch_on_off, ch_level], quirk_id ) assert match.__name__ == match_name @@ -490,7 +486,7 @@ def test_multi_sensor_match( "manufacturer", "model", cluster_handlers=[ch_se, ch_illuminati], - quirk_class="quirk_class", + quirk_id="quirk_id", ) assert s.binary_sensor in match @@ -520,7 +516,7 @@ def test_multi_sensor_match( "manufacturer", "model", cluster_handlers={ch_se, ch_illuminati}, - quirk_class="quirk_class", + quirk_id="quirk_id", ) assert s.binary_sensor in match @@ -554,18 +550,10 @@ def iter_all_rules() -> typing.Iterable[registries.MatchRule, list[type[ZhaEntit def test_quirk_classes() -> None: - """Make sure that quirk_classes in components matches are valid.""" - - def find_quirk_class(base_obj, quirk_mod, quirk_cls): - """Find a specific quirk class.""" - - module = importlib.import_module(quirk_mod) - clss = dict(inspect.getmembers(module, inspect.isclass)) - # Check quirk_cls in module classes - return quirk_cls in clss + """Make sure that all quirk IDs in components matches exist.""" def quirk_class_validator(value): - """Validate quirk classes during self test.""" + """Validate quirk IDs during self test.""" if callable(value): # Callables cannot be tested return @@ -576,16 +564,22 @@ def test_quirk_classes() -> None: quirk_class_validator(v) return - quirk_tok = value.rsplit(".", 1) - if len(quirk_tok) != 2: - # quirk_class is at least __module__.__class__ - raise ValueError(f"Invalid quirk class : '{value}'") + if value not in all_quirk_ids: + raise ValueError(f"Quirk ID '{value}' does not exist.") - if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]): - raise ValueError(f"Quirk class '{value}' does not exists.") + # get all quirk ID from zigpy quirks registry + all_quirk_ids = [] + for manufacturer in zigpy_quirks._DEVICE_REGISTRY._registry.values(): + for model_quirk_list in manufacturer.values(): + for quirk in model_quirk_list: + quirk_id = getattr(quirk, ATTR_QUIRK_ID, None) + if quirk_id is not None and quirk_id not in all_quirk_ids: + all_quirk_ids.append(quirk_id) + del quirk, model_quirk_list, manufacturer + # validate all quirk IDs used in component match rules for rule, _ in iter_all_rules(): - quirk_class_validator(rule.quirk_classes) + quirk_class_validator(rule.quirk_ids) def test_entity_names() -> None: