Make the rest of ZHA platforms to use ZHA class registry (#30261)

* Refactor ZHA component tests fixtures.

* Add tests for ZHA device discovery.

* Refactor ZHA registry MatchRule.

Allow callables as a matching criteria.
Allow sets for model & manufacturer.

* Minor ZHA class registry refactoring.

Less cluttered strict_matching registrations.

* Add entities only if there are any.

* Migrate rest of ZHA platforms to ZHA registry.

* Pylint fixes.
This commit is contained in:
Alexei Chetroi 2019-12-31 11:09:58 -05:00 committed by David F. Mulcahey
parent 5ed44297e6
commit a3061bda60
13 changed files with 2152 additions and 80 deletions

View File

@ -28,7 +28,7 @@ from .core.const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES, MatchRule from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -85,7 +85,8 @@ async def _async_setup_entities(
if entity: if entity:
entities.append(entity(**discovery_info)) entities.append(entity(**discovery_info))
async_add_entities(entities, update_before_add=True) if entities:
async_add_entities(entities, update_before_add=True)
class BinarySensor(ZhaEntity, BinarySensorDevice): class BinarySensor(ZhaEntity, BinarySensorDevice):
@ -141,28 +142,28 @@ class BinarySensor(ZhaEntity, BinarySensorDevice):
self._state = await self._channel.get_attribute_value(attribute) self._state = await self._channel.get_attribute_value(attribute)
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_ACCELEROMETER})) @STRICT_MATCH(channel_names=CHANNEL_ACCELEROMETER)
class Accelerometer(BinarySensor): class Accelerometer(BinarySensor):
"""ZHA BinarySensor.""" """ZHA BinarySensor."""
DEVICE_CLASS = DEVICE_CLASS_MOVING DEVICE_CLASS = DEVICE_CLASS_MOVING
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_OCCUPANCY})) @STRICT_MATCH(channel_names=CHANNEL_OCCUPANCY)
class Occupancy(BinarySensor): class Occupancy(BinarySensor):
"""ZHA BinarySensor.""" """ZHA BinarySensor."""
DEVICE_CLASS = DEVICE_CLASS_OCCUPANCY DEVICE_CLASS = DEVICE_CLASS_OCCUPANCY
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_ON_OFF})) @STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Opening(BinarySensor): class Opening(BinarySensor):
"""ZHA BinarySensor.""" """ZHA BinarySensor."""
DEVICE_CLASS = DEVICE_CLASS_OPENING DEVICE_CLASS = DEVICE_CLASS_OPENING
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_ZONE})) @STRICT_MATCH(channel_names=CHANNEL_ZONE)
class IASZone(BinarySensor): class IASZone(BinarySensor):
"""ZHA IAS BinarySensor.""" """ZHA IAS BinarySensor."""

View File

