mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add matching on quirk_classes to zha (#87653)
* Add matching on quirk_classes. * Add and fix tests for matching on quirk_classes. * Black fix. * Add a unit test to validate quirk classes.
This commit is contained in:
parent
8968ed1c47
commit
7365522d1f
@ -239,6 +239,11 @@ class ChannelPool:
|
||||
"""Return device model."""
|
||||
return self._channels.zha_device.model
|
||||
|
||||
@property
|
||||
def quirk_class(self) -> str:
|
||||
"""Return device quirk class."""
|
||||
return self._channels.zha_device.quirk_class
|
||||
|
||||
@property
|
||||
def skip_configuration(self) -> bool:
|
||||
"""Return True if device does not require channel configuration."""
|
||||
|
@ -95,7 +95,11 @@ class ProbeEndpoint:
|
||||
if component and component in zha_const.PLATFORMS:
|
||||
channels = channel_pool.unclaimed_channels()
|
||||
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
||||
component, channel_pool.manufacturer, channel_pool.model, channels
|
||||
component,
|
||||
channel_pool.manufacturer,
|
||||
channel_pool.model,
|
||||
channels,
|
||||
channel_pool.quirk_class,
|
||||
)
|
||||
if entity_class is None:
|
||||
return
|
||||
@ -145,7 +149,11 @@ class ProbeEndpoint:
|
||||
unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}"
|
||||
|
||||
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
||||
component, ep_channels.manufacturer, ep_channels.model, channel_list
|
||||
component,
|
||||
ep_channels.manufacturer,
|
||||
ep_channels.model,
|
||||
channel_list,
|
||||
ep_channels.quirk_class,
|
||||
)
|
||||
if entity_class is None:
|
||||
return
|
||||
@ -190,12 +198,14 @@ class ProbeEndpoint:
|
||||
channel_pool.manufacturer,
|
||||
channel_pool.model,
|
||||
list(channel_pool.all_channels.values()),
|
||||
channel_pool.quirk_class,
|
||||
)
|
||||
else:
|
||||
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
|
||||
channel_pool.manufacturer,
|
||||
channel_pool.model,
|
||||
channel_pool.unclaimed_channels(),
|
||||
channel_pool.quirk_class,
|
||||
)
|
||||
|
||||
channel_pool.claim_channels(claimed)
|
||||
@ -210,8 +220,7 @@ class ProbeEndpoint:
|
||||
for component, ent_n_chan_list in matches.items():
|
||||
for entity_and_channel in ent_n_chan_list:
|
||||
if component == cmpt_by_dev_type:
|
||||
# for well known device types, like thermostats
|
||||
# we'll take only 1st class
|
||||
# for well known device types, like thermostats we'll take only 1st class
|
||||
channel_pool.async_new_entity(
|
||||
component,
|
||||
entity_and_channel.entity_class,
|
||||
|
@ -93,9 +93,7 @@ DEVICE_CLASS = {
|
||||
zigpy.profiles.zha.DeviceType.ON_OFF_PLUG_IN_UNIT: Platform.SWITCH,
|
||||
zigpy.profiles.zha.DeviceType.SHADE: Platform.COVER,
|
||||
zigpy.profiles.zha.DeviceType.SMART_PLUG: Platform.SWITCH,
|
||||
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: (
|
||||
Platform.ALARM_CONTROL_PANEL
|
||||
),
|
||||
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: Platform.ALARM_CONTROL_PANEL,
|
||||
zigpy.profiles.zha.DeviceType.IAS_WARNING_DEVICE: Platform.SIREN,
|
||||
},
|
||||
zigpy.profiles.zll.PROFILE_ID: {
|
||||
@ -146,13 +144,17 @@ class MatchRule:
|
||||
aux_channels: frozenset[str] | Callable = attr.ib(
|
||||
factory=_get_empty_frozenset, converter=set_or_callable
|
||||
)
|
||||
quirk_classes: frozenset[str] | Callable = attr.ib(
|
||||
factory=_get_empty_frozenset, converter=set_or_callable
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self) -> int:
|
||||
"""Return the weight of the matching rule.
|
||||
|
||||
More specific matches should be preferred over less specific. Model matching
|
||||
rules have a priority over manufacturer matching rules and rules matching a
|
||||
More specific matches should be preferred over less specific. Quirk class
|
||||
matching rules have priority over model matching rules
|
||||
and have a priority over manufacturer matching rules and rules matching a
|
||||
single model/manufacturer get a better priority over rules matching multiple
|
||||
models/manufacturers. And any model or manufacturers matching rules get better
|
||||
priority over rules matching only channels.
|
||||
@ -160,6 +162,11 @@ class MatchRule:
|
||||
multiple channels a better priority over rules matching a single channel.
|
||||
"""
|
||||
weight = 0
|
||||
if self.quirk_classes:
|
||||
weight += 501 - (
|
||||
1 if callable(self.quirk_classes) else len(self.quirk_classes)
|
||||
)
|
||||
|
||||
if self.models:
|
||||
weight += 401 - (1 if callable(self.models) else len(self.models))
|
||||
|
||||
@ -187,15 +194,21 @@ class MatchRule:
|
||||
claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels])
|
||||
return claimed
|
||||
|
||||
def strict_matched(self, manufacturer: str, model: str, channels: list) -> bool:
|
||||
def strict_matched(
|
||||
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||
) -> bool:
|
||||
"""Return True if this device matches the criteria."""
|
||||
return all(self._matched(manufacturer, model, channels))
|
||||
return all(self._matched(manufacturer, model, channels, quirk_class))
|
||||
|
||||
def loose_matched(self, manufacturer: str, model: str, channels: list) -> bool:
|
||||
def loose_matched(
|
||||
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||
) -> bool:
|
||||
"""Return True if this device matches the criteria."""
|
||||
return any(self._matched(manufacturer, model, channels))
|
||||
return any(self._matched(manufacturer, model, channels, quirk_class))
|
||||
|
||||
def _matched(self, manufacturer: str, model: str, channels: list) -> list:
|
||||
def _matched(
|
||||
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||
) -> list:
|
||||
"""Return a list of field matches."""
|
||||
if not any(attr.asdict(self).values()):
|
||||
return [False]
|
||||
@ -221,6 +234,12 @@ class MatchRule:
|
||||
else:
|
||||
matches.append(model in self.models)
|
||||
|
||||
if self.quirk_classes:
|
||||
if callable(self.quirk_classes):
|
||||
matches.append(self.quirk_classes(quirk_class))
|
||||
else:
|
||||
matches.append(quirk_class in self.quirk_classes)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
@ -261,12 +280,13 @@ class ZHAEntityRegistry:
|
||||
manufacturer: str,
|
||||
model: str,
|
||||
channels: list[ZigbeeChannel],
|
||||
quirk_class: str,
|
||||
default: type[ZhaEntity] | None = None,
|
||||
) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]:
|
||||
"""Match a ZHA Channels to a ZHA Entity class."""
|
||||
matches = self._strict_registry[component]
|
||||
for match in sorted(matches, key=lambda x: x.weight, reverse=True):
|
||||
if match.strict_matched(manufacturer, model, channels):
|
||||
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||
claimed = match.claim_channels(channels)
|
||||
return self._strict_registry[component][match], claimed
|
||||
|
||||
@ -277,6 +297,7 @@ class ZHAEntityRegistry:
|
||||
manufacturer: str,
|
||||
model: str,
|
||||
channels: list[ZigbeeChannel],
|
||||
quirk_class: str,
|
||||
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
||||
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
||||
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
||||
@ -285,7 +306,7 @@ class ZHAEntityRegistry:
|
||||
for stop_match_grp, matches in stop_match_groups.items():
|
||||
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
||||
for match in sorted_matches:
|
||||
if match.strict_matched(manufacturer, model, channels):
|
||||
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||
claimed = match.claim_channels(channels)
|
||||
for ent_class in stop_match_groups[stop_match_grp][match]:
|
||||
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
||||
@ -301,6 +322,7 @@ class ZHAEntityRegistry:
|
||||
manufacturer: str,
|
||||
model: str,
|
||||
channels: list[ZigbeeChannel],
|
||||
quirk_class: str,
|
||||
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
||||
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
||||
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
||||
@ -312,7 +334,7 @@ class ZHAEntityRegistry:
|
||||
for stop_match_grp, matches in stop_match_groups.items():
|
||||
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
||||
for match in sorted_matches:
|
||||
if match.strict_matched(manufacturer, model, channels):
|
||||
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||
claimed = match.claim_channels(channels)
|
||||
for ent_class in stop_match_groups[stop_match_grp][match]:
|
||||
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
||||
@ -335,11 +357,17 @@ class ZHAEntityRegistry:
|
||||
manufacturers: Callable | set[str] | str | None = None,
|
||||
models: Callable | set[str] | str | None = None,
|
||||
aux_channels: Callable | set[str] | str | None = None,
|
||||
quirk_classes: set[str] | str | None = None,
|
||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||
"""Decorate a strict match rule."""
|
||||
|
||||
rule = MatchRule(
|
||||
channel_names, generic_ids, manufacturers, models, aux_channels
|
||||
channel_names,
|
||||
generic_ids,
|
||||
manufacturers,
|
||||
models,
|
||||
aux_channels,
|
||||
quirk_classes,
|
||||
)
|
||||
|
||||
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
|
||||
@ -361,6 +389,7 @@ class ZHAEntityRegistry:
|
||||
models: Callable | set[str] | str | None = None,
|
||||
aux_channels: Callable | set[str] | str | None = None,
|
||||
stop_on_match_group: int | str | None = None,
|
||||
quirk_classes: set[str] | str | None = None,
|
||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||
"""Decorate a loose match rule."""
|
||||
|
||||
@ -370,6 +399,7 @@ class ZHAEntityRegistry:
|
||||
manufacturers,
|
||||
models,
|
||||
aux_channels,
|
||||
quirk_classes,
|
||||
)
|
||||
|
||||
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
||||
@ -394,6 +424,7 @@ class ZHAEntityRegistry:
|
||||
models: Callable | set[str] | str | None = None,
|
||||
aux_channels: Callable | set[str] | str | None = None,
|
||||
stop_on_match_group: int | str | None = None,
|
||||
quirk_classes: set[str] | str | None = None,
|
||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||
"""Decorate a loose match rule."""
|
||||
|
||||
@ -403,6 +434,7 @@ class ZHAEntityRegistry:
|
||||
manufacturers,
|
||||
models,
|
||||
aux_channels,
|
||||
quirk_classes,
|
||||
)
|
||||
|
||||
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
||||
|
@ -1,13 +1,16 @@
|
||||
"""Test ZHA registries."""
|
||||
import inspect
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import zhaquirks
|
||||
|
||||
import homeassistant.components.zha.core.registries as registries
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
MANUFACTURER = "mock manufacturer"
|
||||
MODEL = "mock model"
|
||||
QUIRK_CLASS = "mock.class"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -16,6 +19,7 @@ def zha_device():
|
||||
dev = mock.MagicMock()
|
||||
dev.manufacturer = MANUFACTURER
|
||||
dev.model = MODEL
|
||||
dev.quirk_class = QUIRK_CLASS
|
||||
return dev
|
||||
|
||||
|
||||
@ -70,6 +74,16 @@ def channels(channel):
|
||||
(registries.MatchRule(models="no match"), False),
|
||||
(registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True),
|
||||
(registries.MatchRule(models="no match", aux_channels="aux_channel"), False),
|
||||
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
|
||||
(registries.MatchRule(quirk_classes="no match"), False),
|
||||
(
|
||||
registries.MatchRule(quirk_classes=QUIRK_CLASS, aux_channels="aux_channel"),
|
||||
True,
|
||||
),
|
||||
(
|
||||
registries.MatchRule(quirk_classes="no match", aux_channels="aux_channel"),
|
||||
False,
|
||||
),
|
||||
# match everything
|
||||
(
|
||||
registries.MatchRule(
|
||||
@ -77,6 +91,7 @@ def channels(channel):
|
||||
channel_names={"on_off", "level"},
|
||||
manufacturers=MANUFACTURER,
|
||||
models=MODEL,
|
||||
quirk_classes=QUIRK_CLASS,
|
||||
),
|
||||
True,
|
||||
),
|
||||
@ -124,11 +139,35 @@ def channels(channel):
|
||||
registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL),
|
||||
False,
|
||||
),
|
||||
(
|
||||
registries.MatchRule(
|
||||
channel_names="on_off", quirk_classes={"random quirk", QUIRK_CLASS}
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
registries.MatchRule(
|
||||
channel_names="on_off", quirk_classes={"random quirk", "another quirk"}
|
||||
),
|
||||
False,
|
||||
),
|
||||
(
|
||||
registries.MatchRule(
|
||||
channel_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
registries.MatchRule(
|
||||
channel_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS
|
||||
),
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_registry_matching(rule, matched, channels) -> None:
|
||||
"""Test strict rule matching."""
|
||||
assert rule.strict_matched(MANUFACTURER, MODEL, channels) is matched
|
||||
assert rule.strict_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -197,6 +236,8 @@ def test_registry_matching(rule, matched, channels) -> None:
|
||||
(registries.MatchRule(manufacturers=MANUFACTURER), True),
|
||||
(registries.MatchRule(models=MODEL), True),
|
||||
(registries.MatchRule(models="no match"), False),
|
||||
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
|
||||
(registries.MatchRule(quirk_classes="no match"), False),
|
||||
# match everything
|
||||
(
|
||||
registries.MatchRule(
|
||||
@ -204,6 +245,7 @@ def test_registry_matching(rule, matched, channels) -> None:
|
||||
channel_names={"on_off", "level"},
|
||||
manufacturers=MANUFACTURER,
|
||||
models=MODEL,
|
||||
quirk_classes=QUIRK_CLASS,
|
||||
),
|
||||
True,
|
||||
),
|
||||
@ -211,7 +253,7 @@ def test_registry_matching(rule, matched, channels) -> None:
|
||||
)
|
||||
def test_registry_loose_matching(rule, matched, channels) -> None:
|
||||
"""Test loose rule matching."""
|
||||
assert rule.loose_matched(MANUFACTURER, MODEL, channels) is matched
|
||||
assert rule.loose_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
|
||||
|
||||
|
||||
def test_match_rule_claim_channels_color(channel) -> None:
|
||||
@ -264,18 +306,24 @@ def entity_registry():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("manufacturer", "model", "match_name"),
|
||||
("manufacturer", "model", "quirk_class", "match_name"),
|
||||
(
|
||||
("random manufacturer", "random model", "OnOff"),
|
||||
("random manufacturer", MODEL, "OnOffModel"),
|
||||
(MANUFACTURER, "random model", "OnOffManufacturer"),
|
||||
(MANUFACTURER, MODEL, "OnOffModelManufacturer"),
|
||||
(MANUFACTURER, "some model", "OnOffMultimodel"),
|
||||
("random manufacturer", "random model", "random.class", "OnOff"),
|
||||
("random manufacturer", MODEL, "random.class", "OnOffModel"),
|
||||
(MANUFACTURER, "random model", "random.class", "OnOffManufacturer"),
|
||||
("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"),
|
||||
(MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"),
|
||||
(MANUFACTURER, "some model", "random.class", "OnOffMultimodel"),
|
||||
),
|
||||
)
|
||||
def test_weighted_match(
|
||||
channel, entity_registry: er.EntityRegistry, manufacturer, model, match_name
|
||||
) -> None:
|
||||
channel,
|
||||
entity_registry: er.EntityRegistry,
|
||||
manufacturer,
|
||||
model,
|
||||
quirk_class,
|
||||
match_name,
|
||||
):
|
||||
"""Test weightedd match."""
|
||||
|
||||
s = mock.sentinel
|
||||
@ -308,11 +356,17 @@ def test_weighted_match(
|
||||
class OnOffModelManufacturer:
|
||||
pass
|
||||
|
||||
@entity_registry.strict_match(
|
||||
s.component, channel_names="on_off", quirk_classes=QUIRK_CLASS
|
||||
)
|
||||
class OnOffQuirk:
|
||||
pass
|
||||
|
||||
ch_on_off = channel("on_off", 6)
|
||||
ch_level = channel("level", 8)
|
||||
|
||||
match, claimed = entity_registry.get_entity(
|
||||
s.component, manufacturer, model, [ch_on_off, ch_level]
|
||||
s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class
|
||||
)
|
||||
|
||||
assert match.__name__ == match_name
|
||||
@ -335,7 +389,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
||||
ch_illuminati = channel("illuminance", 0x0401)
|
||||
|
||||
match, claimed = entity_registry.get_multi_entity(
|
||||
"manufacturer", "model", channels=[ch_se, ch_illuminati]
|
||||
"manufacturer",
|
||||
"model",
|
||||
channels=[ch_se, ch_illuminati],
|
||||
quirk_class="quirk_class",
|
||||
)
|
||||
|
||||
assert s.binary_sensor in match
|
||||
@ -360,7 +417,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
||||
pass
|
||||
|
||||
match, claimed = entity_registry.get_multi_entity(
|
||||
"manufacturer", "model", channels={ch_se, ch_illuminati}
|
||||
"manufacturer",
|
||||
"model",
|
||||
channels={ch_se, ch_illuminati},
|
||||
quirk_class="quirk_class",
|
||||
)
|
||||
|
||||
assert s.binary_sensor in match
|
||||
@ -373,3 +433,62 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
||||
assert {cls.entity_class.__name__ for cls in match[s.component]} == {
|
||||
SmartEnergySensor1.__name__
|
||||
}
|
||||
|
||||
|
||||
def test_quirk_classes():
|
||||
"""Make sure that quirk_classes in components matches are valid."""
|
||||
|
||||
def find_quirk_class(base_obj, quirk_mod, quirk_cls):
|
||||
"""Find a specific quirk class."""
|
||||
mods = dict(inspect.getmembers(base_obj, inspect.ismodule))
|
||||
|
||||
# Check if we have found the right module
|
||||
if quirk_mod in mods:
|
||||
# If so, look for the class
|
||||
clss = dict(inspect.getmembers(mods[quirk_mod], inspect.isclass))
|
||||
if quirk_cls in clss:
|
||||
# Quirk class found
|
||||
return True
|
||||
|
||||
else:
|
||||
# Recurse into other modules
|
||||
for mod in mods:
|
||||
if not mods[mod].__name__.startswith("zhaquirks."):
|
||||
continue
|
||||
if find_quirk_class(mods[mod], quirk_mod, quirk_cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
def quirk_class_validator(value):
|
||||
"""Validate quirk classes during self test."""
|
||||
if callable(value):
|
||||
# Callables cannot be tested
|
||||
return
|
||||
|
||||
if isinstance(value, (frozenset, set, list)):
|
||||
for v in value:
|
||||
# Unpack the value if needed
|
||||
quirk_class_validator(v)
|
||||
return
|
||||
|
||||
quirk_tok = value.split(".")
|
||||
if len(quirk_tok) != 2:
|
||||
# quirk_class is always __module__.__class__
|
||||
raise ValueError(f"Invalid quirk class : '{value}'")
|
||||
|
||||
if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]):
|
||||
raise ValueError(f"Quirk class '{value}' does not exists.")
|
||||
|
||||
for component in registries.ZHA_ENTITIES._strict_registry.items():
|
||||
for rule in component[1].items():
|
||||
quirk_class_validator(rule[0].quirk_classes)
|
||||
|
||||
for component in registries.ZHA_ENTITIES._multi_entity_registry.items():
|
||||
for item in component[1].items():
|
||||
for rule in item[1].items():
|
||||
quirk_class_validator(rule[0].quirk_classes)
|
||||
|
||||
for component in registries.ZHA_ENTITIES._config_diagnostic_entity_registry.items():
|
||||
for item in component[1].items():
|
||||
for rule in item[1].items():
|
||||
quirk_class_validator(rule[0].quirk_classes)
|
||||
|
Loading…
x
Reference in New Issue
Block a user