From 54c4e9335fea750bd8c954cf568ef3665bfcd426 Mon Sep 17 00:00:00 2001 From: Alexei Chetroi Date: Wed, 18 Nov 2020 21:34:12 -0500 Subject: [PATCH] Refactor ZHA sensor initialization (#43339) * Refactor ZHA sensors to use cached values after restart * Get attr from cluster, not channel * Run cached state through formatter method * Use cached values for div/multiplier for SmartEnergy channel * Restore batter voltage from cache * Refactor sensor to use cached values only * Update tests * Add battery sensor test --- .../components/zha/core/channels/base.py | 2 +- .../components/zha/core/channels/general.py | 49 +------- .../zha/core/channels/smartenergy.py | 39 +++--- .../components/zha/core/discovery.py | 3 +- homeassistant/components/zha/sensor.py | 115 +++++++----------- tests/components/zha/test_sensor.py | 58 ++++++++- 6 files changed, 128 insertions(+), 138 deletions(-) diff --git a/homeassistant/components/zha/core/channels/base.py b/homeassistant/components/zha/core/channels/base.py index 9f2fe4f21bd..c6019c10843 100644 --- a/homeassistant/components/zha/core/channels/base.py +++ b/homeassistant/components/zha/core/channels/base.py @@ -208,7 +208,7 @@ class ZigbeeChannel(LogMixin): attributes = [] for report_config in self._report_config: attributes.append(report_config["attr"]) - if len(attributes) > 0: + if attributes: await self.get_attributes(attributes, from_cache=from_cache) self._status = ChannelStatus.INITIALIZED diff --git a/homeassistant/components/zha/core/channels/general.py b/homeassistant/components/zha/core/channels/general.py index f443151de02..8747355a21a 100644 --- a/homeassistant/components/zha/core/channels/general.py +++ b/homeassistant/components/zha/core/channels/general.py @@ -17,10 +17,9 @@ from ..const import ( SIGNAL_ATTR_UPDATED, SIGNAL_MOVE_LEVEL, SIGNAL_SET_LEVEL, - SIGNAL_STATE_ATTR, SIGNAL_UPDATE_DEVICE, ) -from .base import ClientChannel, ZigbeeChannel, parse_and_log_command +from .base import ChannelStatus, ClientChannel, ZigbeeChannel, parse_and_log_command @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.Alarms.cluster_id) @@ -72,13 +71,6 @@ class BasicChannel(ZigbeeChannel): 6: "Emergency mains and transfer switch", } - def __init__( - self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType - ) -> None: - """Initialize BasicChannel.""" - super().__init__(cluster, ch_pool) - self._power_source = None - async def async_configure(self): """Configure this channel.""" await super().async_configure() @@ -87,16 +79,12 @@ class BasicChannel(ZigbeeChannel): async def async_initialize(self, from_cache): """Initialize channel.""" if not self._ch_pool.skip_configuration or from_cache: - power_source = await self.get_attribute_value( - "power_source", from_cache=from_cache - ) - if power_source is not None: - self._power_source = power_source + await self.get_attribute_value("power_source", from_cache=from_cache) await super().async_initialize(from_cache) def get_power_source(self): """Get the power source.""" - return self._power_source + return self.cluster.get("power_source") @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.BinaryInput.cluster_id) @@ -392,38 +380,8 @@ class PowerConfigurationChannel(ZigbeeChannel): {"attr": "battery_percentage_remaining", "config": REPORT_CONFIG_BATTERY_SAVE}, ) - @callback - def attribute_updated(self, attrid, value): - """Handle attribute updates on this cluster.""" - attr = self._report_config[1].get("attr") - if isinstance(attr, str): - attr_id = self.cluster.attridx.get(attr) - else: - attr_id = attr - if attrid == attr_id: - self.async_send_signal( - f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}", - attrid, - self.cluster.attributes.get(attrid, [attrid])[0], - value, - ) - return - attr_name = self.cluster.attributes.get(attrid, [attrid])[0] - self.async_send_signal( - f"{self.unique_id}_{SIGNAL_STATE_ATTR}", attr_name, value - ) - async def async_initialize(self, from_cache): """Initialize channel.""" - await self.async_read_state(from_cache) - await super().async_initialize(from_cache) - - async def async_update(self): - """Retrieve latest state.""" - await self.async_read_state(True) - - async def async_read_state(self, from_cache): - """Read data from the cluster.""" attributes = [ "battery_size", "battery_percentage_remaining", @@ -431,6 +389,7 @@ class PowerConfigurationChannel(ZigbeeChannel): "battery_quantity", ] await self.get_attributes(attributes, from_cache=from_cache) + self._status = ChannelStatus.INITIALIZED @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PowerProfile.cluster_id) diff --git a/homeassistant/components/zha/core/channels/smartenergy.py b/homeassistant/components/zha/core/channels/smartenergy.py index 120d0afdfb6..792b9413294 100644 --- a/homeassistant/components/zha/core/channels/smartenergy.py +++ b/homeassistant/components/zha/core/channels/smartenergy.py @@ -1,4 +1,6 @@ """Smart energy channels module for Zigbee Home Automation.""" +from typing import Union + import zigpy.zcl.clusters.smartenergy as smartenergy from homeassistant.const import ( @@ -82,44 +84,48 @@ class Metering(ZigbeeChannel): ) -> None: """Initialize Metering.""" super().__init__(cluster, ch_pool) - self._divisor = 1 - self._multiplier = 1 - self._unit_enum = None self._format_spec = None - async def async_configure(self): + @property + def divisor(self) -> int: + """Return divisor for the value.""" + return self.cluster.get("divisor") + + @property + def multiplier(self) -> int: + """Return multiplier for the value.""" + return self.cluster.get("multiplier") + + async def async_configure(self) -> None: """Configure channel.""" await self.fetch_config(False) await super().async_configure() - async def async_initialize(self, from_cache): + async def async_initialize(self, from_cache: bool) -> None: """Initialize channel.""" await self.fetch_config(True) await super().async_initialize(from_cache) @callback - def attribute_updated(self, attrid, value): + def attribute_updated(self, attrid: int, value: int) -> None: """Handle attribute update from Metering cluster.""" - if None in (self._multiplier, self._divisor, self._format_spec): + if None in (self.multiplier, self.divisor, self._format_spec): return - super().attribute_updated(attrid, value * self._multiplier / self._divisor) + super().attribute_updated(attrid, value) @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> str: """Return unit of measurement.""" - return self.unit_of_measure_map.get(self._unit_enum & 0x7F, "unknown") + uom = self.cluster.get("unit_of_measure", 0x7F) + return self.unit_of_measure_map.get(uom & 0x7F, "unknown") - async def fetch_config(self, from_cache): + async def fetch_config(self, from_cache: bool) -> None: """Fetch config from device and updates format specifier.""" results = await self.get_attributes( ["divisor", "multiplier", "unit_of_measure", "demand_formatting"], from_cache=from_cache, ) - self._divisor = results.get("divisor", self._divisor) - self._multiplier = results.get("multiplier", self._multiplier) - self._unit_enum = results.get("unit_of_measure", 0x7F) # default to unknown - fmting = results.get( "demand_formatting", 0xF9 ) # 1 digit to the right, 15 digits to the left @@ -135,8 +141,9 @@ class Metering(ZigbeeChannel): else: self._format_spec = "{:0" + str(width) + "." + str(r_digits) + "f}" - def formatter_function(self, value): + def formatter_function(self, value: int) -> Union[int, float]: """Return formatted value for display.""" + value = value * self.multiplier / self.divisor if self.unit_of_measurement == POWER_WATT: # Zigbee spec power unit is kW, but we show the value in W value_watt = value * 1000 diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 25f320b0bf1..4dff2c6b16b 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -39,12 +39,13 @@ async def async_add_entities( Tuple[str, zha_typing.ZhaDeviceType, List[zha_typing.ChannelType]], ] ], + update_before_add: bool = True, ) -> None: """Add entities helper.""" if not entities: return to_add = [ent_cls(*args) for ent_cls, args in entities] - _async_add_entities(to_add, update_before_add=True) + _async_add_entities(to_add, update_before_add=update_before_add) entities.clear() diff --git a/homeassistant/components/zha/sensor.py b/homeassistant/components/zha/sensor.py index 302637cc068..eff3892630b 100644 --- a/homeassistant/components/zha/sensor.py +++ b/homeassistant/components/zha/sensor.py @@ -1,6 +1,7 @@ """Sensors on Zigbee Home Automation networks.""" import functools import numbers +from typing import Any, Callable, Dict, List, Optional, Union from homeassistant.components.sensor import ( DEVICE_CLASS_BATTERY, @@ -11,18 +12,17 @@ from homeassistant.components.sensor import ( DEVICE_CLASS_TEMPERATURE, DOMAIN, ) +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( - ATTR_UNIT_OF_MEASUREMENT, LIGHT_LUX, PERCENTAGE, POWER_WATT, PRESSURE_HPA, - STATE_UNKNOWN, TEMP_CELSIUS, ) from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.util.temperature import fahrenheit_to_celsius +from homeassistant.helpers.typing import HomeAssistantType, StateType from .core import discovery from .core.const import ( @@ -38,9 +38,9 @@ from .core.const import ( DATA_ZHA_DISPATCHERS, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, - SIGNAL_STATE_ATTR, ) from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES +from .core.typing import ChannelType, ZhaDeviceType from .entity import ZhaEntity PARALLEL_UPDATES = 5 @@ -65,7 +65,9 @@ CHANNEL_ST_HUMIDITY_CLUSTER = f"channel_0x{SMARTTHINGS_HUMIDITY_CLUSTER:04x}" STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) -async def async_setup_entry(hass, config_entry, async_add_entities): +async def async_setup_entry( + hass: HomeAssistantType, config_entry: ConfigEntry, async_add_entities: Callable +) -> None: """Set up the Zigbee Home Automation sensor from config entry.""" entities_to_create = hass.data[DATA_ZHA][DOMAIN] @@ -73,7 +75,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass, SIGNAL_ADD_ENTITIES, functools.partial( - discovery.async_add_entities, async_add_entities, entities_to_create + discovery.async_add_entities, + async_add_entities, + entities_to_create, + update_before_add=False, ), ) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) @@ -82,29 +87,30 @@ async def async_setup_entry(hass, config_entry, async_add_entities): class Sensor(ZhaEntity): """Base ZHA sensor.""" - SENSOR_ATTR = None - _decimals = 1 - _device_class = None - _divisor = 1 - _multiplier = 1 - _unit = None + SENSOR_ATTR: Optional[Union[int, str]] = None + _decimals: int = 1 + _device_class: Optional[str] = None + _divisor: int = 1 + _multiplier: int = 1 + _unit: Optional[str] = None - def __init__(self, unique_id, zha_device, channels, **kwargs): + def __init__( + self, + unique_id: str, + zha_device: ZhaDeviceType, + channels: List[ChannelType], + **kwargs, + ): """Init this sensor.""" super().__init__(unique_id, zha_device, channels, **kwargs) - self._channel = channels[0] + self._channel: ChannelType = channels[0] - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Run when about to be added to hass.""" await super().async_added_to_hass() - self._device_state_attributes.update(await self.async_state_attr_provider()) - self.async_accept_signal( self._channel, SIGNAL_ATTR_UPDATED, self.async_set_state ) - self.async_accept_signal( - self._channel, SIGNAL_STATE_ATTR, self.async_update_state_attribute - ) @property def device_class(self) -> str: @@ -112,37 +118,25 @@ class Sensor(ZhaEntity): return self._device_class @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> Optional[str]: """Return the unit of measurement of this entity.""" return self._unit @property - def state(self) -> str: + def state(self) -> StateType: """Return the state of the entity.""" - if self._state is None: + assert self.SENSOR_ATTR is not None + raw_state = self._channel.cluster.get(self.SENSOR_ATTR) + if raw_state is None: return None - return self._state + return self.formatter(raw_state) @callback - def async_set_state(self, attr_id, attr_name, value): + def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None: """Handle state update from channel.""" - if self.SENSOR_ATTR is None or self.SENSOR_ATTR != attr_name: - return - if value is not None: - value = self.formatter(value) - self._state = value self.async_write_ha_state() - @callback - def async_restore_last_state(self, last_state): - """Restore previous state.""" - self._state = last_state.state - - async def async_state_attr_provider(self): - """Initialize device state attributes.""" - return {} - - def formatter(self, value): + def formatter(self, value: int) -> Union[int, float]: """Numeric pass-through formatter.""" if self._decimals > 0: return round( @@ -167,7 +161,7 @@ class Battery(Sensor): _unit = PERCENTAGE @staticmethod - def formatter(value): + def formatter(value: int) -> int: """Return the state of the entity.""" # per zcl specs battery percent is reported at 200% ¯\_(ツ)_/¯ if not isinstance(value, numbers.Number) or value == -1: @@ -175,26 +169,21 @@ class Battery(Sensor): value = round(value / 2) return value - async def async_state_attr_provider(self): + @property + def device_state_attributes(self) -> Dict[str, Any]: """Return device state attrs for battery sensors.""" state_attrs = {} - attributes = ["battery_size", "battery_quantity"] - results = await self._channel.get_attributes(attributes) - battery_size = results.get("battery_size") + battery_size = self._channel.cluster.get("battery_size") if battery_size is not None: state_attrs["battery_size"] = BATTERY_SIZES.get(battery_size, "Unknown") - battery_quantity = results.get("battery_quantity") + battery_quantity = self._channel.cluster.get("battery_quantity") if battery_quantity is not None: state_attrs["battery_quantity"] = battery_quantity + battery_voltage = self._channel.cluster.get("battery_voltage") + if battery_voltage is not None: + state_attrs["battery_voltage"] = round(battery_voltage / 10, 1) return state_attrs - @callback - def async_update_state_attribute(self, key, value): - """Update a single device state attribute.""" - if key == "battery_voltage": - self._device_state_attributes[key] = round(value / 10, 1) - self.async_write_ha_state() - @STRICT_MATCH(channel_names=CHANNEL_ELECTRICAL_MEASUREMENT) class ElectricalMeasurement(Sensor): @@ -202,7 +191,6 @@ class ElectricalMeasurement(Sensor): SENSOR_ATTR = "active_power" _device_class = DEVICE_CLASS_POWER - _divisor = 10 _unit = POWER_WATT @property @@ -210,7 +198,7 @@ class ElectricalMeasurement(Sensor): """Return True if HA needs to poll for state changes.""" return True - def formatter(self, value) -> int: + def formatter(self, value: int) -> Union[int, float]: """Return 'normalized' value.""" value = value * self._channel.multiplier / self._channel.divisor if value < 100 and self._channel.divisor > 1: @@ -244,7 +232,7 @@ class Illuminance(Sensor): _unit = LIGHT_LUX @staticmethod - def formatter(value): + def formatter(value: int) -> float: """Convert illumination data.""" return round(pow(10, ((value - 1) / 10000)), 1) @@ -256,12 +244,12 @@ class SmartEnergyMetering(Sensor): SENSOR_ATTR = "instantaneous_demand" _device_class = DEVICE_CLASS_POWER - def formatter(self, value): + def formatter(self, value: int) -> Union[int, float]: """Pass through channel formatter.""" return self._channel.formatter_function(value) @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> str: """Return Unit of measurement.""" return self._channel.unit_of_measurement @@ -284,14 +272,3 @@ class Temperature(Sensor): _device_class = DEVICE_CLASS_TEMPERATURE _divisor = 100 _unit = TEMP_CELSIUS - - @callback - def async_restore_last_state(self, last_state): - """Restore previous state.""" - if last_state.state == STATE_UNKNOWN: - return - if last_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) != TEMP_CELSIUS: - ftemp = float(last_state.state) - self._state = round(fahrenheit_to_celsius(ftemp), 1) - return - self._state = last_state.state diff --git a/tests/components/zha/test_sensor.py b/tests/components/zha/test_sensor.py index 8d854894ba0..b6b4b343e3b 100644 --- a/tests/components/zha/test_sensor.py +++ b/tests/components/zha/test_sensor.py @@ -93,18 +93,59 @@ async def async_test_electrical_measurement(hass, cluster, entity_id): assert_state(hass, entity_id, "9.9", POWER_WATT) +async def async_test_powerconfiguration(hass, cluster, entity_id): + """Test powerconfiguration/battery sensor.""" + await send_attributes_report(hass, cluster, {33: 98}) + assert_state(hass, entity_id, "49", "%") + assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.9 + assert hass.states.get(entity_id).attributes["battery_quantity"] == 3 + assert hass.states.get(entity_id).attributes["battery_size"] == "AAA" + await send_attributes_report(hass, cluster, {32: 20}) + assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.0 + + @pytest.mark.parametrize( - "cluster_id, test_func, report_count", + "cluster_id, test_func, report_count, read_plug", ( - (measurement.RelativeHumidity.cluster_id, async_test_humidity, 1), - (measurement.TemperatureMeasurement.cluster_id, async_test_temperature, 1), - (measurement.PressureMeasurement.cluster_id, async_test_pressure, 1), - (measurement.IlluminanceMeasurement.cluster_id, async_test_illuminance, 1), - (smartenergy.Metering.cluster_id, async_test_metering, 1), + (measurement.RelativeHumidity.cluster_id, async_test_humidity, 1, None), + ( + measurement.TemperatureMeasurement.cluster_id, + async_test_temperature, + 1, + None, + ), + (measurement.PressureMeasurement.cluster_id, async_test_pressure, 1, None), + ( + measurement.IlluminanceMeasurement.cluster_id, + async_test_illuminance, + 1, + None, + ), + ( + smartenergy.Metering.cluster_id, + async_test_metering, + 1, + { + "demand_formatting": 0xF9, + "divisor": 1, + "multiplier": 1, + }, + ), ( homeautomation.ElectricalMeasurement.cluster_id, async_test_electrical_measurement, 1, + None, + ), + ( + general.PowerConfiguration.cluster_id, + async_test_powerconfiguration, + 2, + { + "battery_size": 4, # AAA + "battery_voltage": 29, + "battery_quantity": 3, + }, ), ), ) @@ -115,6 +156,7 @@ async def test_sensor( cluster_id, test_func, report_count, + read_plug, ): """Test zha sensor platform.""" @@ -128,6 +170,10 @@ async def test_sensor( } ) cluster = zigpy_device.endpoints[1].in_clusters[cluster_id] + if cluster_id == smartenergy.Metering.cluster_id: + # this one is mains powered + zigpy_device.node_desc.mac_capability_flags |= 0b_0000_0100 + cluster.PLUGGED_ATTR_READS = read_plug zha_device = await zha_device_joined_restored(zigpy_device) entity_id = await find_entity_id(DOMAIN, zha_device, hass)