diff --git a/homeassistant/components/zha/core/channels/__init__.py b/homeassistant/components/zha/core/channels/__init__.py index 149b733be39..a708e65a07a 100644 --- a/homeassistant/components/zha/core/channels/__init__.py +++ b/homeassistant/components/zha/core/channels/__init__.py @@ -239,6 +239,11 @@ class ChannelPool: """Return device model.""" return self._channels.zha_device.model + @property + def quirk_class(self) -> str: + """Return device quirk class.""" + return self._channels.zha_device.quirk_class + @property def skip_configuration(self) -> bool: """Return True if device does not require channel configuration.""" diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index eb7dd81e381..d256b98cfb1 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -95,7 +95,11 @@ class ProbeEndpoint: if component and component in zha_const.PLATFORMS: channels = channel_pool.unclaimed_channels() entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( - component, channel_pool.manufacturer, channel_pool.model, channels + component, + channel_pool.manufacturer, + channel_pool.model, + channels, + channel_pool.quirk_class, ) if entity_class is None: return @@ -145,7 +149,11 @@ class ProbeEndpoint: unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}" entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( - component, ep_channels.manufacturer, ep_channels.model, channel_list + component, + ep_channels.manufacturer, + ep_channels.model, + channel_list, + ep_channels.quirk_class, ) if entity_class is None: return @@ -190,12 +198,14 @@ class ProbeEndpoint: channel_pool.manufacturer, channel_pool.model, list(channel_pool.all_channels.values()), + channel_pool.quirk_class, ) else: matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity( channel_pool.manufacturer, channel_pool.model, channel_pool.unclaimed_channels(), + channel_pool.quirk_class, ) channel_pool.claim_channels(claimed) @@ -210,8 +220,7 @@ class ProbeEndpoint: for component, ent_n_chan_list in matches.items(): for entity_and_channel in ent_n_chan_list: if component == cmpt_by_dev_type: - # for well known device types, like thermostats - # we'll take only 1st class + # for well known device types, like thermostats we'll take only 1st class channel_pool.async_new_entity( component, entity_and_channel.entity_class, diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 6b99d412688..a7504ae7a96 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -93,9 +93,7 @@ DEVICE_CLASS = { zigpy.profiles.zha.DeviceType.ON_OFF_PLUG_IN_UNIT: Platform.SWITCH, zigpy.profiles.zha.DeviceType.SHADE: Platform.COVER, zigpy.profiles.zha.DeviceType.SMART_PLUG: Platform.SWITCH, - zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: ( - Platform.ALARM_CONTROL_PANEL - ), + zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: Platform.ALARM_CONTROL_PANEL, zigpy.profiles.zha.DeviceType.IAS_WARNING_DEVICE: Platform.SIREN, }, zigpy.profiles.zll.PROFILE_ID: { @@ -146,13 +144,17 @@ class MatchRule: aux_channels: frozenset[str] | Callable = attr.ib( factory=_get_empty_frozenset, converter=set_or_callable ) + quirk_classes: frozenset[str] | Callable = attr.ib( + factory=_get_empty_frozenset, converter=set_or_callable + ) @property def weight(self) -> int: """Return the weight of the matching rule. - More specific matches should be preferred over less specific. Model matching - rules have a priority over manufacturer matching rules and rules matching a + More specific matches should be preferred over less specific. Quirk class + matching rules have priority over model matching rules + and have a priority over manufacturer matching rules and rules matching a single model/manufacturer get a better priority over rules matching multiple models/manufacturers. And any model or manufacturers matching rules get better priority over rules matching only channels. @@ -160,6 +162,11 @@ class MatchRule: multiple channels a better priority over rules matching a single channel. """ weight = 0 + if self.quirk_classes: + weight += 501 - ( + 1 if callable(self.quirk_classes) else len(self.quirk_classes) + ) + if self.models: weight += 401 - (1 if callable(self.models) else len(self.models)) @@ -187,15 +194,21 @@ class MatchRule: claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels]) return claimed - def strict_matched(self, manufacturer: str, model: str, channels: list) -> bool: + def strict_matched( + self, manufacturer: str, model: str, channels: list, quirk_class: str + ) -> bool: """Return True if this device matches the criteria.""" - return all(self._matched(manufacturer, model, channels)) + return all(self._matched(manufacturer, model, channels, quirk_class)) - def loose_matched(self, manufacturer: str, model: str, channels: list) -> bool: + def loose_matched( + self, manufacturer: str, model: str, channels: list, quirk_class: str + ) -> bool: """Return True if this device matches the criteria.""" - return any(self._matched(manufacturer, model, channels)) + return any(self._matched(manufacturer, model, channels, quirk_class)) - def _matched(self, manufacturer: str, model: str, channels: list) -> list: + def _matched( + self, manufacturer: str, model: str, channels: list, quirk_class: str + ) -> list: """Return a list of field matches.""" if not any(attr.asdict(self).values()): return [False] @@ -221,6 +234,12 @@ class MatchRule: else: matches.append(model in self.models) + if self.quirk_classes: + if callable(self.quirk_classes): + matches.append(self.quirk_classes(quirk_class)) + else: + matches.append(quirk_class in self.quirk_classes) + return matches @@ -261,12 +280,13 @@ class ZHAEntityRegistry: manufacturer: str, model: str, channels: list[ZigbeeChannel], + quirk_class: str, default: type[ZhaEntity] | None = None, ) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]: """Match a ZHA Channels to a ZHA Entity class.""" matches = self._strict_registry[component] for match in sorted(matches, key=lambda x: x.weight, reverse=True): - if match.strict_matched(manufacturer, model, channels): + if match.strict_matched(manufacturer, model, channels, quirk_class): claimed = match.claim_channels(channels) return self._strict_registry[component][match], claimed @@ -277,6 +297,7 @@ class ZHAEntityRegistry: manufacturer: str, model: str, channels: list[ZigbeeChannel], + quirk_class: str, ) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]: """Match ZHA Channels to potentially multiple ZHA Entity classes.""" result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list) @@ -285,7 +306,7 @@ class ZHAEntityRegistry: for stop_match_grp, matches in stop_match_groups.items(): sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True) for match in sorted_matches: - if match.strict_matched(manufacturer, model, channels): + if match.strict_matched(manufacturer, model, channels, quirk_class): claimed = match.claim_channels(channels) for ent_class in stop_match_groups[stop_match_grp][match]: ent_n_channels = EntityClassAndChannels(ent_class, claimed) @@ -301,6 +322,7 @@ class ZHAEntityRegistry: manufacturer: str, model: str, channels: list[ZigbeeChannel], + quirk_class: str, ) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]: """Match ZHA Channels to potentially multiple ZHA Entity classes.""" result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list) @@ -312,7 +334,7 @@ class ZHAEntityRegistry: for stop_match_grp, matches in stop_match_groups.items(): sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True) for match in sorted_matches: - if match.strict_matched(manufacturer, model, channels): + if match.strict_matched(manufacturer, model, channels, quirk_class): claimed = match.claim_channels(channels) for ent_class in stop_match_groups[stop_match_grp][match]: ent_n_channels = EntityClassAndChannels(ent_class, claimed) @@ -335,11 +357,17 @@ class ZHAEntityRegistry: manufacturers: Callable | set[str] | str | None = None, models: Callable | set[str] | str | None = None, aux_channels: Callable | set[str] | str | None = None, + quirk_classes: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a strict match rule.""" rule = MatchRule( - channel_names, generic_ids, manufacturers, models, aux_channels + channel_names, + generic_ids, + manufacturers, + models, + aux_channels, + quirk_classes, ) def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT: @@ -361,6 +389,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str | None = None, aux_channels: Callable | set[str] | str | None = None, stop_on_match_group: int | str | None = None, + quirk_classes: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" @@ -370,6 +399,7 @@ class ZHAEntityRegistry: manufacturers, models, aux_channels, + quirk_classes, ) def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: @@ -394,6 +424,7 @@ class ZHAEntityRegistry: models: Callable | set[str] | str | None = None, aux_channels: Callable | set[str] | str | None = None, stop_on_match_group: int | str | None = None, + quirk_classes: set[str] | str | None = None, ) -> Callable[[_ZhaEntityT], _ZhaEntityT]: """Decorate a loose match rule.""" @@ -403,6 +434,7 @@ class ZHAEntityRegistry: manufacturers, models, aux_channels, + quirk_classes, ) def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT: diff --git a/tests/components/zha/test_registries.py b/tests/components/zha/test_registries.py index db7aa2791cf..24cd7a5785f 100644 --- a/tests/components/zha/test_registries.py +++ b/tests/components/zha/test_registries.py @@ -1,13 +1,16 @@ """Test ZHA registries.""" +import inspect from unittest import mock import pytest +import zhaquirks import homeassistant.components.zha.core.registries as registries from homeassistant.helpers import entity_registry as er MANUFACTURER = "mock manufacturer" MODEL = "mock model" +QUIRK_CLASS = "mock.class" @pytest.fixture @@ -16,6 +19,7 @@ def zha_device(): dev = mock.MagicMock() dev.manufacturer = MANUFACTURER dev.model = MODEL + dev.quirk_class = QUIRK_CLASS return dev @@ -70,6 +74,16 @@ def channels(channel): (registries.MatchRule(models="no match"), False), (registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True), (registries.MatchRule(models="no match", aux_channels="aux_channel"), False), + (registries.MatchRule(quirk_classes=QUIRK_CLASS), True), + (registries.MatchRule(quirk_classes="no match"), False), + ( + registries.MatchRule(quirk_classes=QUIRK_CLASS, aux_channels="aux_channel"), + True, + ), + ( + registries.MatchRule(quirk_classes="no match", aux_channels="aux_channel"), + False, + ), # match everything ( registries.MatchRule( @@ -77,6 +91,7 @@ def channels(channel): channel_names={"on_off", "level"}, manufacturers=MANUFACTURER, models=MODEL, + quirk_classes=QUIRK_CLASS, ), True, ), @@ -124,11 +139,35 @@ def channels(channel): registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL), False, ), + ( + registries.MatchRule( + channel_names="on_off", quirk_classes={"random quirk", QUIRK_CLASS} + ), + True, + ), + ( + registries.MatchRule( + channel_names="on_off", quirk_classes={"random quirk", "another quirk"} + ), + False, + ), + ( + registries.MatchRule( + channel_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS + ), + True, + ), + ( + registries.MatchRule( + channel_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS + ), + False, + ), ], ) def test_registry_matching(rule, matched, channels) -> None: """Test strict rule matching.""" - assert rule.strict_matched(MANUFACTURER, MODEL, channels) is matched + assert rule.strict_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched @pytest.mark.parametrize( @@ -197,6 +236,8 @@ def test_registry_matching(rule, matched, channels) -> None: (registries.MatchRule(manufacturers=MANUFACTURER), True), (registries.MatchRule(models=MODEL), True), (registries.MatchRule(models="no match"), False), + (registries.MatchRule(quirk_classes=QUIRK_CLASS), True), + (registries.MatchRule(quirk_classes="no match"), False), # match everything ( registries.MatchRule( @@ -204,6 +245,7 @@ def test_registry_matching(rule, matched, channels) -> None: channel_names={"on_off", "level"}, manufacturers=MANUFACTURER, models=MODEL, + quirk_classes=QUIRK_CLASS, ), True, ), @@ -211,7 +253,7 @@ def test_registry_matching(rule, matched, channels) -> None: ) def test_registry_loose_matching(rule, matched, channels) -> None: """Test loose rule matching.""" - assert rule.loose_matched(MANUFACTURER, MODEL, channels) is matched + assert rule.loose_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched def test_match_rule_claim_channels_color(channel) -> None: @@ -264,18 +306,24 @@ def entity_registry(): @pytest.mark.parametrize( - ("manufacturer", "model", "match_name"), + ("manufacturer", "model", "quirk_class", "match_name"), ( - ("random manufacturer", "random model", "OnOff"), - ("random manufacturer", MODEL, "OnOffModel"), - (MANUFACTURER, "random model", "OnOffManufacturer"), - (MANUFACTURER, MODEL, "OnOffModelManufacturer"), - (MANUFACTURER, "some model", "OnOffMultimodel"), + ("random manufacturer", "random model", "random.class", "OnOff"), + ("random manufacturer", MODEL, "random.class", "OnOffModel"), + (MANUFACTURER, "random model", "random.class", "OnOffManufacturer"), + ("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"), + (MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"), + (MANUFACTURER, "some model", "random.class", "OnOffMultimodel"), ), ) def test_weighted_match( - channel, entity_registry: er.EntityRegistry, manufacturer, model, match_name -) -> None: + channel, + entity_registry: er.EntityRegistry, + manufacturer, + model, + quirk_class, + match_name, +): """Test weightedd match.""" s = mock.sentinel @@ -308,11 +356,17 @@ def test_weighted_match( class OnOffModelManufacturer: pass + @entity_registry.strict_match( + s.component, channel_names="on_off", quirk_classes=QUIRK_CLASS + ) + class OnOffQuirk: + pass + ch_on_off = channel("on_off", 6) ch_level = channel("level", 8) match, claimed = entity_registry.get_entity( - s.component, manufacturer, model, [ch_on_off, ch_level] + s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class ) assert match.__name__ == match_name @@ -335,7 +389,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None ch_illuminati = channel("illuminance", 0x0401) match, claimed = entity_registry.get_multi_entity( - "manufacturer", "model", channels=[ch_se, ch_illuminati] + "manufacturer", + "model", + channels=[ch_se, ch_illuminati], + quirk_class="quirk_class", ) assert s.binary_sensor in match @@ -360,7 +417,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None pass match, claimed = entity_registry.get_multi_entity( - "manufacturer", "model", channels={ch_se, ch_illuminati} + "manufacturer", + "model", + channels={ch_se, ch_illuminati}, + quirk_class="quirk_class", ) assert s.binary_sensor in match @@ -373,3 +433,62 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None assert {cls.entity_class.__name__ for cls in match[s.component]} == { SmartEnergySensor1.__name__ } + + +def test_quirk_classes(): + """Make sure that quirk_classes in components matches are valid.""" + + def find_quirk_class(base_obj, quirk_mod, quirk_cls): + """Find a specific quirk class.""" + mods = dict(inspect.getmembers(base_obj, inspect.ismodule)) + + # Check if we have found the right module + if quirk_mod in mods: + # If so, look for the class + clss = dict(inspect.getmembers(mods[quirk_mod], inspect.isclass)) + if quirk_cls in clss: + # Quirk class found + return True + + else: + # Recurse into other modules + for mod in mods: + if not mods[mod].__name__.startswith("zhaquirks."): + continue + if find_quirk_class(mods[mod], quirk_mod, quirk_cls): + return True + return False + + def quirk_class_validator(value): + """Validate quirk classes during self test.""" + if callable(value): + # Callables cannot be tested + return + + if isinstance(value, (frozenset, set, list)): + for v in value: + # Unpack the value if needed + quirk_class_validator(v) + return + + quirk_tok = value.split(".") + if len(quirk_tok) != 2: + # quirk_class is always __module__.__class__ + raise ValueError(f"Invalid quirk class : '{value}'") + + if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]): + raise ValueError(f"Quirk class '{value}' does not exists.") + + for component in registries.ZHA_ENTITIES._strict_registry.items(): + for rule in component[1].items(): + quirk_class_validator(rule[0].quirk_classes) + + for component in registries.ZHA_ENTITIES._multi_entity_registry.items(): + for item in component[1].items(): + for rule in item[1].items(): + quirk_class_validator(rule[0].quirk_classes) + + for component in registries.ZHA_ENTITIES._config_diagnostic_entity_registry.items(): + for item in component[1].items(): + for rule in item[1].items(): + quirk_class_validator(rule[0].quirk_classes)