diff --git a/homeassistant/components/zha/core/channels/general.py b/homeassistant/components/zha/core/channels/general.py index 09a1fd80f17..f528057c313 100644 --- a/homeassistant/components/zha/core/channels/general.py +++ b/homeassistant/components/zha/core/channels/general.py @@ -291,6 +291,9 @@ class OnOffChannel(ZigbeeChannel): ON_OFF = 0 REPORT_CONFIG = ({"attr": "on_off", "config": REPORT_CONFIG_IMMEDIATE},) + ZCL_INIT_ATTRS = { + "start_up_on_off": True, + } def __init__( self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 9f7523d41f0..8d7d53468e2 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -78,6 +78,7 @@ class ProbeEndpoint: self.discover_by_device_type(channel_pool) self.discover_multi_entities(channel_pool) self.discover_by_cluster_id(channel_pool) + self.discover_multi_entities(channel_pool, config_diagnostic_entities=True) zha_regs.ZHA_ENTITIES.clean_up() @callback @@ -177,17 +178,28 @@ class ProbeEndpoint: @staticmethod @callback - def discover_multi_entities(channel_pool: ChannelPool) -> None: + def discover_multi_entities( + channel_pool: ChannelPool, + config_diagnostic_entities: bool = False, + ) -> None: """Process an endpoint on and discover multiple entities.""" ep_profile_id = channel_pool.endpoint.profile_id ep_device_type = channel_pool.endpoint.device_type cmpt_by_dev_type = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type) - remaining_channels = channel_pool.unclaimed_channels() - matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity( - channel_pool.manufacturer, channel_pool.model, remaining_channels - ) + if config_diagnostic_entities: + matches, claimed = zha_regs.ZHA_ENTITIES.get_config_diagnostic_entity( + channel_pool.manufacturer, + channel_pool.model, + list(channel_pool.all_channels.values()), + ) + else: + matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity( + channel_pool.manufacturer, + channel_pool.model, + channel_pool.unclaimed_channels(), + ) channel_pool.claim_channels(claimed) for component, ent_n_chan_list in matches.items(): diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 1d3482cd8f4..fb00e23ac6f 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -232,6 +232,11 @@ class ZHAEntityRegistry: ] = collections.defaultdict( lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) + self._config_diagnostic_entity_registry: dict[ + str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]] + ] = collections.defaultdict( + lambda: collections.defaultdict(lambda: collections.defaultdict(list)) + ) self._group_registry: dict[str, CALLABLE_T] = {} self.single_device_matches: dict[ Platform, dict[EUI64, list[str]] @@ -278,6 +283,33 @@ class ZHAEntityRegistry: return result, list(all_claimed) + def get_config_diagnostic_entity( + self, + manufacturer: str, + model: str, + channels: list[ChannelType], + ) -> tuple[dict[str, list[EntityClassAndChannels]], list[ChannelType]]: + """Match ZHA Channels to potentially multiple ZHA Entity classes.""" + result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list) + all_claimed: set[ChannelType] = set() + for ( + component, + stop_match_groups, + ) in self._config_diagnostic_entity_registry.items(): + for stop_match_grp, matches in stop_match_groups.items(): + sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True) + for match in sorted_matches: + if match.strict_matched(manufacturer, model, channels): + claimed = match.claim_channels(channels) + for ent_class in stop_match_groups[stop_match_grp][match]: + ent_n_channels = EntityClassAndChannels(ent_class, claimed) + result[component].append(ent_n_channels) + all_claimed |= set(claimed) + if stop_match_grp: + break + + return result, list(all_claimed) + def get_group_entity(self, component: str) -> CALLABLE_T: """Match a ZHA group to a ZHA Entity class.""" return self._group_registry.get(component) @@ -340,6 +372,39 @@ class ZHAEntityRegistry: return decorator + def config_diagnostic_match( + self, + component: str, + channel_names: set[str] | str = None, + generic_ids: set[str] | str = None, + manufacturers: Callable | set[str] | str = None, + models: Callable | set[str] | str = None, + aux_channels: Callable | set[str] | str = None, + stop_on_match_group: int | str | None = None, + ) -> Callable[[CALLABLE_T], CALLABLE_T]: + """Decorate a loose match rule.""" + + rule = MatchRule( + channel_names, + generic_ids, + manufacturers, + models, + aux_channels, + ) + + def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: + """Register a loose match rule. + + All non empty fields of a match rule must match. + """ + # group the rules by channels + self._config_diagnostic_entity_registry[component][stop_on_match_group][ + rule + ].append(zha_entity) + return zha_entity + + return decorator + def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]: """Decorate a group match rule.""" diff --git a/homeassistant/components/zha/select.py b/homeassistant/components/zha/select.py index 7cb214566d1..0a67f1eac5f 100644 --- a/homeassistant/components/zha/select.py +++ b/homeassistant/components/zha/select.py @@ -4,6 +4,7 @@ from __future__ import annotations from enum import Enum import functools +from zigpy.zcl.clusters.general import OnOff from zigpy.zcl.clusters.security import IasWd from homeassistant.components.select import SelectEntity @@ -15,12 +16,20 @@ from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from .core import discovery -from .core.const import CHANNEL_IAS_WD, DATA_ZHA, SIGNAL_ADD_ENTITIES, Strobe +from .core.const import ( + CHANNEL_IAS_WD, + CHANNEL_ON_OFF, + DATA_ZHA, + SIGNAL_ADD_ENTITIES, + Strobe, +) from .core.registries import ZHA_ENTITIES from .core.typing import ChannelType, ZhaDeviceType from .entity import ZhaEntity -MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.SELECT) +CONFIG_DIAGNOSTIC_MATCH = functools.partial( + ZHA_ENTITIES.config_diagnostic_match, Platform.SELECT +) async def async_setup_entry( @@ -100,7 +109,7 @@ class ZHANonZCLSelectEntity(ZHAEnumSelectEntity): return True -@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) +@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD) class ZHADefaultToneSelectEntity( ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.WarningMode.__name__ ): @@ -109,7 +118,7 @@ class ZHADefaultToneSelectEntity( _enum: Enum = IasWd.Warning.WarningMode -@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) +@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD) class ZHADefaultSirenLevelSelectEntity( ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.SirenLevel.__name__ ): @@ -118,7 +127,7 @@ class ZHADefaultSirenLevelSelectEntity( _enum: Enum = IasWd.Warning.SirenLevel -@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) +@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD) class ZHADefaultStrobeLevelSelectEntity( ZHANonZCLSelectEntity, id_suffix=IasWd.StrobeLevel.__name__ ): @@ -127,8 +136,72 @@ class ZHADefaultStrobeLevelSelectEntity( _enum: Enum = IasWd.StrobeLevel -@MULTI_MATCH(channel_names=CHANNEL_IAS_WD) +@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD) class ZHADefaultStrobeSelectEntity(ZHANonZCLSelectEntity, id_suffix=Strobe.__name__): """Representation of a ZHA default siren strobe select entity.""" _enum: Enum = Strobe + + +class ZCLEnumSelectEntity(ZhaEntity, SelectEntity): + """Representation of a ZHA ZCL enum select entity.""" + + _select_attr: str + _attr_entity_category = EntityCategory.CONFIG + _enum: Enum + + @classmethod + def create_entity( + cls, + unique_id: str, + zha_device: ZhaDeviceType, + channels: list[ChannelType], + **kwargs, + ) -> ZhaEntity | None: + """Entity Factory. + + Return entity if it is a supported configuration, otherwise return None + """ + channel = channels[0] + if cls._select_attr in channel.cluster.unsupported_attributes: + return None + + return cls(unique_id, zha_device, channels, **kwargs) + + def __init__( + self, + unique_id: str, + zha_device: ZhaDeviceType, + channels: list[ChannelType], + **kwargs, + ) -> None: + """Init this select entity.""" + self._attr_options = [entry.name.replace("_", " ") for entry in self._enum] + self._channel: ChannelType = channels[0] + super().__init__(unique_id, zha_device, channels, **kwargs) + + @property + def current_option(self) -> str | None: + """Return the selected entity option to represent the entity state.""" + option = self._channel.cluster.get(self._select_attr) + if option is None: + return None + option = self._enum(option) + return option.name.replace("_", " ") + + async def async_select_option(self, option: str | int) -> None: + """Change the selected option.""" + await self._channel.cluster.write_attributes( + {self._select_attr: self._enum[option.replace(" ", "_")]} + ) + self.async_write_ha_state() + + +@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_ON_OFF) +class ZHAStartupOnOffSelectEntity( + ZCLEnumSelectEntity, id_suffix=OnOff.StartUpOnOff.__name__ +): + """Representation of a ZHA startup onoff select entity.""" + + _select_attr = "start_up_on_off" + _enum: Enum = OnOff.StartUpOnOff diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index 93a50c77c90..149c77314a1 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -44,6 +44,16 @@ from .zha_devices_list import ( NO_TAIL_ID = re.compile("_\\d$") UNIQUE_ID_HD = re.compile(r"^(([\da-fA-F]{2}:){7}[\da-fA-F]{2}-\d{1,3})", re.X) +IGNORE_SUFFIXES = [zigpy.zcl.clusters.general.OnOff.StartUpOnOff.__name__] + + +def contains_ignored_suffix(unique_id: str) -> bool: + """Return true if the unique_id ends with an ignored suffix.""" + for suffix in IGNORE_SUFFIXES: + if suffix.lower() in unique_id.lower(): + return True + return False + @pytest.fixture def channels_mock(zha_device_mock): @@ -142,7 +152,7 @@ async def test_devices( _, component, entity_cls, unique_id, channels = call[0] # the factory can return None. We filter these out to get an accurate created entity count response = entity_cls.create_entity(unique_id, zha_dev, channels) - if response: + if response and not contains_ignored_suffix(response.name): created_entity_count += 1 unique_id_head = UNIQUE_ID_HD.match(unique_id).group( 0 @@ -178,7 +188,9 @@ async def test_devices( await hass_disable_services.async_block_till_done() zha_entity_ids = { - ent for ent in entity_ids if ent.split(".")[0] in zha_const.PLATFORMS + ent + for ent in entity_ids + if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS } assert zha_entity_ids == { e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values() @@ -319,12 +331,15 @@ async def test_discover_endpoint(device_info, channels_mock, hass): ha_ent_info = {} for call in new_ent.call_args_list: component, entity_cls, unique_id, channels = call[0] - unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id - ha_ent_info[(unique_id_head, entity_cls.__name__)] = ( - component, - unique_id, - channels, - ) + if not contains_ignored_suffix(unique_id): + unique_id_head = UNIQUE_ID_HD.match(unique_id).group( + 0 + ) # ieee + endpoint_id + ha_ent_info[(unique_id_head, entity_cls.__name__)] = ( + component, + unique_id, + channels, + ) for comp_id, ent_info in device_info[DEV_SIG_ENT_MAP].items(): component, unique_id = comp_id diff --git a/tests/components/zha/test_select.py b/tests/components/zha/test_select.py index fb21c900838..a761b8ea36b 100644 --- a/tests/components/zha/test_select.py +++ b/tests/components/zha/test_select.py @@ -33,6 +33,28 @@ async def siren(hass, zigpy_device_mock, zha_device_joined_restored): return zha_device, zigpy_device.endpoints[1].ias_wd +@pytest.fixture +async def light(hass, zigpy_device_mock): + """Siren fixture.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + SIG_EP_PROFILE: zha.PROFILE_ID, + SIG_EP_TYPE: zha.DeviceType.ON_OFF_LIGHT, + SIG_EP_INPUT: [ + general.Basic.cluster_id, + general.Identify.cluster_id, + general.OnOff.cluster_id, + ], + SIG_EP_OUTPUT: [general.Ota.cluster_id], + } + }, + ) + + return zigpy_device + + @pytest.fixture def core_rs(hass_storage): """Core.restore_state fixture.""" @@ -149,3 +171,67 @@ async def test_select_restore_state( state = hass.states.get(entity_id) assert state assert state.state == security.IasWd.Warning.WarningMode.Burglar.name + + +async def test_on_off_select(hass, light, zha_device_joined_restored): + """Test zha on off select.""" + + entity_registry = er.async_get(hass) + on_off_cluster = light.endpoints[1].on_off + on_off_cluster.PLUGGED_ATTR_READS = { + "start_up_on_off": general.OnOff.StartUpOnOff.On + } + zha_device = await zha_device_joined_restored(light) + select_name = general.OnOff.StartUpOnOff.__name__ + entity_id = await find_entity_id( + Platform.SELECT, + zha_device, + hass, + qualifier=select_name.lower(), + ) + assert entity_id is not None + + state = hass.states.get(entity_id) + assert state + assert state.state == STATE_UNKNOWN + assert state.attributes["options"] == ["Off", "On", "Toggle", "PreviousValue"] + + entity_entry = entity_registry.async_get(entity_id) + assert entity_entry + assert entity_entry.entity_category == ENTITY_CATEGORY_CONFIG + + # Test select option with string value + await hass.services.async_call( + "select", + "select_option", + { + "entity_id": entity_id, + "option": general.OnOff.StartUpOnOff.Off.name, + }, + blocking=True, + ) + + assert on_off_cluster.write_attributes.call_count == 1 + assert on_off_cluster.write_attributes.call_args[0][0] == { + "start_up_on_off": general.OnOff.StartUpOnOff.Off + } + + state = hass.states.get(entity_id) + assert state + assert state.state == general.OnOff.StartUpOnOff.Off.name + + +async def test_on_off_select_unsupported(hass, light, zha_device_joined_restored): + """Test zha on off select unsupported.""" + + on_off_cluster = light.endpoints[1].on_off + on_off_cluster.add_unsupported_attribute("start_up_on_off") + zha_device = await zha_device_joined_restored(light) + select_name = general.OnOff.StartUpOnOff.__name__ + entity_id = await find_entity_id( + Platform.SELECT, + zha_device, + hass, + qualifier=select_name.lower(), + ) + assert entity_id is None