From 3385893b7724af50f22ac5d49ce9bacbe283b5f6 Mon Sep 17 00:00:00 2001 From: Alexei Chetroi Date: Fri, 21 Feb 2020 18:06:57 -0500 Subject: [PATCH] ZHA device channel refactoring (#31971) * Add ZHA core typing helper. * Add aux_channels to ZHA rule matching. * Add match rule claim_channels() method. * Expose underlying zigpy device. * Not sure we need this one. * Move "base" channels. * Framework for channel discovery. * Make DEVICE_CLASS and REMOTE_DEVICE_TYPE default dicts. * Remove attribute reporting configuration registry. * Refactor channels. - Refactor zha events - Use compound IDs and unique_ids - Refactor signal dispatching on attribute updates * Use unique id compatible with entities unique ids. * Refactor ZHA Entity registry. Let match rule to check for the match. * Refactor discovery to use new channels. * Cleanup ZDO channel. Remove unused zha store call. * Handle channel configuration and initialization. * Refactor ZHA Device to use new channels. * Refactor ZHA Gateway to use new discovery framework. Use hass.data for entity info intermediate store. * Don't keep entities in hass.data. * ZHA gateway new discovery framework. * Refactor ZHA platform loading. * Don't update ZHA entities, when restoring from zigpy. * ZHA entity discover tests. * Add AnalogInput sensor. * Remove 0xFC02 based entity from Keen smart vents. * Clean up IAS channels. * Refactor entity restoration. * Fix lumi.router entities name. * Rename EndpointsChannel to ChannelPool. * Make Channels.pools a list. * Fix cover test. * Fix FakeDevice class. * Fix device actions. * Fix channels typing. * Revert update_before_add=False * Refactor channel class matching. * Use a helper function for adding entities. * Make Pylint happy. * Rebase cleanup. * Update coverage for ZHA device type overrides. * Use cluster_id for single output cluster registry. * Remove ZHA typing from coverage. * Fix tests. * Address comments. * Address comments. --- .coveragerc | 1 + homeassistant/components/zha/__init__.py | 26 +- homeassistant/components/zha/binary_sensor.py | 39 +- .../components/zha/core/channels/__init__.py | 683 ++++++++---------- .../components/zha/core/channels/base.py | 383 ++++++++++ .../components/zha/core/channels/closures.py | 19 +- .../components/zha/core/channels/general.py | 43 +- .../zha/core/channels/homeautomation.py | 16 +- .../components/zha/core/channels/hvac.py | 11 +- .../components/zha/core/channels/lighting.py | 10 +- .../components/zha/core/channels/lightlink.py | 2 +- .../zha/core/channels/manufacturerspecific.py | 33 +- .../zha/core/channels/measurement.py | 2 +- .../components/zha/core/channels/protocol.py | 2 +- .../components/zha/core/channels/security.py | 43 +- .../zha/core/channels/smartenergy.py | 10 +- homeassistant/components/zha/core/const.py | 7 + homeassistant/components/zha/core/device.py | 152 ++-- .../components/zha/core/discovery.py | 358 +++------ homeassistant/components/zha/core/gateway.py | 113 ++- .../components/zha/core/registries.py | 162 +++-- homeassistant/components/zha/core/typing.py | 41 ++ homeassistant/components/zha/cover.py | 39 +- homeassistant/components/zha/device_action.py | 7 +- .../components/zha/device_tracker.py | 45 +- homeassistant/components/zha/entity.py | 1 - homeassistant/components/zha/fan.py | 39 +- homeassistant/components/zha/light.py | 41 +- homeassistant/components/zha/lock.py | 39 +- homeassistant/components/zha/sensor.py | 52 +- homeassistant/components/zha/switch.py | 43 +- tests/components/zha/common.py | 1 + tests/components/zha/conftest.py | 44 +- tests/components/zha/test_channels.py | 233 +++++- tests/components/zha/test_cover.py | 2 +- tests/components/zha/test_device_action.py | 5 +- tests/components/zha/test_device_trigger.py | 5 +- tests/components/zha/test_discover.py | 357 ++++++++- tests/components/zha/test_registries.py | 65 +- tests/components/zha/zha_devices_list.py | 69 +- 40 files changed, 1918 insertions(+), 1325 deletions(-) create mode 100644 homeassistant/components/zha/core/channels/base.py create mode 100644 homeassistant/components/zha/core/typing.py diff --git a/.coveragerc b/.coveragerc index bf980a40c92..e51f4de886d 100644 --- a/.coveragerc +++ b/.coveragerc @@ -841,6 +841,7 @@ omit = homeassistant/components/zha/core/helpers.py homeassistant/components/zha/core/patches.py homeassistant/components/zha/core/registries.py + homeassistant/components/zha/core/typing.py homeassistant/components/zha/entity.py homeassistant/components/zha/light.py homeassistant/components/zha/sensor.py diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 377c77bf601..0d4ceed829b 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -1,5 +1,6 @@ """Support for Zigbee Home Automation devices.""" +import asyncio import logging import voluptuous as vol @@ -22,6 +23,7 @@ from .core.const import ( DATA_ZHA_CONFIG, DATA_ZHA_DISPATCHERS, DATA_ZHA_GATEWAY, + DATA_ZHA_PLATFORM_LOADED, DEFAULT_BAUDRATE, DEFAULT_RADIO_TYPE, DOMAIN, @@ -87,11 +89,23 @@ async def async_setup_entry(hass, config_entry): Will automatically load components to support devices found on the network. """ - for component in COMPONENTS: - hass.data[DATA_ZHA][component] = hass.data[DATA_ZHA].get(component, {}) - hass.data[DATA_ZHA] = hass.data.get(DATA_ZHA, {}) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS] = [] + hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED] = asyncio.Event() + platforms = [] + for component in COMPONENTS: + platforms.append( + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(config_entry, component) + ) + ) + + async def _platforms_loaded(): + await asyncio.gather(*platforms) + hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED].set() + + hass.async_create_task(_platforms_loaded()) + config = hass.data[DATA_ZHA].get(DATA_ZHA_CONFIG, {}) if config.get(CONF_ENABLE_QUIRKS, True): @@ -112,11 +126,6 @@ async def async_setup_entry(hass, config_entry): model=zha_gateway.radio_description, ) - for component in COMPONENTS: - hass.async_create_task( - hass.config_entries.async_forward_entry_setup(config_entry, component) - ) - api.async_load_api(hass) async def async_zha_shutdown(event): @@ -125,6 +134,7 @@ async def async_setup_entry(hass, config_entry): await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_update_device_storage() hass.bus.async_listen_once(ha_const.EVENT_HOMEASSISTANT_STOP, async_zha_shutdown) + hass.async_create_task(zha_gateway.async_load_devices()) return True diff --git a/homeassistant/components/zha/binary_sensor.py b/homeassistant/components/zha/binary_sensor.py index 58b671a340f..93baf8e111b 100644 --- a/homeassistant/components/zha/binary_sensor.py +++ b/homeassistant/components/zha/binary_sensor.py @@ -18,6 +18,7 @@ from homeassistant.const import STATE_ON from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_ACCELEROMETER, CHANNEL_OCCUPANCY, @@ -25,8 +26,8 @@ from .core.const import ( CHANNEL_ZONE, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -48,41 +49,17 @@ STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation binary sensor from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - binary_sensors = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if binary_sensors is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, binary_sensors.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA binary sensors.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, BinarySensor) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - class BinarySensor(ZhaEntity, BinarySensorDevice): """ZHA BinarySensor.""" diff --git a/homeassistant/components/zha/core/channels/__init__.py b/homeassistant/components/zha/core/channels/__init__.py index 1210ac9d32c..ea838a05665 100644 --- a/homeassistant/components/zha/core/channels/__init__.py +++ b/homeassistant/components/zha/core/channels/__init__.py @@ -1,394 +1,13 @@ """Channels module for Zigbee Home Automation.""" import asyncio -from concurrent.futures import TimeoutError as Timeout -from enum import Enum -from functools import wraps import logging -from random import uniform - -import zigpy.exceptions +from typing import Any, Dict, List, Optional, Union from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_send -from ..const import ( - CHANNEL_EVENT_RELAY, - CHANNEL_ZDO, - REPORT_CONFIG_DEFAULT, - REPORT_CONFIG_MAX_INT, - REPORT_CONFIG_MIN_INT, - REPORT_CONFIG_RPT_CHANGE, - SIGNAL_ATTR_UPDATED, -) -from ..helpers import LogMixin, get_attr_id_by_name, safe_read -from ..registries import CLUSTER_REPORT_CONFIGS - -_LOGGER = logging.getLogger(__name__) - - -def parse_and_log_command(channel, tsn, command_id, args): - """Parse and log a zigbee cluster command.""" - cmd = channel.cluster.server_commands.get(command_id, [command_id])[0] - channel.debug( - "received '%s' command with %s args on cluster_id '%s' tsn '%s'", - cmd, - args, - channel.cluster.cluster_id, - tsn, - ) - return cmd - - -def decorate_command(channel, command): - """Wrap a cluster command to make it safe.""" - - @wraps(command) - async def wrapper(*args, **kwds): - try: - result = await command(*args, **kwds) - channel.debug( - "executed command: %s %s %s %s", - command.__name__, - "{}: {}".format("with args", args), - "{}: {}".format("with kwargs", kwds), - "{}: {}".format("and result", result), - ) - return result - - except (zigpy.exceptions.DeliveryError, Timeout) as ex: - channel.debug("command failed: %s exception: %s", command.__name__, str(ex)) - return ex - - return wrapper - - -class ChannelStatus(Enum): - """Status of a channel.""" - - CREATED = 1 - CONFIGURED = 2 - INITIALIZED = 3 - - -class ZigbeeChannel(LogMixin): - """Base channel for a Zigbee cluster.""" - - CHANNEL_NAME = None - REPORT_CONFIG = () - - def __init__(self, cluster, device): - """Initialize ZigbeeChannel.""" - self._channel_name = cluster.ep_attribute - if self.CHANNEL_NAME: - self._channel_name = self.CHANNEL_NAME - self._generic_id = f"channel_0x{cluster.cluster_id:04x}" - self._cluster = cluster - self._zha_device = device - self._id = f"{cluster.endpoint.endpoint_id}:0x{cluster.cluster_id:04x}" - self._unique_id = f"{str(device.ieee)}:{self._id}" - self._report_config = CLUSTER_REPORT_CONFIGS.get( - self._cluster.cluster_id, self.REPORT_CONFIG - ) - self._status = ChannelStatus.CREATED - self._cluster.add_listener(self) - - @property - def id(self) -> str: - """Return channel id unique for this device only.""" - return self._id - - @property - def generic_id(self): - """Return the generic id for this channel.""" - return self._generic_id - - @property - def unique_id(self): - """Return the unique id for this channel.""" - return self._unique_id - - @property - def cluster(self): - """Return the zigpy cluster for this channel.""" - return self._cluster - - @property - def device(self): - """Return the device this channel is linked to.""" - return self._zha_device - - @property - def name(self) -> str: - """Return friendly name.""" - return self._channel_name - - @property - def status(self): - """Return the status of the channel.""" - return self._status - - def set_report_config(self, report_config): - """Set the reporting configuration.""" - self._report_config = report_config - - async def bind(self): - """Bind a zigbee cluster. - - This also swallows DeliveryError exceptions that are thrown when - devices are unreachable. - """ - try: - res = await self.cluster.bind() - self.debug("bound '%s' cluster: %s", self.cluster.ep_attribute, res[0]) - except (zigpy.exceptions.DeliveryError, Timeout) as ex: - self.debug( - "Failed to bind '%s' cluster: %s", self.cluster.ep_attribute, str(ex) - ) - - async def configure_reporting( - self, - attr, - report_config=( - REPORT_CONFIG_MIN_INT, - REPORT_CONFIG_MAX_INT, - REPORT_CONFIG_RPT_CHANGE, - ), - ): - """Configure attribute reporting for a cluster. - - This also swallows DeliveryError exceptions that are thrown when - devices are unreachable. - """ - attr_name = self.cluster.attributes.get(attr, [attr])[0] - - kwargs = {} - if self.cluster.cluster_id >= 0xFC00 and self.device.manufacturer_code: - kwargs["manufacturer"] = self.device.manufacturer_code - - min_report_int, max_report_int, reportable_change = report_config - try: - res = await self.cluster.configure_reporting( - attr, min_report_int, max_report_int, reportable_change, **kwargs - ) - self.debug( - "reporting '%s' attr on '%s' cluster: %d/%d/%d: Result: '%s'", - attr_name, - self.cluster.ep_attribute, - min_report_int, - max_report_int, - reportable_change, - res, - ) - except (zigpy.exceptions.DeliveryError, Timeout) as ex: - self.debug( - "failed to set reporting for '%s' attr on '%s' cluster: %s", - attr_name, - self.cluster.ep_attribute, - str(ex), - ) - - async def async_configure(self): - """Set cluster binding and attribute reporting.""" - if not self._zha_device.skip_configuration: - await self.bind() - if self.cluster.is_server: - for report_config in self._report_config: - await self.configure_reporting( - report_config["attr"], report_config["config"] - ) - await asyncio.sleep(uniform(0.1, 0.5)) - self.debug("finished channel configuration") - else: - self.debug("skipping channel configuration") - self._status = ChannelStatus.CONFIGURED - - async def async_initialize(self, from_cache): - """Initialize channel.""" - self.debug("initializing channel: from_cache: %s", from_cache) - self._status = ChannelStatus.INITIALIZED - - @callback - def cluster_command(self, tsn, command_id, args): - """Handle commands received to this cluster.""" - pass - - @callback - def attribute_updated(self, attrid, value): - """Handle attribute updates on this cluster.""" - pass - - @callback - def zdo_command(self, *args, **kwargs): - """Handle ZDO commands on this cluster.""" - pass - - @callback - def zha_send_event(self, cluster, command, args): - """Relay events to hass.""" - self._zha_device.hass.bus.async_fire( - "zha_event", - { - "unique_id": self._unique_id, - "device_ieee": str(self._zha_device.ieee), - "endpoint_id": cluster.endpoint.endpoint_id, - "cluster_id": cluster.cluster_id, - "command": command, - "args": args, - }, - ) - - async def async_update(self): - """Retrieve latest state from cluster.""" - pass - - async def get_attribute_value(self, attribute, from_cache=True): - """Get the value for an attribute.""" - manufacturer = None - manufacturer_code = self._zha_device.manufacturer_code - if self.cluster.cluster_id >= 0xFC00 and manufacturer_code: - manufacturer = manufacturer_code - result = await safe_read( - self._cluster, - [attribute], - allow_cache=from_cache, - only_cache=from_cache, - manufacturer=manufacturer, - ) - return result.get(attribute) - - def log(self, level, msg, *args): - """Log a message.""" - msg = f"[%s:%s]: {msg}" - args = (self.device.nwk, self._id) + args - _LOGGER.log(level, msg, *args) - - def __getattr__(self, name): - """Get attribute or a decorated cluster command.""" - if hasattr(self._cluster, name) and callable(getattr(self._cluster, name)): - command = getattr(self._cluster, name) - command.__name__ = name - return decorate_command(self, command) - return self.__getattribute__(name) - - -class AttributeListeningChannel(ZigbeeChannel): - """Channel for attribute reports from the cluster.""" - - REPORT_CONFIG = [{"attr": 0, "config": REPORT_CONFIG_DEFAULT}] - - def __init__(self, cluster, device): - """Initialize AttributeListeningChannel.""" - super().__init__(cluster, device) - attr = self._report_config[0].get("attr") - if isinstance(attr, str): - self.value_attribute = get_attr_id_by_name(self.cluster, attr) - else: - self.value_attribute = attr - - @callback - def attribute_updated(self, attrid, value): - """Handle attribute updates on this cluster.""" - if attrid == self.value_attribute: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) - - async def async_initialize(self, from_cache): - """Initialize listener.""" - await self.get_attribute_value( - self._report_config[0].get("attr"), from_cache=from_cache - ) - await super().async_initialize(from_cache) - - -class ZDOChannel(LogMixin): - """Channel for ZDO events.""" - - def __init__(self, cluster, device): - """Initialize ZDOChannel.""" - self.name = CHANNEL_ZDO - self._cluster = cluster - self._zha_device = device - self._status = ChannelStatus.CREATED - self._unique_id = "{}:{}_ZDO".format(str(device.ieee), device.name) - self._cluster.add_listener(self) - - @property - def unique_id(self): - """Return the unique id for this channel.""" - return self._unique_id - - @property - def cluster(self): - """Return the aigpy cluster for this channel.""" - return self._cluster - - @property - def status(self): - """Return the status of the channel.""" - return self._status - - @callback - def device_announce(self, zigpy_device): - """Device announce handler.""" - pass - - @callback - def permit_duration(self, duration): - """Permit handler.""" - pass - - async def async_initialize(self, from_cache): - """Initialize channel.""" - entry = self._zha_device.gateway.zha_storage.async_get_or_create( - self._zha_device - ) - self.debug("entry loaded from storage: %s", entry) - self._status = ChannelStatus.INITIALIZED - - async def async_configure(self): - """Configure channel.""" - self._status = ChannelStatus.CONFIGURED - - def log(self, level, msg, *args): - """Log a message.""" - msg = f"[%s:ZDO](%s): {msg}" - args = (self._zha_device.nwk, self._zha_device.model) + args - _LOGGER.log(level, msg, *args) - - -class EventRelayChannel(ZigbeeChannel): - """Event relay that can be attached to zigbee clusters.""" - - CHANNEL_NAME = CHANNEL_EVENT_RELAY - - @callback - def attribute_updated(self, attrid, value): - """Handle an attribute updated on this cluster.""" - self.zha_send_event( - self._cluster, - SIGNAL_ATTR_UPDATED, - { - "attribute_id": attrid, - "attribute_name": self._cluster.attributes.get(attrid, ["Unknown"])[0], - "value": value, - }, - ) - - @callback - def cluster_command(self, tsn, command_id, args): - """Handle a cluster command received on this cluster.""" - if ( - self._cluster.server_commands is not None - and self._cluster.server_commands.get(command_id) is not None - ): - self.zha_send_event( - self._cluster, self._cluster.server_commands.get(command_id)[0], args - ) - - -# pylint: disable=wrong-import-position, import-outside-toplevel -from . import ( # noqa: F401 isort:skip +from . import ( # noqa: F401 # pylint: disable=unused-import + base, closures, general, homeautomation, @@ -401,3 +20,299 @@ from . import ( # noqa: F401 isort:skip security, smartenergy, ) +from .. import ( + const, + device as zha_core_device, + discovery as zha_disc, + registries as zha_regs, + typing as zha_typing, +) + +_LOGGER = logging.getLogger(__name__) +ChannelsDict = Dict[str, zha_typing.ChannelType] + + +class Channels: + """All discovered channels of a device.""" + + def __init__(self, zha_device: zha_typing.ZhaDeviceType) -> None: + """Initialize instance.""" + self._pools: List[zha_typing.ChannelPoolType] = [] + self._power_config = None + self._semaphore = asyncio.Semaphore(3) + self._unique_id = str(zha_device.ieee) + self._zdo_channel = base.ZDOChannel(zha_device.device.endpoints[0], zha_device) + self._zha_device = zha_device + + @property + def pools(self) -> List["ChannelPool"]: + """Return channel pools list.""" + return self._pools + + @property + def power_configuration_ch(self) -> zha_typing.ChannelType: + """Return power configuration channel.""" + return self._power_config + + @power_configuration_ch.setter + def power_configuration_ch(self, channel: zha_typing.ChannelType) -> None: + """Power configuration channel setter.""" + if self._power_config is None: + self._power_config = channel + + @property + def semaphore(self) -> asyncio.Semaphore: + """Return semaphore for concurrent tasks.""" + return self._semaphore + + @property + def zdo_channel(self) -> zha_typing.ZDOChannelType: + """Return ZDO channel.""" + return self._zdo_channel + + @property + def zha_device(self) -> zha_typing.ZhaDeviceType: + """Return parent zha device.""" + return self._zha_device + + @property + def unique_id(self): + """Return the unique id for this channel.""" + return self._unique_id + + @classmethod + def new(cls, zha_device: zha_typing.ZhaDeviceType) -> "Channels": + """Create new instance.""" + channels = cls(zha_device) + for ep_id in sorted(zha_device.device.endpoints): + channels.add_pool(ep_id) + return channels + + def add_pool(self, ep_id: int) -> None: + """Add channels for a specific endpoint.""" + if ep_id == 0: + return + self._pools.append(ChannelPool.new(self, ep_id)) + + async def async_initialize(self, from_cache: bool = False) -> None: + """Initialize claimed channels.""" + await self.zdo_channel.async_initialize(from_cache) + self.zdo_channel.debug("'async_initialize' stage succeeded") + await asyncio.gather( + *(pool.async_initialize(from_cache) for pool in self.pools) + ) + + async def async_configure(self) -> None: + """Configure claimed channels.""" + await self.zdo_channel.async_configure() + self.zdo_channel.debug("'async_configure' stage succeeded") + await asyncio.gather(*(pool.async_configure() for pool in self.pools)) + + @callback + def async_new_entity( + self, + component: str, + entity_class: zha_typing.CALLABLE_T, + unique_id: str, + channels: List[zha_typing.ChannelType], + ): + """Signal new entity addition.""" + if self.zha_device.status == zha_core_device.DeviceStatus.INITIALIZED: + return + + self.zha_device.hass.data[const.DATA_ZHA][component].append( + (entity_class, (unique_id, self.zha_device, channels)) + ) + + @callback + def async_send_signal(self, signal: str, *args: Any) -> None: + """Send a signal through hass dispatcher.""" + async_dispatcher_send(self.zha_device.hass, signal, *args) + + @callback + def zha_send_event(self, event_data: Dict[str, Union[str, int]]) -> None: + """Relay events to hass.""" + self.zha_device.hass.bus.async_fire( + "zha_event", + { + const.ATTR_DEVICE_IEEE: str(self.zha_device.ieee), + const.ATTR_UNIQUE_ID: self.unique_id, + **event_data, + }, + ) + + +class ChannelPool: + """All channels of an endpoint.""" + + def __init__(self, channels: Channels, ep_id: int): + """Initialize instance.""" + self._all_channels: ChannelsDict = {} + self._channels: Channels = channels + self._claimed_channels: ChannelsDict = {} + self._id: int = ep_id + self._relay_channels: Dict[str, zha_typing.EventRelayChannelType] = {} + self._unique_id: str = f"{channels.unique_id}-{ep_id}" + + @property + def all_channels(self) -> ChannelsDict: + """All channels of an endpoint.""" + return self._all_channels + + @property + def claimed_channels(self) -> ChannelsDict: + """Channels in use.""" + return self._claimed_channels + + @property + def endpoint(self) -> zha_typing.ZigpyEndpointType: + """Return endpoint of zigpy device.""" + return self._channels.zha_device.device.endpoints[self.id] + + @property + def id(self) -> int: + """Return endpoint id.""" + return self._id + + @property + def nwk(self) -> int: + """Device NWK for logging.""" + return self._channels.zha_device.nwk + + @property + def manufacturer(self) -> Optional[str]: + """Return device manufacturer.""" + return self._channels.zha_device.manufacturer + + @property + def manufacturer_code(self) -> Optional[int]: + """Return device manufacturer.""" + return self._channels.zha_device.manufacturer_code + + @property + def model(self) -> Optional[str]: + """Return device model.""" + return self._channels.zha_device.model + + @property + def relay_channels(self) -> Dict[str, zha_typing.EventRelayChannelType]: + """Return a dict of event relay channels.""" + return self._relay_channels + + @property + def skip_configuration(self) -> bool: + """Return True if device does not require channel configuration.""" + return self._channels.zha_device.skip_configuration + + @property + def unique_id(self): + """Return the unique id for this channel.""" + return self._unique_id + + @classmethod + def new(cls, channels: Channels, ep_id: int) -> "ChannelPool": + """Create new channels for an endpoint.""" + pool = cls(channels, ep_id) + pool.add_all_channels() + pool.add_relay_channels() + zha_disc.PROBE.discover_entities(pool) + return pool + + @callback + def add_all_channels(self) -> None: + """Create and add channels for all input clusters.""" + for cluster_id, cluster in self.endpoint.in_clusters.items(): + channel_class = zha_regs.ZIGBEE_CHANNEL_REGISTRY.get( + cluster_id, base.AttributeListeningChannel + ) + # really ugly hack to deal with xiaomi using the door lock cluster + # incorrectly. + if ( + hasattr(cluster, "ep_attribute") + and cluster.ep_attribute == "multistate_input" + ): + channel_class = base.AttributeListeningChannel + # end of ugly hack + channel = channel_class(cluster, self) + if channel.name == const.CHANNEL_POWER_CONFIGURATION: + if ( + self._channels.power_configuration_ch + or self._channels.zha_device.is_mains_powered + ): + # on power configuration channel per device + continue + self._channels.power_configuration_ch = channel + + self.all_channels[channel.id] = channel + + @callback + def add_relay_channels(self) -> None: + """Create relay channels for all output clusters if in the registry.""" + for cluster_id in zha_regs.EVENT_RELAY_CLUSTERS: + cluster = self.endpoint.out_clusters.get(cluster_id) + if cluster is not None: + channel = base.EventRelayChannel(cluster, self) + self.relay_channels[channel.id] = channel + + async def async_initialize(self, from_cache: bool = False) -> None: + """Initialize claimed channels.""" + await self._execute_channel_tasks("async_initialize", from_cache) + + async def async_configure(self) -> None: + """Configure claimed channels.""" + await self._execute_channel_tasks("async_configure") + + async def _execute_channel_tasks(self, func_name: str, *args: Any) -> None: + """Add a throttled channel task and swallow exceptions.""" + + async def _throttle(coro): + async with self._channels.semaphore: + return await coro + + channels = [*self.claimed_channels.values(), *self.relay_channels.values()] + tasks = [_throttle(getattr(ch, func_name)(*args)) for ch in channels] + results = await asyncio.gather(*tasks, return_exceptions=True) + for channel, outcome in zip(channels, results): + if isinstance(outcome, Exception): + channel.warning("'%s' stage failed: %s", func_name, str(outcome)) + continue + channel.debug("'%s' stage succeeded", func_name) + + @callback + def async_new_entity( + self, + component: str, + entity_class: zha_typing.CALLABLE_T, + unique_id: str, + channels: List[zha_typing.ChannelType], + ): + """Signal new entity addition.""" + self._channels.async_new_entity(component, entity_class, unique_id, channels) + + @callback + def async_send_signal(self, signal: str, *args: Any) -> None: + """Send a signal through hass dispatcher.""" + self._channels.async_send_signal(signal, *args) + + @callback + def claim_channels(self, channels: List[zha_typing.ChannelType]) -> None: + """Claim a channel.""" + self.claimed_channels.update({ch.id: ch for ch in channels}) + + @callback + def unclaimed_channels(self) -> List[zha_typing.ChannelType]: + """Return a list of available (unclaimed) channels.""" + claimed = set(self.claimed_channels) + available = set(self.all_channels) + return [self.all_channels[chan_id] for chan_id in (available - claimed)] + + @callback + def zha_send_event(self, event_data: Dict[str, Union[str, int]]) -> None: + """Relay events to hass.""" + self._channels.zha_send_event( + { + const.ATTR_UNIQUE_ID: self.unique_id, + const.ATTR_ENDPOINT_ID: self.id, + **event_data, + } + ) diff --git a/homeassistant/components/zha/core/channels/base.py b/homeassistant/components/zha/core/channels/base.py new file mode 100644 index 00000000000..7bb2ad7b57e --- /dev/null +++ b/homeassistant/components/zha/core/channels/base.py @@ -0,0 +1,383 @@ +"""Base classes for channels.""" + +import asyncio +from enum import Enum +from functools import wraps +import logging +from random import uniform +from typing import Any, Union + +import zigpy.exceptions + +from homeassistant.core import callback + +from .. import typing as zha_typing +from ..const import ( + ATTR_ARGS, + ATTR_ATTRIBUTE_ID, + ATTR_ATTRIBUTE_NAME, + ATTR_CLUSTER_ID, + ATTR_COMMAND, + ATTR_UNIQUE_ID, + ATTR_VALUE, + CHANNEL_EVENT_RELAY, + CHANNEL_ZDO, + REPORT_CONFIG_DEFAULT, + REPORT_CONFIG_MAX_INT, + REPORT_CONFIG_MIN_INT, + REPORT_CONFIG_RPT_CHANGE, + SIGNAL_ATTR_UPDATED, +) +from ..helpers import LogMixin, get_attr_id_by_name, safe_read + +_LOGGER = logging.getLogger(__name__) + + +def parse_and_log_command(channel, tsn, command_id, args): + """Parse and log a zigbee cluster command.""" + cmd = channel.cluster.server_commands.get(command_id, [command_id])[0] + channel.debug( + "received '%s' command with %s args on cluster_id '%s' tsn '%s'", + cmd, + args, + channel.cluster.cluster_id, + tsn, + ) + return cmd + + +def decorate_command(channel, command): + """Wrap a cluster command to make it safe.""" + + @wraps(command) + async def wrapper(*args, **kwds): + try: + result = await command(*args, **kwds) + channel.debug( + "executed '%s' command with args: '%s' kwargs: '%s' result: %s", + command.__name__, + args, + kwds, + result, + ) + return result + + except (zigpy.exceptions.DeliveryError, asyncio.TimeoutError) as ex: + channel.debug("command failed: %s exception: %s", command.__name__, str(ex)) + return ex + + return wrapper + + +class ChannelStatus(Enum): + """Status of a channel.""" + + CREATED = 1 + CONFIGURED = 2 + INITIALIZED = 3 + + +class ZigbeeChannel(LogMixin): + """Base channel for a Zigbee cluster.""" + + CHANNEL_NAME = None + REPORT_CONFIG = () + + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: + """Initialize ZigbeeChannel.""" + self._channel_name = cluster.ep_attribute + if self.CHANNEL_NAME: + self._channel_name = self.CHANNEL_NAME + self._ch_pool = ch_pool + self._generic_id = f"channel_0x{cluster.cluster_id:04x}" + self._cluster = cluster + self._id = f"{ch_pool.id}:0x{cluster.cluster_id:04x}" + unique_id = ch_pool.unique_id.replace("-", ":") + self._unique_id = f"{unique_id}:0x{cluster.cluster_id:04x}" + self._report_config = self.REPORT_CONFIG + self._status = ChannelStatus.CREATED + self._cluster.add_listener(self) + + @property + def id(self) -> str: + """Return channel id unique for this device only.""" + return self._id + + @property + def generic_id(self): + """Return the generic id for this channel.""" + return self._generic_id + + @property + def unique_id(self): + """Return the unique id for this channel.""" + return self._unique_id + + @property + def cluster(self): + """Return the zigpy cluster for this channel.""" + return self._cluster + + @property + def name(self) -> str: + """Return friendly name.""" + return self._channel_name + + @property + def status(self): + """Return the status of the channel.""" + return self._status + + @callback + def async_send_signal(self, signal: str, *args: Any) -> None: + """Send a signal through hass dispatcher.""" + self._ch_pool.async_send_signal(signal, *args) + + async def bind(self): + """Bind a zigbee cluster. + + This also swallows DeliveryError exceptions that are thrown when + devices are unreachable. + """ + try: + res = await self.cluster.bind() + self.debug("bound '%s' cluster: %s", self.cluster.ep_attribute, res[0]) + except (zigpy.exceptions.DeliveryError, asyncio.TimeoutError) as ex: + self.debug( + "Failed to bind '%s' cluster: %s", self.cluster.ep_attribute, str(ex) + ) + + async def configure_reporting( + self, + attr, + report_config=( + REPORT_CONFIG_MIN_INT, + REPORT_CONFIG_MAX_INT, + REPORT_CONFIG_RPT_CHANGE, + ), + ): + """Configure attribute reporting for a cluster. + + This also swallows DeliveryError exceptions that are thrown when + devices are unreachable. + """ + attr_name = self.cluster.attributes.get(attr, [attr])[0] + + kwargs = {} + if self.cluster.cluster_id >= 0xFC00 and self._ch_pool.manufacturer_code: + kwargs["manufacturer"] = self._ch_pool.manufacturer_code + + min_report_int, max_report_int, reportable_change = report_config + try: + res = await self.cluster.configure_reporting( + attr, min_report_int, max_report_int, reportable_change, **kwargs + ) + self.debug( + "reporting '%s' attr on '%s' cluster: %d/%d/%d: Result: '%s'", + attr_name, + self.cluster.ep_attribute, + min_report_int, + max_report_int, + reportable_change, + res, + ) + except (zigpy.exceptions.DeliveryError, asyncio.TimeoutError) as ex: + self.debug( + "failed to set reporting for '%s' attr on '%s' cluster: %s", + attr_name, + self.cluster.ep_attribute, + str(ex), + ) + + async def async_configure(self): + """Set cluster binding and attribute reporting.""" + if not self._ch_pool.skip_configuration: + await self.bind() + if self.cluster.is_server: + for report_config in self._report_config: + await self.configure_reporting( + report_config["attr"], report_config["config"] + ) + await asyncio.sleep(uniform(0.1, 0.5)) + self.debug("finished channel configuration") + else: + self.debug("skipping channel configuration") + self._status = ChannelStatus.CONFIGURED + + async def async_initialize(self, from_cache): + """Initialize channel.""" + self.debug("initializing channel: from_cache: %s", from_cache) + self._status = ChannelStatus.INITIALIZED + + @callback + def cluster_command(self, tsn, command_id, args): + """Handle commands received to this cluster.""" + pass + + @callback + def attribute_updated(self, attrid, value): + """Handle attribute updates on this cluster.""" + pass + + @callback + def zdo_command(self, *args, **kwargs): + """Handle ZDO commands on this cluster.""" + pass + + @callback + def zha_send_event(self, command: str, args: Union[int, dict]) -> None: + """Relay events to hass.""" + self._ch_pool.zha_send_event( + { + ATTR_UNIQUE_ID: self.unique_id, + ATTR_CLUSTER_ID: self.cluster.cluster_id, + ATTR_COMMAND: command, + ATTR_ARGS: args, + } + ) + + async def async_update(self): + """Retrieve latest state from cluster.""" + pass + + async def get_attribute_value(self, attribute, from_cache=True): + """Get the value for an attribute.""" + manufacturer = None + manufacturer_code = self._ch_pool.manufacturer_code + if self.cluster.cluster_id >= 0xFC00 and manufacturer_code: + manufacturer = manufacturer_code + result = await safe_read( + self._cluster, + [attribute], + allow_cache=from_cache, + only_cache=from_cache, + manufacturer=manufacturer, + ) + return result.get(attribute) + + def log(self, level, msg, *args): + """Log a message.""" + msg = f"[%s:%s]: {msg}" + args = (self._ch_pool.nwk, self._id) + args + _LOGGER.log(level, msg, *args) + + def __getattr__(self, name): + """Get attribute or a decorated cluster command.""" + if hasattr(self._cluster, name) and callable(getattr(self._cluster, name)): + command = getattr(self._cluster, name) + command.__name__ = name + return decorate_command(self, command) + return self.__getattribute__(name) + + +class AttributeListeningChannel(ZigbeeChannel): + """Channel for attribute reports from the cluster.""" + + REPORT_CONFIG = [{"attr": 0, "config": REPORT_CONFIG_DEFAULT}] + + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: + """Initialize AttributeListeningChannel.""" + super().__init__(cluster, ch_pool) + attr = self._report_config[0].get("attr") + if isinstance(attr, str): + self.value_attribute = get_attr_id_by_name(self.cluster, attr) + else: + self.value_attribute = attr + + @callback + def attribute_updated(self, attrid, value): + """Handle attribute updates on this cluster.""" + if attrid == self.value_attribute: + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) + + async def async_initialize(self, from_cache): + """Initialize listener.""" + await self.get_attribute_value( + self._report_config[0].get("attr"), from_cache=from_cache + ) + await super().async_initialize(from_cache) + + +class ZDOChannel(LogMixin): + """Channel for ZDO events.""" + + def __init__(self, cluster, device): + """Initialize ZDOChannel.""" + self.name = CHANNEL_ZDO + self._cluster = cluster + self._zha_device = device + self._status = ChannelStatus.CREATED + self._unique_id = "{}:{}_ZDO".format(str(device.ieee), device.name) + self._cluster.add_listener(self) + + @property + def unique_id(self): + """Return the unique id for this channel.""" + return self._unique_id + + @property + def cluster(self): + """Return the aigpy cluster for this channel.""" + return self._cluster + + @property + def status(self): + """Return the status of the channel.""" + return self._status + + @callback + def device_announce(self, zigpy_device): + """Device announce handler.""" + pass + + @callback + def permit_duration(self, duration): + """Permit handler.""" + pass + + async def async_initialize(self, from_cache): + """Initialize channel.""" + self._status = ChannelStatus.INITIALIZED + + async def async_configure(self): + """Configure channel.""" + self._status = ChannelStatus.CONFIGURED + + def log(self, level, msg, *args): + """Log a message.""" + msg = f"[%s:ZDO](%s): {msg}" + args = (self._zha_device.nwk, self._zha_device.model) + args + _LOGGER.log(level, msg, *args) + + +class EventRelayChannel(ZigbeeChannel): + """Event relay that can be attached to zigbee clusters.""" + + CHANNEL_NAME = CHANNEL_EVENT_RELAY + + @callback + def attribute_updated(self, attrid, value): + """Handle an attribute updated on this cluster.""" + self.zha_send_event( + SIGNAL_ATTR_UPDATED, + { + ATTR_ATTRIBUTE_ID: attrid, + ATTR_ATTRIBUTE_NAME: self._cluster.attributes.get(attrid, ["Unknown"])[ + 0 + ], + ATTR_VALUE: value, + }, + ) + + @callback + def cluster_command(self, tsn, command_id, args): + """Handle a cluster command received on this cluster.""" + if ( + self._cluster.server_commands is not None + and self._cluster.server_commands.get(command_id) is not None + ): + self.zha_send_event(self._cluster.server_commands.get(command_id)[0], args) diff --git a/homeassistant/components/zha/core/channels/closures.py b/homeassistant/components/zha/core/channels/closures.py index 0cf6f840070..e25c2253bb3 100644 --- a/homeassistant/components/zha/core/channels/closures.py +++ b/homeassistant/components/zha/core/channels/closures.py @@ -4,11 +4,10 @@ import logging import zigpy.zcl.clusters.closures as closures from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send -from . import ZigbeeChannel from .. import registries from ..const import REPORT_CONFIG_IMMEDIATE, SIGNAL_ATTR_UPDATED +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -24,9 +23,7 @@ class DoorLockChannel(ZigbeeChannel): """Retrieve latest state.""" result = await self.get_attribute_value("lock_state", from_cache=True) - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result) @callback def attribute_updated(self, attrid, value): @@ -36,9 +33,7 @@ class DoorLockChannel(ZigbeeChannel): "Attribute report '%s'[%s] = %s", self.cluster.name, attr_name, value ) if attrid == self._value_attribute: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) async def async_initialize(self, from_cache): """Initialize channel.""" @@ -69,9 +64,7 @@ class WindowCovering(ZigbeeChannel): ) self.debug("read current position: %s", result) - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result) @callback def attribute_updated(self, attrid, value): @@ -81,9 +74,7 @@ class WindowCovering(ZigbeeChannel): "Attribute report '%s'[%s] = %s", self.cluster.name, attr_name, value ) if attrid == self._value_attribute: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) async def async_initialize(self, from_cache): """Initialize channel.""" diff --git a/homeassistant/components/zha/core/channels/general.py b/homeassistant/components/zha/core/channels/general.py index 111b35e7e58..3e41e961f0a 100644 --- a/homeassistant/components/zha/core/channels/general.py +++ b/homeassistant/components/zha/core/channels/general.py @@ -4,11 +4,9 @@ import logging import zigpy.zcl.clusters.general as general from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_call_later -from . import AttributeListeningChannel, ZigbeeChannel, parse_and_log_command -from .. import registries +from .. import registries, typing as zha_typing from ..const import ( REPORT_CONFIG_ASAP, REPORT_CONFIG_BATTERY_SAVE, @@ -20,6 +18,7 @@ from ..const import ( SIGNAL_STATE_ATTR, ) from ..helpers import get_attr_id_by_name +from .base import AttributeListeningChannel, ZigbeeChannel, parse_and_log_command _LOGGER = logging.getLogger(__name__) @@ -77,9 +76,11 @@ class BasicChannel(ZigbeeChannel): 6: "Emergency mains and transfer switch", } - def __init__(self, cluster, device): + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: """Initialize BasicChannel.""" - super().__init__(cluster, device) + super().__init__(cluster, ch_pool) self._power_source = None async def async_configure(self): @@ -193,9 +194,7 @@ class LevelControlChannel(ZigbeeChannel): def dispatch_level_change(self, command, level): """Dispatch level change.""" - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{command}", level - ) + self.async_send_signal(f"{self.unique_id}_{command}", level) async def async_initialize(self, from_cache): """Initialize channel.""" @@ -236,9 +235,11 @@ class OnOffChannel(ZigbeeChannel): ON_OFF = 0 REPORT_CONFIG = ({"attr": "on_off", "config": REPORT_CONFIG_IMMEDIATE},) - def __init__(self, cluster, device): + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: """Initialize OnOffChannel.""" - super().__init__(cluster, device) + super().__init__(cluster, ch_pool) self._state = None self._off_listener = None @@ -279,9 +280,7 @@ class OnOffChannel(ZigbeeChannel): def attribute_updated(self, attrid, value): """Handle attribute updates on this cluster.""" if attrid == self.ON_OFF: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) self._state = bool(value) async def async_initialize(self, from_cache): @@ -293,10 +292,11 @@ class OnOffChannel(ZigbeeChannel): async def async_update(self): """Initialize channel.""" - from_cache = not self.device.is_mains_powered - self.debug("attempting to update onoff state - from cache: %s", from_cache) + if self.cluster.is_client: + return + self.debug("attempting to update onoff state - from cache: False") self._state = bool( - await self.get_attribute_value(self.ON_OFF, from_cache=from_cache) + await self.get_attribute_value(self.ON_OFF, from_cache=False) ) await super().async_update() @@ -348,16 +348,11 @@ class PowerConfigurationChannel(ZigbeeChannel): else: attr_id = attr if attrid == attr_id: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) return attr_name = self.cluster.attributes.get(attrid, [attrid])[0] - async_dispatcher_send( - self._zha_device.hass, - f"{self.unique_id}_{SIGNAL_STATE_ATTR}", - attr_name, - value, + self.async_send_signal( + f"{self.unique_id}_{SIGNAL_STATE_ATTR}", attr_name, value ) async def async_initialize(self, from_cache): diff --git a/homeassistant/components/zha/core/channels/homeautomation.py b/homeassistant/components/zha/core/channels/homeautomation.py index 8c2c2e57972..e47aca5eafd 100644 --- a/homeassistant/components/zha/core/channels/homeautomation.py +++ b/homeassistant/components/zha/core/channels/homeautomation.py @@ -4,15 +4,13 @@ from typing import Optional import zigpy.zcl.clusters.homeautomation as homeautomation -from homeassistant.helpers.dispatcher import async_dispatcher_send - -from . import AttributeListeningChannel, ZigbeeChannel -from .. import registries +from .. import registries, typing as zha_typing from ..const import ( CHANNEL_ELECTRICAL_MEASUREMENT, REPORT_CONFIG_DEFAULT, SIGNAL_ATTR_UPDATED, ) +from .base import AttributeListeningChannel, ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -61,9 +59,11 @@ class ElectricalMeasurementChannel(AttributeListeningChannel): REPORT_CONFIG = ({"attr": "active_power", "config": REPORT_CONFIG_DEFAULT},) - def __init__(self, cluster, device): + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: """Initialize Metering.""" - super().__init__(cluster, device) + super().__init__(cluster, ch_pool) self._divisor = None self._multiplier = None @@ -73,9 +73,7 @@ class ElectricalMeasurementChannel(AttributeListeningChannel): # This is a polling channel. Don't allow cache. result = await self.get_attribute_value("active_power", from_cache=False) - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result) async def async_initialize(self, from_cache): """Initialize channel.""" diff --git a/homeassistant/components/zha/core/channels/hvac.py b/homeassistant/components/zha/core/channels/hvac.py index b638259b4a1..e4519d5cb2c 100644 --- a/homeassistant/components/zha/core/channels/hvac.py +++ b/homeassistant/components/zha/core/channels/hvac.py @@ -5,11 +5,10 @@ from zigpy.exceptions import DeliveryError import zigpy.zcl.clusters.hvac as hvac from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send -from . import ZigbeeChannel from .. import registries from ..const import REPORT_CONFIG_OP, SIGNAL_ATTR_UPDATED +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -42,9 +41,7 @@ class FanChannel(ZigbeeChannel): """Retrieve latest state.""" result = await self.get_attribute_value("fan_mode", from_cache=True) - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", result) @callback def attribute_updated(self, attrid, value): @@ -54,9 +51,7 @@ class FanChannel(ZigbeeChannel): "Attribute report '%s'[%s] = %s", self.cluster.name, attr_name, value ) if attrid == self._value_attribute: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) async def async_initialize(self, from_cache): """Initialize channel.""" diff --git a/homeassistant/components/zha/core/channels/lighting.py b/homeassistant/components/zha/core/channels/lighting.py index 0a1e2048132..c87235d9ec0 100644 --- a/homeassistant/components/zha/core/channels/lighting.py +++ b/homeassistant/components/zha/core/channels/lighting.py @@ -3,9 +3,9 @@ import logging import zigpy.zcl.clusters.lighting as lighting -from . import ZigbeeChannel -from .. import registries +from .. import registries, typing as zha_typing from ..const import REPORT_CONFIG_DEFAULT +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -33,9 +33,11 @@ class ColorChannel(ZigbeeChannel): {"attr": "color_temperature", "config": REPORT_CONFIG_DEFAULT}, ) - def __init__(self, cluster, device): + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: """Initialize ColorChannel.""" - super().__init__(cluster, device) + super().__init__(cluster, ch_pool) self._color_capabilities = None def get_color_capabilities(self): diff --git a/homeassistant/components/zha/core/channels/lightlink.py b/homeassistant/components/zha/core/channels/lightlink.py index 5d0ac199185..af0248c9713 100644 --- a/homeassistant/components/zha/core/channels/lightlink.py +++ b/homeassistant/components/zha/core/channels/lightlink.py @@ -3,8 +3,8 @@ import logging import zigpy.zcl.clusters.lightlink as lightlink -from . import ZigbeeChannel from .. import registries +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/zha/core/channels/manufacturerspecific.py b/homeassistant/components/zha/core/channels/manufacturerspecific.py index e3d1e67439f..90f81513ec4 100644 --- a/homeassistant/components/zha/core/channels/manufacturerspecific.py +++ b/homeassistant/components/zha/core/channels/manufacturerspecific.py @@ -2,16 +2,19 @@ import logging from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send -from . import AttributeListeningChannel, ZigbeeChannel from .. import registries from ..const import ( + ATTR_ATTRIBUTE_ID, + ATTR_ATTRIBUTE_NAME, + ATTR_VALUE, REPORT_CONFIG_ASAP, REPORT_CONFIG_MAX_INT, REPORT_CONFIG_MIN_INT, SIGNAL_ATTR_UPDATED, + UNKNOWN, ) +from .base import AttributeListeningChannel, ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -53,18 +56,14 @@ class SmartThingsAcceleration(AttributeListeningChannel): def attribute_updated(self, attrid, value): """Handle attribute updates on this cluster.""" if attrid == self.value_attribute: - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) - else: - self.zha_send_event( - self._cluster, - SIGNAL_ATTR_UPDATED, - { - "attribute_id": attrid, - "attribute_name": self._cluster.attributes.get(attrid, ["Unknown"])[ - 0 - ], - "value": value, - }, - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) + return + + self.zha_send_event( + SIGNAL_ATTR_UPDATED, + { + ATTR_ATTRIBUTE_ID: attrid, + ATTR_ATTRIBUTE_NAME: self._cluster.attributes.get(attrid, [UNKNOWN])[0], + ATTR_VALUE: value, + }, + ) diff --git a/homeassistant/components/zha/core/channels/measurement.py b/homeassistant/components/zha/core/channels/measurement.py index dfb83224505..68952c64e8d 100644 --- a/homeassistant/components/zha/core/channels/measurement.py +++ b/homeassistant/components/zha/core/channels/measurement.py @@ -3,7 +3,6 @@ import logging import zigpy.zcl.clusters.measurement as measurement -from . import AttributeListeningChannel from .. import registries from ..const import ( REPORT_CONFIG_DEFAULT, @@ -11,6 +10,7 @@ from ..const import ( REPORT_CONFIG_MAX_INT, REPORT_CONFIG_MIN_INT, ) +from .base import AttributeListeningChannel _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/zha/core/channels/protocol.py b/homeassistant/components/zha/core/channels/protocol.py index 20867553121..db7488e9a7f 100644 --- a/homeassistant/components/zha/core/channels/protocol.py +++ b/homeassistant/components/zha/core/channels/protocol.py @@ -4,7 +4,7 @@ import logging import zigpy.zcl.clusters.protocol as protocol from .. import registries -from ..channels import ZigbeeChannel +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/zha/core/channels/security.py b/homeassistant/components/zha/core/channels/security.py index a529ff69d32..20390c018d8 100644 --- a/homeassistant/components/zha/core/channels/security.py +++ b/homeassistant/components/zha/core/channels/security.py @@ -1,16 +1,19 @@ -"""Security channels module for Zigbee Home Automation.""" +""" +Security channels module for Zigbee Home Automation. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/integrations/zha/ +""" +import asyncio import logging from zigpy.exceptions import DeliveryError import zigpy.zcl.clusters.security as security from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send -from . import ZigbeeChannel from .. import registries from ..const import ( - CLUSTER_COMMAND_SERVER, SIGNAL_ATTR_UPDATED, WARNING_DEVICE_MODE_EMERGENCY, WARNING_DEVICE_SOUND_HIGH, @@ -18,6 +21,7 @@ from ..const import ( WARNING_DEVICE_STROBE_HIGH, WARNING_DEVICE_STROBE_YES, ) +from .base import ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -70,13 +74,7 @@ class IasWd(ZigbeeChannel): value = IasWd.set_bit(value, 6, mode, 2) value = IasWd.set_bit(value, 7, mode, 3) - await self.device.issue_cluster_command( - self.cluster.endpoint.endpoint_id, - self.cluster.cluster_id, - 0x0001, - CLUSTER_COMMAND_SERVER, - [value], - ) + await self.squawk(value) async def start_warning( self, @@ -111,12 +109,8 @@ class IasWd(ZigbeeChannel): value = IasWd.set_bit(value, 6, mode, 2) value = IasWd.set_bit(value, 7, mode, 3) - await self.device.issue_cluster_command( - self.cluster.endpoint.endpoint_id, - self.cluster.cluster_id, - 0x0000, - CLUSTER_COMMAND_SERVER, - [value, warning_duration, strobe_duty_cycle, strobe_intensity], + await self.start_warning( + value, warning_duration, strobe_duty_cycle, strobe_intensity ) @@ -130,18 +124,17 @@ class IASZoneChannel(ZigbeeChannel): """Handle commands received to this cluster.""" if command_id == 0: state = args[0] & 3 - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", state - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", state) self.debug("Updated alarm state: %s", state) elif command_id == 1: self.debug("Enroll requested") res = self._cluster.enroll_response(0, 0) - self._zha_device.hass.async_create_task(res) + asyncio.create_task(res) async def async_configure(self): """Configure IAS device.""" - if self._zha_device.skip_configuration: + await self.get_attribute_value("zone_type", from_cache=False) + if self._ch_pool.skip_configuration: self.debug("skipping IASZoneChannel configuration") return @@ -167,16 +160,12 @@ class IASZoneChannel(ZigbeeChannel): ) self.debug("finished IASZoneChannel configuration") - await self.get_attribute_value("zone_type", from_cache=False) - @callback def attribute_updated(self, attrid, value): """Handle attribute updates on this cluster.""" if attrid == 2: value = value & 3 - async_dispatcher_send( - self._zha_device.hass, f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value - ) + self.async_send_signal(f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", value) async def async_initialize(self, from_cache): """Initialize channel.""" diff --git a/homeassistant/components/zha/core/channels/smartenergy.py b/homeassistant/components/zha/core/channels/smartenergy.py index 08feb328603..b738b665e80 100644 --- a/homeassistant/components/zha/core/channels/smartenergy.py +++ b/homeassistant/components/zha/core/channels/smartenergy.py @@ -5,9 +5,9 @@ import zigpy.zcl.clusters.smartenergy as smartenergy from homeassistant.core import callback -from .. import registries -from ..channels import AttributeListeningChannel, ZigbeeChannel +from .. import registries, typing as zha_typing from ..const import REPORT_CONFIG_DEFAULT +from .base import AttributeListeningChannel, ZigbeeChannel _LOGGER = logging.getLogger(__name__) @@ -90,9 +90,11 @@ class Metering(AttributeListeningChannel): 0x0C: "MJ/s", } - def __init__(self, cluster, device): + def __init__( + self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType, + ) -> None: """Initialize Metering.""" - super().__init__(cluster, device) + super().__init__(cluster, ch_pool) self._divisor = None self._multiplier = None self._unit_enum = None diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index f4cccfa4e52..cb0ac2182ec 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -13,11 +13,14 @@ from homeassistant.components.switch import DOMAIN as SWITCH ATTR_ARGS = "args" ATTR_ATTRIBUTE = "attribute" +ATTR_ATTRIBUTE_ID = "attribute_id" +ATTR_ATTRIBUTE_NAME = "attribute_name" ATTR_AVAILABLE = "available" ATTR_CLUSTER_ID = "cluster_id" ATTR_CLUSTER_TYPE = "cluster_type" ATTR_COMMAND = "command" ATTR_COMMAND_TYPE = "command_type" +ATTR_DEVICE_IEEE = "device_ieee" ATTR_DEVICE_TYPE = "device_type" ATTR_ENDPOINT_ID = "endpoint_id" ATTR_IEEE = "ieee" @@ -36,6 +39,7 @@ ATTR_QUIRK_CLASS = "quirk_class" ATTR_RSSI = "rssi" ATTR_SIGNATURE = "signature" ATTR_TYPE = "type" +ATTR_UNIQUE_ID = "unique_id" ATTR_VALUE = "value" ATTR_WARNING_DEVICE_DURATION = "duration" ATTR_WARNING_DEVICE_MODE = "mode" @@ -47,6 +51,7 @@ BAUD_RATES = [2400, 4800, 9600, 14400, 19200, 38400, 57600, 115200, 128000, 2560 BINDINGS = "bindings" CHANNEL_ACCELEROMETER = "accelerometer" +CHANNEL_ANALOG_INPUT = "analog_input" CHANNEL_ATTRIBUTE = "attribute" CHANNEL_BASIC = "basic" CHANNEL_COLOR = "light_color" @@ -92,6 +97,7 @@ DATA_ZHA_BRIDGE_ID = "zha_bridge_id" DATA_ZHA_CORE_EVENTS = "zha_core_events" DATA_ZHA_DISPATCHERS = "zha_dispatchers" DATA_ZHA_GATEWAY = "zha_gateway" +DATA_ZHA_PLATFORM_LOADED = "platform_loaded" DEBUG_COMP_BELLOWS = "bellows" DEBUG_COMP_ZHA = "homeassistant.components.zha" @@ -192,6 +198,7 @@ SENSOR_PRESSURE = CHANNEL_PRESSURE SENSOR_TEMPERATURE = CHANNEL_TEMPERATURE SENSOR_TYPE = "sensor_type" +SIGNAL_ADD_ENTITIES = "zha_add_new_entities" SIGNAL_ATTR_UPDATED = "attribute_updated" SIGNAL_AVAILABLE = "available" SIGNAL_MOVE_LEVEL = "move_level" diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index ffa264dde63..54c1bbe49a8 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -18,8 +18,9 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_send, ) from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.typing import HomeAssistantType -from .channels import EventRelayChannel +from . import channels, typing as zha_typing from .const import ( ATTR_ARGS, ATTR_ATTRIBUTE, @@ -42,9 +43,6 @@ from .const import ( ATTR_QUIRK_CLASS, ATTR_RSSI, ATTR_VALUE, - CHANNEL_BASIC, - CHANNEL_POWER_CONFIGURATION, - CHANNEL_ZDO, CLUSTER_COMMAND_SERVER, CLUSTER_COMMANDS_CLIENT, CLUSTER_COMMANDS_SERVER, @@ -75,14 +73,16 @@ class DeviceStatus(Enum): class ZHADevice(LogMixin): """ZHA Zigbee device object.""" - def __init__(self, hass, zigpy_device, zha_gateway): + def __init__( + self, + hass: HomeAssistantType, + zigpy_device: zha_typing.ZigpyDeviceType, + zha_gateway: zha_typing.ZhaGatewayType, + ): """Initialize the gateway.""" self.hass = hass self._zigpy_device = zigpy_device self._zha_gateway = zha_gateway - self.cluster_channels = {} - self._relay_channels = {} - self._all_channels = [] self._available = False self._available_signal = "{}_{}_{}".format( self.name, self.ieee, SIGNAL_AVAILABLE @@ -101,6 +101,7 @@ class ZHADevice(LogMixin): ) self._ha_device_id = None self.status = DeviceStatus.CREATED + self._channels = channels.Channels(self) @property def device_id(self): @@ -111,6 +112,22 @@ class ZHADevice(LogMixin): """Set the HA device registry device id.""" self._ha_device_id = device_id + @property + def device(self) -> zha_typing.ZigpyDeviceType: + """Return underlying Zigpy device.""" + return self._zigpy_device + + @property + def channels(self) -> zha_typing.ChannelsType: + """Return ZHA channels.""" + return self._channels + + @channels.setter + def channels(self, value: zha_typing.ChannelsType) -> None: + """Channels setter.""" + assert isinstance(value, channels.Channels) + self._channels = value + @property def name(self): """Return device name.""" @@ -218,11 +235,6 @@ class ZHADevice(LogMixin): """Return the gateway for this device.""" return self._zha_gateway - @property - def all_channels(self): - """Return cluster channels and relay channels for device.""" - return self._all_channels - @property def device_automation_triggers(self): """Return the device automation triggers for this device.""" @@ -244,6 +256,19 @@ class ZHADevice(LogMixin): """Set availability from restore and prevent signals.""" self._available = available + @classmethod + def new( + cls, + hass: HomeAssistantType, + zigpy_dev: zha_typing.ZigpyDeviceType, + gateway: zha_typing.ZhaGatewayType, + restored: bool = False, + ): + """Create new device.""" + zha_dev = cls(hass, zigpy_dev, gateway) + zha_dev.channels = channels.Channels.new(zha_dev) + return zha_dev + def _check_available(self, *_): if self.last_seen is None: self.update_available(False) @@ -252,16 +277,17 @@ class ZHADevice(LogMixin): if difference > _KEEP_ALIVE_INTERVAL: if self._checkins_missed_count < _CHECKIN_GRACE_PERIODS: self._checkins_missed_count += 1 - if ( - CHANNEL_BASIC in self.cluster_channels - and self.manufacturer != "LUMI" - ): + if self.manufacturer != "LUMI": self.debug( "Attempting to checkin with device - missed checkins: %s", self._checkins_missed_count, ) + if not self._channels.pools: + return + pool = self._channels.pools[0] + basic_ch = pool.all_channels[f"{pool.id}:0"] self.hass.async_create_task( - self.cluster_channels[CHANNEL_BASIC].get_attribute_value( + basic_ch.get_attribute_value( ATTR_MANUFACTURER, from_cache=False ) ) @@ -304,66 +330,10 @@ class ZHADevice(LogMixin): ATTR_DEVICE_TYPE: self.device_type, } - def add_cluster_channel(self, cluster_channel): - """Add cluster channel to device.""" - # only keep 1 power configuration channel - if ( - cluster_channel.name is CHANNEL_POWER_CONFIGURATION - and CHANNEL_POWER_CONFIGURATION in self.cluster_channels - ): - return - - if isinstance(cluster_channel, EventRelayChannel): - self._relay_channels[cluster_channel.unique_id] = cluster_channel - self._all_channels.append(cluster_channel) - else: - self.cluster_channels[cluster_channel.name] = cluster_channel - self._all_channels.append(cluster_channel) - - def get_channels_to_configure(self): - """Get a deduped list of channels for configuration. - - This goes through all channels and gets a unique list of channels to - configure. It first assembles a unique list of channels that are part - of entities while stashing relay channels off to the side. It then - takse the stashed relay channels and adds them to the list of channels - that will be returned if there isn't a channel in the list for that - cluster already. This is done to ensure each cluster is only configured - once. - """ - channel_keys = [] - channels = [] - relay_channels = self._relay_channels.values() - - def get_key(channel): - channel_key = "ZDO" - if hasattr(channel.cluster, "cluster_id"): - channel_key = "{}_{}".format( - channel.cluster.endpoint.endpoint_id, channel.cluster.cluster_id - ) - return channel_key - - # first we get all unique non event channels - for channel in self.all_channels: - c_key = get_key(channel) - if c_key not in channel_keys and channel not in relay_channels: - channel_keys.append(c_key) - channels.append(channel) - - # now we get event channels that still need their cluster configured - for channel in relay_channels: - channel_key = get_key(channel) - if channel_key not in channel_keys: - channel_keys.append(channel_key) - channels.append(channel) - return channels - async def async_configure(self): """Configure the device.""" self.debug("started configuration") - await self._execute_channel_tasks( - self.get_channels_to_configure(), "async_configure" - ) + await self._channels.async_configure() self.debug("completed configuration") entry = self.gateway.zha_storage.async_create_or_update(self) self.debug("stored in registry: %s", entry) @@ -371,41 +341,11 @@ class ZHADevice(LogMixin): async def async_initialize(self, from_cache=False): """Initialize channels.""" self.debug("started initialization") - await self._execute_channel_tasks( - self.all_channels, "async_initialize", from_cache - ) + await self._channels.async_initialize(from_cache) self.debug("power source: %s", self.power_source) self.status = DeviceStatus.INITIALIZED self.debug("completed initialization") - async def _execute_channel_tasks(self, channels, task_name, *args): - """Gather and execute a set of CHANNEL tasks.""" - channel_tasks = [] - semaphore = asyncio.Semaphore(3) - zdo_task = None - for channel in channels: - if channel.name == CHANNEL_ZDO: - if zdo_task is None: # We only want to do this once - zdo_task = self._async_create_task( - semaphore, channel, task_name, *args - ) - else: - channel_tasks.append( - self._async_create_task(semaphore, channel, task_name, *args) - ) - if zdo_task is not None: - await zdo_task - await asyncio.gather(*channel_tasks) - - async def _async_create_task(self, semaphore, channel, func_name, *args): - """Configure a single channel on this device.""" - try: - async with semaphore: - await getattr(channel, func_name)(*args) - channel.debug("channel: '%s' stage succeeded", func_name) - except Exception as ex: # pylint: disable=broad-except - channel.warning("channel: '%s' stage failed ex: %s", func_name, ex) - @callback def async_unsub_dispatcher(self): """Unsubscribe the dispatcher.""" diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index c8514e2937d..e6b844b9c43 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -1,268 +1,150 @@ """Device discovery functions for Zigbee Home Automation.""" import logging - -import zigpy.profiles -from zigpy.zcl.clusters.general import OnOff, PowerConfiguration +from typing import Callable, List, Tuple from homeassistant import const as ha_const from homeassistant.core import callback -from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.typing import HomeAssistantType -from .channels import AttributeListeningChannel, EventRelayChannel, ZDOChannel -from .const import COMPONENTS, CONF_DEVICE_CONFIG, DATA_ZHA, ZHA_DISCOVERY_NEW -from .registries import ( - CHANNEL_ONLY_CLUSTERS, - COMPONENT_CLUSTERS, - DEVICE_CLASS, - EVENT_RELAY_CLUSTERS, - OUTPUT_CHANNEL_ONLY_CLUSTERS, - REMOTE_DEVICE_TYPES, - SINGLE_INPUT_CLUSTER_DEVICE_CLASS, - SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS, - ZIGBEE_CHANNEL_REGISTRY, -) +from . import const as zha_const, registries as zha_regs, typing as zha_typing +from .channels import base _LOGGER = logging.getLogger(__name__) @callback -def async_process_endpoint( - hass, - config, - endpoint_id, - endpoint, - discovery_infos, - device, - zha_device, - is_new_join, -): - """Process an endpoint on a zigpy device.""" - if endpoint_id == 0: # ZDO - _async_create_cluster_channel( - endpoint, zha_device, is_new_join, channel_class=ZDOChannel - ) +async def async_add_entities( + _async_add_entities: Callable, + entities: List[ + Tuple[ + zha_typing.ZhaEntityType, + Tuple[str, zha_typing.ZhaDeviceType, List[zha_typing.ChannelType]], + ] + ], +) -> None: + """Add entities helper.""" + if not entities: return + to_add = [ent_cls(*args) for ent_cls, args in entities] + _async_add_entities(to_add, update_before_add=True) + entities.clear() - component = None - profile_clusters = [] - device_key = f"{device.ieee}-{endpoint_id}" - node_config = {} - if CONF_DEVICE_CONFIG in config: - node_config = config[CONF_DEVICE_CONFIG].get(device_key, {}) - if endpoint.profile_id in zigpy.profiles.PROFILES: - if DEVICE_CLASS.get(endpoint.profile_id, {}).get(endpoint.device_type, None): - profile_info = DEVICE_CLASS[endpoint.profile_id] - component = profile_info[endpoint.device_type] +class ProbeEndpoint: + """All discovered channels and entities of an endpoint.""" - if ha_const.CONF_TYPE in node_config: - component = node_config[ha_const.CONF_TYPE] + def __init__(self): + """Initialize instance.""" + self._device_configs = {} - if component and component in COMPONENTS and component in COMPONENT_CLUSTERS: - profile_clusters = COMPONENT_CLUSTERS[component] - if profile_clusters: - profile_match = _async_handle_profile_match( - hass, - endpoint, - profile_clusters, - zha_device, - component, - device_key, - is_new_join, + @callback + def discover_entities(self, channel_pool: zha_typing.ChannelPoolType) -> None: + """Process an endpoint on a zigpy device.""" + self.discover_by_device_type(channel_pool) + self.discover_by_cluster_id(channel_pool) + + @callback + def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None: + """Process an endpoint on a zigpy device.""" + + unique_id = channel_pool.unique_id + + component = self._device_configs.get(unique_id, {}).get(ha_const.CONF_TYPE) + if component is None: + ep_profile_id = channel_pool.endpoint.profile_id + ep_device_type = channel_pool.endpoint.device_type + component = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type) + + if component and component in zha_const.COMPONENTS: + channels = channel_pool.unclaimed_channels() + entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( + component, channel_pool.manufacturer, channel_pool.model, channels ) - discovery_infos.append(profile_match) + if entity_class is None: + return + channel_pool.claim_channels(claimed) + channel_pool.async_new_entity(component, entity_class, unique_id, claimed) - discovery_infos.extend( - _async_handle_single_cluster_matches( - hass, endpoint, zha_device, profile_clusters, device_key, is_new_join - ) - ) + @callback + def discover_by_cluster_id(self, channel_pool: zha_typing.ChannelPoolType) -> None: + """Process an endpoint on a zigpy device.""" - -@callback -def _async_create_cluster_channel( - cluster, zha_device, is_new_join, channels=None, channel_class=None -): - """Create a cluster channel and attach it to a device.""" - # really ugly hack to deal with xiaomi using the door lock cluster - # incorrectly. - if hasattr(cluster, "ep_attribute") and cluster.ep_attribute == "multistate_input": - channel_class = AttributeListeningChannel - # end of ugly hack - if channel_class is None: - channel_class = ZIGBEE_CHANNEL_REGISTRY.get( - cluster.cluster_id, AttributeListeningChannel - ) - channel = channel_class(cluster, zha_device) - zha_device.add_cluster_channel(channel) - if channels is not None: - channels.append(channel) - - -@callback -def async_dispatch_discovery_info(hass, is_new_join, discovery_info): - """Dispatch or store discovery information.""" - if not discovery_info["channels"]: - _LOGGER.warning( - "there are no channels in the discovery info: %s", discovery_info - ) - return - component = discovery_info["component"] - if is_new_join: - async_dispatcher_send(hass, ZHA_DISCOVERY_NEW.format(component), discovery_info) - else: - hass.data[DATA_ZHA][component][discovery_info["unique_id"]] = discovery_info - - -@callback -def _async_handle_profile_match( - hass, endpoint, profile_clusters, zha_device, component, device_key, is_new_join -): - """Dispatch a profile match to the appropriate HA component.""" - in_clusters = [ - endpoint.in_clusters[c] for c in profile_clusters if c in endpoint.in_clusters - ] - out_clusters = [ - endpoint.out_clusters[c] for c in profile_clusters if c in endpoint.out_clusters - ] - - channels = [] - - for cluster in in_clusters: - _async_create_cluster_channel( - cluster, zha_device, is_new_join, channels=channels - ) - - for cluster in out_clusters: - _async_create_cluster_channel( - cluster, zha_device, is_new_join, channels=channels - ) - - discovery_info = { - "unique_id": device_key, - "zha_device": zha_device, - "channels": channels, - "component": component, - } - - return discovery_info - - -@callback -def _async_handle_single_cluster_matches( - hass, endpoint, zha_device, profile_clusters, device_key, is_new_join -): - """Dispatch single cluster matches to HA components.""" - cluster_matches = [] - cluster_match_results = [] - matched_power_configuration = False - for cluster in endpoint.in_clusters.values(): - if cluster.cluster_id in CHANNEL_ONLY_CLUSTERS: - cluster_match_results.append( - _async_handle_channel_only_cluster_match( - zha_device, cluster, is_new_join - ) - ) - continue - - if cluster.cluster_id not in profile_clusters: - # Only create one battery sensor per device - if cluster.cluster_id == PowerConfiguration.cluster_id and ( - zha_device.is_mains_powered or matched_power_configuration - ): + items = zha_regs.SINGLE_INPUT_CLUSTER_DEVICE_CLASS.items() + single_input_clusters = { + cluster_class: match + for cluster_class, match in items + if not isinstance(cluster_class, int) + } + remaining_channels = channel_pool.unclaimed_channels() + for channel in remaining_channels: + if channel.cluster.cluster_id in zha_regs.CHANNEL_ONLY_CLUSTERS: + channel_pool.claim_channels([channel]) continue - if ( - cluster.cluster_id == PowerConfiguration.cluster_id - and not zha_device.is_mains_powered - ): - matched_power_configuration = True - - cluster_match_results.append( - _async_handle_single_cluster_match( - hass, - zha_device, - cluster, - device_key, - SINGLE_INPUT_CLUSTER_DEVICE_CLASS, - is_new_join, - ) + component = zha_regs.SINGLE_INPUT_CLUSTER_DEVICE_CLASS.get( + channel.cluster.cluster_id ) + if component is None: + for cluster_class, match in single_input_clusters.items(): + if isinstance(channel.cluster, cluster_class): + component = match + break - for cluster in endpoint.out_clusters.values(): - if cluster.cluster_id in OUTPUT_CHANNEL_ONLY_CLUSTERS: - cluster_match_results.append( - _async_handle_channel_only_cluster_match( - zha_device, cluster, is_new_join - ) + self.probe_single_cluster(component, channel, channel_pool) + + # until we can get rid off registries + self.handle_on_off_output_cluster_exception(channel_pool) + + @staticmethod + def probe_single_cluster( + component: str, + channel: zha_typing.ChannelType, + ep_channels: zha_typing.ChannelPoolType, + ) -> None: + """Probe specified cluster for specific component.""" + if component is None or component not in zha_const.COMPONENTS: + return + channel_list = [channel] + unique_id = f"{ep_channels.unique_id}-{channel.cluster.cluster_id}" + + entity_class, claimed = zha_regs.ZHA_ENTITIES.get_entity( + component, ep_channels.manufacturer, ep_channels.model, channel_list + ) + if entity_class is None: + return + ep_channels.claim_channels(claimed) + ep_channels.async_new_entity(component, entity_class, unique_id, claimed) + + def handle_on_off_output_cluster_exception( + self, ep_channels: zha_typing.ChannelPoolType + ) -> None: + """Process output clusters of the endpoint.""" + + profile_id = ep_channels.endpoint.profile_id + device_type = ep_channels.endpoint.device_type + if device_type in zha_regs.REMOTE_DEVICE_TYPES.get(profile_id, []): + return + + for cluster_id, cluster in ep_channels.endpoint.out_clusters.items(): + component = zha_regs.SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS.get( + cluster.cluster_id ) - continue + if component is None: + continue - device_type = cluster.endpoint.device_type - profile_id = cluster.endpoint.profile_id - - if cluster.cluster_id not in profile_clusters: - # prevent remotes and controllers from getting entities - if not ( - cluster.cluster_id == OnOff.cluster_id - and profile_id in REMOTE_DEVICE_TYPES - and device_type in REMOTE_DEVICE_TYPES[profile_id] - ): - cluster_match_results.append( - _async_handle_single_cluster_match( - hass, - zha_device, - cluster, - device_key, - SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS, - is_new_join, - ) - ) - - if cluster.cluster_id in EVENT_RELAY_CLUSTERS: - _async_create_cluster_channel( - cluster, zha_device, is_new_join, channel_class=EventRelayChannel + channel_class = zha_regs.ZIGBEE_CHANNEL_REGISTRY.get( + cluster_id, base.AttributeListeningChannel ) + channel = channel_class(cluster, ep_channels) + self.probe_single_cluster(component, channel, ep_channels) - for cluster_match in cluster_match_results: - if cluster_match is not None: - cluster_matches.append(cluster_match) - return cluster_matches + def initialize(self, hass: HomeAssistantType) -> None: + """Update device overrides config.""" + zha_config = hass.data[zha_const.DATA_ZHA].get(zha_const.DATA_ZHA_CONFIG, {}) + overrides = zha_config.get(zha_const.CONF_DEVICE_CONFIG) + if overrides: + self._device_configs.update(overrides) -@callback -def _async_handle_channel_only_cluster_match(zha_device, cluster, is_new_join): - """Handle a channel only cluster match.""" - _async_create_cluster_channel(cluster, zha_device, is_new_join) - - -@callback -def _async_handle_single_cluster_match( - hass, zha_device, cluster, device_key, device_classes, is_new_join -): - """Dispatch a single cluster match to a HA component.""" - component = None # sub_component = None - for cluster_type, candidate_component in device_classes.items(): - if isinstance(cluster_type, int): - if cluster.cluster_id == cluster_type: - component = candidate_component - elif isinstance(cluster, cluster_type): - component = candidate_component - break - - if component is None or component not in COMPONENTS: - return - channels = [] - _async_create_cluster_channel(cluster, zha_device, is_new_join, channels=channels) - - cluster_key = f"{device_key}-{cluster.cluster_id}" - discovery_info = { - "unique_id": cluster_key, - "zha_device": zha_device, - "channels": channels, - "entity_suffix": f"_{cluster.cluster_id}", - "component": component, - } - - return discovery_info +PROBE = ProbeEndpoint() diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 8a8f57764a6..90d8165c640 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -18,6 +18,7 @@ from homeassistant.helpers.device_registry import ( from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity_registry import async_get_registry as get_ent_reg +from . import discovery, typing as zha_typing from .const import ( ATTR_IEEE, ATTR_MANUFACTURER, @@ -33,6 +34,7 @@ from .const import ( DATA_ZHA, DATA_ZHA_BRIDGE_ID, DATA_ZHA_GATEWAY, + DATA_ZHA_PLATFORM_LOADED, DEBUG_COMP_BELLOWS, DEBUG_COMP_ZHA, DEBUG_COMP_ZIGPY, @@ -47,6 +49,7 @@ from .const import ( DEFAULT_BAUDRATE, DEFAULT_DATABASE_NAME, DOMAIN, + SIGNAL_ADD_ENTITIES, SIGNAL_REMOVE, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, @@ -67,7 +70,6 @@ from .const import ( ZHA_GW_RADIO_DESCRIPTION, ) from .device import DeviceStatus, ZHADevice -from .discovery import async_dispatch_discovery_info, async_process_endpoint from .group import ZHAGroup from .patches import apply_application_controller_patch from .registries import RADIO_TYPES @@ -107,6 +109,8 @@ class ZHAGateway: async def async_initialize(self): """Initialize controller and connect radio.""" + discovery.PROBE.initialize(self._hass) + self.zha_storage = await async_get_registry(self._hass) self.ha_device_registry = await get_dev_reg(self._hass) self.ha_entity_registry = await get_ent_reg(self._hass) @@ -133,22 +137,34 @@ class ZHAGateway: self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str( self.application_controller.ieee ) + self._initialize_groups() + + async def async_load_devices(self) -> None: + """Restore ZHA devices from zigpy application state.""" + await self._hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED].wait() - init_tasks = [] semaphore = asyncio.Semaphore(2) - async def init_with_semaphore(coro, semaphore): - """Don't flood the zigbee network during initialization.""" + async def _throttle(device: zha_typing.ZigpyDeviceType): async with semaphore: - await coro + await self.async_device_restored(device) - for device in self.application_controller.devices.values(): - init_tasks.append( - init_with_semaphore(self.async_device_restored(device), semaphore) - ) - await asyncio.gather(*init_tasks) + zigpy_devices = self.application_controller.devices.values() + _LOGGER.debug("Loading battery powered devices") + await asyncio.gather( + *[ + _throttle(dev) + for dev in zigpy_devices + if not dev.node_desc.is_mains_powered + ] + ) + async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) - self._initialize_groups() + _LOGGER.debug("Loading mains powered devices") + await asyncio.gather( + *[_throttle(dev) for dev in zigpy_devices if dev.node_desc.is_mains_powered] + ) + async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) def device_joined(self, device): """Handle device joined. @@ -356,11 +372,13 @@ class ZHAGateway: self._async_get_or_create_group(group) @callback - def _async_get_or_create_device(self, zigpy_device): + def _async_get_or_create_device( + self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False + ): """Get or create a ZHA device.""" zha_device = self._devices.get(zigpy_device.ieee) if zha_device is None: - zha_device = ZHADevice(self._hass, zigpy_device, self) + zha_device = ZHADevice.new(self._hass, zigpy_device, self, restored) self._devices[zigpy_device.ieee] = zha_device device_registry_device = self.ha_device_registry.async_get_or_create( config_entry_id=self._config_entry.entry_id, @@ -406,13 +424,14 @@ class ZHAGateway: self.zha_storage.async_update(device) await self.zha_storage.async_save() - async def async_device_initialized(self, device): + async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): """Handle device joined and basic information discovered (async).""" zha_device = self._async_get_or_create_device(device) _LOGGER.debug( - "device - %s entering async_device_initialized - is_new_join: %s", - f"0x{device.nwk:04x}:{device.ieee}", + "device - %s:%s entering async_device_initialized - is_new_join: %s", + device.nwk, + device.ieee, zha_device.status is not DeviceStatus.INITIALIZED, ) @@ -420,16 +439,18 @@ class ZHAGateway: # ZHA already has an initialized device so either the device was assigned a # new nwk or device was physically reset and added again without being removed _LOGGER.debug( - "device - %s has been reset and re-added or its nwk address changed", - f"0x{device.nwk:04x}:{device.ieee}", + "device - %s:%s has been reset and re-added or its nwk address changed", + device.nwk, + device.ieee, ) await self._async_device_rejoined(zha_device) else: _LOGGER.debug( - "device - %s has joined the ZHA zigbee network", - f"0x{device.nwk:04x}:{device.ieee}", + "device - %s:%s has joined the ZHA zigbee network", + device.nwk, + device.ieee, ) - await self._async_device_joined(device, zha_device) + await self._async_device_joined(zha_device) device_info = zha_device.async_get_info() @@ -442,64 +463,36 @@ class ZHAGateway: }, ) - async def _async_device_joined(self, device, zha_device): - discovery_infos = [] - for endpoint_id, endpoint in device.endpoints.items(): - async_process_endpoint( - self._hass, - self._config, - endpoint_id, - endpoint, - discovery_infos, - device, - zha_device, - True, - ) - + async def _async_device_joined(self, zha_device: zha_typing.ZhaDeviceType) -> None: await zha_device.async_configure() # will cause async_init to fire so don't explicitly call it zha_device.update_available(True) - - for discovery_info in discovery_infos: - async_dispatch_discovery_info(self._hass, True, discovery_info) + async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) # only public for testing - async def async_device_restored(self, device): + async def async_device_restored(self, device: zha_typing.ZigpyDeviceType): """Add an existing device to the ZHA zigbee network when ZHA first starts.""" - zha_device = self._async_get_or_create_device(device) - discovery_infos = [] - for endpoint_id, endpoint in device.endpoints.items(): - async_process_endpoint( - self._hass, - self._config, - endpoint_id, - endpoint, - discovery_infos, - device, - zha_device, - False, - ) + zha_device = self._async_get_or_create_device(device, restored=True) if zha_device.is_mains_powered: # the device isn't a battery powered device so we should be able # to update it now _LOGGER.debug( - "attempting to request fresh state for device - %s %s %s", - f"0x{zha_device.nwk:04x}:{zha_device.ieee}", + "attempting to request fresh state for device - %s:%s %s with power source %s", + zha_device.nwk, + zha_device.ieee, zha_device.name, - f"with power source: {zha_device.power_source}", + zha_device.power_source, ) await zha_device.async_initialize(from_cache=False) else: await zha_device.async_initialize(from_cache=True) - for discovery_info in discovery_infos: - async_dispatch_discovery_info(self._hass, False, discovery_info) - async def _async_device_rejoined(self, zha_device): _LOGGER.debug( - "skipping discovery for previously discovered device - %s", - f"0x{zha_device.nwk:04x}:{zha_device.ieee}", + "skipping discovery for previously discovered device - %s:%s", + zha_device.nwk, + zha_device.ieee, ) # we don't have to do this on a nwk swap but we don't have a way to tell currently await zha_device.async_configure() diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index bc788b39ee7..3b08d1acd37 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -1,6 +1,6 @@ """Mapping registries for Zigbee Home Automation.""" import collections -from typing import Callable, Set, Union +from typing import Callable, Dict, List, Set, Tuple, Union import attr import bellows.ezsp @@ -27,9 +27,10 @@ from homeassistant.components.sensor import DOMAIN as SENSOR from homeassistant.components.switch import DOMAIN as SWITCH # importing channels updates registries -from . import channels # noqa: F401 pylint: disable=unused-import +from . import channels as zha_channels # noqa: F401 pylint: disable=unused-import from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType from .decorators import CALLABLE_T, DictRegistry, SetRegistry +from .typing import ChannelType SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02 SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000 @@ -57,30 +58,33 @@ REMOTE_DEVICE_TYPES = { zigpy.profiles.zll.DeviceType.SCENE_CONTROLLER, ], } +REMOTE_DEVICE_TYPES = collections.defaultdict(list, REMOTE_DEVICE_TYPES) SINGLE_INPUT_CLUSTER_DEVICE_CLASS = { # this works for now but if we hit conflicts we can break it out to # a different dict that is keyed by manufacturer SMARTTHINGS_ACCELERATION_CLUSTER: BINARY_SENSOR, SMARTTHINGS_HUMIDITY_CLUSTER: SENSOR, - zcl.clusters.closures.DoorLock: LOCK, - zcl.clusters.closures.WindowCovering: COVER, + zcl.clusters.closures.DoorLock.cluster_id: LOCK, + zcl.clusters.closures.WindowCovering.cluster_id: COVER, zcl.clusters.general.AnalogInput.cluster_id: SENSOR, zcl.clusters.general.MultistateInput.cluster_id: SENSOR, - zcl.clusters.general.OnOff: SWITCH, - zcl.clusters.general.PowerConfiguration: SENSOR, - zcl.clusters.homeautomation.ElectricalMeasurement: SENSOR, - zcl.clusters.hvac.Fan: FAN, - zcl.clusters.measurement.IlluminanceMeasurement: SENSOR, - zcl.clusters.measurement.OccupancySensing: BINARY_SENSOR, - zcl.clusters.measurement.PressureMeasurement: SENSOR, - zcl.clusters.measurement.RelativeHumidity: SENSOR, - zcl.clusters.measurement.TemperatureMeasurement: SENSOR, - zcl.clusters.security.IasZone: BINARY_SENSOR, - zcl.clusters.smartenergy.Metering: SENSOR, + zcl.clusters.general.OnOff.cluster_id: SWITCH, + zcl.clusters.general.PowerConfiguration.cluster_id: SENSOR, + zcl.clusters.homeautomation.ElectricalMeasurement.cluster_id: SENSOR, + zcl.clusters.hvac.Fan.cluster_id: FAN, + zcl.clusters.measurement.IlluminanceMeasurement.cluster_id: SENSOR, + zcl.clusters.measurement.OccupancySensing.cluster_id: BINARY_SENSOR, + zcl.clusters.measurement.PressureMeasurement.cluster_id: SENSOR, + zcl.clusters.measurement.RelativeHumidity.cluster_id: SENSOR, + zcl.clusters.measurement.TemperatureMeasurement.cluster_id: SENSOR, + zcl.clusters.security.IasZone.cluster_id: BINARY_SENSOR, + zcl.clusters.smartenergy.Metering.cluster_id: SENSOR, } -SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS = {zcl.clusters.general.OnOff: BINARY_SENSOR} +SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS = { + zcl.clusters.general.OnOff.cluster_id: BINARY_SENSOR +} SWITCH_CLUSTERS = SetRegistry() @@ -89,7 +93,6 @@ BINARY_SENSOR_CLUSTERS.add(SMARTTHINGS_ACCELERATION_CLUSTER) BINDABLE_CLUSTERS = SetRegistry() CHANNEL_ONLY_CLUSTERS = SetRegistry() -CLUSTER_REPORT_CONFIGS = {} CUSTOM_CLUSTER_MAPPINGS = {} DEVICE_CLASS = { @@ -117,6 +120,7 @@ DEVICE_CLASS = { zigpy.profiles.zll.DeviceType.ON_OFF_PLUGIN_UNIT: SWITCH, }, } +DEVICE_CLASS = collections.defaultdict(dict, DEVICE_CLASS) DEVICE_TRACKER_CLUSTERS = SetRegistry() EVENT_RELAY_CLUSTERS = SetRegistry() @@ -188,6 +192,63 @@ class MatchRule: models: Union[Callable, Set[str], str] = attr.ib( factory=frozenset, converter=set_or_callable ) + aux_channels: Union[Callable, Set[str], str] = attr.ib( + factory=frozenset, converter=set_or_callable + ) + + def claim_channels(self, channel_pool: List[ChannelType]) -> List[ChannelType]: + """Return a list of channels this rule matches + aux channels.""" + claimed = [] + if isinstance(self.channel_names, frozenset): + claimed.extend([ch for ch in channel_pool if ch.name in self.channel_names]) + if isinstance(self.generic_ids, frozenset): + claimed.extend( + [ch for ch in channel_pool if ch.generic_id in self.generic_ids] + ) + if isinstance(self.aux_channels, frozenset): + claimed.extend([ch for ch in channel_pool if ch.name in self.aux_channels]) + return claimed + + def strict_matched(self, manufacturer: str, model: str, channels: List) -> bool: + """Return True if this device matches the criteria.""" + return all(self._matched(manufacturer, model, channels)) + + def loose_matched(self, manufacturer: str, model: str, channels: List) -> bool: + """Return True if this device matches the criteria.""" + return any(self._matched(manufacturer, model, channels)) + + def _matched(self, manufacturer: str, model: str, channels: List) -> list: + """Return a list of field matches.""" + if not any(attr.asdict(self).values()): + return [False] + + matches = [] + if self.channel_names: + channel_names = {ch.name for ch in channels} + matches.append(self.channel_names.issubset(channel_names)) + + if self.generic_ids: + all_generic_ids = {ch.generic_id for ch in channels} + matches.append(self.generic_ids.issubset(all_generic_ids)) + + if self.manufacturers: + if callable(self.manufacturers): + matches.append(self.manufacturers(manufacturer)) + else: + matches.append(manufacturer in self.manufacturers) + + if self.models: + if callable(self.models): + matches.append(self.models(model)) + else: + matches.append(model in self.models) + + return matches + + +RegistryDictType = Dict[ + str, Dict[MatchRule, CALLABLE_T] +] # pylint: disable=invalid-name class ZHAEntityRegistry: @@ -195,18 +256,24 @@ class ZHAEntityRegistry: def __init__(self): """Initialize Registry instance.""" - self._strict_registry = collections.defaultdict(dict) - self._loose_registry = collections.defaultdict(dict) + self._strict_registry: RegistryDictType = collections.defaultdict(dict) + self._loose_registry: RegistryDictType = collections.defaultdict(dict) def get_entity( - self, component: str, zha_device, chnls: dict, default: CALLABLE_T = None - ) -> CALLABLE_T: + self, + component: str, + manufacturer: str, + model: str, + channels: List[ChannelType], + default: CALLABLE_T = None, + ) -> Tuple[CALLABLE_T, List[ChannelType]]: """Match a ZHA Channels to a ZHA Entity class.""" for match in self._strict_registry[component]: - if self._strict_matched(zha_device, chnls, match): - return self._strict_registry[component][match] + if match.strict_matched(manufacturer, model, channels): + claimed = match.claim_channels(channels) + return self._strict_registry[component][match], claimed - return default + return default, [] def strict_match( self, @@ -215,10 +282,13 @@ class ZHAEntityRegistry: generic_ids: Union[Callable, Set[str], str] = None, manufacturers: Union[Callable, Set[str], str] = None, models: Union[Callable, Set[str], str] = None, + aux_channels: Union[Callable, Set[str], str] = None, ) -> Callable[[CALLABLE_T], CALLABLE_T]: """Decorate a strict match rule.""" - rule = MatchRule(channel_names, generic_ids, manufacturers, models) + rule = MatchRule( + channel_names, generic_ids, manufacturers, models, aux_channels + ) def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T: """Register a strict match rule. @@ -237,10 +307,13 @@ class ZHAEntityRegistry: generic_ids: Union[Callable, Set[str], str] = None, manufacturers: Union[Callable, Set[str], str] = None, models: Union[Callable, Set[str], str] = None, + aux_channels: Union[Callable, Set[str], str] = None, ) -> Callable[[CALLABLE_T], CALLABLE_T]: """Decorate a loose match rule.""" - rule = MatchRule(channel_names, generic_ids, manufacturers, models) + rule = MatchRule( + channel_names, generic_ids, manufacturers, models, aux_channels + ) def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: """Register a loose match rule. @@ -252,42 +325,5 @@ class ZHAEntityRegistry: return decorator - def _strict_matched(self, zha_device, chnls: dict, rule: MatchRule) -> bool: - """Return True if this device matches the criteria.""" - return all(self._matched(zha_device, chnls, rule)) - - def _loose_matched(self, zha_device, chnls: dict, rule: MatchRule) -> bool: - """Return True if this device matches the criteria.""" - return any(self._matched(zha_device, chnls, rule)) - - @staticmethod - def _matched(zha_device, chnls: dict, rule: MatchRule) -> list: - """Return a list of field matches.""" - if not any(attr.asdict(rule).values()): - return [False] - - matches = [] - if rule.channel_names: - channel_names = {ch.name for ch in chnls} - matches.append(rule.channel_names.issubset(channel_names)) - - if rule.generic_ids: - all_generic_ids = {ch.generic_id for ch in chnls} - matches.append(rule.generic_ids.issubset(all_generic_ids)) - - if rule.manufacturers: - if callable(rule.manufacturers): - matches.append(rule.manufacturers(zha_device.manufacturer)) - else: - matches.append(zha_device.manufacturer in rule.manufacturers) - - if rule.models: - if callable(rule.models): - matches.append(rule.models(zha_device.model)) - else: - matches.append(zha_device.model in rule.models) - - return matches - ZHA_ENTITIES = ZHAEntityRegistry() diff --git a/homeassistant/components/zha/core/typing.py b/homeassistant/components/zha/core/typing.py new file mode 100644 index 00000000000..3d10912d165 --- /dev/null +++ b/homeassistant/components/zha/core/typing.py @@ -0,0 +1,41 @@ +"""Typing helpers for ZHA component.""" + +from typing import TYPE_CHECKING, Callable, TypeVar + +import zigpy.device +import zigpy.endpoint +import zigpy.zcl +import zigpy.zdo + +# pylint: disable=invalid-name +CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) +ChannelType = "ZigbeeChannel" +ChannelsType = "Channels" +ChannelPoolType = "ChannelPool" +EventRelayChannelType = "EventRelayChannel" +ZDOChannelType = "ZDOChannel" +ZhaDeviceType = "ZHADevice" +ZhaEntityType = "ZHAEntity" +ZhaGatewayType = "ZHAGateway" +ZigpyClusterType = zigpy.zcl.Cluster +ZigpyDeviceType = zigpy.device.Device +ZigpyEndpointType = zigpy.endpoint.Endpoint +ZigpyZdoType = zigpy.zdo.ZDO + +if TYPE_CHECKING: + import homeassistant.components.zha.core.channels as channels + import homeassistant.components.zha.core.channels.base as base_channels + import homeassistant.components.zha.core.device + import homeassistant.components.zha.core.gateway + import homeassistant.components.zha.entity + import homeassistant.components.zha.core.channels + + # pylint: disable=invalid-name + ChannelType = base_channels.ZigbeeChannel + ChannelsType = channels.Channels + ChannelPoolType = channels.ChannelPool + EventRelayChannelType = base_channels.EventRelayChannel + ZDOChannelType = base_channels.ZDOChannel + ZhaDeviceType = homeassistant.components.zha.core.device.ZHADevice + ZhaEntityType = homeassistant.components.zha.entity.ZhaEntity + ZhaGatewayType = homeassistant.components.zha.core.gateway.ZHAGateway diff --git a/homeassistant/components/zha/cover.py b/homeassistant/components/zha/cover.py index 3eeb73a23fd..13de445cf37 100644 --- a/homeassistant/components/zha/cover.py +++ b/homeassistant/components/zha/cover.py @@ -10,12 +10,13 @@ from homeassistant.const import STATE_CLOSED, STATE_CLOSING, STATE_OPEN, STATE_O from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_COVER, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -28,41 +29,17 @@ STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation cover from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - covers = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if covers is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, covers.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA covers.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, ZhaCover) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - @STRICT_MATCH(channel_names=CHANNEL_COVER) class ZhaCover(ZhaEntity, CoverDevice): diff --git a/homeassistant/components/zha/device_action.py b/homeassistant/components/zha/device_action.py index 60cfa0eec00..5a2e0c40881 100644 --- a/homeassistant/components/zha/device_action.py +++ b/homeassistant/components/zha/device_action.py @@ -57,11 +57,16 @@ async def async_call_action_from_config( async def async_get_actions(hass: HomeAssistant, device_id: str) -> List[dict]: """List device actions.""" zha_device = await async_get_zha_device(hass, device_id) + cluster_channels = [ + ch.name + for pool in zha_device.channels.pools + for ch in pool.claimed_channels.values() + ] actions = [ action for channel in DEVICE_ACTIONS for action in DEVICE_ACTIONS[channel] - if channel in zha_device.cluster_channels + if channel in cluster_channels ] for action in actions: action[CONF_DEVICE_ID] = device_id diff --git a/homeassistant/components/zha/device_tracker.py b/homeassistant/components/zha/device_tracker.py index 76548935814..5481ec70f52 100644 --- a/homeassistant/components/zha/device_tracker.py +++ b/homeassistant/components/zha/device_tracker.py @@ -8,12 +8,13 @@ from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_POWER_CONFIGURATION, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -25,51 +26,25 @@ _LOGGER = logging.getLogger(__name__) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation device tracker from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - device_trackers = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if device_trackers is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, device_trackers.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA device trackers.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity( - DOMAIN, zha_dev, channels, ZHADeviceScannerEntity - ) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - @STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION) class ZHADeviceScannerEntity(ScannerEntity, ZhaEntity): """Represent a tracked device.""" - def __init__(self, **kwargs): + def __init__(self, unique_id, zha_device, channels, **kwargs): """Initialize the ZHA device tracker.""" - super().__init__(**kwargs) + super().__init__(unique_id, zha_device, channels, **kwargs) self._battery_channel = self.cluster_channels.get(CHANNEL_POWER_CONFIGURATION) self._connected = False self._keepalive_interval = 60 diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 6a9dfc63432..76d0908000b 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -44,7 +44,6 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): self._zha_device = zha_device self.cluster_channels = {} self._available = False - self._component = kwargs["component"] self._unsubs = [] self.remove_future = None for channel in channels: diff --git a/homeassistant/components/zha/fan.py b/homeassistant/components/zha/fan.py index 6ad13d1c802..59a6bfb9c47 100644 --- a/homeassistant/components/zha/fan.py +++ b/homeassistant/components/zha/fan.py @@ -14,12 +14,13 @@ from homeassistant.components.fan import ( from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_FAN, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -52,41 +53,17 @@ STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation fan from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - fans = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if fans is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, fans.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA fans.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, ZhaFan) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - @STRICT_MATCH(channel_names=CHANNEL_FAN) class ZhaFan(ZhaEntity, FanEntity): diff --git a/homeassistant/components/zha/light.py b/homeassistant/components/zha/light.py index 409cd339122..dc2e156dbf5 100644 --- a/homeassistant/components/zha/light.py +++ b/homeassistant/components/zha/light.py @@ -12,15 +12,16 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.event import async_track_time_interval import homeassistant.util.color as color_util +from .core import discovery from .core.const import ( CHANNEL_COLOR, CHANNEL_LEVEL, CHANNEL_ON_OFF, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, SIGNAL_SET_LEVEL, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -44,43 +45,19 @@ PARALLEL_UPDATES = 5 async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation light from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][light.DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(light.DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - lights = hass.data.get(DATA_ZHA, {}).get(light.DOMAIN) - if lights is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, lights.values() - ) - del hass.data[DATA_ZHA][light.DOMAIN] - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA lights.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(light.DOMAIN, zha_dev, channels, Light) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - - -@STRICT_MATCH(channel_names=CHANNEL_ON_OFF) +@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL}) class Light(ZhaEntity, light.Light): """Representation of a ZHA or ZLL light.""" diff --git a/homeassistant/components/zha/lock.py b/homeassistant/components/zha/lock.py index b173c166a77..7ba31158fc3 100644 --- a/homeassistant/components/zha/lock.py +++ b/homeassistant/components/zha/lock.py @@ -13,12 +13,13 @@ from homeassistant.components.lock import ( from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_DOORLOCK, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -35,41 +36,17 @@ VALUE_TO_STATE = dict(enumerate(STATE_LIST)) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation Door Lock from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - locks = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if locks is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, locks.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA locks.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, ZhaDoorLock) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - @STRICT_MATCH(channel_names=CHANNEL_DOORLOCK) class ZhaDoorLock(ZhaEntity, LockDevice): diff --git a/homeassistant/components/zha/sensor.py b/homeassistant/components/zha/sensor.py index 8b7dd894973..b98c50d1fa4 100644 --- a/homeassistant/components/zha/sensor.py +++ b/homeassistant/components/zha/sensor.py @@ -22,7 +22,9 @@ from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.util.temperature import fahrenheit_to_celsius +from .core import discovery from .core.const import ( + CHANNEL_ANALOG_INPUT, CHANNEL_ELECTRICAL_MEASUREMENT, CHANNEL_HUMIDITY, CHANNEL_ILLUMINANCE, @@ -33,9 +35,9 @@ from .core.const import ( CHANNEL_TEMPERATURE, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, SIGNAL_STATE_ATTR, - ZHA_DISCOVERY_NEW, ) from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES from .entity import ZhaEntity @@ -65,46 +67,17 @@ STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation sensor from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - sensors = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if sensors is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, sensors.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA sensors.""" - entities = [] - for discovery_info in discovery_infos: - entities.append(await make_sensor(discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - - -async def make_sensor(discovery_info): - """Create ZHA sensors factory.""" - - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, Sensor) - return entity(**discovery_info) - class Sensor(ZhaEntity): """Base ZHA sensor.""" @@ -176,6 +149,13 @@ class Sensor(ZhaEntity): return round(float(value * self._multiplier) / self._divisor) +@STRICT_MATCH(channel_names=CHANNEL_ANALOG_INPUT) +class AnalogInput(Sensor): + """Sensor that displays analog input values.""" + + pass + + @STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION) class Battery(Sensor): """Battery sensor of power configuration cluster.""" diff --git a/homeassistant/components/zha/switch.py b/homeassistant/components/zha/switch.py index 1280ace34dc..e6a82fe0270 100644 --- a/homeassistant/components/zha/switch.py +++ b/homeassistant/components/zha/switch.py @@ -9,12 +9,13 @@ from homeassistant.const import STATE_ON from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .core import discovery from .core.const import ( CHANNEL_ON_OFF, DATA_ZHA, DATA_ZHA_DISPATCHERS, + SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - ZHA_DISCOVERY_NEW, ) from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity @@ -25,49 +26,25 @@ STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation switch from config entry.""" - - async def async_discover(discovery_info): - await _async_setup_entities( - hass, config_entry, async_add_entities, [discovery_info] - ) + entities_to_create = hass.data[DATA_ZHA][DOMAIN] = [] unsub = async_dispatcher_connect( - hass, ZHA_DISCOVERY_NEW.format(DOMAIN), async_discover + hass, + SIGNAL_ADD_ENTITIES, + functools.partial( + discovery.async_add_entities, async_add_entities, entities_to_create + ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) - switches = hass.data.get(DATA_ZHA, {}).get(DOMAIN) - if switches is not None: - await _async_setup_entities( - hass, config_entry, async_add_entities, switches.values() - ) - del hass.data[DATA_ZHA][DOMAIN] - - -async def _async_setup_entities( - hass, config_entry, async_add_entities, discovery_infos -): - """Set up the ZHA switches.""" - entities = [] - for discovery_info in discovery_infos: - zha_dev = discovery_info["zha_device"] - channels = discovery_info["channels"] - - entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, Switch) - if entity: - entities.append(entity(**discovery_info)) - - if entities: - async_add_entities(entities, update_before_add=True) - @STRICT_MATCH(channel_names=CHANNEL_ON_OFF) class Switch(ZhaEntity, SwitchDevice): """ZHA switch.""" - def __init__(self, **kwargs): + def __init__(self, unique_id, zha_device, channels, **kwargs): """Initialize the ZHA switch.""" - super().__init__(**kwargs) + super().__init__(unique_id, zha_device, channels, **kwargs) self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF) @property diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index 03b6ed21148..dfa0c455649 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -63,6 +63,7 @@ class FakeDevice: def __init__(self, app, ieee, manufacturer, model, node_desc=None): """Init fake device.""" self._application = app + self.application = app self.ieee = zigpy.types.EUI64.convert(ieee) self.nwk = 0xB79C self.zdo = Mock() diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index 26dd2b5da5c..e3a8f6bf4dc 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -9,6 +9,7 @@ import zigpy.group import zigpy.types import homeassistant.components.zha.core.const as zha_const +import homeassistant.components.zha.core.device as zha_core_device import homeassistant.components.zha.core.registries as zha_regs from homeassistant.setup import async_setup_component @@ -63,7 +64,7 @@ async def config_entry_fixture(hass): @pytest.fixture def setup_zha(hass, config_entry, zigpy_app_controller, zigpy_radio): """Set up ZHA component.""" - zha_config = {zha_const.DOMAIN: {zha_const.CONF_ENABLE_QUIRKS: False}} + zha_config = {zha_const.CONF_ENABLE_QUIRKS: False} radio_details = { zha_const.ZHA_GW_RADIO: mock.MagicMock(return_value=zigpy_radio), @@ -71,9 +72,12 @@ def setup_zha(hass, config_entry, zigpy_app_controller, zigpy_radio): zha_const.ZHA_GW_RADIO_DESCRIPTION: "mock radio", } - async def _setup(): + async def _setup(config=None): + config = config or {} with mock.patch.dict(zha_regs.RADIO_TYPES, {"MockRadio": radio_details}): - status = await async_setup_component(hass, zha_const.DOMAIN, zha_config) + status = await async_setup_component( + hass, zha_const.DOMAIN, {zha_const.DOMAIN: {**zha_config, **config}} + ) assert status is True await hass.async_block_till_done() @@ -153,6 +157,7 @@ def zha_device_restored(hass, zigpy_app_controller, setup_zha): zigpy_app_controller.devices[zigpy_dev.ieee] = zigpy_dev await setup_zha() zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + await zha_gateway.async_load_devices() return zha_gateway.get_device(zigpy_dev.ieee) return _zha_device @@ -162,3 +167,36 @@ def zha_device_restored(hass, zigpy_app_controller, setup_zha): def zha_device_joined_restored(request): """Join or restore ZHA device.""" return request.getfixturevalue(request.param) + + +@pytest.fixture +def zha_device_mock(hass, zigpy_device_mock): + """Return a zha Device factory.""" + + def _zha_device( + endpoints=None, + ieee="00:11:22:33:44:55:66:77", + manufacturer="mock manufacturer", + model="mock model", + node_desc=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", + ): + if endpoints is None: + endpoints = { + 1: { + "in_clusters": [0, 1, 8, 768], + "out_clusters": [0x19], + "device_type": 0x0105, + }, + 2: { + "in_clusters": [0], + "out_clusters": [6, 8, 0x19, 768], + "device_type": 0x0810, + }, + } + zigpy_device = zigpy_device_mock( + endpoints, ieee, manufacturer, model, node_desc + ) + zha_device = zha_core_device.ZHADevice(hass, zigpy_device, mock.MagicMock()) + return zha_device + + return _zha_device diff --git a/tests/components/zha/test_channels.py b/tests/components/zha/test_channels.py index ee493ca01a7..3f38108cf89 100644 --- a/tests/components/zha/test_channels.py +++ b/tests/components/zha/test_channels.py @@ -1,9 +1,14 @@ """Test ZHA Core channels.""" +import asyncio +from unittest import mock + +import asynctest import pytest import zigpy.types as t -import homeassistant.components.zha.core.channels as channels -import homeassistant.components.zha.core.device as zha_device +import homeassistant.components.zha.core.channels as zha_channels +import homeassistant.components.zha.core.channels.base as base_channels +import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.registries as registries from .common import get_zha_gateway @@ -28,6 +33,15 @@ async def zha_gateway(hass, setup_zha): return get_zha_gateway(hass) +@pytest.fixture +def channel_pool(): + """Endpoint Channels fixture.""" + ch_pool_mock = mock.MagicMock(spec_set=zha_channels.ChannelPool) + type(ch_pool_mock).skip_configuration = mock.PropertyMock(return_value=False) + ch_pool_mock.id = 1 + return ch_pool_mock + + @pytest.mark.parametrize( "cluster_id, bind_count, attrs", [ @@ -72,7 +86,7 @@ async def zha_gateway(hass, setup_zha): ], ) async def test_in_channel_config( - cluster_id, bind_count, attrs, hass, zigpy_device_mock, zha_gateway + cluster_id, bind_count, attrs, channel_pool, zigpy_device_mock, zha_gateway ): """Test ZHA core channel configuration for input clusters.""" zigpy_dev = zigpy_device_mock( @@ -81,13 +95,12 @@ async def test_in_channel_config( "test manufacturer", "test model", ) - zha_dev = zha_device.ZHADevice(hass, zigpy_dev, zha_gateway) cluster = zigpy_dev.endpoints[1].in_clusters[cluster_id] channel_class = registries.ZIGBEE_CHANNEL_REGISTRY.get( - cluster_id, channels.AttributeListeningChannel + cluster_id, base_channels.AttributeListeningChannel ) - channel = channel_class(cluster, zha_dev) + channel = channel_class(cluster, channel_pool) await channel.async_configure() @@ -130,7 +143,7 @@ async def test_in_channel_config( ], ) async def test_out_channel_config( - cluster_id, bind_count, zha_gateway, hass, zigpy_device_mock + cluster_id, bind_count, channel_pool, zigpy_device_mock, zha_gateway ): """Test ZHA core channel configuration for output clusters.""" zigpy_dev = zigpy_device_mock( @@ -139,14 +152,13 @@ async def test_out_channel_config( "test manufacturer", "test model", ) - zha_dev = zha_device.ZHADevice(hass, zigpy_dev, zha_gateway) cluster = zigpy_dev.endpoints[1].out_clusters[cluster_id] cluster.bind_only = True channel_class = registries.ZIGBEE_CHANNEL_REGISTRY.get( - cluster_id, channels.AttributeListeningChannel + cluster_id, base_channels.AttributeListeningChannel ) - channel = channel_class(cluster, zha_dev) + channel = channel_class(cluster, channel_pool) await channel.async_configure() @@ -159,4 +171,203 @@ def test_channel_registry(): for (cluster_id, channel) in registries.ZIGBEE_CHANNEL_REGISTRY.items(): assert isinstance(cluster_id, int) assert 0 <= cluster_id <= 0xFFFF - assert issubclass(channel, channels.ZigbeeChannel) + assert issubclass(channel, base_channels.ZigbeeChannel) + + +def test_epch_unclaimed_channels(channel): + """Test unclaimed channels.""" + + ch_1 = channel(zha_const.CHANNEL_ON_OFF, 6) + ch_2 = channel(zha_const.CHANNEL_LEVEL, 8) + ch_3 = channel(zha_const.CHANNEL_COLOR, 768) + + ep_channels = zha_channels.ChannelPool( + mock.MagicMock(spec_set=zha_channels.Channels), mock.sentinel.ep + ) + all_channels = {ch_1.id: ch_1, ch_2.id: ch_2, ch_3.id: ch_3} + with mock.patch.dict(ep_channels.all_channels, all_channels, clear=True): + available = ep_channels.unclaimed_channels() + assert ch_1 in available + assert ch_2 in available + assert ch_3 in available + + ep_channels.claimed_channels[ch_2.id] = ch_2 + available = ep_channels.unclaimed_channels() + assert ch_1 in available + assert ch_2 not in available + assert ch_3 in available + + ep_channels.claimed_channels[ch_1.id] = ch_1 + available = ep_channels.unclaimed_channels() + assert ch_1 not in available + assert ch_2 not in available + assert ch_3 in available + + ep_channels.claimed_channels[ch_3.id] = ch_3 + available = ep_channels.unclaimed_channels() + assert ch_1 not in available + assert ch_2 not in available + assert ch_3 not in available + + +def test_epch_claim_channels(channel): + """Test channel claiming.""" + + ch_1 = channel(zha_const.CHANNEL_ON_OFF, 6) + ch_2 = channel(zha_const.CHANNEL_LEVEL, 8) + ch_3 = channel(zha_const.CHANNEL_COLOR, 768) + + ep_channels = zha_channels.ChannelPool( + mock.MagicMock(spec_set=zha_channels.Channels), mock.sentinel.ep + ) + all_channels = {ch_1.id: ch_1, ch_2.id: ch_2, ch_3.id: ch_3} + with mock.patch.dict(ep_channels.all_channels, all_channels, clear=True): + assert ch_1.id not in ep_channels.claimed_channels + assert ch_2.id not in ep_channels.claimed_channels + assert ch_3.id not in ep_channels.claimed_channels + + ep_channels.claim_channels([ch_2]) + assert ch_1.id not in ep_channels.claimed_channels + assert ch_2.id in ep_channels.claimed_channels + assert ep_channels.claimed_channels[ch_2.id] is ch_2 + assert ch_3.id not in ep_channels.claimed_channels + + ep_channels.claim_channels([ch_3, ch_1]) + assert ch_1.id in ep_channels.claimed_channels + assert ep_channels.claimed_channels[ch_1.id] is ch_1 + assert ch_2.id in ep_channels.claimed_channels + assert ep_channels.claimed_channels[ch_2.id] is ch_2 + assert ch_3.id in ep_channels.claimed_channels + assert ep_channels.claimed_channels[ch_3.id] is ch_3 + assert "1:0x0300" in ep_channels.claimed_channels + + +@mock.patch("homeassistant.components.zha.core.channels.ChannelPool.add_relay_channels") +@mock.patch( + "homeassistant.components.zha.core.discovery.PROBE.discover_entities", + mock.MagicMock(), +) +def test_ep_channels_all_channels(m1, zha_device_mock): + """Test EndpointChannels adding all channels.""" + zha_device = zha_device_mock( + { + 1: {"in_clusters": [0, 1, 6, 8], "out_clusters": [], "device_type": 0x0000}, + 2: { + "in_clusters": [0, 1, 6, 8, 768], + "out_clusters": [], + "device_type": 0x0000, + }, + } + ) + channels = zha_channels.Channels(zha_device) + + ep_channels = zha_channels.ChannelPool.new(channels, 1) + assert "1:0x0000" in ep_channels.all_channels + assert "1:0x0001" in ep_channels.all_channels + assert "1:0x0006" in ep_channels.all_channels + assert "1:0x0008" in ep_channels.all_channels + assert "1:0x0300" not in ep_channels.all_channels + assert "2:0x0000" not in ep_channels.all_channels + assert "2:0x0001" not in ep_channels.all_channels + assert "2:0x0006" not in ep_channels.all_channels + assert "2:0x0008" not in ep_channels.all_channels + assert "2:0x0300" not in ep_channels.all_channels + + channels = zha_channels.Channels(zha_device) + ep_channels = zha_channels.ChannelPool.new(channels, 2) + assert "1:0x0000" not in ep_channels.all_channels + assert "1:0x0001" not in ep_channels.all_channels + assert "1:0x0006" not in ep_channels.all_channels + assert "1:0x0008" not in ep_channels.all_channels + assert "1:0x0300" not in ep_channels.all_channels + assert "2:0x0000" in ep_channels.all_channels + assert "2:0x0001" in ep_channels.all_channels + assert "2:0x0006" in ep_channels.all_channels + assert "2:0x0008" in ep_channels.all_channels + assert "2:0x0300" in ep_channels.all_channels + + +@mock.patch("homeassistant.components.zha.core.channels.ChannelPool.add_relay_channels") +@mock.patch( + "homeassistant.components.zha.core.discovery.PROBE.discover_entities", + mock.MagicMock(), +) +def test_channel_power_config(m1, zha_device_mock): + """Test that channels only get a single power channel.""" + in_clusters = [0, 1, 6, 8] + zha_device = zha_device_mock( + { + 1: {"in_clusters": in_clusters, "out_clusters": [], "device_type": 0x0000}, + 2: { + "in_clusters": [*in_clusters, 768], + "out_clusters": [], + "device_type": 0x0000, + }, + } + ) + channels = zha_channels.Channels.new(zha_device) + pools = {pool.id: pool for pool in channels.pools} + assert "1:0x0000" in pools[1].all_channels + assert "1:0x0001" in pools[1].all_channels + assert "1:0x0006" in pools[1].all_channels + assert "1:0x0008" in pools[1].all_channels + assert "1:0x0300" not in pools[1].all_channels + assert "2:0x0000" in pools[2].all_channels + assert "2:0x0001" not in pools[2].all_channels + assert "2:0x0006" in pools[2].all_channels + assert "2:0x0008" in pools[2].all_channels + assert "2:0x0300" in pools[2].all_channels + + zha_device = zha_device_mock( + { + 1: {"in_clusters": [], "out_clusters": [], "device_type": 0x0000}, + 2: {"in_clusters": in_clusters, "out_clusters": [], "device_type": 0x0000}, + } + ) + channels = zha_channels.Channels.new(zha_device) + pools = {pool.id: pool for pool in channels.pools} + assert "1:0x0001" not in pools[1].all_channels + assert "2:0x0001" in pools[2].all_channels + + zha_device = zha_device_mock( + {2: {"in_clusters": in_clusters, "out_clusters": [], "device_type": 0x0000}} + ) + channels = zha_channels.Channels.new(zha_device) + pools = {pool.id: pool for pool in channels.pools} + assert "2:0x0001" in pools[2].all_channels + + +async def test_ep_channels_configure(channel): + """Test unclaimed channels.""" + + ch_1 = channel(zha_const.CHANNEL_ON_OFF, 6) + ch_2 = channel(zha_const.CHANNEL_LEVEL, 8) + ch_3 = channel(zha_const.CHANNEL_COLOR, 768) + ch_3.async_configure = asynctest.CoroutineMock(side_effect=asyncio.TimeoutError) + ch_3.async_initialize = asynctest.CoroutineMock(side_effect=asyncio.TimeoutError) + ch_4 = channel(zha_const.CHANNEL_ON_OFF, 6) + ch_5 = channel(zha_const.CHANNEL_LEVEL, 8) + ch_5.async_configure = asynctest.CoroutineMock(side_effect=asyncio.TimeoutError) + ch_5.async_initialize = asynctest.CoroutineMock(side_effect=asyncio.TimeoutError) + + channels = mock.MagicMock(spec_set=zha_channels.Channels) + type(channels).semaphore = mock.PropertyMock(return_value=asyncio.Semaphore(3)) + ep_channels = zha_channels.ChannelPool(channels, mock.sentinel.ep) + + claimed = {ch_1.id: ch_1, ch_2.id: ch_2, ch_3.id: ch_3} + relay = {ch_4.id: ch_4, ch_5.id: ch_5} + + with mock.patch.dict(ep_channels.claimed_channels, claimed, clear=True): + with mock.patch.dict(ep_channels.relay_channels, relay, clear=True): + await ep_channels.async_configure() + await ep_channels.async_initialize(mock.sentinel.from_cache) + + for ch in [*claimed.values(), *relay.values()]: + assert ch.async_initialize.call_count == 1 + assert ch.async_initialize.await_count == 1 + assert ch.async_initialize.call_args[0][0] is mock.sentinel.from_cache + assert ch.async_configure.call_count == 1 + assert ch.async_configure.await_count == 1 + + assert ch_3.warning.call_count == 2 + assert ch_5.warning.call_count == 2 diff --git a/tests/components/zha/test_cover.py b/tests/components/zha/test_cover.py index e5883605e34..4fbabf4485a 100644 --- a/tests/components/zha/test_cover.py +++ b/tests/components/zha/test_cover.py @@ -45,7 +45,7 @@ async def test_cover(m1, hass, zha_device_joined_restored, zigpy_cover_device): return 100 with patch( - "homeassistant.components.zha.core.channels.ZigbeeChannel.get_attribute_value", + "homeassistant.components.zha.core.channels.base.ZigbeeChannel.get_attribute_value", new=MagicMock(side_effect=get_chan_attr), ) as get_attr_mock: # load up cover domain diff --git a/tests/components/zha/test_device_action.py b/tests/components/zha/test_device_action.py index 8866e6cff55..c779dda6cf8 100644 --- a/tests/components/zha/test_device_action.py +++ b/tests/components/zha/test_device_action.py @@ -11,7 +11,6 @@ from homeassistant.components.device_automation import ( _async_get_device_automations as async_get_device_automations, ) from homeassistant.components.zha import DOMAIN -from homeassistant.components.zha.core.const import CHANNEL_EVENT_RELAY from homeassistant.helpers.device_registry import async_get_registry from homeassistant.setup import async_setup_component @@ -104,8 +103,8 @@ async def test_action(hass, device_ias): await hass.async_block_till_done() calls = async_mock_service(hass, DOMAIN, "warning_device_warn") - channel = {ch.name: ch for ch in zha_device.all_channels}[CHANNEL_EVENT_RELAY] - channel.zha_send_event(channel.cluster, COMMAND_SINGLE, []) + channel = zha_device.channels.pools[0].relay_channels["1:0x0006"] + channel.zha_send_event(COMMAND_SINGLE, []) await hass.async_block_till_done() assert len(calls) == 1 diff --git a/tests/components/zha/test_device_trigger.py b/tests/components/zha/test_device_trigger.py index 4bb7567d1e6..9b69ba06e4f 100644 --- a/tests/components/zha/test_device_trigger.py +++ b/tests/components/zha/test_device_trigger.py @@ -3,7 +3,6 @@ import pytest import zigpy.zcl.clusters.general as general import homeassistant.components.automation as automation -from homeassistant.components.zha.core.const import CHANNEL_EVENT_RELAY from homeassistant.helpers.device_registry import async_get_registry from homeassistant.setup import async_setup_component @@ -173,8 +172,8 @@ async def test_if_fires_on_event(hass, mock_devices, calls): await hass.async_block_till_done() - channel = {ch.name: ch for ch in zha_device.all_channels}[CHANNEL_EVENT_RELAY] - channel.zha_send_event(channel.cluster, COMMAND_SINGLE, []) + channel = zha_device.channels.pools[0].relay_channels["1:0x0006"] + channel.zha_send_event(COMMAND_SINGLE, []) await hass.async_block_till_done() assert len(calls) == 1 diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index a194453bd65..c8f2eb0dd7c 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -4,10 +4,24 @@ import re from unittest import mock import pytest +import zigpy.quirks +import zigpy.zcl.clusters.closures +import zigpy.zcl.clusters.general +import zigpy.zcl.clusters.security +import homeassistant.components.zha.binary_sensor +import homeassistant.components.zha.core.channels as zha_channels +import homeassistant.components.zha.core.channels.base as base_channels import homeassistant.components.zha.core.const as zha_const import homeassistant.components.zha.core.discovery as disc -import homeassistant.components.zha.core.gateway as core_zha_gw +import homeassistant.components.zha.core.registries as zha_regs +import homeassistant.components.zha.cover +import homeassistant.components.zha.device_tracker +import homeassistant.components.zha.fan +import homeassistant.components.zha.light +import homeassistant.components.zha.lock +import homeassistant.components.zha.sensor +import homeassistant.components.zha.switch import homeassistant.helpers.entity_registry from .common import get_zha_gateway @@ -16,12 +30,34 @@ from .zha_devices_list import DEVICES NO_TAIL_ID = re.compile("_\\d$") +@pytest.fixture +def channels_mock(zha_device_mock): + """Channels mock factory.""" + + def _mock( + endpoints, + ieee="00:11:22:33:44:55:66:77", + manufacturer="mock manufacturer", + model="mock model", + node_desc=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", + ): + zha_dev = zha_device_mock(endpoints, ieee, manufacturer, model, node_desc) + channels = zha_channels.Channels.new(zha_dev) + return channels + + return _mock + + @pytest.mark.parametrize("device", DEVICES) async def test_devices( device, hass, zigpy_device_mock, monkeypatch, zha_device_joined_restored ): """Test device discovery.""" + entity_registry = await homeassistant.helpers.entity_registry.async_get_registry( + hass + ) + zigpy_device = zigpy_device_mock( device["endpoints"], "00:11:22:33:44:55:66:77", @@ -30,45 +66,298 @@ async def test_devices( node_descriptor=device["node_descriptor"], ) - _dispatch = mock.MagicMock(wraps=disc.async_dispatch_discovery_info) - monkeypatch.setattr(core_zha_gw, "async_dispatch_discovery_info", _dispatch) - entity_registry = await homeassistant.helpers.entity_registry.async_get_registry( - hass + orig_new_entity = zha_channels.ChannelPool.async_new_entity + _dispatch = mock.MagicMock(wraps=orig_new_entity) + try: + zha_channels.ChannelPool.async_new_entity = lambda *a, **kw: _dispatch(*a, **kw) + zha_dev = await zha_device_joined_restored(zigpy_device) + await hass.async_block_till_done() + finally: + zha_channels.ChannelPool.async_new_entity = orig_new_entity + + entity_ids = hass.states.async_entity_ids() + await hass.async_block_till_done() + zha_entity_ids = { + ent for ent in entity_ids if ent.split(".")[0] in zha_const.COMPONENTS + } + + event_channels = { + ch.id for pool in zha_dev.channels.pools for ch in pool.relay_channels.values() + } + + entity_map = device["entity_map"] + assert zha_entity_ids == set( + [ + e["entity_id"] + for e in entity_map.values() + if not e.get("default_match", False) + ] ) + assert event_channels == set(device["event_channels"]) + + for call in _dispatch.call_args_list: + _, component, entity_cls, unique_id, channels = call[0] + key = (component, unique_id) + entity_id = entity_registry.async_get_entity_id(component, "zha", unique_id) + + assert key in entity_map + assert entity_id is not None + no_tail_id = NO_TAIL_ID.sub("", entity_map[key]["entity_id"]) + assert entity_id.startswith(no_tail_id) + assert set([ch.name for ch in channels]) == set(entity_map[key]["channels"]) + assert entity_cls.__name__ == entity_map[key]["entity_class"] + + +@mock.patch( + "homeassistant.components.zha.core.discovery.ProbeEndpoint.discover_by_device_type" +) +@mock.patch( + "homeassistant.components.zha.core.discovery.ProbeEndpoint.discover_by_cluster_id" +) +def test_discover_entities(m1, m2): + """Test discover endpoint class method.""" + ep_channels = mock.MagicMock() + disc.PROBE.discover_entities(ep_channels) + assert m1.call_count == 1 + assert m1.call_args[0][0] is ep_channels + assert m2.call_count == 1 + assert m2.call_args[0][0] is ep_channels + + +@pytest.mark.parametrize( + "device_type, component, hit", + [ + (0x0100, zha_const.LIGHT, True), + (0x0108, zha_const.SWITCH, True), + (0x0051, zha_const.SWITCH, True), + (0xFFFF, None, False), + ], +) +def test_discover_by_device_type(device_type, component, hit): + """Test entity discovery by device type.""" + + ep_channels = mock.MagicMock(spec_set=zha_channels.ChannelPool) + ep_mock = mock.PropertyMock() + ep_mock.return_value.profile_id = 0x0104 + ep_mock.return_value.device_type = device_type + type(ep_channels).endpoint = ep_mock + + get_entity_mock = mock.MagicMock( + return_value=(mock.sentinel.entity_cls, mock.sentinel.claimed) + ) + with mock.patch( + "homeassistant.components.zha.core.registries.ZHA_ENTITIES.get_entity", + get_entity_mock, + ): + disc.PROBE.discover_by_device_type(ep_channels) + if hit: + assert get_entity_mock.call_count == 1 + assert ep_channels.claim_channels.call_count == 1 + assert ep_channels.claim_channels.call_args[0][0] is mock.sentinel.claimed + assert ep_channels.async_new_entity.call_count == 1 + assert ep_channels.async_new_entity.call_args[0][0] == component + assert ep_channels.async_new_entity.call_args[0][1] == mock.sentinel.entity_cls + + +def test_discover_by_device_type_override(): + """Test entity discovery by device type overriding.""" + + ep_channels = mock.MagicMock(spec_set=zha_channels.ChannelPool) + ep_mock = mock.PropertyMock() + ep_mock.return_value.profile_id = 0x0104 + ep_mock.return_value.device_type = 0x0100 + type(ep_channels).endpoint = ep_mock + + overrides = {ep_channels.unique_id: {"type": zha_const.SWITCH}} + get_entity_mock = mock.MagicMock( + return_value=(mock.sentinel.entity_cls, mock.sentinel.claimed) + ) + with mock.patch( + "homeassistant.components.zha.core.registries.ZHA_ENTITIES.get_entity", + get_entity_mock, + ): + with mock.patch.dict(disc.PROBE._device_configs, overrides, clear=True): + disc.PROBE.discover_by_device_type(ep_channels) + assert get_entity_mock.call_count == 1 + assert ep_channels.claim_channels.call_count == 1 + assert ep_channels.claim_channels.call_args[0][0] is mock.sentinel.claimed + assert ep_channels.async_new_entity.call_count == 1 + assert ep_channels.async_new_entity.call_args[0][0] == zha_const.SWITCH + assert ( + ep_channels.async_new_entity.call_args[0][1] == mock.sentinel.entity_cls + ) + + +def test_discover_probe_single_cluster(): + """Test entity discovery by single cluster.""" + + ep_channels = mock.MagicMock(spec_set=zha_channels.ChannelPool) + ep_mock = mock.PropertyMock() + ep_mock.return_value.profile_id = 0x0104 + ep_mock.return_value.device_type = 0x0100 + type(ep_channels).endpoint = ep_mock + + get_entity_mock = mock.MagicMock( + return_value=(mock.sentinel.entity_cls, mock.sentinel.claimed) + ) + channel_mock = mock.MagicMock(spec_set=base_channels.ZigbeeChannel) + with mock.patch( + "homeassistant.components.zha.core.registries.ZHA_ENTITIES.get_entity", + get_entity_mock, + ): + disc.PROBE.probe_single_cluster(zha_const.SWITCH, channel_mock, ep_channels) + + assert get_entity_mock.call_count == 1 + assert ep_channels.claim_channels.call_count == 1 + assert ep_channels.claim_channels.call_args[0][0] is mock.sentinel.claimed + assert ep_channels.async_new_entity.call_count == 1 + assert ep_channels.async_new_entity.call_args[0][0] == zha_const.SWITCH + assert ep_channels.async_new_entity.call_args[0][1] == mock.sentinel.entity_cls + assert ep_channels.async_new_entity.call_args[0][3] == mock.sentinel.claimed + + +@pytest.mark.parametrize("device_info", DEVICES) +async def test_discover_endpoint(device_info, channels_mock, hass): + """Test device discovery.""" with mock.patch( - "homeassistant.components.zha.core.discovery._async_create_cluster_channel", - wraps=disc._async_create_cluster_channel, + "homeassistant.components.zha.core.channels.Channels.async_new_entity" + ) as new_ent: + channels = channels_mock( + device_info["endpoints"], + manufacturer=device_info["manufacturer"], + model=device_info["model"], + node_desc=device_info["node_descriptor"], + ) + + assert device_info["event_channels"] == sorted( + [ch.id for pool in channels.pools for ch in pool.relay_channels.values()] + ) + assert new_ent.call_count == len( + [ + device_info + for device_info in device_info["entity_map"].values() + if not device_info.get("default_match", False) + ] + ) + + for call_args in new_ent.call_args_list: + comp, ent_cls, unique_id, channels = call_args[0] + map_id = (comp, unique_id) + assert map_id in device_info["entity_map"] + entity_info = device_info["entity_map"][map_id] + assert set([ch.name for ch in channels]) == set(entity_info["channels"]) + assert ent_cls.__name__ == entity_info["entity_class"] + + +def _ch_mock(cluster): + """Return mock of a channel with a cluster.""" + channel = mock.MagicMock() + type(channel).cluster = mock.PropertyMock(return_value=cluster(mock.MagicMock())) + return channel + + +@mock.patch( + "homeassistant.components.zha.core.discovery.ProbeEndpoint" + ".handle_on_off_output_cluster_exception", + new=mock.MagicMock(), +) +@mock.patch( + "homeassistant.components.zha.core.discovery.ProbeEndpoint.probe_single_cluster" +) +def _test_single_input_cluster_device_class(probe_mock): + """Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class.""" + + door_ch = _ch_mock(zigpy.zcl.clusters.closures.DoorLock) + cover_ch = _ch_mock(zigpy.zcl.clusters.closures.WindowCovering) + multistate_ch = _ch_mock(zigpy.zcl.clusters.general.MultistateInput) + + class QuirkedIAS(zigpy.quirks.CustomCluster, zigpy.zcl.clusters.security.IasZone): + pass + + ias_ch = _ch_mock(QuirkedIAS) + + class _Analog(zigpy.quirks.CustomCluster, zigpy.zcl.clusters.general.AnalogInput): + pass + + analog_ch = _ch_mock(_Analog) + + ch_pool = mock.MagicMock(spec_set=zha_channels.ChannelPool) + ch_pool.unclaimed_channels.return_value = [ + door_ch, + cover_ch, + multistate_ch, + ias_ch, + analog_ch, + ] + + disc.ProbeEndpoint().discover_by_cluster_id(ch_pool) + assert probe_mock.call_count == len(ch_pool.unclaimed_channels()) + probes = ( + (zha_const.LOCK, door_ch), + (zha_const.COVER, cover_ch), + (zha_const.SENSOR, multistate_ch), + (zha_const.BINARY_SENSOR, ias_ch), + (zha_const.SENSOR, analog_ch), + ) + for call, details in zip(probe_mock.call_args_list, probes): + component, ch = details + assert call[0][0] == component + assert call[0][1] == ch + + +def test_single_input_cluster_device_class(): + """Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class.""" + _test_single_input_cluster_device_class() + + +def test_single_input_cluster_device_class_by_cluster_class(): + """Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class.""" + mock_reg = { + zigpy.zcl.clusters.closures.DoorLock.cluster_id: zha_const.LOCK, + zigpy.zcl.clusters.closures.WindowCovering.cluster_id: zha_const.COVER, + zigpy.zcl.clusters.general.AnalogInput: zha_const.SENSOR, + zigpy.zcl.clusters.general.MultistateInput: zha_const.SENSOR, + zigpy.zcl.clusters.security.IasZone: zha_const.BINARY_SENSOR, + } + + with mock.patch.dict( + zha_regs.SINGLE_INPUT_CLUSTER_DEVICE_CLASS, mock_reg, clear=True ): - await zha_device_joined_restored(zigpy_device) - await hass.async_block_till_done() + _test_single_input_cluster_device_class() - entity_ids = hass.states.async_entity_ids() - await hass.async_block_till_done() - zha_entities = { - ent for ent in entity_ids if ent.split(".")[0] in zha_const.COMPONENTS - } - zha_gateway = get_zha_gateway(hass) - zha_dev = zha_gateway.get_device(zigpy_device.ieee) - event_channels = { # pylint: disable=protected-access - ch.id for ch in zha_dev._relay_channels.values() - } +@pytest.mark.parametrize( + "override, entity_id", + [ + (None, "light.manufacturer_model_77665544_level_light_color_on_off"), + ("switch", "switch.manufacturer_model_77665544_on_off"), + ], +) +async def test_device_override(hass, zigpy_device_mock, setup_zha, override, entity_id): + """Test device discovery override.""" - assert zha_entities == set(device["entities"]) - assert event_channels == set(device["event_channels"]) + zigpy_device = zigpy_device_mock( + { + 1: { + "device_type": 258, + "endpoint_id": 1, + "in_clusters": [0, 3, 4, 5, 6, 8, 768, 2821, 64513], + "out_clusters": [25], + "profile_id": 260, + } + }, + "00:11:22:33:44:55:66:77", + "manufacturer", + "model", + ) - entity_map = device["entity_map"] - for calls in _dispatch.call_args_list: - discovery_info = calls[0][2] - unique_id = discovery_info["unique_id"] - channels = discovery_info["channels"] - component = discovery_info["component"] - key = (component, unique_id) - entity_id = entity_registry.async_get_entity_id(component, "zha", unique_id) + if override is not None: + override = {"device_config": {"00:11:22:33:44:55:66:77-1": {"type": override}}} - assert key in entity_map - assert entity_id is not None - no_tail_id = NO_TAIL_ID.sub("", entity_map[key]["entity_id"]) - assert entity_id.startswith(no_tail_id) - assert set([ch.name for ch in channels]) == set(entity_map[key]["channels"]) + await setup_zha(override) + assert hass.states.get(entity_id) is None + zha_gateway = get_zha_gateway(hass) + await zha_gateway.async_device_initialized(zigpy_device) + await hass.async_block_till_done() + assert hass.states.get(entity_id) is not None diff --git a/tests/components/zha/test_registries.py b/tests/components/zha/test_registries.py index 383b61e6c66..fc41a409518 100644 --- a/tests/components/zha/test_registries.py +++ b/tests/components/zha/test_registries.py @@ -55,8 +55,20 @@ def channels(channel): # manufacturer matching (registries.MatchRule(manufacturers="no match"), False), (registries.MatchRule(manufacturers=MANUFACTURER), True), + ( + registries.MatchRule(manufacturers="no match", aux_channels="aux_channel"), + False, + ), + ( + registries.MatchRule( + manufacturers=MANUFACTURER, aux_channels="aux_channel" + ), + True, + ), (registries.MatchRule(models=MODEL), True), (registries.MatchRule(models="no match"), False), + (registries.MatchRule(models=MODEL, aux_channels="aux_channel"), True), + (registries.MatchRule(models="no match", aux_channels="aux_channel"), False), # match everything ( registries.MatchRule( @@ -113,10 +125,9 @@ def channels(channel): ), ], ) -def test_registry_matching(rule, matched, zha_device, channels): +def test_registry_matching(rule, matched, channels): """Test strict rule matching.""" - reg = registries.ZHAEntityRegistry() - assert reg._strict_matched(zha_device, channels, rule) is matched + assert rule.strict_matched(MANUFACTURER, MODEL, channels) is matched @pytest.mark.parametrize( @@ -197,7 +208,49 @@ def test_registry_matching(rule, matched, zha_device, channels): ), ], ) -def test_registry_loose_matching(rule, matched, zha_device, channels): +def test_registry_loose_matching(rule, matched, channels): """Test loose rule matching.""" - reg = registries.ZHAEntityRegistry() - assert reg._loose_matched(zha_device, channels, rule) is matched + assert rule.loose_matched(MANUFACTURER, MODEL, channels) is matched + + +def test_match_rule_claim_channels_color(channel): + """Test channel claiming.""" + ch_color = channel("color", 0x300) + ch_level = channel("level", 8) + ch_onoff = channel("on_off", 6) + + rule = registries.MatchRule(channel_names="on_off", aux_channels={"color", "level"}) + claimed = rule.claim_channels([ch_color, ch_level, ch_onoff]) + assert {"color", "level", "on_off"} == set([ch.name for ch in claimed]) + + +@pytest.mark.parametrize( + "rule, match", + [ + (registries.MatchRule(channel_names={"level"}), {"level"}), + (registries.MatchRule(channel_names={"level", "no match"}), {"level"}), + (registries.MatchRule(channel_names={"on_off"}), {"on_off"}), + (registries.MatchRule(generic_ids="channel_0x0000"), {"basic"}), + ( + registries.MatchRule(channel_names="level", generic_ids="channel_0x0000"), + {"basic", "level"}, + ), + (registries.MatchRule(channel_names={"level", "power"}), {"level", "power"}), + ( + registries.MatchRule( + channel_names={"level", "on_off"}, aux_channels={"basic", "power"} + ), + {"basic", "level", "on_off", "power"}, + ), + (registries.MatchRule(channel_names={"color"}), set()), + ], +) +def test_match_rule_claim_channels(rule, match, channel, channels): + """Test channel claiming.""" + ch_basic = channel("basic", 0) + channels.append(ch_basic) + ch_power = channel("power", 1) + channels.append(ch_power) + + claimed = rule.claim_channels(channels) + assert match == set([ch.name for ch in claimed]) diff --git a/tests/components/zha/zha_devices_list.py b/tests/components/zha/zha_devices_list.py index a8c83406435..a3dc4f1d780 100644 --- a/tests/components/zha/zha_devices_list.py +++ b/tests/components/zha/zha_devices_list.py @@ -523,7 +523,7 @@ DEVICES = [ "channels": ["ias_zone"], "entity_class": "IASZone", "entity_id": "binary_sensor.heiman_co_v16_77665544_ias_zone", - }, + } }, "event_channels": [], "manufacturer": "Heiman", @@ -547,7 +547,7 @@ DEVICES = [ "channels": ["ias_zone"], "entity_class": "IASZone", "entity_id": "binary_sensor.heiman_warningdevice_77665544_ias_zone", - }, + } }, "event_channels": [], "manufacturer": "Heiman", @@ -1036,7 +1036,6 @@ DEVICES = [ } }, "entities": [ - "binary_sensor.keen_home_inc_sv02_610_mp_1_3_77665544_manufacturer_specific", "light.keen_home_inc_sv02_610_mp_1_3_77665544_level_on_off", "sensor.keen_home_inc_sv02_610_mp_1_3_77665544_power", "sensor.keen_home_inc_sv02_610_mp_1_3_77665544_pressure", @@ -1063,12 +1062,6 @@ DEVICES = [ "entity_class": "Pressure", "entity_id": "sensor.keen_home_inc_sv02_610_mp_1_3_77665544_pressure", }, - ("binary_sensor", "00:11:22:33:44:55:66:77-1-64514"): { - "channels": ["manufacturer_specific"], - "entity_class": "BinarySensor", - "entity_id": "binary_sensor.keen_home_inc_sv02_610_mp_1_3_77665544_manufacturer_specific", - "default_match": True, - }, }, "event_channels": [], "manufacturer": "Keen Home Inc", @@ -1101,7 +1094,6 @@ DEVICES = [ } }, "entities": [ - "binary_sensor.keen_home_inc_sv02_612_mp_1_2_77665544_manufacturer_specific", "light.keen_home_inc_sv02_612_mp_1_2_77665544_level_on_off", "sensor.keen_home_inc_sv02_612_mp_1_2_77665544_power", "sensor.keen_home_inc_sv02_612_mp_1_2_77665544_pressure", @@ -1128,12 +1120,6 @@ DEVICES = [ "entity_class": "Pressure", "entity_id": "sensor.keen_home_inc_sv02_612_mp_1_2_77665544_pressure", }, - ("binary_sensor", "00:11:22:33:44:55:66:77-1-64514"): { - "channels": ["manufacturer_specific"], - "entity_class": "BinarySensor", - "entity_id": "binary_sensor.keen_home_inc_sv02_612_mp_1_2_77665544_manufacturer_specific", - "default_match": True, - }, }, "event_channels": [], "manufacturer": "Keen Home Inc", @@ -1166,7 +1152,6 @@ DEVICES = [ } }, "entities": [ - "binary_sensor.keen_home_inc_sv02_612_mp_1_3_77665544_manufacturer_specific", "light.keen_home_inc_sv02_612_mp_1_3_77665544_level_on_off", "sensor.keen_home_inc_sv02_612_mp_1_3_77665544_power", "sensor.keen_home_inc_sv02_612_mp_1_3_77665544_pressure", @@ -1193,12 +1178,6 @@ DEVICES = [ "entity_class": "Pressure", "entity_id": "sensor.keen_home_inc_sv02_612_mp_1_3_77665544_pressure", }, - ("binary_sensor", "00:11:22:33:44:55:66:77-1-64514"): { - "channels": ["manufacturer_specific"], - "entity_class": "BinarySensor", - "entity_id": "binary_sensor.keen_home_inc_sv02_612_mp_1_3_77665544_manufacturer_specific", - "default_match": True, - }, }, "event_channels": [], "manufacturer": "Keen Home Inc", @@ -1784,13 +1763,21 @@ DEVICES = [ "profile_id": 260, } }, - "entities": ["light.lumi_lumi_router_77665544_on_off_on_off"], + "entities": [ + "binary_sensor.lumi_lumi_router_77665544_on_off", + "light.lumi_lumi_router_77665544_on_off", + ], "entity_map": { + ("binary_sensor", "00:11:22:33:44:55:66:77-8-6"): { + "channels": ["on_off", "on_off"], + "entity_class": "Opening", + "entity_id": "binary_sensor.lumi_lumi_router_77665544_on_off", + }, ("light", "00:11:22:33:44:55:66:77-8"): { "channels": ["on_off", "on_off"], "entity_class": "Light", - "entity_id": "light.lumi_lumi_router_77665544_on_off_on_off", - } + "entity_id": "light.lumi_lumi_router_77665544_on_off", + }, }, "event_channels": ["8:0x0006"], "manufacturer": "LUMI", @@ -1808,13 +1795,21 @@ DEVICES = [ "profile_id": 260, } }, - "entities": ["light.lumi_lumi_router_77665544_on_off_on_off"], + "entities": [ + "binary_sensor.lumi_lumi_router_77665544_on_off", + "light.lumi_lumi_router_77665544_on_off", + ], "entity_map": { + ("binary_sensor", "00:11:22:33:44:55:66:77-8-6"): { + "channels": ["on_off", "on_off"], + "entity_class": "Opening", + "entity_id": "binary_sensor.lumi_lumi_router_77665544_on_off", + }, ("light", "00:11:22:33:44:55:66:77-8"): { "channels": ["on_off", "on_off"], "entity_class": "Light", - "entity_id": "light.lumi_lumi_router_77665544_on_off_on_off", - } + "entity_id": "light.lumi_lumi_router_77665544_on_off", + }, }, "event_channels": ["8:0x0006"], "manufacturer": "LUMI", @@ -1832,13 +1827,21 @@ DEVICES = [ "profile_id": 260, } }, - "entities": ["light.lumi_lumi_router_77665544_on_off_on_off"], + "entities": [ + "binary_sensor.lumi_lumi_router_77665544_on_off", + "light.lumi_lumi_router_77665544_on_off", + ], "entity_map": { + ("binary_sensor", "00:11:22:33:44:55:66:77-8-6"): { + "channels": ["on_off", "on_off"], + "entity_class": "Opening", + "entity_id": "binary_sensor.lumi_lumi_router_77665544_on_off", + }, ("light", "00:11:22:33:44:55:66:77-8"): { "channels": ["on_off", "on_off"], "entity_class": "Light", - "entity_id": "light.lumi_lumi_router_77665544_on_off_on_off", - } + "entity_id": "light.lumi_lumi_router_77665544_on_off", + }, }, "event_channels": ["8:0x0006"], "manufacturer": "LUMI", @@ -1862,7 +1865,7 @@ DEVICES = [ "channels": ["illuminance"], "entity_class": "Illuminance", "entity_id": "sensor.lumi_lumi_sen_ill_mgl01_77665544_illuminance", - }, + } }, "event_channels": [], "manufacturer": "LUMI",