diff --git a/homeassistant/components/statistics/sensor.py b/homeassistant/components/statistics/sensor.py index bb4fd2821bc..7edffc54fcd 100644 --- a/homeassistant/components/statistics/sensor.py +++ b/homeassistant/components/statistics/sensor.py @@ -17,6 +17,7 @@ from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAI from homeassistant.components.recorder import get_instance, history from homeassistant.components.sensor import ( DEVICE_CLASS_STATE_CLASSES, + DEVICE_CLASS_UNITS, PLATFORM_SCHEMA as SENSOR_PLATFORM_SCHEMA, SensorDeviceClass, SensorEntity, @@ -359,15 +360,14 @@ class StatisticsSensor(SensorEntity): self.samples_keep_last: bool = samples_keep_last self._precision: int = precision self._percentile: int = percentile - self._value: StateType | datetime = None - self._unit_of_measurement: str | None = None + self._value: float | int | datetime | None = None self._available: bool = False self.states: deque[float | bool] = deque(maxlen=self._samples_max_buffer_size) self.ages: deque[datetime] = deque(maxlen=self._samples_max_buffer_size) self.attributes: dict[str, StateType] = {} - self._state_characteristic_fn: Callable[[], StateType | datetime] = ( + self._state_characteristic_fn: Callable[[], float | int | datetime | None] = ( self._callable_characteristic_fn(self._state_characteristic) ) @@ -486,11 +486,28 @@ class StatisticsSensor(SensorEntity): ) return - self._unit_of_measurement = self._derive_unit_of_measurement(new_state) + self._calculate_state_attributes(new_state) + + def _calculate_state_attributes(self, new_state: State) -> None: + """Set the entity state attributes.""" + + self._attr_native_unit_of_measurement = self._calculate_unit_of_measurement( + new_state + ) + self._attr_device_class = self._calculate_device_class( + new_state, self._attr_native_unit_of_measurement + ) + self._attr_state_class = self._calculate_state_class(new_state) + + def _calculate_unit_of_measurement(self, new_state: State) -> str | None: + """Return the calculated unit of measurement. + + The unit of measurement is that of the source sensor, adjusted based on the + state characteristics. + """ - def _derive_unit_of_measurement(self, new_state: State) -> str | None: base_unit: str | None = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - unit: str | None + unit: str | None = None if self.is_binary and self._state_characteristic in STATS_BINARY_PERCENTAGE: unit = PERCENTAGE elif not base_unit: @@ -513,48 +530,66 @@ class StatisticsSensor(SensorEntity): unit = base_unit + "/sample" elif self._state_characteristic == STAT_CHANGE_SECOND: unit = base_unit + "/s" + return unit - @property - def device_class(self) -> SensorDeviceClass | None: - """Return the class of this device.""" + def _calculate_device_class( + self, new_state: State, unit: str | None + ) -> SensorDeviceClass | None: + """Return the calculated device class. + + The device class is calculated based on the state characteristics, + the source device class and the unit of measurement is + in the device class units list. + """ + + device_class: SensorDeviceClass | None = None if self._state_characteristic in STATS_DATETIME: return SensorDeviceClass.TIMESTAMP if self._state_characteristic in STATS_NUMERIC_RETAIN_UNIT: - source_state = self.hass.states.get(self._source_entity_id) - if source_state is None: + device_class = new_state.attributes.get(ATTR_DEVICE_CLASS) + if device_class is None: return None - source_device_class = source_state.attributes.get(ATTR_DEVICE_CLASS) - if source_device_class is None: + if ( + sensor_device_class := try_parse_enum(SensorDeviceClass, device_class) + ) is None: return None - sensor_device_class = try_parse_enum(SensorDeviceClass, source_device_class) - if sensor_device_class is None: + if ( + sensor_device_class + and ( + sensor_state_classes := DEVICE_CLASS_STATE_CLASSES.get( + sensor_device_class + ) + ) + and sensor_state_classes + and SensorStateClass.MEASUREMENT not in sensor_state_classes + ): return None - sensor_state_classes = DEVICE_CLASS_STATE_CLASSES.get( - sensor_device_class, set() - ) - if SensorStateClass.MEASUREMENT not in sensor_state_classes: + if device_class not in DEVICE_CLASS_UNITS: + return None + if ( + device_class in DEVICE_CLASS_UNITS + and unit not in DEVICE_CLASS_UNITS[device_class] + ): return None - return sensor_device_class - return None - @property - def state_class(self) -> SensorStateClass | None: - """Return the state class of this entity.""" + return device_class + + def _calculate_state_class(self, new_state: State) -> SensorStateClass | None: + """Return the calculated state class. + + Will be None if the characteristics is not numerical, otherwise + SensorStateClass.MEASUREMENT. + """ if self._state_characteristic in STATS_NOT_A_NUMBER: return None return SensorStateClass.MEASUREMENT @property - def native_value(self) -> StateType | datetime: + def native_value(self) -> float | int | datetime | None: """Return the state of the sensor.""" return self._value - @property - def native_unit_of_measurement(self) -> str | None: - """Return the unit the value is expressed in.""" - return self._unit_of_measurement - @property def available(self) -> bool: """Return the availability of the sensor linked to the source sensor.""" @@ -703,7 +738,7 @@ class StatisticsSensor(SensorEntity): ): for state in reversed(states): self._add_state_to_queue(state) - + self._calculate_state_attributes(state) self._async_purge_update_and_schedule() # only write state to the state machine if we are not in preview mode @@ -750,9 +785,9 @@ class StatisticsSensor(SensorEntity): def _callable_characteristic_fn( self, characteristic: str - ) -> Callable[[], StateType | datetime]: + ) -> Callable[[], float | int | datetime | None]: """Return the function callable of one characteristic function.""" - function: Callable[[], StateType | datetime] = getattr( + function: Callable[[], float | int | datetime | None] = getattr( self, f"_stat_binary_{characteristic}" if self.is_binary diff --git a/tests/components/statistics/snapshots/test_config_flow.ambr b/tests/components/statistics/snapshots/test_config_flow.ambr index 8d274cd86c6..5f79c56dec7 100644 --- a/tests/components/statistics/snapshots/test_config_flow.ambr +++ b/tests/components/statistics/snapshots/test_config_flow.ambr @@ -4,7 +4,6 @@ 'attributes': dict({ 'friendly_name': 'Statistical characteristic', 'icon': 'mdi:calculator', - 'state_class': 'measurement', }), 'state': 'unavailable', }) diff --git a/tests/components/statistics/test_sensor.py b/tests/components/statistics/test_sensor.py index fa9e627fe6b..7e2bc1cb16b 100644 --- a/tests/components/statistics/test_sensor.py +++ b/tests/components/statistics/test_sensor.py @@ -1832,3 +1832,203 @@ async def test_average_linear_unevenly_timed(hass: HomeAssistant) -> None: "value mismatch for characteristic 'sensor/average_linear' - " f"assert {state.state} == 8.33" ) + + +async def test_sensor_unit_gets_removed(hass: HomeAssistant) -> None: + """Test when input lose its unit of measurement.""" + assert await async_setup_component( + hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test", + "entity_id": "sensor.test_monitored", + "state_characteristic": "mean", + "sampling_size": 10, + }, + ] + }, + ) + await hass.async_block_till_done() + + input_attributes = { + ATTR_STATE_CLASS: SensorStateClass.MEASUREMENT, + ATTR_DEVICE_CLASS: SensorDeviceClass.TEMPERATURE, + ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS, + } + + for value in VALUES_NUMERIC: + hass.states.async_set( + "sensor.test_monitored", + str(value), + input_attributes, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == str(round(sum(VALUES_NUMERIC) / len(VALUES_NUMERIC), 2)) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfTemperature.CELSIUS + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + hass.states.async_set( + "sensor.test_monitored", + str(VALUES_NUMERIC[0]), + { + ATTR_STATE_CLASS: SensorStateClass.MEASUREMENT, + ATTR_DEVICE_CLASS: SensorDeviceClass.TEMPERATURE, + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == "11.39" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + # Temperature device class is not valid with no unit of measurement + assert state.attributes.get(ATTR_DEVICE_CLASS) is None + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + for value in VALUES_NUMERIC: + hass.states.async_set( + "sensor.test_monitored", + str(value), + input_attributes, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == "11.39" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfTemperature.CELSIUS + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + +async def test_sensor_device_class_gets_removed(hass: HomeAssistant) -> None: + """Test when device class gets removed.""" + assert await async_setup_component( + hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test", + "entity_id": "sensor.test_monitored", + "state_characteristic": "mean", + "sampling_size": 10, + }, + ] + }, + ) + await hass.async_block_till_done() + + input_attributes = { + ATTR_STATE_CLASS: SensorStateClass.MEASUREMENT, + ATTR_DEVICE_CLASS: SensorDeviceClass.TEMPERATURE, + ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS, + } + + for value in VALUES_NUMERIC: + hass.states.async_set( + "sensor.test_monitored", + str(value), + input_attributes, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == str(round(sum(VALUES_NUMERIC) / len(VALUES_NUMERIC), 2)) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfTemperature.CELSIUS + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + hass.states.async_set( + "sensor.test_monitored", + str(VALUES_NUMERIC[0]), + { + ATTR_STATE_CLASS: SensorStateClass.MEASUREMENT, + ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS, + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == "11.39" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfTemperature.CELSIUS + assert state.attributes.get(ATTR_DEVICE_CLASS) is None + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + for value in VALUES_NUMERIC: + hass.states.async_set( + "sensor.test_monitored", + str(value), + input_attributes, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == "11.39" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfTemperature.CELSIUS + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + +async def test_not_valid_device_class(hass: HomeAssistant) -> None: + """Test when not valid device class.""" + assert await async_setup_component( + hass, + "sensor", + { + "sensor": [ + { + "platform": "statistics", + "name": "test", + "entity_id": "sensor.test_monitored", + "state_characteristic": "mean", + "sampling_size": 10, + }, + ] + }, + ) + await hass.async_block_till_done() + + for value in VALUES_NUMERIC: + hass.states.async_set( + "sensor.test_monitored", + str(value), + { + ATTR_DEVICE_CLASS: SensorDeviceClass.DATE, + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == str(round(sum(VALUES_NUMERIC) / len(VALUES_NUMERIC), 2)) + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + assert state.attributes.get(ATTR_DEVICE_CLASS) is None + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT + + hass.states.async_set( + "sensor.test_monitored", + str(10), + { + ATTR_DEVICE_CLASS: "not_exist", + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is not None + assert state.state == "10.69" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None + assert state.attributes.get(ATTR_DEVICE_CLASS) is None + assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT