Implement support for start_up_on_off in ZHA (#70110)

* Implement support for start_up_on_off

fix discovery issues

remove cover change

* add tests
This commit is contained in:
David F. Mulcahey 2022-04-24 12:50:06 -04:00 committed by GitHub
parent 8a73381b56
commit 9b8d217b0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 273 additions and 19 deletions

View File

@ -291,6 +291,9 @@ class OnOffChannel(ZigbeeChannel):
ON_OFF = 0 ON_OFF = 0
REPORT_CONFIG = ({"attr": "on_off", "config": REPORT_CONFIG_IMMEDIATE},) REPORT_CONFIG = ({"attr": "on_off", "config": REPORT_CONFIG_IMMEDIATE},)
ZCL_INIT_ATTRS = {
"start_up_on_off": True,
}
def __init__( def __init__(
self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType

View File

@ -78,6 +78,7 @@ class ProbeEndpoint:
self.discover_by_device_type(channel_pool) self.discover_by_device_type(channel_pool)
self.discover_multi_entities(channel_pool) self.discover_multi_entities(channel_pool)
self.discover_by_cluster_id(channel_pool) self.discover_by_cluster_id(channel_pool)
self.discover_multi_entities(channel_pool, config_diagnostic_entities=True)
zha_regs.ZHA_ENTITIES.clean_up() zha_regs.ZHA_ENTITIES.clean_up()
@callback @callback
@ -177,16 +178,27 @@ class ProbeEndpoint:
@staticmethod @staticmethod
@callback @callback
def discover_multi_entities(channel_pool: ChannelPool) -> None: def discover_multi_entities(
channel_pool: ChannelPool,
config_diagnostic_entities: bool = False,
) -> None:
"""Process an endpoint on and discover multiple entities.""" """Process an endpoint on and discover multiple entities."""
ep_profile_id = channel_pool.endpoint.profile_id ep_profile_id = channel_pool.endpoint.profile_id
ep_device_type = channel_pool.endpoint.device_type ep_device_type = channel_pool.endpoint.device_type
cmpt_by_dev_type = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type) cmpt_by_dev_type = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type)
remaining_channels = channel_pool.unclaimed_channels()
if config_diagnostic_entities:
matches, claimed = zha_regs.ZHA_ENTITIES.get_config_diagnostic_entity(
channel_pool.manufacturer,
channel_pool.model,
list(channel_pool.all_channels.values()),
)
else:
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity( matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
channel_pool.manufacturer, channel_pool.model, remaining_channels channel_pool.manufacturer,
channel_pool.model,
channel_pool.unclaimed_channels(),
) )
channel_pool.claim_channels(claimed) channel_pool.claim_channels(claimed)

View File