@ -5,7 +5,7 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/integrations/zha/ https://home-assistant.io/integrations/zha/
""" """
import collections import collections
from typing import Callable, Set from typing import Callable, Set, Union
import attr import attr
import bellows.ezsp import bellows.ezsp
@ -171,14 +171,33 @@ def establish_device_mappings():
REMOTE_DEVICE_TYPES[zll.PROFILE_ID].append(zll.DeviceType.SCENE_CONTROLLER) REMOTE_DEVICE_TYPES[zll.PROFILE_ID].append(zll.DeviceType.SCENE_CONTROLLER)
def set_or_callable(value):
"""Convert single str or None to a set. Pass through callables and sets."""
if value is None:
return frozenset()
if callable(value):
return value
if isinstance(value, (frozenset, set, list)):
return frozenset(value)
return frozenset([str(value)])
@attr.s(frozen=True) @attr.s(frozen=True)
class MatchRule: class MatchRule:
"""Match a ZHA Entity to a channel name or generic id.""" """Match a ZHA Entity to a channel name or generic id."""
channel_names: Set[str] = attr.ib(factory=frozenset, converter=frozenset) channel_names: Union[Callable, Set[str], str] = attr.ib(
generic_ids: Set[str] = attr.ib(factory=frozenset, converter=frozenset) factory=frozenset, converter=set_or_callable
manufacturer: str = attr.ib(default=None) )
model: str = attr.ib(default=None) generic_ids: Union[Callable, Set[str], str] = attr.ib(
factory=frozenset, converter=set_or_callable
)
manufacturers: Union[Callable, Set[str], str] = attr.ib(
factory=frozenset, converter=set_or_callable
)
models: Union[Callable, Set[str], str] = attr.ib(
factory=frozenset, converter=set_or_callable
)
class ZHAEntityRegistry: class ZHAEntityRegistry:
@ -190,7 +209,7 @@ class ZHAEntityRegistry:
self._loose_registry = collections.defaultdict(dict) self._loose_registry = collections.defaultdict(dict)
def get_entity( def get_entity(
self, component: str, zha_device, chnls: list, default: CALLABLE_T = None self, component: str, zha_device, chnls: dict, default: CALLABLE_T = None
) -> CALLABLE_T: ) -> CALLABLE_T:
"""Match a ZHA Channels to a ZHA Entity class.""" """Match a ZHA Channels to a ZHA Entity class."""
for match in self._strict_registry[component]: for match in self._strict_registry[component]:
@ -200,10 +219,17 @@ class ZHAEntityRegistry:
return default return default
def strict_match( def strict_match(
self, component: str, rule: MatchRule self,
component: str,
channel_names: Union[Callable, Set[str], str] = None,
generic_ids: Union[Callable, Set[str], str] = None,
manufacturers: Union[Callable, Set[str], str] = None,
models: Union[Callable, Set[str], str] = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]: ) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a strict match rule.""" """Decorate a strict match rule."""
rule = MatchRule(channel_names, generic_ids, manufacturers, models)
def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T: def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:
"""Register a strict match rule. """Register a strict match rule.
@ -215,10 +241,17 @@ class ZHAEntityRegistry:
return decorator return decorator
def loose_match( def loose_match(
self, component: str, rule: MatchRule self,
component: str,
channel_names: Union[Callable, Set[str], str] = None,
generic_ids: Union[Callable, Set[str], str] = None,
manufacturers: Union[Callable, Set[str], str] = None,
models: Union[Callable, Set[str], str] = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]: ) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a loose match rule.""" """Decorate a loose match rule."""
rule = MatchRule(channel_names, generic_ids, manufacturers, models)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
"""Register a loose match rule. """Register a loose match rule.
@ -238,7 +271,7 @@ class ZHAEntityRegistry:
return any(self._matched(zha_device, chnls, rule)) return any(self._matched(zha_device, chnls, rule))
@staticmethod @staticmethod
def _matched(zha_device, chnls: list, rule: MatchRule) -> bool: def _matched(zha_device, chnls: dict, rule: MatchRule) -> list:
"""Return a list of field matches.""" """Return a list of field matches."""
if not any(attr.asdict(rule).values()): if not any(attr.asdict(rule).values()):
return [False] return [False]
@ -252,11 +285,17 @@ class ZHAEntityRegistry:
all_generic_ids = {ch.generic_id for ch in chnls} all_generic_ids = {ch.generic_id for ch in chnls}
matches.append(rule.generic_ids.issubset(all_generic_ids)) matches.append(rule.generic_ids.issubset(all_generic_ids))
if rule.manufacturer: if rule.manufacturers:
matches.append(zha_device.manufacturer == rule.manufacturer) if callable(rule.manufacturers):
matches.append(rule.manufacturers(zha_device.manufacturer))
else:
matches.append(zha_device.manufacturer in rule.manufacturers)
if rule.model: if rule.models:
matches.append(zha_device.model == rule.model) if callable(rule.models):
matches.append(rule.models(zha_device.model))
else:
matches.append(zha_device.model in rule.models)
return matches return matches

View File

