From 0ab3b10aed46a5d80f515bdb705343f26ad322f2 Mon Sep 17 00:00:00 2001 From: Thomas Dietrich Date: Wed, 17 Nov 2021 12:31:32 +0100 Subject: [PATCH] Allow selection of statistics state characteristic (#49960) * Make statistics state characteristic selectable * Move computation in helper function * Add relevant config elements for clarity * Rename variables for better readability * Avoid reserved prefix ATTR_ for stats * Fix NoneType base_unit error * Add testcases for statistics characteristic * Add testcases for state_class, unitless, and characteristics * Add testcase coverage for no unit with binary * Replace error catching by an exception * Attend to review comments --- homeassistant/components/statistics/sensor.py | 368 +++++++++++------- tests/components/statistics/test_sensor.py | 191 ++++++++- 2 files changed, 403 insertions(+), 156 deletions(-) diff --git a/homeassistant/components/statistics/sensor.py b/homeassistant/components/statistics/sensor.py index 5b0ad3765b1..92fc682196b 100644 --- a/homeassistant/components/statistics/sensor.py +++ b/homeassistant/components/statistics/sensor.py @@ -8,12 +8,15 @@ import voluptuous as vol from homeassistant.components.recorder.models import States from homeassistant.components.recorder.util import execute, session_scope -from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity +from homeassistant.components.sensor import ( + PLATFORM_SCHEMA, + STATE_CLASS_MEASUREMENT, + SensorEntity, +) from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, CONF_ENTITY_ID, CONF_NAME, - EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE, STATE_UNKNOWN, ) @@ -24,36 +27,37 @@ from homeassistant.helpers.event import ( async_track_state_change_event, ) from homeassistant.helpers.reload import async_setup_reload_service +from homeassistant.helpers.start import async_at_start from homeassistant.util import dt as dt_util from . import DOMAIN, PLATFORMS _LOGGER = logging.getLogger(__name__) -ATTR_AVERAGE_CHANGE = "average_change" -ATTR_CHANGE = "change" -ATTR_CHANGE_RATE = "change_rate" -ATTR_COUNT = "count" -ATTR_MAX_AGE = "max_age" -ATTR_MAX_VALUE = "max_value" -ATTR_MEAN = "mean" -ATTR_MEDIAN = "median" -ATTR_MIN_AGE = "min_age" -ATTR_MIN_VALUE = "min_value" -ATTR_QUANTILES = "quantiles" -ATTR_SAMPLING_SIZE = "sampling_size" -ATTR_STANDARD_DEVIATION = "standard_deviation" -ATTR_TOTAL = "total" -ATTR_VARIANCE = "variance" +STAT_AVERAGE_CHANGE = "average_change" +STAT_CHANGE = "change" +STAT_CHANGE_RATE = "change_rate" +STAT_COUNT = "count" +STAT_MAX_AGE = "max_age" +STAT_MAX_VALUE = "max_value" +STAT_MEAN = "mean" +STAT_MEDIAN = "median" +STAT_MIN_AGE = "min_age" +STAT_MIN_VALUE = "min_value" +STAT_QUANTILES = "quantiles" +STAT_STANDARD_DEVIATION = "standard_deviation" +STAT_TOTAL = "total" +STAT_VARIANCE = "variance" -CONF_SAMPLING_SIZE = "sampling_size" +CONF_STATE_CHARACTERISTIC = "state_characteristic" +CONF_SAMPLES_MAX_BUFFER_SIZE = "sampling_size" CONF_MAX_AGE = "max_age" CONF_PRECISION = "precision" CONF_QUANTILE_INTERVALS = "quantile_intervals" CONF_QUANTILE_METHOD = "quantile_method" DEFAULT_NAME = "Stats" -DEFAULT_SIZE = 20 +DEFAULT_BUFFER_SIZE = 20 DEFAULT_PRECISION = 2 DEFAULT_QUANTILE_INTERVALS = 4 DEFAULT_QUANTILE_METHOD = "exclusive" @@ -63,9 +67,27 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { vol.Required(CONF_ENTITY_ID): cv.entity_id, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Optional(CONF_SAMPLING_SIZE, default=DEFAULT_SIZE): vol.All( - vol.Coerce(int), vol.Range(min=1) + vol.Optional(CONF_STATE_CHARACTERISTIC, default=STAT_MEAN): vol.In( + [ + STAT_AVERAGE_CHANGE, + STAT_CHANGE, + STAT_CHANGE_RATE, + STAT_COUNT, + STAT_MAX_AGE, + STAT_MAX_VALUE, + STAT_MEAN, + STAT_MEDIAN, + STAT_MIN_AGE, + STAT_MIN_VALUE, + STAT_QUANTILES, + STAT_STANDARD_DEVIATION, + STAT_TOTAL, + STAT_VARIANCE, + ] ), + vol.Optional( + CONF_SAMPLES_MAX_BUFFER_SIZE, default=DEFAULT_BUFFER_SIZE + ): vol.All(vol.Coerce(int), vol.Range(min=1)), vol.Optional(CONF_MAX_AGE): cv.time_period, vol.Optional(CONF_PRECISION, default=DEFAULT_PRECISION): vol.Coerce(int), vol.Optional( @@ -83,29 +105,21 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= await async_setup_reload_service(hass, DOMAIN, PLATFORMS) - entity_id = config.get(CONF_ENTITY_ID) - name = config.get(CONF_NAME) - sampling_size = config.get(CONF_SAMPLING_SIZE) - max_age = config.get(CONF_MAX_AGE) - precision = config.get(CONF_PRECISION) - quantile_intervals = config.get(CONF_QUANTILE_INTERVALS) - quantile_method = config.get(CONF_QUANTILE_METHOD) - async_add_entities( [ StatisticsSensor( - entity_id, - name, - sampling_size, - max_age, - precision, - quantile_intervals, - quantile_method, + source_entity_id=config.get(CONF_ENTITY_ID), + name=config.get(CONF_NAME), + state_characteristic=config.get(CONF_STATE_CHARACTERISTIC), + samples_max_buffer_size=config.get(CONF_SAMPLES_MAX_BUFFER_SIZE), + samples_max_age=config.get(CONF_MAX_AGE), + precision=config.get(CONF_PRECISION), + quantile_intervals=config.get(CONF_QUANTILE_INTERVALS), + quantile_method=config.get(CONF_QUANTILE_METHOD), ) ], True, ) - return True @@ -114,33 +128,45 @@ class StatisticsSensor(SensorEntity): def __init__( self, - entity_id, + source_entity_id, name, - sampling_size, - max_age, + state_characteristic, + samples_max_buffer_size, + samples_max_age, precision, quantile_intervals, quantile_method, ): """Initialize the Statistics sensor.""" - self._entity_id = entity_id - self.is_binary = self._entity_id.split(".")[0] == "binary_sensor" + self._source_entity_id = source_entity_id + self.is_binary = self._source_entity_id.split(".")[0] == "binary_sensor" self._name = name self._available = False - self._sampling_size = sampling_size - self._max_age = max_age + self._state_characteristic = state_characteristic + self._samples_max_buffer_size = samples_max_buffer_size + self._samples_max_age = samples_max_age self._precision = precision self._quantile_intervals = quantile_intervals self._quantile_method = quantile_method self._unit_of_measurement = None - self.states = deque(maxlen=self._sampling_size) - self.ages = deque(maxlen=self._sampling_size) - - self.count = 0 - self.mean = self.median = self.quantiles = self.stdev = self.variance = None - self.total = self.min = self.max = None - self.min_age = self.max_age = None - self.change = self.average_change = self.change_rate = None + self.states = deque(maxlen=self._samples_max_buffer_size) + self.ages = deque(maxlen=self._samples_max_buffer_size) + self.attr = { + STAT_COUNT: 0, + STAT_TOTAL: None, + STAT_MEAN: None, + STAT_MEDIAN: None, + STAT_STANDARD_DEVIATION: None, + STAT_VARIANCE: None, + STAT_MIN_VALUE: None, + STAT_MAX_VALUE: None, + STAT_MIN_AGE: None, + STAT_MAX_AGE: None, + STAT_CHANGE: None, + STAT_AVERAGE_CHANGE: None, + STAT_CHANGE_RATE: None, + STAT_QUANTILES: None, + } self._update_listener = None async def async_added_to_hass(self): @@ -151,9 +177,7 @@ class StatisticsSensor(SensorEntity): """Handle the sensor state changes.""" if (new_state := event.data.get("new_state")) is None: return - self._add_state_to_queue(new_state) - self.async_schedule_update_ha_state(True) @callback @@ -163,17 +187,16 @@ class StatisticsSensor(SensorEntity): self.async_on_remove( async_track_state_change_event( - self.hass, [self._entity_id], async_stats_sensor_state_listener + self.hass, + [self._source_entity_id], + async_stats_sensor_state_listener, ) ) if "recorder" in self.hass.config.components: - # Only use the database if it's configured self.hass.async_create_task(self._initialize_from_database()) - self.hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_START, async_stats_sensor_startup - ) + async_at_start(self.hass, async_stats_sensor_startup) def _add_state_to_queue(self, new_state): """Add the state to the queue.""" @@ -195,27 +218,75 @@ class StatisticsSensor(SensorEntity): ) return - self._unit_of_measurement = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + self._unit_of_measurement = self._derive_unit_of_measurement(new_state) + + def _derive_unit_of_measurement(self, new_state): + base_unit = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + if not base_unit: + unit = None + elif self.is_binary: + unit = None + elif self._state_characteristic in ( + STAT_COUNT, + STAT_MIN_AGE, + STAT_MAX_AGE, + STAT_QUANTILES, + ): + unit = None + elif self._state_characteristic in ( + STAT_TOTAL, + STAT_MEAN, + STAT_MEDIAN, + STAT_STANDARD_DEVIATION, + STAT_MIN_VALUE, + STAT_MAX_VALUE, + STAT_CHANGE, + ): + unit = base_unit + elif self._state_characteristic == STAT_VARIANCE: + unit = base_unit + "²" + elif self._state_characteristic == STAT_AVERAGE_CHANGE: + unit = base_unit + "/sample" + elif self._state_characteristic == STAT_CHANGE_RATE: + unit = base_unit + "/s" + return unit @property def name(self): """Return the name of the sensor.""" return self._name + @property + def state_class(self): + """Return the state class of this entity.""" + if self._state_characteristic in ( + STAT_MIN_AGE, + STAT_MAX_AGE, + STAT_QUANTILES, + ): + return None + return STATE_CLASS_MEASUREMENT + @property def native_value(self): """Return the state of the sensor.""" if self.is_binary: - return self.count + return self.attr[STAT_COUNT] + if self._state_characteristic in ( + STAT_MIN_AGE, + STAT_MAX_AGE, + STAT_QUANTILES, + ): + return self.attr[self._state_characteristic] if self._precision == 0: with contextlib.suppress(TypeError, ValueError): - return int(self.mean) - return self.mean + return int(self.attr[self._state_characteristic]) + return self.attr[self._state_characteristic] @property def native_unit_of_measurement(self): """Return the unit the value is expressed in.""" - return self._unit_of_measurement if not self.is_binary else None + return self._unit_of_measurement @property def available(self): @@ -230,24 +301,9 @@ class StatisticsSensor(SensorEntity): @property def extra_state_attributes(self): """Return the state attributes of the sensor.""" - if not self.is_binary: - return { - ATTR_SAMPLING_SIZE: self._sampling_size, - ATTR_COUNT: self.count, - ATTR_MEAN: self.mean, - ATTR_MEDIAN: self.median, - ATTR_QUANTILES: self.quantiles, - ATTR_STANDARD_DEVIATION: self.stdev, - ATTR_VARIANCE: self.variance, - ATTR_TOTAL: self.total, - ATTR_MIN_VALUE: self.min, - ATTR_MAX_VALUE: self.max, - ATTR_MIN_AGE: self.min_age, - ATTR_MAX_AGE: self.max_age, - ATTR_CHANGE: self.change, - ATTR_AVERAGE_CHANGE: self.average_change, - ATTR_CHANGE_RATE: self.change_rate, - } + if self.is_binary: + return None + return self.attr @property def icon(self): @@ -255,17 +311,17 @@ class StatisticsSensor(SensorEntity): return ICON def _purge_old(self): - """Remove states which are older than self._max_age.""" + """Remove states which are older than self._samples_max_age.""" now = dt_util.utcnow() _LOGGER.debug( "%s: purging records older then %s(%s)", self.entity_id, - dt_util.as_local(now - self._max_age), - self._max_age, + dt_util.as_local(now - self._samples_max_age), + self._samples_max_age, ) - while self.ages and (now - self.ages[0]) > self._max_age: + while self.ages and (now - self.ages[0]) > self._samples_max_age: _LOGGER.debug( "%s: purging record with datetime %s(%s)", self.entity_id, @@ -277,73 +333,91 @@ class StatisticsSensor(SensorEntity): def _next_to_purge_timestamp(self): """Find the timestamp when the next purge would occur.""" - if self.ages and self._max_age: + if self.ages and self._samples_max_age: # Take the oldest entry from the ages list and add the configured max_age. # If executed after purging old states, the result is the next timestamp # in the future when the oldest state will expire. - return self.ages[0] + self._max_age + return self.ages[0] + self._samples_max_age return None + def _update_characteristics(self): + """Calculate and update the various statistical characteristics.""" + states_count = len(self.states) + self.attr[STAT_COUNT] = states_count + + if self.is_binary: + return + + if states_count >= 2: + self.attr[STAT_STANDARD_DEVIATION] = round( + statistics.stdev(self.states), self._precision + ) + self.attr[STAT_VARIANCE] = round( + statistics.variance(self.states), self._precision + ) + else: + self.attr[STAT_STANDARD_DEVIATION] = STATE_UNKNOWN + self.attr[STAT_VARIANCE] = STATE_UNKNOWN + + if states_count > self._quantile_intervals: + self.attr[STAT_QUANTILES] = [ + round(quantile, self._precision) + for quantile in statistics.quantiles( + self.states, + n=self._quantile_intervals, + method=self._quantile_method, + ) + ] + else: + self.attr[STAT_QUANTILES] = STATE_UNKNOWN + + if states_count == 0: + self.attr[STAT_MEAN] = STATE_UNKNOWN + self.attr[STAT_MEDIAN] = STATE_UNKNOWN + self.attr[STAT_TOTAL] = STATE_UNKNOWN + self.attr[STAT_MIN_VALUE] = self.attr[STAT_MAX_VALUE] = STATE_UNKNOWN + self.attr[STAT_MIN_AGE] = self.attr[STAT_MAX_AGE] = STATE_UNKNOWN + self.attr[STAT_CHANGE] = self.attr[STAT_AVERAGE_CHANGE] = STATE_UNKNOWN + self.attr[STAT_CHANGE_RATE] = STATE_UNKNOWN + return + + self.attr[STAT_MEAN] = round(statistics.mean(self.states), self._precision) + self.attr[STAT_MEDIAN] = round(statistics.median(self.states), self._precision) + + self.attr[STAT_TOTAL] = round(sum(self.states), self._precision) + self.attr[STAT_MIN_VALUE] = round(min(self.states), self._precision) + self.attr[STAT_MAX_VALUE] = round(max(self.states), self._precision) + + self.attr[STAT_MIN_AGE] = self.ages[0] + self.attr[STAT_MAX_AGE] = self.ages[-1] + + self.attr[STAT_CHANGE] = self.states[-1] - self.states[0] + + self.attr[STAT_AVERAGE_CHANGE] = self.attr[STAT_CHANGE] + self.attr[STAT_CHANGE_RATE] = 0 + if states_count > 1: + self.attr[STAT_AVERAGE_CHANGE] /= len(self.states) - 1 + + time_diff = ( + self.attr[STAT_MAX_AGE] - self.attr[STAT_MIN_AGE] + ).total_seconds() + if time_diff > 0: + self.attr[STAT_CHANGE_RATE] = self.attr[STAT_CHANGE] / time_diff + self.attr[STAT_CHANGE] = round(self.attr[STAT_CHANGE], self._precision) + self.attr[STAT_AVERAGE_CHANGE] = round( + self.attr[STAT_AVERAGE_CHANGE], self._precision + ) + self.attr[STAT_CHANGE_RATE] = round( + self.attr[STAT_CHANGE_RATE], self._precision + ) + async def async_update(self): """Get the latest data and updates the states.""" _LOGGER.debug("%s: updating statistics", self.entity_id) - if self._max_age is not None: + if self._samples_max_age is not None: self._purge_old() - self.count = len(self.states) - - if not self.is_binary: - try: # require only one data point - self.mean = round(statistics.mean(self.states), self._precision) - self.median = round(statistics.median(self.states), self._precision) - except statistics.StatisticsError as err: - _LOGGER.debug("%s: %s", self.entity_id, err) - self.mean = self.median = STATE_UNKNOWN - - try: # require at least two data points - self.stdev = round(statistics.stdev(self.states), self._precision) - self.variance = round(statistics.variance(self.states), self._precision) - if self._quantile_intervals < self.count: - self.quantiles = [ - round(quantile, self._precision) - for quantile in statistics.quantiles( - self.states, - n=self._quantile_intervals, - method=self._quantile_method, - ) - ] - except statistics.StatisticsError as err: - _LOGGER.debug("%s: %s", self.entity_id, err) - self.stdev = self.variance = self.quantiles = STATE_UNKNOWN - - if self.states: - self.total = round(sum(self.states), self._precision) - self.min = round(min(self.states), self._precision) - self.max = round(max(self.states), self._precision) - - self.min_age = self.ages[0] - self.max_age = self.ages[-1] - - self.change = self.states[-1] - self.states[0] - self.average_change = self.change - self.change_rate = 0 - - if len(self.states) > 1: - self.average_change /= len(self.states) - 1 - - time_diff = (self.max_age - self.min_age).total_seconds() - if time_diff > 0: - self.change_rate = self.change / time_diff - - self.change = round(self.change, self._precision) - self.average_change = round(self.average_change, self._precision) - self.change_rate = round(self.change_rate, self._precision) - - else: - self.total = self.min = self.max = STATE_UNKNOWN - self.min_age = self.max_age = dt_util.utcnow() - self.change = self.average_change = STATE_UNKNOWN - self.change_rate = STATE_UNKNOWN + self._update_characteristics() # If max_age is set, ensure to update again after the defined interval. next_to_purge_timestamp = self._next_to_purge_timestamp() @@ -381,11 +455,11 @@ class StatisticsSensor(SensorEntity): with session_scope(hass=self.hass) as session: query = session.query(States).filter( - States.entity_id == self._entity_id.lower() + States.entity_id == self._source_entity_id.lower() ) - if self._max_age is not None: - records_older_then = dt_util.utcnow() - self._max_age + if self._samples_max_age is not None: + records_older_then = dt_util.utcnow() - self._samples_max_age _LOGGER.debug( "%s: retrieve records not older then %s", self.entity_id, @@ -396,7 +470,7 @@ class StatisticsSensor(SensorEntity): _LOGGER.debug("%s: retrieving all records", self.entity_id) query = query.order_by(States.last_updated.desc()).limit( - self._sampling_size + self._samples_max_buffer_size ) states = execute(query, to_native=True, validate_entity_ids=False) diff --git a/tests/components/statistics/test_sensor.py b/tests/components/statistics/test_sensor.py index 0aa1c116991..4412dad843a 100644 --- a/tests/components/statistics/test_sensor.py +++ b/tests/components/statistics/test_sensor.py @@ -8,6 +8,7 @@ import pytest from homeassistant import config as hass_config from homeassistant.components import recorder +from homeassistant.components.sensor import ATTR_STATE_CLASS, STATE_CLASS_MEASUREMENT from homeassistant.components.statistics.sensor import DOMAIN, StatisticsSensor from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, @@ -64,11 +65,18 @@ class TestStatisticsSensor(unittest.TestCase): self.hass, "sensor", { - "sensor": { - "platform": "statistics", - "name": "test", - "entity_id": "binary_sensor.test_monitored", - } + "sensor": [ + { + "platform": "statistics", + "name": "test", + "entity_id": "binary_sensor.test_monitored", + }, + { + "platform": "statistics", + "name": "test_unitless", + "entity_id": "binary_sensor.test_monitored_unitless", + }, + ] }, ) @@ -77,12 +85,21 @@ class TestStatisticsSensor(unittest.TestCase): self.hass.block_till_done() for value in values: - self.hass.states.set("binary_sensor.test_monitored", value) + self.hass.states.set( + "binary_sensor.test_monitored", + value, + {ATTR_UNIT_OF_MEASUREMENT: TEMP_CELSIUS}, + ) + self.hass.states.set("binary_sensor.test_monitored_unitless", value) self.hass.block_till_done() state = self.hass.states.get("sensor.test") + assert state.state == str(len(values)) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT - assert str(len(values)) == state.state + state = self.hass.states.get("sensor.test_unitless") + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None def test_sensor_source(self): """Test if source is a sensor.""" @@ -121,17 +138,18 @@ class TestStatisticsSensor(unittest.TestCase): assert self.mean == state.attributes.get("mean") assert self.count == state.attributes.get("count") assert self.total == state.attributes.get("total") - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == TEMP_CELSIUS assert self.change == state.attributes.get("change") assert self.average_change == state.attributes.get("average_change") + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == TEMP_CELSIUS + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + # Source sensor turns unavailable, then available with valid value, # statistics sensor should follow state = self.hass.states.get("sensor.test") self.hass.states.set( "sensor.test_monitored", STATE_UNAVAILABLE, - {ATTR_UNIT_OF_MEASUREMENT: TEMP_CELSIUS}, ) self.hass.block_till_done() new_state = self.hass.states.get("sensor.test") @@ -445,6 +463,161 @@ class TestStatisticsSensor(unittest.TestCase): state = self.hass.states.get("sensor.test") assert state.state == str(round(sum(self.values) / len(self.values), 1)) + def test_state_characteristic_unit(self): + """Test statistics characteristic selection (via config).""" + assert setup_component( + self.hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test_min_age", + "entity_id": "sensor.test_monitored", + "state_characteristic": "min_age", + }, + { + "platform": "statistics", + "name": "test_variance", + "entity_id": "sensor.test_monitored", + "state_characteristic": "variance", + }, + { + "platform": "statistics", + "name": "test_average_change", + "entity_id": "sensor.test_monitored", + "state_characteristic": "average_change", + }, + { + "platform": "statistics", + "name": "test_change_rate", + "entity_id": "sensor.test_monitored", + "state_characteristic": "change_rate", + }, + ] + }, + ) + + self.hass.block_till_done() + self.hass.start() + self.hass.block_till_done() + + for value in self.values: + self.hass.states.set( + "sensor.test_monitored", + value, + {ATTR_UNIT_OF_MEASUREMENT: TEMP_CELSIUS}, + ) + self.hass.states.set( + "sensor.test_monitored_unitless", + value, + ) + self.hass.block_till_done() + + state = self.hass.states.get("sensor.test_min_age") + assert state.state == str(state.attributes.get("min_age")) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + state = self.hass.states.get("sensor.test_variance") + assert state.state == str(state.attributes.get("variance")) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == TEMP_CELSIUS + "²" + state = self.hass.states.get("sensor.test_average_change") + assert state.state == str(state.attributes.get("average_change")) + assert ( + state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == TEMP_CELSIUS + "/sample" + ) + state = self.hass.states.get("sensor.test_change_rate") + assert state.state == str(state.attributes.get("change_rate")) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == TEMP_CELSIUS + "/s" + + def test_state_class(self): + """Test state class, which depends on the characteristic configured.""" + assert setup_component( + self.hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test_normal", + "entity_id": "sensor.test_monitored", + "state_characteristic": "count", + }, + { + "platform": "statistics", + "name": "test_nan", + "entity_id": "sensor.test_monitored", + "state_characteristic": "min_age", + }, + ] + }, + ) + + self.hass.block_till_done() + self.hass.start() + self.hass.block_till_done() + + for value in self.values: + self.hass.states.set( + "sensor.test_monitored", + value, + {ATTR_UNIT_OF_MEASUREMENT: TEMP_CELSIUS}, + ) + self.hass.block_till_done() + + state = self.hass.states.get("sensor.test_normal") + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + state = self.hass.states.get("sensor.test_nan") + assert state.attributes.get(ATTR_STATE_CLASS) is None + + def test_unitless_source_sensor(self): + """Statistics for a unitless source sensor should never have a unit.""" + assert setup_component( + self.hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test_unitless_1", + "entity_id": "sensor.test_monitored_unitless", + "state_characteristic": "count", + }, + { + "platform": "statistics", + "name": "test_unitless_2", + "entity_id": "sensor.test_monitored_unitless", + "state_characteristic": "mean", + }, + { + "platform": "statistics", + "name": "test_unitless_3", + "entity_id": "sensor.test_monitored_unitless", + "state_characteristic": "change_rate", + }, + ] + }, + ) + + self.hass.block_till_done() + self.hass.start() + self.hass.block_till_done() + + for value in self.values: + self.hass.states.set( + "sensor.test_monitored_unitless", + value, + ) + self.hass.block_till_done() + + state = self.hass.states.get("sensor.test_unitless_1") + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + state = self.hass.states.get("sensor.test_unitless_2") + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + state = self.hass.states.get("sensor.test_unitless_3") + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + def test_initialize_from_database(self): """Test initializing the statistics from the database.""" # enable the recorder