@ -232,6 +232,11 @@ class ZHAEntityRegistry:
] = collections.defaultdict( ] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list)) lambda: collections.defaultdict(lambda: collections.defaultdict(list))
) )
self._config_diagnostic_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]]
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._group_registry: dict[str, CALLABLE_T] = {} self._group_registry: dict[str, CALLABLE_T] = {}
self.single_device_matches: dict[ self.single_device_matches: dict[
Platform, dict[EUI64, list[str]] Platform, dict[EUI64, list[str]]
@ -278,6 +283,33 @@ class ZHAEntityRegistry:
return result, list(all_claimed) return result, list(all_claimed)
def get_config_diagnostic_entity(
self,
manufacturer: str,
model: str,
channels: list[ChannelType],
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ChannelType]]:
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
all_claimed: set[ChannelType] = set()
for (
component,
stop_match_groups,
) in self._config_diagnostic_entity_registry.items():
for stop_match_grp, matches in stop_match_groups.items():
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
for match in sorted_matches:
if match.strict_matched(manufacturer, model, channels):
claimed = match.claim_channels(channels)
for ent_class in stop_match_groups[stop_match_grp][match]:
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
result[component].append(ent_n_channels)
all_claimed |= set(claimed)
if stop_match_grp:
break
return result, list(all_claimed)
def get_group_entity(self, component: str) -> CALLABLE_T: def get_group_entity(self, component: str) -> CALLABLE_T:
"""Match a ZHA group to a ZHA Entity class.""" """Match a ZHA group to a ZHA Entity class."""
return self._group_registry.get(component) return self._group_registry.get(component)
@ -340,6 +372,39 @@ class ZHAEntityRegistry:
return decorator return decorator
def config_diagnostic_match(
self,
component: str,
channel_names: set[str] | str = None,
generic_ids: set[str] | str = None,
manufacturers: Callable | set[str] | str = None,
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
stop_on_match_group: int | str | None = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a loose match rule."""
rule = MatchRule(
channel_names,
generic_ids,
manufacturers,
models,
aux_channels,
)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
"""Register a loose match rule.
All non empty fields of a match rule must match.
"""
# group the rules by channels
self._config_diagnostic_entity_registry[component][stop_on_match_group][
rule
].append(zha_entity)
return zha_entity
return decorator
def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]: def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a group match rule.""" """Decorate a group match rule."""

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from enum import Enum from enum import Enum
import functools import functools
from zigpy.zcl.clusters.general import OnOff
from zigpy.zcl.clusters.security import IasWd from zigpy.zcl.clusters.security import IasWd
from homeassistant.components.select import SelectEntity from homeassistant.components.select import SelectEntity
@ -15,12 +16,20 @@ from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery from .core import discovery
from .core.const import CHANNEL_IAS_WD, DATA_ZHA, SIGNAL_ADD_ENTITIES, Strobe from .core.const import (
CHANNEL_IAS_WD,
CHANNEL_ON_OFF,
DATA_ZHA,
SIGNAL_ADD_ENTITIES,
Strobe,
)
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity from .entity import ZhaEntity
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.SELECT) CONFIG_DIAGNOSTIC_MATCH = functools.partial(
ZHA_ENTITIES.config_diagnostic_match, Platform.SELECT
)
async def async_setup_entry( async def async_setup_entry(
@ -100,7 +109,7 @@ class ZHANonZCLSelectEntity(ZHAEnumSelectEntity):
return True return True
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) @CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultToneSelectEntity( class ZHADefaultToneSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.WarningMode.__name__ ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.WarningMode.__name__
): ):
@ -109,7 +118,7 @@ class ZHADefaultToneSelectEntity(
_enum: Enum = IasWd.Warning.WarningMode _enum: Enum = IasWd.Warning.WarningMode
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) @CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultSirenLevelSelectEntity( class ZHADefaultSirenLevelSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.SirenLevel.__name__ ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.SirenLevel.__name__
): ):
@ -118,7 +127,7 @@ class ZHADefaultSirenLevelSelectEntity(
_enum: Enum = IasWd.Warning.SirenLevel _enum: Enum = IasWd.Warning.SirenLevel
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) @CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultStrobeLevelSelectEntity( class ZHADefaultStrobeLevelSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.StrobeLevel.__name__ ZHANonZCLSelectEntity, id_suffix=IasWd.StrobeLevel.__name__
): ):
@ -127,8 +136,72 @@ class ZHADefaultStrobeLevelSelectEntity(
_enum: Enum = IasWd.StrobeLevel _enum: Enum = IasWd.StrobeLevel
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) @CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultStrobeSelectEntity(ZHANonZCLSelectEntity, id_suffix=Strobe.__name__): class ZHADefaultStrobeSelectEntity(ZHANonZCLSelectEntity, id_suffix=Strobe.__name__):
"""Representation of a ZHA default siren strobe select entity.""" """Representation of a ZHA default siren strobe select entity."""
_enum: Enum = Strobe _enum: Enum = Strobe
class ZCLEnumSelectEntity(ZhaEntity, SelectEntity):
"""Representation of a ZHA ZCL enum select entity."""
_select_attr: str
_attr_entity_category = EntityCategory.CONFIG
_enum: Enum
@classmethod
def create_entity(
cls,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> ZhaEntity | None:
"""Entity Factory.
Return entity if it is a supported configuration, otherwise return None
"""
channel = channels[0]
if cls._select_attr in channel.cluster.unsupported_attributes:
return None
return cls(unique_id, zha_device, channels, **kwargs)
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> None:
"""Init this select entity."""
self._attr_options = [entry.name.replace("_", " ") for entry in self._enum]
self._channel: ChannelType = channels[0]
super().__init__(unique_id, zha_device, channels, **kwargs)
@property
def current_option(self) -> str | None:
"""Return the selected entity option to represent the entity state."""
option = self._channel.cluster.get(self._select_attr)
if option is None:
return None
option = self._enum(option)
return option.name.replace("_", " ")
async def async_select_option(self, option: str | int) -> None:
"""Change the selected option."""
await self._channel.cluster.write_attributes(
{self._select_attr: self._enum[option.replace(" ", "_")]}
)
self.async_write_ha_state()
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_ON_OFF)
class ZHAStartupOnOffSelectEntity(
ZCLEnumSelectEntity, id_suffix=OnOff.StartUpOnOff.__name__
):
"""Representation of a ZHA startup onoff select entity."""
_select_attr = "start_up_on_off"
_enum: Enum = OnOff.StartUpOnOff

View File

