Refactor ZHA sensor initialization (#43339)

* Refactor ZHA sensors to use cached values after restart

* Get attr from cluster, not channel

* Run cached state through formatter method

* Use cached values for div/multiplier for SmartEnergy channel

* Restore batter voltage from cache

* Refactor sensor to use cached values only

* Update tests

* Add battery sensor test
This commit is contained in:
Alexei Chetroi 2020-11-18 21:34:12 -05:00 committed by GitHub
parent 70a3489845
commit 54c4e9335f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 138 deletions

View File

@ -208,7 +208,7 @@ class ZigbeeChannel(LogMixin):
attributes = [] attributes = []
for report_config in self._report_config: for report_config in self._report_config:
attributes.append(report_config["attr"]) attributes.append(report_config["attr"])
if len(attributes) > 0: if attributes:
await self.get_attributes(attributes, from_cache=from_cache) await self.get_attributes(attributes, from_cache=from_cache)
self._status = ChannelStatus.INITIALIZED self._status = ChannelStatus.INITIALIZED

View File

@ -17,10 +17,9 @@ from ..const import (
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
SIGNAL_MOVE_LEVEL, SIGNAL_MOVE_LEVEL,
SIGNAL_SET_LEVEL, SIGNAL_SET_LEVEL,
SIGNAL_STATE_ATTR,
SIGNAL_UPDATE_DEVICE, SIGNAL_UPDATE_DEVICE,
) )
from .base import ClientChannel, ZigbeeChannel, parse_and_log_command from .base import ChannelStatus, ClientChannel, ZigbeeChannel, parse_and_log_command
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.Alarms.cluster_id) @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.Alarms.cluster_id)
@ -72,13 +71,6 @@ class BasicChannel(ZigbeeChannel):
6: "Emergency mains and transfer switch", 6: "Emergency mains and transfer switch",
} }
def __init__(
self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType
) -> None:
"""Initialize BasicChannel."""
super().__init__(cluster, ch_pool)
self._power_source = None
async def async_configure(self): async def async_configure(self):
"""Configure this channel.""" """Configure this channel."""
await super().async_configure() await super().async_configure()
@ -87,16 +79,12 @@ class BasicChannel(ZigbeeChannel):
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache):
"""Initialize channel.""" """Initialize channel."""
if not self._ch_pool.skip_configuration or from_cache: if not self._ch_pool.skip_configuration or from_cache:
power_source = await self.get_attribute_value( await self.get_attribute_value("power_source", from_cache=from_cache)
"power_source", from_cache=from_cache
)
if power_source is not None:
self._power_source = power_source
await super().async_initialize(from_cache) await super().async_initialize(from_cache)
def get_power_source(self): def get_power_source(self):
"""Get the power source.""" """Get the power source."""
return self._power_source return self.cluster.get("power_source")
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.BinaryInput.cluster_id) @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.BinaryInput.cluster_id)
@ -392,38 +380,8 @@ class PowerConfigurationChannel(ZigbeeChannel):
{"attr": "battery_percentage_remaining", "config": REPORT_CONFIG_BATTERY_SAVE}, {"attr": "battery_percentage_remaining", "config": REPORT_CONFIG_BATTERY_SAVE},
) )
@callback
def attribute_updated(self, attrid, value):
"""Handle attribute updates on this cluster."""
attr = self._report_config[1].get("attr")
if isinstance(attr, str):
attr_id = self.cluster.attridx.get(attr)
else:
attr_id = attr
if attrid == attr_id:
self.async_send_signal(
f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}",
attrid,
self.cluster.attributes.get(attrid, [attrid])[0],
value,
)
return
attr_name = self.cluster.attributes.get(attrid, [attrid])[0]
self.async_send_signal(
f"{self.unique_id}_{SIGNAL_STATE_ATTR}", attr_name, value
)
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache):
"""Initialize channel.""" """Initialize channel."""
await self.async_read_state(from_cache)
await super().async_initialize(from_cache)
async def async_update(self):
"""Retrieve latest state."""
await self.async_read_state(True)
async def async_read_state(self, from_cache):
"""Read data from the cluster."""
attributes = [ attributes = [
"battery_size", "battery_size",
"battery_percentage_remaining", "battery_percentage_remaining",
@ -431,6 +389,7 @@ class PowerConfigurationChannel(ZigbeeChannel):
"battery_quantity", "battery_quantity",
] ]
await self.get_attributes(attributes, from_cache=from_cache) await self.get_attributes(attributes, from_cache=from_cache)
self._status = ChannelStatus.INITIALIZED
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PowerProfile.cluster_id) @registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PowerProfile.cluster_id)

