mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Add matching on quirk_classes to zha (#87653)
* Add matching on quirk_classes. * Add and fix tests for matching on quirk_classes. * Black fix. * Add a unit test to validate quirk classes.
This commit is contained in:
parent
8968ed1c47
commit
7365522d1f
@ -239,6 +239,11 @@ class ChannelPool:
|
|||||||
"""Return device model."""
|
"""Return device model."""
|
||||||
return self._channels.zha_device.model
|
return self._channels.zha_device.model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def quirk_class(self) -> str:
|
||||||
|
"""Return device quirk class."""
|
||||||
|
return self._channels.zha_device.quirk_class
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def skip_configuration(self) -> bool:
|
def skip_configuration(self) -> bool:
|
||||||
"""Return True if device does not require channel configuration."""
|
"""Return True if device does not require channel configuration."""
|
||||||
|
@ -95,7 +95,11 @@ class ProbeEndpoint:
|
|||||||
if component and component in zha_const.PLATFORMS:
|
if component and component in zha_const.PLATFORMS:
|
||||||
channels = channel_pool.unclaimed_channels()
|
channels = channel_pool.unclaimed_channels()
|
||||||
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
||||||
component, channel_pool.manufacturer, channel_pool.model, channels
|
component,
|
||||||
|
channel_pool.manufacturer,
|
||||||
|
channel_pool.model,
|
||||||
|
channels,
|
||||||
|
channel_pool.quirk_class,
|
||||||
)
|
)
|
||||||
if entity_class is None:
|
if entity_class is None:
|
||||||
return
|
return
|
||||||
@ -145,7 +149,11 @@ class ProbeEndpoint:
|
|||||||
unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}"
|
unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}"
|
||||||
|
|
||||||
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity(
|
||||||
component, ep_channels.manufacturer, ep_channels.model, channel_list
|
component,
|
||||||
|
ep_channels.manufacturer,
|
||||||
|
ep_channels.model,
|
||||||
|
channel_list,
|
||||||
|
ep_channels.quirk_class,
|
||||||
)
|
)
|
||||||
if entity_class is None:
|
if entity_class is None:
|
||||||
return
|
return
|
||||||
@ -190,12 +198,14 @@ class ProbeEndpoint:
|
|||||||
channel_pool.manufacturer,
|
channel_pool.manufacturer,
|
||||||
channel_pool.model,
|
channel_pool.model,
|
||||||
list(channel_pool.all_channels.values()),
|
list(channel_pool.all_channels.values()),
|
||||||
|
channel_pool.quirk_class,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
|
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
|
||||||
channel_pool.manufacturer,
|
channel_pool.manufacturer,
|
||||||
channel_pool.model,
|
channel_pool.model,
|
||||||
channel_pool.unclaimed_channels(),
|
channel_pool.unclaimed_channels(),
|
||||||
|
channel_pool.quirk_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_pool.claim_channels(claimed)
|
channel_pool.claim_channels(claimed)
|
||||||
@ -210,8 +220,7 @@ class ProbeEndpoint:
|
|||||||
for component, ent_n_chan_list in matches.items():
|
for component, ent_n_chan_list in matches.items():
|
||||||
for entity_and_channel in ent_n_chan_list:
|
for entity_and_channel in ent_n_chan_list:
|
||||||
if component == cmpt_by_dev_type:
|
if component == cmpt_by_dev_type:
|
||||||
# for well known device types, like thermostats
|
# for well known device types, like thermostats we'll take only 1st class
|
||||||
# we'll take only 1st class
|
|
||||||
channel_pool.async_new_entity(
|
channel_pool.async_new_entity(
|
||||||
component,
|
component,
|
||||||
entity_and_channel.entity_class,
|
entity_and_channel.entity_class,
|
||||||
|
@ -93,9 +93,7 @@ DEVICE_CLASS = {
|
|||||||
zigpy.profiles.zha.DeviceType.ON_OFF_PLUG_IN_UNIT: Platform.SWITCH,
|
zigpy.profiles.zha.DeviceType.ON_OFF_PLUG_IN_UNIT: Platform.SWITCH,
|
||||||
zigpy.profiles.zha.DeviceType.SHADE: Platform.COVER,
|
zigpy.profiles.zha.DeviceType.SHADE: Platform.COVER,
|
||||||
zigpy.profiles.zha.DeviceType.SMART_PLUG: Platform.SWITCH,
|
zigpy.profiles.zha.DeviceType.SMART_PLUG: Platform.SWITCH,
|
||||||
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: (
|
zigpy.profiles.zha.DeviceType.IAS_ANCILLARY_CONTROL: Platform.ALARM_CONTROL_PANEL,
|
||||||
Platform.ALARM_CONTROL_PANEL
|
|
||||||
),
|
|
||||||
zigpy.profiles.zha.DeviceType.IAS_WARNING_DEVICE: Platform.SIREN,
|
zigpy.profiles.zha.DeviceType.IAS_WARNING_DEVICE: Platform.SIREN,
|
||||||
},
|
},
|
||||||
zigpy.profiles.zll.PROFILE_ID: {
|
zigpy.profiles.zll.PROFILE_ID: {
|
||||||
@ -146,13 +144,17 @@ class MatchRule:
|
|||||||
aux_channels: frozenset[str] | Callable = attr.ib(
|
aux_channels: frozenset[str] | Callable = attr.ib(
|
||||||
factory=_get_empty_frozenset, converter=set_or_callable
|
factory=_get_empty_frozenset, converter=set_or_callable
|
||||||
)
|
)
|
||||||
|
quirk_classes: frozenset[str] | Callable = attr.ib(
|
||||||
|
factory=_get_empty_frozenset, converter=set_or_callable
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def weight(self) -> int:
|
def weight(self) -> int:
|
||||||
"""Return the weight of the matching rule.
|
"""Return the weight of the matching rule.
|
||||||
|
|
||||||
More specific matches should be preferred over less specific. Model matching
|
More specific matches should be preferred over less specific. Quirk class
|
||||||
rules have a priority over manufacturer matching rules and rules matching a
|
matching rules have priority over model matching rules
|
||||||
|
and have a priority over manufacturer matching rules and rules matching a
|
||||||
single model/manufacturer get a better priority over rules matching multiple
|
single model/manufacturer get a better priority over rules matching multiple
|
||||||
models/manufacturers. And any model or manufacturers matching rules get better
|
models/manufacturers. And any model or manufacturers matching rules get better
|
||||||
priority over rules matching only channels.
|
priority over rules matching only channels.
|
||||||
@ -160,6 +162,11 @@ class MatchRule:
|
|||||||
multiple channels a better priority over rules matching a single channel.
|
multiple channels a better priority over rules matching a single channel.
|
||||||
"""
|
"""
|
||||||
weight = 0
|
weight = 0
|
||||||
|
if self.quirk_classes:
|
||||||
|
weight += 501 - (
|
||||||
|
1 if callable(self.quirk_classes) else len(self.quirk_classes)
|
||||||
|
)
|
||||||
|
|
||||||
if self.models:
|
if self.models:
|
||||||
weight += 401 - (1 if callable(self.models) else len(self.models))
|
weight += 401 - (1 if callable(self.models) else len(self.models))
|
||||||
|
|
||||||
@ -187,15 +194,21 @@ class MatchRule:
|
|||||||
claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels])
|
claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels])
|
||||||
return claimed
|
return claimed
|
||||||
|
|
||||||
def strict_matched(self, manufacturer: str, model: str, channels: list) -> bool:
|
def strict_matched(
|
||||||
|
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||||
|
) -> bool:
|
||||||
"""Return True if this device matches the criteria."""
|
"""Return True if this device matches the criteria."""
|
||||||
return all(self._matched(manufacturer, model, channels))
|
return all(self._matched(manufacturer, model, channels, quirk_class))
|
||||||
|
|
||||||
def loose_matched(self, manufacturer: str, model: str, channels: list) -> bool:
|
def loose_matched(
|
||||||
|
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||||
|
) -> bool:
|
||||||
"""Return True if this device matches the criteria."""
|
"""Return True if this device matches the criteria."""
|
||||||
return any(self._matched(manufacturer, model, channels))
|
return any(self._matched(manufacturer, model, channels, quirk_class))
|
||||||
|
|
||||||
def _matched(self, manufacturer: str, model: str, channels: list) -> list:
|
def _matched(
|
||||||
|
self, manufacturer: str, model: str, channels: list, quirk_class: str
|
||||||
|
) -> list:
|
||||||
"""Return a list of field matches."""
|
"""Return a list of field matches."""
|
||||||
if not any(attr.asdict(self).values()):
|
if not any(attr.asdict(self).values()):
|
||||||
return [False]
|
return [False]
|
||||||
@ -221,6 +234,12 @@ class MatchRule:
|
|||||||
else:
|
else:
|
||||||
matches.append(model in self.models)
|
matches.append(model in self.models)
|
||||||
|
|
||||||
|
if self.quirk_classes:
|
||||||
|
if callable(self.quirk_classes):
|
||||||
|
matches.append(self.quirk_classes(quirk_class))
|
||||||
|
else:
|
||||||
|
matches.append(quirk_class in self.quirk_classes)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
|
||||||
@ -261,12 +280,13 @@ class ZHAEntityRegistry:
|
|||||||
manufacturer: str,
|
manufacturer: str,
|
||||||
model: str,
|
model: str,
|
||||||
channels: list[ZigbeeChannel],
|
channels: list[ZigbeeChannel],
|
||||||
|
quirk_class: str,
|
||||||
default: type[ZhaEntity] | None = None,
|
default: type[ZhaEntity] | None = None,
|
||||||
) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]:
|
) -> tuple[type[ZhaEntity] | None, list[ZigbeeChannel]]:
|
||||||
"""Match a ZHA Channels to a ZHA Entity class."""
|
"""Match a ZHA Channels to a ZHA Entity class."""
|
||||||
matches = self._strict_registry[component]
|
matches = self._strict_registry[component]
|
||||||
for match in sorted(matches, key=lambda x: x.weight, reverse=True):
|
for match in sorted(matches, key=lambda x: x.weight, reverse=True):
|
||||||
if match.strict_matched(manufacturer, model, channels):
|
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||||
claimed = match.claim_channels(channels)
|
claimed = match.claim_channels(channels)
|
||||||
return self._strict_registry[component][match], claimed
|
return self._strict_registry[component][match], claimed
|
||||||
|
|
||||||
@ -277,6 +297,7 @@ class ZHAEntityRegistry:
|
|||||||
manufacturer: str,
|
manufacturer: str,
|
||||||
model: str,
|
model: str,
|
||||||
channels: list[ZigbeeChannel],
|
channels: list[ZigbeeChannel],
|
||||||
|
quirk_class: str,
|
||||||
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
||||||
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
||||||
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
||||||
@ -285,7 +306,7 @@ class ZHAEntityRegistry:
|
|||||||
for stop_match_grp, matches in stop_match_groups.items():
|
for stop_match_grp, matches in stop_match_groups.items():
|
||||||
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
||||||
for match in sorted_matches:
|
for match in sorted_matches:
|
||||||
if match.strict_matched(manufacturer, model, channels):
|
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||||
claimed = match.claim_channels(channels)
|
claimed = match.claim_channels(channels)
|
||||||
for ent_class in stop_match_groups[stop_match_grp][match]:
|
for ent_class in stop_match_groups[stop_match_grp][match]:
|
||||||
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
||||||
@ -301,6 +322,7 @@ class ZHAEntityRegistry:
|
|||||||
manufacturer: str,
|
manufacturer: str,
|
||||||
model: str,
|
model: str,
|
||||||
channels: list[ZigbeeChannel],
|
channels: list[ZigbeeChannel],
|
||||||
|
quirk_class: str,
|
||||||
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ZigbeeChannel]]:
|
||||||
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
|
||||||
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
|
||||||
@ -312,7 +334,7 @@ class ZHAEntityRegistry:
|
|||||||
for stop_match_grp, matches in stop_match_groups.items():
|
for stop_match_grp, matches in stop_match_groups.items():
|
||||||
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
|
||||||
for match in sorted_matches:
|
for match in sorted_matches:
|
||||||
if match.strict_matched(manufacturer, model, channels):
|
if match.strict_matched(manufacturer, model, channels, quirk_class):
|
||||||
claimed = match.claim_channels(channels)
|
claimed = match.claim_channels(channels)
|
||||||
for ent_class in stop_match_groups[stop_match_grp][match]:
|
for ent_class in stop_match_groups[stop_match_grp][match]:
|
||||||
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
|
||||||
@ -335,11 +357,17 @@ class ZHAEntityRegistry:
|
|||||||
manufacturers: Callable | set[str] | str | None = None,
|
manufacturers: Callable | set[str] | str | None = None,
|
||||||
models: Callable | set[str] | str | None = None,
|
models: Callable | set[str] | str | None = None,
|
||||||
aux_channels: Callable | set[str] | str | None = None,
|
aux_channels: Callable | set[str] | str | None = None,
|
||||||
|
quirk_classes: set[str] | str | None = None,
|
||||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||||
"""Decorate a strict match rule."""
|
"""Decorate a strict match rule."""
|
||||||
|
|
||||||
rule = MatchRule(
|
rule = MatchRule(
|
||||||
channel_names, generic_ids, manufacturers, models, aux_channels
|
channel_names,
|
||||||
|
generic_ids,
|
||||||
|
manufacturers,
|
||||||
|
models,
|
||||||
|
aux_channels,
|
||||||
|
quirk_classes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
|
def decorator(zha_ent: _ZhaEntityT) -> _ZhaEntityT:
|
||||||
@ -361,6 +389,7 @@ class ZHAEntityRegistry:
|
|||||||
models: Callable | set[str] | str | None = None,
|
models: Callable | set[str] | str | None = None,
|
||||||
aux_channels: Callable | set[str] | str | None = None,
|
aux_channels: Callable | set[str] | str | None = None,
|
||||||
stop_on_match_group: int | str | None = None,
|
stop_on_match_group: int | str | None = None,
|
||||||
|
quirk_classes: set[str] | str | None = None,
|
||||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||||
"""Decorate a loose match rule."""
|
"""Decorate a loose match rule."""
|
||||||
|
|
||||||
@ -370,6 +399,7 @@ class ZHAEntityRegistry:
|
|||||||
manufacturers,
|
manufacturers,
|
||||||
models,
|
models,
|
||||||
aux_channels,
|
aux_channels,
|
||||||
|
quirk_classes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
||||||
@ -394,6 +424,7 @@ class ZHAEntityRegistry:
|
|||||||
models: Callable | set[str] | str | None = None,
|
models: Callable | set[str] | str | None = None,
|
||||||
aux_channels: Callable | set[str] | str | None = None,
|
aux_channels: Callable | set[str] | str | None = None,
|
||||||
stop_on_match_group: int | str | None = None,
|
stop_on_match_group: int | str | None = None,
|
||||||
|
quirk_classes: set[str] | str | None = None,
|
||||||
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
) -> Callable[[_ZhaEntityT], _ZhaEntityT]:
|
||||||
"""Decorate a loose match rule."""
|
"""Decorate a loose match rule."""
|
||||||
|
|
||||||
@ -403,6 +434,7 @@ class ZHAEntityRegistry:
|
|||||||
manufacturers,
|
manufacturers,
|
||||||
models,
|
models,
|
||||||
aux_channels,
|
aux_channels,
|
||||||
|
quirk_classes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
def decorator(zha_entity: _ZhaEntityT) -> _ZhaEntityT:
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
"""Test ZHA registries."""
|
"""Test ZHA registries."""
|
||||||
|
import inspect
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import zhaquirks
|
||||||
|
|
||||||
import homeassistant.components.zha.core.registries as registries
|
import homeassistant.components.zha.core.registries as registries
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
|
||||||
MANUFACTURER = "mock manufacturer"
|
MANUFACTURER = "mock manufacturer"
|
||||||
MODEL = "mock model"
|
MODEL = "mock model"
|
||||||
|
QUIRK_CLASS = "mock.class"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -16,6 +19,7 @@ def zha_device():
|
|||||||
dev = mock.MagicMock()
|
dev = mock.MagicMock()
|
||||||
dev.manufacturer = MANUFACTURER
|
dev.manufacturer = MANUFACTURER
|
||||||
dev.model = MODEL
|
dev.model = MODEL
|
||||||
|
dev.quirk_class = QUIRK_CLASS
|
||||||
return dev
|
return dev
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +74,16 @@ def channels(channel):
|
|||||||
(registries.MatchRule(models="no match"), False),
|
(registries.MatchRule(models="no match"), False),
|
||||||
(registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True),
|
(registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True),
|
||||||
(registries.MatchRule(models="no match", aux_channels="aux_channel"), False),
|
(registries.MatchRule(models="no match", aux_channels="aux_channel"), False),
|
||||||
|
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
|
||||||
|
(registries.MatchRule(quirk_classes="no match"), False),
|
||||||
|
(
|
||||||
|
registries.MatchRule(quirk_classes=QUIRK_CLASS, aux_channels="aux_channel"),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
registries.MatchRule(quirk_classes="no match", aux_channels="aux_channel"),
|
||||||
|
False,
|
||||||
|
),
|
||||||
# match everything
|
# match everything
|
||||||
(
|
(
|
||||||
registries.MatchRule(
|
registries.MatchRule(
|
||||||
@ -77,6 +91,7 @@ def channels(channel):
|
|||||||
channel_names={"on_off", "level"},
|
channel_names={"on_off", "level"},
|
||||||
manufacturers=MANUFACTURER,
|
manufacturers=MANUFACTURER,
|
||||||
models=MODEL,
|
models=MODEL,
|
||||||
|
quirk_classes=QUIRK_CLASS,
|
||||||
),
|
),
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
@ -124,11 +139,35 @@ def channels(channel):
|
|||||||
registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL),
|
registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
registries.MatchRule(
|
||||||
|
channel_names="on_off", quirk_classes={"random quirk", QUIRK_CLASS}
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
registries.MatchRule(
|
||||||
|
channel_names="on_off", quirk_classes={"random quirk", "another quirk"}
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
registries.MatchRule(
|
||||||
|
channel_names="on_off", quirk_classes=lambda x: x == QUIRK_CLASS
|
||||||
|
),
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
registries.MatchRule(
|
||||||
|
channel_names="on_off", quirk_classes=lambda x: x != QUIRK_CLASS
|
||||||
|
),
|
||||||
|
False,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_registry_matching(rule, matched, channels) -> None:
|
def test_registry_matching(rule, matched, channels) -> None:
|
||||||
"""Test strict rule matching."""
|
"""Test strict rule matching."""
|
||||||
assert rule.strict_matched(MANUFACTURER, MODEL, channels) is matched
|
assert rule.strict_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -197,6 +236,8 @@ def test_registry_matching(rule, matched, channels) -> None:
|
|||||||
(registries.MatchRule(manufacturers=MANUFACTURER), True),
|
(registries.MatchRule(manufacturers=MANUFACTURER), True),
|
||||||
(registries.MatchRule(models=MODEL), True),
|
(registries.MatchRule(models=MODEL), True),
|
||||||
(registries.MatchRule(models="no match"), False),
|
(registries.MatchRule(models="no match"), False),
|
||||||
|
(registries.MatchRule(quirk_classes=QUIRK_CLASS), True),
|
||||||
|
(registries.MatchRule(quirk_classes="no match"), False),
|
||||||
# match everything
|
# match everything
|
||||||
(
|
(
|
||||||
registries.MatchRule(
|
registries.MatchRule(
|
||||||
@ -204,6 +245,7 @@ def test_registry_matching(rule, matched, channels) -> None:
|
|||||||
channel_names={"on_off", "level"},
|
channel_names={"on_off", "level"},
|
||||||
manufacturers=MANUFACTURER,
|
manufacturers=MANUFACTURER,
|
||||||
models=MODEL,
|
models=MODEL,
|
||||||
|
quirk_classes=QUIRK_CLASS,
|
||||||
),
|
),
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
@ -211,7 +253,7 @@ def test_registry_matching(rule, matched, channels) -> None:
|
|||||||
)
|
)
|
||||||
def test_registry_loose_matching(rule, matched, channels) -> None:
|
def test_registry_loose_matching(rule, matched, channels) -> None:
|
||||||
"""Test loose rule matching."""
|
"""Test loose rule matching."""
|
||||||
assert rule.loose_matched(MANUFACTURER, MODEL, channels) is matched
|
assert rule.loose_matched(MANUFACTURER, MODEL, channels, QUIRK_CLASS) is matched
|
||||||
|
|
||||||
|
|
||||||
def test_match_rule_claim_channels_color(channel) -> None:
|
def test_match_rule_claim_channels_color(channel) -> None:
|
||||||
@ -264,18 +306,24 @@ def entity_registry():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("manufacturer", "model", "match_name"),
|
("manufacturer", "model", "quirk_class", "match_name"),
|
||||||
(
|
(
|
||||||
("random manufacturer", "random model", "OnOff"),
|
("random manufacturer", "random model", "random.class", "OnOff"),
|
||||||
("random manufacturer", MODEL, "OnOffModel"),
|
("random manufacturer", MODEL, "random.class", "OnOffModel"),
|
||||||
(MANUFACTURER, "random model", "OnOffManufacturer"),
|
(MANUFACTURER, "random model", "random.class", "OnOffManufacturer"),
|
||||||
(MANUFACTURER, MODEL, "OnOffModelManufacturer"),
|
("random manufacturer", "random model", QUIRK_CLASS, "OnOffQuirk"),
|
||||||
(MANUFACTURER, "some model", "OnOffMultimodel"),
|
(MANUFACTURER, MODEL, "random.class", "OnOffModelManufacturer"),
|
||||||
|
(MANUFACTURER, "some model", "random.class", "OnOffMultimodel"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_weighted_match(
|
def test_weighted_match(
|
||||||
channel, entity_registry: er.EntityRegistry, manufacturer, model, match_name
|
channel,
|
||||||
) -> None:
|
entity_registry: er.EntityRegistry,
|
||||||
|
manufacturer,
|
||||||
|
model,
|
||||||
|
quirk_class,
|
||||||
|
match_name,
|
||||||
|
):
|
||||||
"""Test weightedd match."""
|
"""Test weightedd match."""
|
||||||
|
|
||||||
s = mock.sentinel
|
s = mock.sentinel
|
||||||
@ -308,11 +356,17 @@ def test_weighted_match(
|
|||||||
class OnOffModelManufacturer:
|
class OnOffModelManufacturer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@entity_registry.strict_match(
|
||||||
|
s.component, channel_names="on_off", quirk_classes=QUIRK_CLASS
|
||||||
|
)
|
||||||
|
class OnOffQuirk:
|
||||||
|
pass
|
||||||
|
|
||||||
ch_on_off = channel("on_off", 6)
|
ch_on_off = channel("on_off", 6)
|
||||||
ch_level = channel("level", 8)
|
ch_level = channel("level", 8)
|
||||||
|
|
||||||
match, claimed = entity_registry.get_entity(
|
match, claimed = entity_registry.get_entity(
|
||||||
s.component, manufacturer, model, [ch_on_off, ch_level]
|
s.component, manufacturer, model, [ch_on_off, ch_level], quirk_class
|
||||||
)
|
)
|
||||||
|
|
||||||
assert match.__name__ == match_name
|
assert match.__name__ == match_name
|
||||||
@ -335,7 +389,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
|||||||
ch_illuminati = channel("illuminance", 0x0401)
|
ch_illuminati = channel("illuminance", 0x0401)
|
||||||
|
|
||||||
match, claimed = entity_registry.get_multi_entity(
|
match, claimed = entity_registry.get_multi_entity(
|
||||||
"manufacturer", "model", channels=[ch_se, ch_illuminati]
|
"manufacturer",
|
||||||
|
"model",
|
||||||
|
channels=[ch_se, ch_illuminati],
|
||||||
|
quirk_class="quirk_class",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert s.binary_sensor in match
|
assert s.binary_sensor in match
|
||||||
@ -360,7 +417,10 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
match, claimed = entity_registry.get_multi_entity(
|
match, claimed = entity_registry.get_multi_entity(
|
||||||
"manufacturer", "model", channels={ch_se, ch_illuminati}
|
"manufacturer",
|
||||||
|
"model",
|
||||||
|
channels={ch_se, ch_illuminati},
|
||||||
|
quirk_class="quirk_class",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert s.binary_sensor in match
|
assert s.binary_sensor in match
|
||||||
@ -373,3 +433,62 @@ def test_multi_sensor_match(channel, entity_registry: er.EntityRegistry) -> None
|
|||||||
assert {cls.entity_class.__name__ for cls in match[s.component]} == {
|
assert {cls.entity_class.__name__ for cls in match[s.component]} == {
|
||||||
SmartEnergySensor1.__name__
|
SmartEnergySensor1.__name__
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_quirk_classes():
|
||||||
|
"""Make sure that quirk_classes in components matches are valid."""
|
||||||
|
|
||||||
|
def find_quirk_class(base_obj, quirk_mod, quirk_cls):
|
||||||
|
"""Find a specific quirk class."""
|
||||||
|
mods = dict(inspect.getmembers(base_obj, inspect.ismodule))
|
||||||
|
|
||||||
|
# Check if we have found the right module
|
||||||
|
if quirk_mod in mods:
|
||||||
|
# If so, look for the class
|
||||||
|
clss = dict(inspect.getmembers(mods[quirk_mod], inspect.isclass))
|
||||||
|
if quirk_cls in clss:
|
||||||
|
# Quirk class found
|
||||||
|
return True
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Recurse into other modules
|
||||||
|
for mod in mods:
|
||||||
|
if not mods[mod].__name__.startswith("zhaquirks."):
|
||||||
|
continue
|
||||||
|
if find_quirk_class(mods[mod], quirk_mod, quirk_cls):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def quirk_class_validator(value):
|
||||||
|
"""Validate quirk classes during self test."""
|
||||||
|
if callable(value):
|
||||||
|
# Callables cannot be tested
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(value, (frozenset, set, list)):
|
||||||
|
for v in value:
|
||||||
|
# Unpack the value if needed
|
||||||
|
quirk_class_validator(v)
|
||||||
|
return
|
||||||
|
|
||||||
|
quirk_tok = value.split(".")
|
||||||
|
if len(quirk_tok) != 2:
|
||||||
|
# quirk_class is always __module__.__class__
|
||||||
|
raise ValueError(f"Invalid quirk class : '{value}'")
|
||||||
|
|
||||||
|
if not find_quirk_class(zhaquirks, quirk_tok[0], quirk_tok[1]):
|
||||||
|
raise ValueError(f"Quirk class '{value}' does not exists.")
|
||||||
|
|
||||||
|
for component in registries.ZHA_ENTITIES._strict_registry.items():
|
||||||
|
for rule in component[1].items():
|
||||||
|
quirk_class_validator(rule[0].quirk_classes)
|
||||||
|
|
||||||
|
for component in registries.ZHA_ENTITIES._multi_entity_registry.items():
|
||||||
|
for item in component[1].items():
|
||||||
|
for rule in item[1].items():
|
||||||
|
quirk_class_validator(rule[0].quirk_classes)
|
||||||
|
|
||||||
|
for component in registries.ZHA_ENTITIES._config_diagnostic_entity_registry.items():
|
||||||
|
for item in component[1].items():
|
||||||
|
for rule in item[1].items():
|
||||||
|
quirk_class_validator(rule[0].quirk_classes)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user