diff --git a/homeassistant/components/zha/sensor.py b/homeassistant/components/zha/sensor.py index b2d246c3095..fb4b577ff61 100644 --- a/homeassistant/components/zha/sensor.py +++ b/homeassistant/components/zha/sensor.py @@ -2,7 +2,10 @@ import logging from homeassistant.core import callback -from homeassistant.components.sensor import DOMAIN +from homeassistant.components.sensor import ( + DOMAIN, DEVICE_CLASS_HUMIDITY, DEVICE_CLASS_ILLUMINANCE, + DEVICE_CLASS_TEMPERATURE, DEVICE_CLASS_PRESSURE, DEVICE_CLASS_POWER +) from homeassistant.const import ( TEMP_CELSIUS, POWER_WATT, ATTR_UNIT_OF_MEASUREMENT ) @@ -11,13 +14,12 @@ from .core.const import ( DATA_ZHA, DATA_ZHA_DISPATCHERS, ZHA_DISCOVERY_NEW, HUMIDITY, TEMPERATURE, ILLUMINANCE, PRESSURE, METERING, ELECTRICAL_MEASUREMENT, GENERIC, SENSOR_TYPE, ATTRIBUTE_CHANNEL, ELECTRICAL_MEASUREMENT_CHANNEL, - SIGNAL_ATTR_UPDATED, SIGNAL_STATE_ATTR) + SIGNAL_ATTR_UPDATED, SIGNAL_STATE_ATTR, UNKNOWN) from .entity import ZhaEntity PARALLEL_UPDATES = 5 _LOGGER = logging.getLogger(__name__) - # Formatter functions def pass_through_formatter(value): """No op update function.""" @@ -91,6 +93,16 @@ FORCE_UPDATE_REGISTRY = { ELECTRICAL_MEASUREMENT: False } +DEVICE_CLASS_REGISTRY = { + UNKNOWN: None, + HUMIDITY: DEVICE_CLASS_HUMIDITY, + TEMPERATURE: DEVICE_CLASS_TEMPERATURE, + PRESSURE: DEVICE_CLASS_PRESSURE, + ILLUMINANCE: DEVICE_CLASS_ILLUMINANCE, + METERING: DEVICE_CLASS_POWER, + ELECTRICAL_MEASUREMENT: DEVICE_CLASS_POWER +} + async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): @@ -155,6 +167,10 @@ class Sensor(ZhaEntity): self._channel = self.cluster_channels.get( CHANNEL_REGISTRY.get(self._sensor_type, ATTRIBUTE_CHANNEL) ) + self._device_class = DEVICE_CLASS_REGISTRY.get( + self._sensor_type, + None + ) async def async_added_to_hass(self): """Run when about to be added to hass.""" @@ -165,6 +181,11 @@ class Sensor(ZhaEntity): self._channel, SIGNAL_STATE_ATTR, self.async_update_state_attribute) + @property + def device_class(self) -> str: + """Return device class from component DEVICE_CLASSES.""" + return self._device_class + @property def unit_of_measurement(self): """Return the unit of measurement of this entity."""