View File

@ -1,4 +1,6 @@
"""Smart energy channels module for Zigbee Home Automation.""" """Smart energy channels module for Zigbee Home Automation."""
from typing import Union
import zigpy.zcl.clusters.smartenergy as smartenergy import zigpy.zcl.clusters.smartenergy as smartenergy
from homeassistant.const import ( from homeassistant.const import (
@ -82,44 +84,48 @@ class Metering(ZigbeeChannel):
) -> None: ) -> None:
"""Initialize Metering.""" """Initialize Metering."""
super().__init__(cluster, ch_pool) super().__init__(cluster, ch_pool)
self._divisor = 1
self._multiplier = 1
self._unit_enum = None
self._format_spec = None self._format_spec = None
async def async_configure(self): @property
def divisor(self) -> int:
"""Return divisor for the value."""
return self.cluster.get("divisor")
@property
def multiplier(self) -> int:
"""Return multiplier for the value."""
return self.cluster.get("multiplier")
async def async_configure(self) -> None:
"""Configure channel.""" """Configure channel."""
await self.fetch_config(False) await self.fetch_config(False)
await super().async_configure() await super().async_configure()
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache: bool) -> None:
"""Initialize channel.""" """Initialize channel."""
await self.fetch_config(True) await self.fetch_config(True)
await super().async_initialize(from_cache) await super().async_initialize(from_cache)
@callback @callback
def attribute_updated(self, attrid, value): def attribute_updated(self, attrid: int, value: int) -> None:
"""Handle attribute update from Metering cluster.""" """Handle attribute update from Metering cluster."""
if None in (self._multiplier, self._divisor, self._format_spec): if None in (self.multiplier, self.divisor, self._format_spec):
return return
super().attribute_updated(attrid, value * self._multiplier / self._divisor) super().attribute_updated(attrid, value)
@property @property
def unit_of_measurement(self): def unit_of_measurement(self) -> str:
"""Return unit of measurement.""" """Return unit of measurement."""
return self.unit_of_measure_map.get(self._unit_enum & 0x7F, "unknown") uom = self.cluster.get("unit_of_measure", 0x7F)
return self.unit_of_measure_map.get(uom & 0x7F, "unknown")
async def fetch_config(self, from_cache): async def fetch_config(self, from_cache: bool) -> None:
"""Fetch config from device and updates format specifier.""" """Fetch config from device and updates format specifier."""
results = await self.get_attributes( results = await self.get_attributes(
["divisor", "multiplier", "unit_of_measure", "demand_formatting"], ["divisor", "multiplier", "unit_of_measure", "demand_formatting"],
from_cache=from_cache, from_cache=from_cache,
) )
self._divisor = results.get("divisor", self._divisor)
self._multiplier = results.get("multiplier", self._multiplier)
self._unit_enum = results.get("unit_of_measure", 0x7F) # default to unknown
fmting = results.get( fmting = results.get(
"demand_formatting", 0xF9 "demand_formatting", 0xF9
) # 1 digit to the right, 15 digits to the left ) # 1 digit to the right, 15 digits to the left
@ -135,8 +141,9 @@ class Metering(ZigbeeChannel):
else: else:
self._format_spec = "{:0" + str(width) + "." + str(r_digits) + "f}" self._format_spec = "{:0" + str(width) + "." + str(r_digits) + "f}"
def formatter_function(self, value): def formatter_function(self, value: int) -> Union[int, float]:
"""Return formatted value for display.""" """Return formatted value for display."""
value = value * self.multiplier / self.divisor
if self.unit_of_measurement == POWER_WATT: if self.unit_of_measurement == POWER_WATT:
# Zigbee spec power unit is kW, but we show the value in W # Zigbee spec power unit is kW, but we show the value in W
value_watt = value * 1000 value_watt = value * 1000

View File

@ -39,12 +39,13 @@ async def async_add_entities(
Tuple[str, zha_typing.ZhaDeviceType, List[zha_typing.ChannelType]], Tuple[str, zha_typing.ZhaDeviceType, List[zha_typing.ChannelType]],
] ]
], ],
update_before_add: bool = True,
) -> None: ) -> None:
"""Add entities helper.""" """Add entities helper."""
if not entities: if not entities:
return return
to_add = [ent_cls(*args) for ent_cls, args in entities] to_add = [ent_cls(*args) for ent_cls, args in entities]
_async_add_entities(to_add, update_before_add=True) _async_add_entities(to_add, update_before_add=update_before_add)
entities.clear() entities.clear()