@ -44,6 +44,16 @@ from .zha_devices_list import (
NO_TAIL_ID = re.compile("_\\d$") NO_TAIL_ID = re.compile("_\\d$")
UNIQUE_ID_HD = re.compile(r"^(([\da-fA-F]{2}:){7}[\da-fA-F]{2}-\d{1,3})", re.X) UNIQUE_ID_HD = re.compile(r"^(([\da-fA-F]{2}:){7}[\da-fA-F]{2}-\d{1,3})", re.X)
IGNORE_SUFFIXES = [zigpy.zcl.clusters.general.OnOff.StartUpOnOff.__name__]
def contains_ignored_suffix(unique_id: str) -> bool:
"""Return true if the unique_id ends with an ignored suffix."""
for suffix in IGNORE_SUFFIXES:
if suffix.lower() in unique_id.lower():
return True
return False
@pytest.fixture @pytest.fixture
def channels_mock(zha_device_mock): def channels_mock(zha_device_mock):
@ -142,7 +152,7 @@ async def test_devices(
_, component, entity_cls, unique_id, channels = call[0] _, component, entity_cls, unique_id, channels = call[0]
# the factory can return None. We filter these out to get an accurate created entity count # the factory can return None. We filter these out to get an accurate created entity count
response = entity_cls.create_entity(unique_id, zha_dev, channels) response = entity_cls.create_entity(unique_id, zha_dev, channels)
if response: if response and not contains_ignored_suffix(response.name):
created_entity_count += 1 created_entity_count += 1
unique_id_head = UNIQUE_ID_HD.match(unique_id).group( unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0 0
@ -178,7 +188,9 @@ async def test_devices(
await hass_disable_services.async_block_till_done() await hass_disable_services.async_block_till_done()
zha_entity_ids = { zha_entity_ids = {
ent for ent in entity_ids if ent.split(".")[0] in zha_const.PLATFORMS ent
for ent in entity_ids
if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS
} }
assert zha_entity_ids == { assert zha_entity_ids == {
e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values() e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values()
@ -319,7 +331,10 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
ha_ent_info = {} ha_ent_info = {}
for call in new_ent.call_args_list: for call in new_ent.call_args_list:
component, entity_cls, unique_id, channels = call[0] component, entity_cls, unique_id, channels = call[0]
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id if not contains_ignored_suffix(unique_id):
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = ( ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
component, component,
unique_id, unique_id,

View File

@ -33,6 +33,28 @@ async def siren(hass, zigpy_device_mock, zha_device_joined_restored):
return zha_device, zigpy_device.endpoints[1].ias_wd return zha_device, zigpy_device.endpoints[1].ias_wd
@pytest.fixture
async def light(hass, zigpy_device_mock):
"""Siren fixture."""
zigpy_device = zigpy_device_mock(
{
1: {
SIG_EP_PROFILE: zha.PROFILE_ID,
SIG_EP_TYPE: zha.DeviceType.ON_OFF_LIGHT,
SIG_EP_INPUT: [
general.Basic.cluster_id,
general.Identify.cluster_id,
general.OnOff.cluster_id,
],
SIG_EP_OUTPUT: [general.Ota.cluster_id],
}
},
)
return zigpy_device
@pytest.fixture @pytest.fixture
def core_rs(hass_storage): def core_rs(hass_storage):
"""Core.restore_state fixture.""" """Core.restore_state fixture."""
@ -149,3 +171,67 @@ async def test_select_restore_state(
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
assert state assert state
assert state.state == security.IasWd.Warning.WarningMode.Burglar.name assert state.state == security.IasWd.Warning.WarningMode.Burglar.name
async def test_on_off_select(hass, light, zha_device_joined_restored):
"""Test zha on off select."""
entity_registry = er.async_get(hass)
on_off_cluster = light.endpoints[1].on_off
on_off_cluster.PLUGGED_ATTR_READS = {
"start_up_on_off": general.OnOff.StartUpOnOff.On
}
zha_device = await zha_device_joined_restored(light)
select_name = general.OnOff.StartUpOnOff.__name__
entity_id = await find_entity_id(
Platform.SELECT,
zha_device,
hass,
qualifier=select_name.lower(),
)
assert entity_id is not None
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
assert state.attributes["options"] == ["Off", "On", "Toggle", "PreviousValue"]
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry
assert entity_entry.entity_category == ENTITY_CATEGORY_CONFIG
# Test select option with string value
await hass.services.async_call(
"select",
"select_option",
{
"entity_id": entity_id,
"option": general.OnOff.StartUpOnOff.Off.name,
},
blocking=True,
)
assert on_off_cluster.write_attributes.call_count == 1
assert on_off_cluster.write_attributes.call_args[0][0] == {
"start_up_on_off": general.OnOff.StartUpOnOff.Off
}
state = hass.states.get(entity_id)
assert state
assert state.state == general.OnOff.StartUpOnOff.Off.name
async def test_on_off_select_unsupported(hass, light, zha_device_joined_restored):
"""Test zha on off select unsupported."""
on_off_cluster = light.endpoints[1].on_off
on_off_cluster.add_unsupported_attribute("start_up_on_off")
zha_device = await zha_device_joined_restored(light)
select_name = general.OnOff.StartUpOnOff.__name__
entity_id = await find_entity_id(
Platform.SELECT,
zha_device,
hass,
qualifier=select_name.lower(),
)
assert entity_id is None