Statistics component fix device_class for incremental source sensors (#88096)

* Return None device_class for incremental source sensors

* Ignore linting error

* Fix ignore linting error

* Fix ignore linting error

* Fix ignore linting error

* Catch potential parsing error with enum
This commit is contained in:
Thomas Dietrich 2023-02-15 10:22:09 +01:00 committed by Paulus Schoutsen
parent 55fed18e3e
commit 634aff0006
2 changed files with 153 additions and 28 deletions

View File

@ -13,7 +13,8 @@ import voluptuous as vol
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.recorder import get_instance, history from homeassistant.components.recorder import get_instance, history
from homeassistant.components.sensor import ( from homeassistant.components.sensor import ( # pylint: disable=hass-deprecated-import
DEVICE_CLASS_STATE_CLASSES,
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
SensorDeviceClass, SensorDeviceClass,
SensorEntity, SensorEntity,
@ -47,6 +48,7 @@ from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.start import async_at_start from homeassistant.helpers.start import async_at_start
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.enum import try_parse_enum
from . import DOMAIN, PLATFORMS from . import DOMAIN, PLATFORMS
@ -144,7 +146,7 @@ STATS_DATETIME = {
} }
# Statistics which retain the unit of the source entity # Statistics which retain the unit of the source entity
STAT_NUMERIC_RETAIN_UNIT = { STATS_NUMERIC_RETAIN_UNIT = {
STAT_AVERAGE_LINEAR, STAT_AVERAGE_LINEAR,
STAT_AVERAGE_STEP, STAT_AVERAGE_STEP,
STAT_AVERAGE_TIMELESS, STAT_AVERAGE_TIMELESS,
@ -166,7 +168,7 @@ STAT_NUMERIC_RETAIN_UNIT = {
} }
# Statistics which produce percentage ratio from binary_sensor source entity # Statistics which produce percentage ratio from binary_sensor source entity
STAT_BINARY_PERCENTAGE = { STATS_BINARY_PERCENTAGE = {
STAT_AVERAGE_STEP, STAT_AVERAGE_STEP,
STAT_AVERAGE_TIMELESS, STAT_AVERAGE_TIMELESS,
STAT_MEAN, STAT_MEAN,
@ -296,15 +298,9 @@ class StatisticsSensor(SensorEntity):
self.ages: deque[datetime] = deque(maxlen=self._samples_max_buffer_size) self.ages: deque[datetime] = deque(maxlen=self._samples_max_buffer_size)
self.attributes: dict[str, StateType] = {} self.attributes: dict[str, StateType] = {}
self._state_characteristic_fn: Callable[[], StateType | datetime] self._state_characteristic_fn: Callable[
if self.is_binary: [], StateType | datetime
self._state_characteristic_fn = getattr( ] = self._callable_characteristic_fn(self._state_characteristic)
self, f"_stat_binary_{self._state_characteristic}"
)
else:
self._state_characteristic_fn = getattr(
self, f"_stat_{self._state_characteristic}"
)
self._update_listener: CALLBACK_TYPE | None = None self._update_listener: CALLBACK_TYPE | None = None
@ -368,11 +364,11 @@ class StatisticsSensor(SensorEntity):
def _derive_unit_of_measurement(self, new_state: State) -> str | None: def _derive_unit_of_measurement(self, new_state: State) -> str | None:
base_unit: str | None = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) base_unit: str | None = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
unit: str | None unit: str | None
if self.is_binary and self._state_characteristic in STAT_BINARY_PERCENTAGE: if self.is_binary and self._state_characteristic in STATS_BINARY_PERCENTAGE:
unit = PERCENTAGE unit = PERCENTAGE
elif not base_unit: elif not base_unit:
unit = None unit = None
elif self._state_characteristic in STAT_NUMERIC_RETAIN_UNIT: elif self._state_characteristic in STATS_NUMERIC_RETAIN_UNIT:
unit = base_unit unit = base_unit
elif self._state_characteristic in STATS_NOT_A_NUMBER: elif self._state_characteristic in STATS_NOT_A_NUMBER:
unit = None unit = None
@ -393,11 +389,24 @@ class StatisticsSensor(SensorEntity):
@property @property
def device_class(self) -> SensorDeviceClass | None: def device_class(self) -> SensorDeviceClass | None:
"""Return the class of this device.""" """Return the class of this device."""
if self._state_characteristic in STAT_NUMERIC_RETAIN_UNIT:
_state = self.hass.states.get(self._source_entity_id)
return None if _state is None else _state.attributes.get(ATTR_DEVICE_CLASS)
if self._state_characteristic in STATS_DATETIME: if self._state_characteristic in STATS_DATETIME:
return SensorDeviceClass.TIMESTAMP return SensorDeviceClass.TIMESTAMP
if self._state_characteristic in STATS_NUMERIC_RETAIN_UNIT:
source_state = self.hass.states.get(self._source_entity_id)
if source_state is None:
return None
source_device_class = source_state.attributes.get(ATTR_DEVICE_CLASS)
if source_device_class is None:
return None
sensor_device_class = try_parse_enum(SensorDeviceClass, source_device_class)
if sensor_device_class is None:
return None
sensor_state_classes = DEVICE_CLASS_STATE_CLASSES.get(
sensor_device_class, set()
)
if SensorStateClass.MEASUREMENT not in sensor_state_classes:
return None
return sensor_device_class
return None return None
@property @property
@ -472,8 +481,8 @@ class StatisticsSensor(SensorEntity):
if timestamp := self._next_to_purge_timestamp(): if timestamp := self._next_to_purge_timestamp():
_LOGGER.debug("%s: scheduling update at %s", self.entity_id, timestamp) _LOGGER.debug("%s: scheduling update at %s", self.entity_id, timestamp)
if self._update_listener: if self._update_listener:
self._update_listener() self._update_listener() # pragma: no cover
self._update_listener = None self._update_listener = None # pragma: no cover
@callback @callback
def _scheduled_update(now: datetime) -> None: def _scheduled_update(now: datetime) -> None:
@ -563,6 +572,18 @@ class StatisticsSensor(SensorEntity):
value = int(value) value = int(value)
self._value = value self._value = value
def _callable_characteristic_fn(
self, characteristic: str
) -> Callable[[], StateType | datetime]:
"""Return the function callable of one characteristic function."""
function: Callable[[], StateType | datetime] = getattr(
self,
f"_stat_binary_{characteristic}"
if self.is_binary
else f"_stat_{characteristic}",
)
return function
# Statistics for numeric sensor # Statistics for numeric sensor
def _stat_average_linear(self) -> StateType: def _stat_average_linear(self) -> StateType:

View File

@ -21,6 +21,7 @@ from homeassistant.const import (
SERVICE_RELOAD, SERVICE_RELOAD,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
UnitOfEnergy,
UnitOfTemperature, UnitOfTemperature,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -250,6 +251,63 @@ async def test_sensor_source_with_force_update(hass: HomeAssistant):
assert state_force.attributes.get("buffer_usage_ratio") == round(9 / 20, 2) assert state_force.attributes.get("buffer_usage_ratio") == round(9 / 20, 2)
async def test_sampling_boundaries_given(hass: HomeAssistant):
"""Test if either sampling_size or max_age are given."""
assert await async_setup_component(
hass,
"sensor",
{
"sensor": [
{
"platform": "statistics",
"name": "test_boundaries_none",
"entity_id": "sensor.test_monitored",
"state_characteristic": "mean",
},
{
"platform": "statistics",
"name": "test_boundaries_size",
"entity_id": "sensor.test_monitored",
"state_characteristic": "mean",
"sampling_size": 20,
},
{
"platform": "statistics",
"name": "test_boundaries_age",
"entity_id": "sensor.test_monitored",
"state_characteristic": "mean",
"max_age": {"minutes": 4},
},
{
"platform": "statistics",
"name": "test_boundaries_both",
"entity_id": "sensor.test_monitored",
"state_characteristic": "mean",
"sampling_size": 20,
"max_age": {"minutes": 4},
},
]
},
)
await hass.async_block_till_done()
hass.states.async_set(
"sensor.test_monitored",
str(VALUES_NUMERIC[0]),
{ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS},
)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_boundaries_none")
assert state is None
state = hass.states.get("sensor.test_boundaries_size")
assert state is not None
state = hass.states.get("sensor.test_boundaries_age")
assert state is not None
state = hass.states.get("sensor.test_boundaries_both")
assert state is not None
async def test_sampling_size_reduced(hass: HomeAssistant): async def test_sampling_size_reduced(hass: HomeAssistant):
"""Test limited buffer size.""" """Test limited buffer size."""
assert await async_setup_component( assert await async_setup_component(
@ -514,9 +572,9 @@ async def test_device_class(hass: HomeAssistant):
{ {
"sensor": [ "sensor": [
{ {
# Device class is carried over from source sensor for characteristics with same unit # Device class is carried over from source sensor for characteristics which retain unit
"platform": "statistics", "platform": "statistics",
"name": "test_source_class", "name": "test_retain_unit",
"entity_id": "sensor.test_monitored", "entity_id": "sensor.test_monitored",
"state_characteristic": "mean", "state_characteristic": "mean",
"sampling_size": 20, "sampling_size": 20,
@ -537,6 +595,14 @@ async def test_device_class(hass: HomeAssistant):
"state_characteristic": "datetime_oldest", "state_characteristic": "datetime_oldest",
"sampling_size": 20, "sampling_size": 20,
}, },
{
# Device class is set to None for any source sensor with TOTAL state class
"platform": "statistics",
"name": "test_source_class_total",
"entity_id": "sensor.test_monitored_total",
"state_characteristic": "mean",
"sampling_size": 20,
},
] ]
}, },
) )
@ -549,11 +615,21 @@ async def test_device_class(hass: HomeAssistant):
{ {
ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS, ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS,
ATTR_DEVICE_CLASS: SensorDeviceClass.TEMPERATURE, ATTR_DEVICE_CLASS: SensorDeviceClass.TEMPERATURE,
ATTR_STATE_CLASS: SensorStateClass.MEASUREMENT,
},
)
hass.states.async_set(
"sensor.test_monitored_total",
str(value),
{
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.WATT_HOUR,
ATTR_DEVICE_CLASS: SensorDeviceClass.ENERGY,
ATTR_STATE_CLASS: SensorStateClass.TOTAL,
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("sensor.test_source_class") state = hass.states.get("sensor.test_retain_unit")
assert state is not None assert state is not None
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE
state = hass.states.get("sensor.test_none") state = hass.states.get("sensor.test_none")
@ -562,6 +638,9 @@ async def test_device_class(hass: HomeAssistant):
state = hass.states.get("sensor.test_timestamp") state = hass.states.get("sensor.test_timestamp")
assert state is not None assert state is not None
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TIMESTAMP assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TIMESTAMP
state = hass.states.get("sensor.test_source_class_total")
assert state is not None
assert state.attributes.get(ATTR_DEVICE_CLASS) is None
async def test_state_class(hass: HomeAssistant): async def test_state_class(hass: HomeAssistant):
@ -572,6 +651,15 @@ async def test_state_class(hass: HomeAssistant):
{ {
"sensor": [ "sensor": [
{ {
# State class is None for datetime characteristics
"platform": "statistics",
"name": "test_nan",
"entity_id": "sensor.test_monitored",
"state_characteristic": "datetime_oldest",
"sampling_size": 20,
},
{
# State class is MEASUREMENT for all other characteristics
"platform": "statistics", "platform": "statistics",
"name": "test_normal", "name": "test_normal",
"entity_id": "sensor.test_monitored", "entity_id": "sensor.test_monitored",
@ -579,10 +667,12 @@ async def test_state_class(hass: HomeAssistant):
"sampling_size": 20, "sampling_size": 20,
}, },
{ {
# State class is MEASUREMENT, even when the source sensor
# is of state class TOTAL
"platform": "statistics", "platform": "statistics",
"name": "test_nan", "name": "test_total",
"entity_id": "sensor.test_monitored", "entity_id": "sensor.test_monitored_total",
"state_characteristic": "datetime_oldest", "state_characteristic": "count",
"sampling_size": 20, "sampling_size": 20,
}, },
] ]
@ -596,14 +686,28 @@ async def test_state_class(hass: HomeAssistant):
str(value), str(value),
{ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS}, {ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS},
) )
hass.states.async_set(
"sensor.test_monitored_total",
str(value),
{
ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS,
ATTR_STATE_CLASS: SensorStateClass.TOTAL,
},
)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("sensor.test_normal")
assert state is not None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.MEASUREMENT
state = hass.states.get("sensor.test_nan") state = hass.states.get("sensor.test_nan")
assert state is not None assert state is not None
assert state.attributes.get(ATTR_STATE_CLASS) is None assert state.attributes.get(ATTR_STATE_CLASS) is None
state = hass.states.get("sensor.test_normal")
assert state is not None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.MEASUREMENT
state = hass.states.get("sensor.test_monitored_total")
assert state is not None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL
state = hass.states.get("sensor.test_total")
assert state is not None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.MEASUREMENT
async def test_unitless_source_sensor(hass: HomeAssistant): async def test_unitless_source_sensor(hass: HomeAssistant):