@ -1,4 +1,5 @@
"""Support for the ZHA platform.""" """Support for the ZHA platform."""
import functools
import logging import logging
import time import time
@ -14,9 +15,11 @@ from .core.const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
from .sensor import Battery from .sensor import Battery
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -47,11 +50,20 @@ async def _async_setup_entities(
"""Set up the ZHA device trackers.""" """Set up the ZHA device trackers."""
entities = [] entities = []
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
entities.append(ZHADeviceScannerEntity(**discovery_info)) zha_dev = discovery_info["zha_device"]
channels = discovery_info["channels"]
async_add_entities(entities, update_before_add=True) entity = ZHA_ENTITIES.get_entity(
DOMAIN, zha_dev, channels, ZHADeviceScannerEntity
)
if entity:
entities.append(entity(**discovery_info))
if entities:
async_add_entities(entities, update_before_add=True)
@STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION)
class ZHADeviceScannerEntity(ScannerEntity, ZhaEntity): class ZHADeviceScannerEntity(ScannerEntity, ZhaEntity):
"""Represent a tracked device.""" """Represent a tracked device."""

View File

@ -1,4 +1,5 @@
"""Fans on Zigbee Home Automation networks.""" """Fans on Zigbee Home Automation networks."""
import functools
import logging import logging
from homeassistant.components.fan import ( from homeassistant.components.fan import (
@ -20,6 +21,7 @@ from .core.const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -45,6 +47,7 @@ SPEED_LIST = [
VALUE_TO_SPEED = dict(enumerate(SPEED_LIST)) VALUE_TO_SPEED = dict(enumerate(SPEED_LIST))
SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)} SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)}
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
@ -79,11 +82,18 @@ async def _async_setup_entities(
"""Set up the ZHA fans.""" """Set up the ZHA fans."""
entities = [] entities = []
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
entities.append(ZhaFan(**discovery_info)) zha_dev = discovery_info["zha_device"]
channels = discovery_info["channels"]
async_add_entities(entities, update_before_add=True) entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, ZhaFan)
if entity:
entities.append(entity(**discovery_info))
if entities:
async_add_entities(entities, update_before_add=True)
@STRICT_MATCH(channel_names=CHANNEL_FAN)
class ZhaFan(ZhaEntity, FanEntity): class ZhaFan(ZhaEntity, FanEntity):
"""Representation of a ZHA fan.""" """Representation of a ZHA fan."""

View File

