Add matching on quirk_classes to zha (#87653)

* Add matching on quirk_classes.

* Add and fix tests for matching on quirk_classes.

* Black fix.

* Add a unit test to validate quirk classes.
This commit is contained in:
Guy Martin 2023-03-02 19:43:11 -05:00 committed by GitHub
parent 8968ed1c47
commit 7365522d1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 196 additions and 31 deletions

View File

@ -239,6 +239,11 @@ class ChannelPool:
"""Return device model."""
return self._channels.zha_device.model
@property
def quirk_class(self) -> str:
"""Return device quirk class."""
return self._channels.zha_device.quirk_class
@property
def skip_configuration(self) -> bool:
"""Return True if device does not require channel configuration."""

View File

@ -95,7 +95,11 @@ class ProbeEndpoint:
if component and component in zha_const.PLATFORMS:
channels = channel_pool.unclaimed_channels()
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
component, channel_pool.manufacturer, channel_pool.model, channels
component,
channel_pool.manufacturer,
channel_pool.model,
channels,
channel_pool.quirk_class,
)
if entity_class is None:
return
@ -145,7 +149,11 @@ class ProbeEndpoint:
unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}"
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
component, ep_channels.manufacturer, ep_channels.model, channel_list
component,
ep_channels.manufacturer,
ep_channels.model,
channel_list,
ep_channels.quirk_class,
)
if entity_class is None:
return
@ -190,12 +198,14 @@ class ProbeEndpoint:
channel_pool.manufacturer,
channel_pool.model,
list(channel_pool.all_channels.values()),
channel_pool.quirk_class,
)
else:
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
channel_pool.manufacturer,
channel_pool.model,
channel_pool.unclaimed_channels(),
channel_pool.quirk_class,
)
channel_pool.claim_channels(claimed)
@ -210,8 +220,7 @@ class ProbeEndpoint:
for component, ent_n_chan_list in matches.items():
for entity_and_channel in ent_n_chan_list:
if component == cmpt_by_dev_type:
# for well known device types, like thermostats
# we'll take only 1st class
# for well known device types, like thermostats we'll take only 1st class
channel_pool.async_new_entity(
component,
entity_and_channel.entity_class,

View File

@ -93,9 +93,7 @@ DEVICE_CLASS = {
zigpy.profiles.zha.DeviceType.ON_OFF_PLUG_IN_UNIT: Platform.SWITCH,
zigpy.profiles.zha.DeviceType.SHADE: Platform.COVER,
zigpy.profiles.zha.DeviceType.SMART_PLUG: Platform.SWITCH,
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: (
Platform.ALARM_CONTROL_PANEL
),
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: Platform.ALARM_CONTROL_PANEL,
zigpy.profiles.zha.DeviceType.IAS_WARNING_DEVICE: Platform.SIREN,
},
zigpy.profiles.zll.PROFILE_ID: {
@ -146,13 +144,17 @@ class MatchRule:
aux_channels: frozenset[str] | Callable = attr.ib(
factory=_get_empty_frozenset, converter=set_or_callable
)
quirk_classes: frozenset[str] | Callable = attr.ib(
factory=_get_empty_frozenset, converter=set_or_callable
)
@property
def weight(self) -> int:
"""Return the weight of the matching rule.
More specific matches should be preferred over less specific. Model matching
rules have a priority over manufacturer matching rules and rules matching a
More specific matches should be preferred over less specific. Quirk class
matching rules have priority over model matching rules
and have a priority over manufacturer matching rules and rules matching a
single model/manufacturer get a better priority over rules matching multiple
models/manufacturers. And any model or manufacturers matching rules get better
priority over rules matching only channels.
@ -160,6 +162,11 @@ class MatchRule:
multiple channels a better priority over rules matching a single channel.
"""
weight = 0
if self.quirk_classes:
weight += 501 - (
1 if callable(self.quirk_classes) else len(self.quirk_classes)
)
if self.models:
weight += 401 - (1 if callable(self.models) else len(self.models))
@ -187,15 +194,21 @@ class MatchRule:
claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels])
return claimed
def strict_matched(self, manufacturer: str, model: str, channels: list) -> bool:
def strict_matched(
self, manufacturer: str, model: str, channels: list, quirk_class: str
) -> bool:
"""Return True if this device matches the criteria."""
return all(self._matched(manufacturer, model, channels))
return all(self._matched(manufacturer, model, channels, quirk_class))
def loose_matched(self, manufacturer: str, model: str, channels: list) -> bool:
def loose_matched(
self, manufacturer: str, model: str, channels: list, quirk_class: str
) -> bool:
"""Return True if this device matches the criteria."""
return any(self._matched(manufacturer, model, channels))
return any(self._matched(manufacturer, model, channels, quirk_class))
def _matched(self, manufacturer: str, model: str, channels: list) -> list:
def _matched(
self, manufacturer: str, model: str, channels: list, quirk_class: str
) -> list:
"""Return a list of field matches."""
if not any(attr.asdict(self).values()):
return [False]
@ -221,6 +234,12 @@ class MatchRule:
else:
matches.append(model in self.models)
if self.quirk_classes:
if callable(self.quirk_classes):
matches.append(self.quirk_classes(quirk_class))
else:
matches.append(quirk_class in self.quirk_classes)
return matches
@ -261,12 +280,13 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
channels: list[ZigbeeChannel],
quirk_class: str,
default: type[ZhaEntity] | None = None,
) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]:
"""Match a ZHA Channels to a ZHA Entity class."""
matches = self._strict_registry[component]
for match in sorted(matches, key=lambda x: x.weight, reverse=True):
if match.strict_matched(manufacturer, model, channels):
if match.strict_matched(manufacturer, model, channels, quirk_class):
claimed = match.claim_channels(channels)
return self._strict_registry[component][match], claimed
@ -277,6 +297,7 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
channels: list[ZigbeeChannel],
quirk_class: str,
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
@ -285,7 +306,7 @@ class ZHAEntityRegistry:
for stop_match_grp, matches in stop_match_groups.items():
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
for match in sorted_matches:
if match.strict_matched(manufacturer, model, channels):
if match.strict_matched(manufacturer, model, channels, quirk_class):
claimed = match.claim_channels(channels)
for ent_class in stop_match_groups[stop_match_grp][match]:
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
@ -301,6 +322,7 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
channels: list[ZigbeeChannel],
quirk_class: str,
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
@ -312,7 +334,7 @@ class ZHAEntityRegistry:
for stop_match_grp, matches in stop_match_groups.items():
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
for match in sorted_matches:
if match.strict_matched(manufacturer, model, channels):
if match.strict_matched(manufacturer, model, channels, quirk_class):
claimed = match.claim_channels(channels)
for ent_class in stop_match_groups[stop_match_grp][match]:
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
@ -335,11 +357,17 @@ class ZHAEntityRegistry:
manufacturers: Callable | set[str] | str | None = None,
models: Callable | set[str] | str | None = None,
aux_channels: Callable | set[str] | str | None = None,
quirk_classes: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a strict match rule."""
rule = MatchRule(
channel_names, generic_ids, manufacturers, models, aux_channels
channel_names,
generic_ids,
manufacturers,
models,
aux_channels,
quirk_classes,
)
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
@ -361,6 +389,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str | None = None,
aux_channels: Callable | set[str] | str | None = None,
stop_on_match_group: int | str | None = None,
quirk_classes: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
@ -370,6 +399,7 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_channels,
quirk_classes,
)
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
@ -394,6 +424,7 @@ class ZHAEntityRegistry:
models: Callable | set[str] | str | None = None,
aux_channels: Callable | set[str] | str | None = None,
stop_on_match_group: int | str | None = None,
quirk_classes: set[str] | str | None = None,
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
"""Decorate a loose match rule."""
@ -403,6 +434,7 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_channels,
quirk_classes,
)
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:

View File

@ -1,13 +1,16 @@
"""Test ZHA registries."""
import inspect
from unittest import mock
import pytest
import zhaquirks
import homeassistant.components.zha.core.registries as registries
from homeassistant.helpers import entity_registry as er
MANUFACTURER = "mock manufacturer"
MODEL = "mock model"
QUIRK_CLASS = "mock.class"
@pytest.fixture
@ -16,6 +19,7 @@ def zha_device():
dev = mock.MagicMock()
dev.manufacturer = MANUFACTURER
dev.model = MODEL
dev.quirk_class = QUIRK_CLASS
return dev
@ -70,6 +74,16 @@ def channels(channel):
(registries.MatchRule(models="no match"), False),
(registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True),
(registries.MatchRule(models="no match", aux_channels="aux_channel"), False),
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
(registries.MatchRule(quirk_classes="no match"), False),
(
registries.MatchRule(quirk_classes=QUIRK_CLASS, aux_channels="aux_channel"),
True,
),
(
registries.MatchRule(quirk_classes="no match", aux_channels="aux_channel"),
False,
),
# match everything
(
registries.MatchRule(
@ -77,6 +91,7 @@ def channels(channel):
channel_names={"on_off", "level"},
manufacturers=MANUFACTURER,
models=MODEL,
quirk_classes=QUIRK_CLASS,
),
True,
),
@ -124,11 +139,35 @@ def channels(channel):
registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL),
False,
),
(
registries.MatchRule(
channel_names="on_off", quirk_classes={"random quirk", QUIRK_CLASS}
),
True,
),
(
registries.MatchRule(
channel_names="on_off", quirk_classes={"random quirk", "another quirk"}
),
False,
),
(
registries.MatchRule(
channel_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS
),
True,
),
(
registries.MatchRule(
channel_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS
),
False,
),
],
)
def test_registry_matching(rule, matched, channels) -> None:
"""Test strict rule matching."""
assert rule.strict_matched(MANUFACTURER, MODEL, channels) is matched
assert rule.strict_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
@pytest.mark.parametrize(
@ -197,6 +236,8 @@ def test_registry_matching(rule, matched, channels) -> None:
(registries.MatchRule(manufacturers=MANUFACTURER), True),
(registries.MatchRule(models=MODEL), True),
(registries.MatchRule(models="no match"), False),
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
(registries.MatchRule(quirk_classes="no match"), False),
# match everything
(
registries.MatchRule(
@ -204,6 +245,7 @@ def test_registry_matching(rule, matched, channels) -> None:
channel_names={"on_off", "level"},
manufacturers=MANUFACTURER,
models=MODEL,
quirk_classes=QUIRK_CLASS,
),
True,
),
@ -211,7 +253,7 @@ def test_registry_matching(rule, matched, channels) -> None:
)
def test_registry_loose_matching(rule, matched, channels) -> None:
"""Test loose rule matching."""
assert rule.loose_matched(MANUFACTURER, MODEL, channels) is matched
assert rule.loose_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
def test_match_rule_claim_channels_color(channel) -> None:
@ -264,18 +306,24 @@ def entity_registry():
@pytest.mark.parametrize(
("manufacturer", "model", "match_name"),
("manufacturer", "model", "quirk_class", "match_name"),
(
("random manufacturer", "random model", "OnOff"),
("random manufacturer", MODEL, "OnOffModel"),
(MANUFACTURER, "random model", "OnOffManufacturer"),
(MANUFACTURER, MODEL, "OnOffModelManufacturer"),
(MANUFACTURER, "some model", "OnOffMultimodel"),
("random manufacturer", "random model", "random.class", "OnOff"),
("random manufacturer", MODEL, "random.class", "OnOffModel"),
(MANUFACTURER, "random model", "random.class", "OnOffManufacturer"),
("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"),
(MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"),
(MANUFACTURER, "some model", "random.class", "OnOffMultimodel"),
),
)
def test_weighted_match(
channel, entity_registry: er.EntityRegistry, manufacturer, model, match_name
) -> None:
channel,
entity_registry: er.EntityRegistry,
manufacturer,
model,
quirk_class,
match_name,
):
"""Test weightedd match."""
s = mock.sentinel
@ -308,11 +356,17 @@ def test_weighted_match(
class OnOffModelManufacturer:
pass
@entity_registry.strict_match(
s.component, channel_names="on_off", quirk_classes=QUIRK_CLASS
)
class OnOffQuirk:
pass
ch_on_off = channel("on_off", 6)
ch_level = channel("level", 8)
match, claimed = entity_registry.get_entity(
s.component, manufacturer, model, [ch_on_off, ch_level]
s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class
)
assert match.__name__ == match_name
@ -335,7 +389,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
ch_illuminati = channel("illuminance", 0x0401)
match, claimed = entity_registry.get_multi_entity(
"manufacturer", "model", channels=[ch_se, ch_illuminati]
"manufacturer",
"model",
channels=[ch_se, ch_illuminati],
quirk_class="quirk_class",
)
assert s.binary_sensor in match
@ -360,7 +417,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
pass
match, claimed = entity_registry.get_multi_entity(
"manufacturer", "model", channels={ch_se, ch_illuminati}
"manufacturer",
"model",
channels={ch_se, ch_illuminati},
quirk_class="quirk_class",
)
assert s.binary_sensor in match
@ -373,3 +433,62 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
assert {cls.entity_class.__name__ for cls in match[s.component]} == {
SmartEnergySensor1.__name__
}
def test_quirk_classes():
"""Make sure that quirk_classes in components matches are valid."""
def find_quirk_class(base_obj, quirk_mod, quirk_cls):
"""Find a specific quirk class."""
mods = dict(inspect.getmembers(base_obj, inspect.ismodule))
# Check if we have found the right module
if quirk_mod in mods:
# If so, look for the class
clss = dict(inspect.getmembers(mods[quirk_mod], inspect.isclass))
if quirk_cls in clss:
# Quirk class found
return True
else:
# Recurse into other modules
for mod in mods:
if not mods[mod].__name__.startswith("zhaquirks."):
continue
if find_quirk_class(mods[mod], quirk_mod, quirk_cls):
return True
return False
def quirk_class_validator(value):
"""Validate quirk classes during self test."""
if callable(value):
# Callables cannot be tested
return
if isinstance(value, (frozenset, set, list)):
for v in value:
# Unpack the value if needed
quirk_class_validator(v)
return
quirk_tok = value.split(".")
if len(quirk_tok) != 2:
# quirk_class is always __module__.__class__
raise ValueError(f"Invalid quirk class : '{value}'")
if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]):
raise ValueError(f"Quirk class '{value}' does not exists.")
for component in registries.ZHA_ENTITIES._strict_registry.items():
for rule in component[1].items():
quirk_class_validator(rule[0].quirk_classes)
for component in registries.ZHA_ENTITIES._multi_entity_registry.items():
for item in component[1].items():
for rule in item[1].items():
quirk_class_validator(rule[0].quirk_classes)
for component in registries.ZHA_ENTITIES._config_diagnostic_entity_registry.items():
for item in component[1].items():
for rule in item[1].items():
quirk_class_validator(rule[0].quirk_classes)