View File

@ -1,6 +1,7 @@
"""Sensors on Zigbee Home Automation networks.""" """Sensors on Zigbee Home Automation networks."""
import functools import functools
import numbers import numbers
from typing import Any, Callable, Dict, List, Optional, Union
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
DEVICE_CLASS_BATTERY, DEVICE_CLASS_BATTERY,
@ -11,18 +12,17 @@ from homeassistant.components.sensor import (
DEVICE_CLASS_TEMPERATURE, DEVICE_CLASS_TEMPERATURE,
DOMAIN, DOMAIN,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_UNIT_OF_MEASUREMENT,
LIGHT_LUX, LIGHT_LUX,
PERCENTAGE, PERCENTAGE,
POWER_WATT, POWER_WATT,
PRESSURE_HPA, PRESSURE_HPA,
STATE_UNKNOWN,
TEMP_CELSIUS, TEMP_CELSIUS,
) )
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.util.temperature import fahrenheit_to_celsius from homeassistant.helpers.typing import HomeAssistantType, StateType
from .core import discovery from .core import discovery
from .core.const import ( from .core.const import (
@ -38,9 +38,9 @@ from .core.const import (
DATA_ZHA_DISPATCHERS, DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
SIGNAL_STATE_ATTR,
) )
from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity from .entity import ZhaEntity
PARALLEL_UPDATES = 5 PARALLEL_UPDATES = 5
@ -65,7 +65,9 @@ CHANNEL_ST_HUMIDITY_CLUSTER = f"channel_0x{SMARTTHINGS_HUMIDITY_CLUSTER:04x}"
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN) STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(
hass: HomeAssistantType, config_entry: ConfigEntry, async_add_entities: Callable
) -> None:
"""Set up the Zigbee Home Automation sensor from config entry.""" """Set up the Zigbee Home Automation sensor from config entry."""
entities_to_create = hass.data[DATA_ZHA][DOMAIN] entities_to_create = hass.data[DATA_ZHA][DOMAIN]
@ -73,7 +75,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass, hass,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
functools.partial( functools.partial(
discovery.async_add_entities, async_add_entities, entities_to_create discovery.async_add_entities,
async_add_entities,
entities_to_create,
update_before_add=False,
), ),
) )
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
@ -82,29 +87,30 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
class Sensor(ZhaEntity): class Sensor(ZhaEntity):
"""Base ZHA sensor.""" """Base ZHA sensor."""
SENSOR_ATTR = None SENSOR_ATTR: Optional[Union[int, str]] = None
_decimals = 1 _decimals: int = 1
_device_class = None _device_class: Optional[str] = None
_divisor = 1 _divisor: int = 1
_multiplier = 1 _multiplier: int = 1
_unit = None _unit: Optional[str] = None
def __init__(self, unique_id, zha_device, channels, **kwargs): def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: List[ChannelType],
**kwargs,
):
"""Init this sensor.""" """Init this sensor."""
super().__init__(unique_id, zha_device, channels, **kwargs) super().__init__(unique_id, zha_device, channels, **kwargs)
self._channel = channels[0] self._channel: ChannelType = channels[0]
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Run when about to be added to hass.""" """Run when about to be added to hass."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._device_state_attributes.update(await self.async_state_attr_provider())
self.async_accept_signal( self.async_accept_signal(
self._channel, SIGNAL_ATTR_UPDATED, self.async_set_state self._channel, SIGNAL_ATTR_UPDATED, self.async_set_state
) )
self.async_accept_signal(
self._channel, SIGNAL_STATE_ATTR, self.async_update_state_attribute
)
@property @property
def device_class(self) -> str: def device_class(self) -> str:
@ -112,37 +118,25 @@ class Sensor(ZhaEntity):
return self._device_class return self._device_class
@property @property
def unit_of_measurement(self): def unit_of_measurement(self) -> Optional[str]:
"""Return the unit of measurement of this entity.""" """Return the unit of measurement of this entity."""
return self._unit return self._unit
@property @property
def state(self) -> str: def state(self) -> StateType:
"""Return the state of the entity.""" """Return the state of the entity."""
if self._state is None: assert self.SENSOR_ATTR is not None
raw_state = self._channel.cluster.get(self.SENSOR_ATTR)
if raw_state is None:
return None return None
return self._state return self.formatter(raw_state)
@callback @callback
def async_set_state(self, attr_id, attr_name, value): def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None:
"""Handle state update from channel.""" """Handle state update from channel."""
if self.SENSOR_ATTR is None or self.SENSOR_ATTR != attr_name:
return
if value is not None:
value = self.formatter(value)
self._state = value
self.async_write_ha_state() self.async_write_ha_state()
@callback def formatter(self, value: int) -> Union[int, float]:
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = last_state.state
async def async_state_attr_provider(self):
"""Initialize device state attributes."""
return {}
def formatter(self, value):
"""Numeric pass-through formatter.""" """Numeric pass-through formatter."""
if self._decimals > 0: if self._decimals > 0:
return round( return round(
@ -167,7 +161,7 @@ class Battery(Sensor):
_unit = PERCENTAGE _unit = PERCENTAGE
@staticmethod @staticmethod
def formatter(value): def formatter(value: int) -> int:
"""Return the state of the entity.""" """Return the state of the entity."""
# per zcl specs battery percent is reported at 200% ¯\_(ツ)_/¯ # per zcl specs battery percent is reported at 200% ¯\_(ツ)_/¯
if not isinstance(value, numbers.Number) or value == -1: if not isinstance(value, numbers.Number) or value == -1:
@ -175,26 +169,21 @@ class Battery(Sensor):
value = round(value / 2) value = round(value / 2)
return value return value
async def async_state_attr_provider(self): @property
def device_state_attributes(self) -> Dict[str, Any]:
"""Return device state attrs for battery sensors.""" """Return device state attrs for battery sensors."""
state_attrs = {} state_attrs = {}
attributes = ["battery_size", "battery_quantity"] battery_size = self._channel.cluster.get("battery_size")
results = await self._channel.get_attributes(attributes)
battery_size = results.get("battery_size")
if battery_size is not None: if battery_size is not None:
state_attrs["battery_size"] = BATTERY_SIZES.get(battery_size, "Unknown") state_attrs["battery_size"] = BATTERY_SIZES.get(battery_size, "Unknown")
battery_quantity = results.get("battery_quantity") battery_quantity = self._channel.cluster.get("battery_quantity")
if battery_quantity is not None: if battery_quantity is not None:
state_attrs["battery_quantity"] = battery_quantity state_attrs["battery_quantity"] = battery_quantity
battery_voltage = self._channel.cluster.get("battery_voltage")
if battery_voltage is not None:
state_attrs["battery_voltage"] = round(battery_voltage / 10, 1)
return state_attrs return state_attrs
@callback
def async_update_state_attribute(self, key, value):
"""Update a single device state attribute."""
if key == "battery_voltage":
self._device_state_attributes[key] = round(value / 10, 1)
self.async_write_ha_state()
@STRICT_MATCH(channel_names=CHANNEL_ELECTRICAL_MEASUREMENT) @STRICT_MATCH(channel_names=CHANNEL_ELECTRICAL_MEASUREMENT)
class ElectricalMeasurement(Sensor): class ElectricalMeasurement(Sensor):
@ -202,7 +191,6 @@ class ElectricalMeasurement(Sensor):
SENSOR_ATTR = "active_power" SENSOR_ATTR = "active_power"
_device_class = DEVICE_CLASS_POWER _device_class = DEVICE_CLASS_POWER
_divisor = 10
_unit = POWER_WATT _unit = POWER_WATT
@property @property
@ -210,7 +198,7 @@ class ElectricalMeasurement(Sensor):
"""Return True if HA needs to poll for state changes.""" """Return True if HA needs to poll for state changes."""
return True return True
def formatter(self, value) -> int: def formatter(self, value: int) -> Union[int, float]:
"""Return 'normalized' value.""" """Return 'normalized' value."""
value = value * self._channel.multiplier / self._channel.divisor value = value * self._channel.multiplier / self._channel.divisor
if value < 100 and self._channel.divisor > 1: if value < 100 and self._channel.divisor > 1:
@ -244,7 +232,7 @@ class Illuminance(Sensor):
_unit = LIGHT_LUX _unit = LIGHT_LUX
@staticmethod @staticmethod
def formatter(value): def formatter(value: int) -> float:
"""Convert illumination data.""" """Convert illumination data."""
return round(pow(10, ((value - 1) / 10000)), 1) return round(pow(10, ((value - 1) / 10000)), 1)
@ -256,12 +244,12 @@ class SmartEnergyMetering(Sensor):
SENSOR_ATTR = "instantaneous_demand" SENSOR_ATTR = "instantaneous_demand"
_device_class = DEVICE_CLASS_POWER _device_class = DEVICE_CLASS_POWER
def formatter(self, value): def formatter(self, value: int) -> Union[int, float]:
"""Pass through channel formatter.""" """Pass through channel formatter."""
return self._channel.formatter_function(value) return self._channel.formatter_function(value)
@property @property
def unit_of_measurement(self): def unit_of_measurement(self) -> str:
"""Return Unit of measurement.""" """Return Unit of measurement."""
return self._channel.unit_of_measurement return self._channel.unit_of_measurement
@ -284,14 +272,3 @@ class Temperature(Sensor):
_device_class = DEVICE_CLASS_TEMPERATURE _device_class = DEVICE_CLASS_TEMPERATURE
_divisor = 100 _divisor = 100
_unit = TEMP_CELSIUS _unit = TEMP_CELSIUS
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
if last_state.state == STATE_UNKNOWN:
return
if last_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) != TEMP_CELSIUS:
ftemp = float(last_state.state)
self._state = round(fahrenheit_to_celsius(ftemp), 1)
return
self._state = last_state.state

View File

@ -93,18 +93,59 @@ async def async_test_electrical_measurement(hass, cluster, entity_id):
assert_state(hass, entity_id, "9.9", POWER_WATT) assert_state(hass, entity_id, "9.9", POWER_WATT)
async def async_test_powerconfiguration(hass, cluster, entity_id):
"""Test powerconfiguration/battery sensor."""
await send_attributes_report(hass, cluster, {33: 98})
assert_state(hass, entity_id, "49", "%")
assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.9
assert hass.states.get(entity_id).attributes["battery_quantity"] == 3
assert hass.states.get(entity_id).attributes["battery_size"] == "AAA"
await send_attributes_report(hass, cluster, {32: 20})
assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cluster_id, test_func, report_count", "cluster_id, test_func, report_count, read_plug",
( (
(measurement.RelativeHumidity.cluster_id, async_test_humidity, 1), (measurement.RelativeHumidity.cluster_id, async_test_humidity, 1, None),
(measurement.TemperatureMeasurement.cluster_id, async_test_temperature, 1), (
(measurement.PressureMeasurement.cluster_id, async_test_pressure, 1), measurement.TemperatureMeasurement.cluster_id,
(measurement.IlluminanceMeasurement.cluster_id, async_test_illuminance, 1), async_test_temperature,
(smartenergy.Metering.cluster_id, async_test_metering, 1), 1,
None,
),
(measurement.PressureMeasurement.cluster_id, async_test_pressure, 1, None),
(
measurement.IlluminanceMeasurement.cluster_id,
async_test_illuminance,
1,
None,
),
(
smartenergy.Metering.cluster_id,
async_test_metering,
1,
{
"demand_formatting": 0xF9,
"divisor": 1,
"multiplier": 1,
},
),
( (
homeautomation.ElectricalMeasurement.cluster_id, homeautomation.ElectricalMeasurement.cluster_id,
async_test_electrical_measurement, async_test_electrical_measurement,
1, 1,
None,
),
(
general.PowerConfiguration.cluster_id,
async_test_powerconfiguration,
2,
{
"battery_size": 4, # AAA
"battery_voltage": 29,
"battery_quantity": 3,
},
), ),
), ),
) )
@ -115,6 +156,7 @@ async def test_sensor(
cluster_id, cluster_id,
test_func, test_func,
report_count, report_count,
read_plug,
): ):
"""Test zha sensor platform.""" """Test zha sensor platform."""
@ -128,6 +170,10 @@ async def test_sensor(
} }
) )
cluster = zigpy_device.endpoints[1].in_clusters[cluster_id] cluster = zigpy_device.endpoints[1].in_clusters[cluster_id]
if cluster_id == smartenergy.Metering.cluster_id:
# this one is mains powered
zigpy_device.node_desc.mac_capability_flags |= 0b_0000_0100
cluster.PLUGGED_ATTR_READS = read_plug
zha_device = await zha_device_joined_restored(zigpy_device) zha_device = await zha_device_joined_restored(zigpy_device)
entity_id = await find_entity_id(DOMAIN, zha_device, hass) entity_id = await find_entity_id(DOMAIN, zha_device, hass)