@ -1,5 +1,6 @@
"""Lights on Zigbee Home Automation networks.""" """Lights on Zigbee Home Automation networks."""
from datetime import timedelta from datetime import timedelta
import functools
import logging import logging
from zigpy.zcl.foundation import Status from zigpy.zcl.foundation import Status
@ -21,6 +22,7 @@ from .core.const import (
SIGNAL_SET_LEVEL, SIGNAL_SET_LEVEL,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -36,6 +38,7 @@ UPDATE_COLORLOOP_HUE = 0x8
UNSUPPORTED_ATTRIBUTE = 0x86 UNSUPPORTED_ATTRIBUTE = 0x86
SCAN_INTERVAL = timedelta(minutes=60) SCAN_INTERVAL = timedelta(minutes=60)
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, light.DOMAIN)
PARALLEL_UPDATES = 5 PARALLEL_UPDATES = 5
@ -71,12 +74,18 @@ async def _async_setup_entities(
"""Set up the ZHA lights.""" """Set up the ZHA lights."""
entities = [] entities = []
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
zha_light = Light(**discovery_info) zha_dev = discovery_info["zha_device"]
entities.append(zha_light) channels = discovery_info["channels"]
async_add_entities(entities, update_before_add=True) entity = ZHA_ENTITIES.get_entity(light.DOMAIN, zha_dev, channels, Light)
if entity:
entities.append(entity(**discovery_info))
if entities:
async_add_entities(entities, update_before_add=True)
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Light(ZhaEntity, light.Light): class Light(ZhaEntity, light.Light):
"""Representation of a ZHA or ZLL light.""" """Representation of a ZHA or ZLL light."""

View File

@ -1,4 +1,5 @@
"""Locks on Zigbee Home Automation networks.""" """Locks on Zigbee Home Automation networks."""
import functools
import logging import logging
from zigpy.zcl.foundation import Status from zigpy.zcl.foundation import Status
@ -19,6 +20,7 @@ from .core.const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -26,6 +28,7 @@ _LOGGER = logging.getLogger(__name__)
""" The first state is Zigbee 'Not fully locked' """ """ The first state is Zigbee 'Not fully locked' """
STATE_LIST = [STATE_UNLOCKED, STATE_LOCKED, STATE_UNLOCKED] STATE_LIST = [STATE_UNLOCKED, STATE_LOCKED, STATE_UNLOCKED]
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
VALUE_TO_STATE = dict(enumerate(STATE_LIST)) VALUE_TO_STATE = dict(enumerate(STATE_LIST))
@ -62,11 +65,18 @@ async def _async_setup_entities(
"""Set up the ZHA locks.""" """Set up the ZHA locks."""
entities = [] entities = []
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
entities.append(ZhaDoorLock(**discovery_info)) zha_dev = discovery_info["zha_device"]
channels = discovery_info["channels"]
async_add_entities(entities, update_before_add=True) entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, ZhaDoorLock)
if entity:
entities.append(entity(**discovery_info))
if entities:
async_add_entities(entities, update_before_add=True)
@STRICT_MATCH(channel_names=CHANNEL_DOORLOCK)
class ZhaDoorLock(ZhaEntity, LockDevice): class ZhaDoorLock(ZhaEntity, LockDevice):
"""Representation of a ZHA lock.""" """Representation of a ZHA lock."""

View File

@ -30,7 +30,7 @@ from .core.const import (
SIGNAL_STATE_ATTR, SIGNAL_STATE_ATTR,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES, MatchRule from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
PARALLEL_UPDATES = 5 PARALLEL_UPDATES = 5
@ -90,7 +90,8 @@ async def _async_setup_entities(
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
entities.append(await make_sensor(discovery_info)) entities.append(await make_sensor(discovery_info))
async_add_entities(entities, update_before_add=True) if entities:
async_add_entities(entities, update_before_add=True)
async def make_sensor(discovery_info): async def make_sensor(discovery_info):
@ -175,7 +176,7 @@ class Sensor(ZhaEntity):
return round(float(value * self._multiplier) / self._divisor) return round(float(value * self._multiplier) / self._divisor)
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_POWER_CONFIGURATION})) @STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION)
class Battery(Sensor): class Battery(Sensor):
"""Battery sensor of power configuration cluster.""" """Battery sensor of power configuration cluster."""
@ -203,7 +204,7 @@ class Battery(Sensor):
return state_attrs return state_attrs
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_ELECTRICAL_MEASUREMENT})) @STRICT_MATCH(channel_names=CHANNEL_ELECTRICAL_MEASUREMENT)
class ElectricalMeasurement(Sensor): class ElectricalMeasurement(Sensor):
"""Active power measurement.""" """Active power measurement."""
@ -221,8 +222,8 @@ class ElectricalMeasurement(Sensor):
return round(value * self._channel.multiplier / self._channel.divisor) return round(value * self._channel.multiplier / self._channel.divisor)
@STRICT_MATCH(MatchRule(generic_ids={CHANNEL_ST_HUMIDITY_CLUSTER})) @STRICT_MATCH(generic_ids=CHANNEL_ST_HUMIDITY_CLUSTER)
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_HUMIDITY})) @STRICT_MATCH(channel_names=CHANNEL_HUMIDITY)
class Humidity(Sensor): class Humidity(Sensor):
"""Humidity sensor.""" """Humidity sensor."""
@ -231,7 +232,7 @@ class Humidity(Sensor):
_unit = "%" _unit = "%"
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_ILLUMINANCE})) @STRICT_MATCH(channel_names=CHANNEL_ILLUMINANCE)
class Illuminance(Sensor): class Illuminance(Sensor):
"""Illuminance Sensor.""" """Illuminance Sensor."""
@ -244,7 +245,7 @@ class Illuminance(Sensor):
return round(pow(10, ((value - 1) / 10000)), 1) return round(pow(10, ((value - 1) / 10000)), 1)
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_SMARTENERGY_METERING})) @STRICT_MATCH(channel_names=CHANNEL_SMARTENERGY_METERING)
class SmartEnergyMetering(Sensor): class SmartEnergyMetering(Sensor):
"""Metering sensor.""" """Metering sensor."""
@ -260,7 +261,7 @@ class SmartEnergyMetering(Sensor):
return self._channel.unit_of_measurement return self._channel.unit_of_measurement
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_PRESSURE})) @STRICT_MATCH(channel_names=CHANNEL_PRESSURE)
class Pressure(Sensor): class Pressure(Sensor):
"""Pressure sensor.""" """Pressure sensor."""
@ -269,7 +270,7 @@ class Pressure(Sensor):
_unit = "hPa" _unit = "hPa"
@STRICT_MATCH(MatchRule(channel_names={CHANNEL_TEMPERATURE})) @STRICT_MATCH(channel_names=CHANNEL_TEMPERATURE)
class Temperature(Sensor): class Temperature(Sensor):
"""Temperature Sensor.""" """Temperature Sensor."""

View File

@ -1,4 +1,5 @@
"""Switches on Zigbee Home Automation networks.""" """Switches on Zigbee Home Automation networks."""
import functools
import logging import logging
from zigpy.zcl.foundation import Status from zigpy.zcl.foundation import Status
@ -15,9 +16,11 @@ from .core.const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
ZHA_DISCOVERY_NEW, ZHA_DISCOVERY_NEW,
) )
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
@ -52,11 +55,18 @@ async def _async_setup_entities(
"""Set up the ZHA switches.""" """Set up the ZHA switches."""
entities = [] entities = []
for discovery_info in discovery_infos: for discovery_info in discovery_infos:
entities.append(Switch(**discovery_info)) zha_dev = discovery_info["zha_device"]
channels = discovery_info["channels"]
async_add_entities(entities, update_before_add=True) entity = ZHA_ENTITIES.get_entity(DOMAIN, zha_dev, channels, Switch)
if entity:
entities.append(entity(**discovery_info))
if entities:
async_add_entities(entities, update_before_add=True)
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Switch(ZhaEntity, SwitchDevice): class Switch(ZhaEntity, SwitchDevice):
"""ZHA switch.""" """ZHA switch."""

View File

@ -36,10 +36,10 @@ APPLICATION = FakeApplication()
class FakeEndpoint: class FakeEndpoint:
"""Fake endpoint for moking zigpy.""" """Fake endpoint for moking zigpy."""
def __init__(self, manufacturer, model): def __init__(self, manufacturer, model, epid=1):
"""Init fake endpoint.""" """Init fake endpoint."""
self.device = None self.device = None
self.endpoint_id = 1 self.endpoint_id = epid
self.in_clusters = {} self.in_clusters = {}
self.out_clusters = {} self.out_clusters = {}
self._cluster_attr = {} self._cluster_attr = {}
@ -97,21 +97,23 @@ class FakeDevice:
self.remove_from_group = CoroutineMock() self.remove_from_group = CoroutineMock()
def make_device( def make_device(endpoints, ieee, manufacturer, model):
in_cluster_ids, out_cluster_ids, device_type, ieee, manufacturer, model
):
"""Make a fake device using the specified cluster classes.""" """Make a fake device using the specified cluster classes."""
device = FakeDevice(ieee, manufacturer, model) device = FakeDevice(ieee, manufacturer, model)
endpoint = FakeEndpoint(manufacturer, model) for epid, ep in endpoints.items():
endpoint.device = device endpoint = FakeEndpoint(manufacturer, model, epid)
device.endpoints[endpoint.endpoint_id] = endpoint endpoint.device = device
endpoint.device_type = device_type device.endpoints[epid] = endpoint
endpoint.device_type = ep["device_type"]
profile_id = ep.get("profile_id")
if profile_id:
endpoint.profile_id = profile_id
for cluster_id in in_cluster_ids: for cluster_id in ep.get("in_clusters", []):
endpoint.add_input_cluster(cluster_id) endpoint.add_input_cluster(cluster_id)
for cluster_id in out_cluster_ids: for cluster_id in ep.get("out_clusters", []):
endpoint.add_output_cluster(cluster_id) endpoint.add_output_cluster(cluster_id)
return device return device
@ -136,7 +138,16 @@ async def async_init_zigpy_device(
happens when the device is paired to the network for the first time. happens when the device is paired to the network for the first time.
""" """
device = make_device( device = make_device(
in_cluster_ids, out_cluster_ids, device_type, ieee, manufacturer, model {
1: {
"in_clusters": in_cluster_ids,
"out_clusters": out_cluster_ids,
"device_type": device_type,
}
},
ieee,
manufacturer,
model,
) )
if is_new_join: if is_new_join:
await gateway.async_device_initialized(device) await gateway.async_device_initialized(device)

View File

@ -67,9 +67,7 @@ def nwk():
async def test_in_channel_config(cluster_id, bind_count, attrs, zha_gateway, hass): async def test_in_channel_config(cluster_id, bind_count, attrs, zha_gateway, hass):
"""Test ZHA core channel configuration for input clusters.""" """Test ZHA core channel configuration for input clusters."""
zigpy_dev = make_device( zigpy_dev = make_device(
[cluster_id], {1: {"in_clusters": [cluster_id], "out_clusters": [], "device_type": 0x1234}},
[],
0x1234,
"00:11:22:33:44:55:66:77", "00:11:22:33:44:55:66:77",
"test manufacturer", "test manufacturer",
"test model", "test model",
@ -125,9 +123,7 @@ async def test_in_channel_config(cluster_id, bind_count, attrs, zha_gateway, has
async def test_out_channel_config(cluster_id, bind_count, zha_gateway, hass): async def test_out_channel_config(cluster_id, bind_count, zha_gateway, hass):
"""Test ZHA core channel configuration for output clusters.""" """Test ZHA core channel configuration for output clusters."""
zigpy_dev = make_device( zigpy_dev = make_device(
[], {1: {"out_clusters": [cluster_id], "in_clusters": [], "device_type": 0x1234}},
[cluster_id],
0x1234,
"00:11:22:33:44:55:66:77", "00:11:22:33:44:55:66:77",
"test manufacturer", "test manufacturer",
"test model", "test model",

View File

@ -0,0 +1,55 @@
"""Test zha device discovery."""
import asyncio
from unittest import mock
import pytest
from homeassistant.components.zha.core.channels import EventRelayChannel
import homeassistant.components.zha.core.const as zha_const
import homeassistant.components.zha.core.discovery as disc
import homeassistant.components.zha.core.gateway as core_zha_gw
from .common import make_device
from .zha_devices_list import DEVICES
@pytest.mark.parametrize("device", DEVICES)
async def test_devices(device, zha_gateway: core_zha_gw.ZHAGateway, hass, config_entry):
"""Test device discovery."""
zigpy_device = make_device(
device["endpoints"],
"00:11:22:33:44:55:66:77",
device["manufacturer"],
device["model"],
)
with mock.patch(
"homeassistant.components.zha.core.discovery._async_create_cluster_channel",
wraps=disc._async_create_cluster_channel,
) as cr_ch:
await zha_gateway.async_device_restored(zigpy_device)
await hass.async_block_till_done()
tasks = [
hass.config_entries.async_forward_entry_setup(config_entry, component)
for component in zha_const.COMPONENTS
]
await asyncio.gather(*tasks)
await hass.async_block_till_done()
entity_ids = hass.states.async_entity_ids()
await hass.async_block_till_done()
zha_entities = {
ent for ent in entity_ids if ent.split(".")[0] in zha_const.COMPONENTS
}
event_channels = {
arg[0].cluster_id
for arg, kwarg in cr_ch.call_args_list
if kwarg.get("channel_class") == EventRelayChannel
}
assert zha_entities == set(device["entities"])
assert event_channels == set(device["event_channels"])

View File

@ -59,24 +59,68 @@ def channels():
True, True,
), ),
# manufacturer matching # manufacturer matching
(registries.MatchRule(manufacturer="no match"), False), (registries.MatchRule(manufacturers="no match"), False),
(registries.MatchRule(manufacturer=MANUFACTURER), True), (registries.MatchRule(manufacturers=MANUFACTURER), True),
(registries.MatchRule(model=MODEL), True), (registries.MatchRule(models=MODEL), True),
(registries.MatchRule(model="no match"), False), (registries.MatchRule(models="no match"), False),
# match everything # match everything
( (
registries.MatchRule( registries.MatchRule(
generic_ids={"channel_0x0006", "channel_0x0008"}, generic_ids={"channel_0x0006", "channel_0x0008"},
channel_names={"on_off", "level"}, channel_names={"on_off", "level"},
manufacturer=MANUFACTURER, manufacturers=MANUFACTURER,
model=MODEL, models=MODEL,
), ),
True, True,
), ),
(
registries.MatchRule(
channel_names="on_off", manufacturers={"random manuf", MANUFACTURER}
),
True,
),
(
registries.MatchRule(
channel_names="on_off", manufacturers={"random manuf", "Another manuf"}
),
False,
),
(
registries.MatchRule(
channel_names="on_off", manufacturers=lambda x: x == MANUFACTURER
),
True,
),
(
registries.MatchRule(
channel_names="on_off", manufacturers=lambda x: x != MANUFACTURER
),
False,
),
(
registries.MatchRule(
channel_names="on_off", models={"random model", MODEL}
),
True,
),
(
registries.MatchRule(
channel_names="on_off", models={"random model", "Another model"}
),
False,
),
(
registries.MatchRule(channel_names="on_off", models=lambda x: x == MODEL),
True,
),
(
registries.MatchRule(channel_names="on_off", models=lambda x: x != MODEL),
False,
),
], ],
) )
def test_registry_matching(rule, matched, zha_device, channels): def test_registry_matching(rule, matched, zha_device, channels):
"""Test empty rule matching.""" """Test strict rule matching."""
reg = registries.ZHAEntityRegistry() reg = registries.ZHAEntityRegistry()
assert reg._strict_matched(zha_device, channels, rule) is matched assert reg._strict_matched(zha_device, channels, rule) is matched
@ -92,22 +136,22 @@ def test_registry_matching(rule, matched, zha_device, channels):
(registries.MatchRule(channel_names={"on_off", "level"}), True), (registries.MatchRule(channel_names={"on_off", "level"}), True),
(registries.MatchRule(channel_names={"on_off", "level", "no match"}), False), (registries.MatchRule(channel_names={"on_off", "level", "no match"}), False),
( (
registries.MatchRule(channel_names={"on_off", "level"}, model="no match"), registries.MatchRule(channel_names={"on_off", "level"}, models="no match"),
True, True,
), ),
( (
registries.MatchRule( registries.MatchRule(
channel_names={"on_off", "level"}, channel_names={"on_off", "level"},
model="no match", models="no match",
manufacturer="no match", manufacturers="no match",
), ),
True, True,
), ),
( (
registries.MatchRule( registries.MatchRule(
channel_names={"on_off", "level"}, channel_names={"on_off", "level"},
model="no match", models="no match",
manufacturer=MANUFACTURER, manufacturers=MANUFACTURER,
), ),
True, True,
), ),
@ -124,14 +168,14 @@ def test_registry_matching(rule, matched, zha_device, channels):
( (
registries.MatchRule( registries.MatchRule(
generic_ids={"channel_0x0006", "channel_0x0008", "channel_0x0009"}, generic_ids={"channel_0x0006", "channel_0x0008", "channel_0x0009"},
model="mo match", models="mo match",
), ),
False, False,
), ),
( (
registries.MatchRule( registries.MatchRule(
generic_ids={"channel_0x0006", "channel_0x0008", "channel_0x0009"}, generic_ids={"channel_0x0006", "channel_0x0008", "channel_0x0009"},
model=MODEL, models=MODEL,
), ),
True, True,
), ),
@ -143,17 +187,17 @@ def test_registry_matching(rule, matched, zha_device, channels):
True, True,
), ),
# manufacturer matching # manufacturer matching
(registries.MatchRule(manufacturer="no match"), False), (registries.MatchRule(manufacturers="no match"), False),
(registries.MatchRule(manufacturer=MANUFACTURER), True), (registries.MatchRule(manufacturers=MANUFACTURER), True),
(registries.MatchRule(model=MODEL), True), (registries.MatchRule(models=MODEL), True),
(registries.MatchRule(model="no match"), False), (registries.MatchRule(models="no match"), False),
# match everything # match everything
( (
registries.MatchRule( registries.MatchRule(
generic_ids={"channel_0x0006", "channel_0x0008"}, generic_ids={"channel_0x0006", "channel_0x0008"},
channel_names={"on_off", "level"}, channel_names={"on_off", "level"},
manufacturer=MANUFACTURER, manufacturers=MANUFACTURER,
model=MODEL, models=MODEL,
), ),
True, True,
), ),

File diff suppressed because it is too large Load Diff