From 6ee6a8a74fa1542f3cae532b4c893a60bc62df04 Mon Sep 17 00:00:00 2001 From: G Johansson Date: Wed, 23 Oct 2024 20:51:18 +0200 Subject: [PATCH] Fix calculation of attributes in group sensor (#128601) * Fix calculation of attributes in group sensor * Fixes * Fixes * Make module level function --- homeassistant/components/group/sensor.py | 161 +++++++++++------- tests/components/group/test_sensor.py | 203 ++++++++++++++++++++++- 2 files changed, 296 insertions(+), 68 deletions(-) diff --git a/homeassistant/components/group/sensor.py b/homeassistant/components/group/sensor.py index 32744bebc33..4a3e191e511 100644 --- a/homeassistant/components/group/sensor.py +++ b/homeassistant/components/group/sensor.py @@ -36,14 +36,7 @@ from homeassistant.const import ( STATE_UNAVAILABLE, STATE_UNKNOWN, ) -from homeassistant.core import ( - CALLBACK_TYPE, - Event, - EventStateChangedData, - HomeAssistant, - State, - callback, -) +from homeassistant.core import HomeAssistant, State, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.entity import ( @@ -52,7 +45,6 @@ from homeassistant.helpers.entity import ( get_unit_of_measurement, ) from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.issue_registry import ( IssueSeverity, async_create_issue, @@ -180,6 +172,17 @@ def async_create_preview_sensor( ) +def _has_numeric_state(hass: HomeAssistant, entity_id: str) -> bool: + """Test if state is numeric.""" + if not (state := hass.states.get(entity_id)): + return False + try: + float(state.state) + except ValueError: + return False + return True + + def calc_min( sensor_values: list[tuple[str, float, State]], ) -> tuple[dict[str, str | None], float | None]: @@ -332,12 +335,11 @@ class SensorGroup(GroupEntity, SensorEntity): self.hass = hass self._entity_ids = entity_ids self._sensor_type = sensor_type - self._state_class = state_class - self._device_class = device_class - self._native_unit_of_measurement = unit_of_measurement + self._configured_state_class = state_class + self._configured_device_class = device_class + self._configured_unit_of_measurement = unit_of_measurement self._valid_units: set[str | None] = set() self._can_convert: bool = False - self.calculate_attributes_later: CALLBACK_TYPE | None = None self._attr_name = name if name == DEFAULT_NAME: self._attr_name = f"{DEFAULT_NAME} {sensor_type}".capitalize() @@ -352,39 +354,25 @@ class SensorGroup(GroupEntity, SensorEntity): self._state_incorrect: set[str] = set() self._extra_state_attribute: dict[str, Any] = {} - async def async_added_to_hass(self) -> None: - """When added to hass.""" - for entity_id in self._entity_ids: - if self.hass.states.get(entity_id) is None: - self.calculate_attributes_later = async_track_state_change_event( - self.hass, self._entity_ids, self.calculate_state_attributes - ) - break - if not self.calculate_attributes_later: - await self.calculate_state_attributes() - await super().async_added_to_hass() - - async def calculate_state_attributes( - self, event: Event[EventStateChangedData] | None = None - ) -> None: + def calculate_state_attributes(self, valid_state_entities: list[str]) -> None: """Calculate state attributes.""" - for entity_id in self._entity_ids: - if self.hass.states.get(entity_id) is None: - return - if self.calculate_attributes_later: - self.calculate_attributes_later() - self.calculate_attributes_later = None - self._attr_state_class = self._calculate_state_class(self._state_class) - self._attr_device_class = self._calculate_device_class(self._device_class) + self._attr_state_class = self._calculate_state_class( + self._configured_state_class, valid_state_entities + ) + self._attr_device_class = self._calculate_device_class( + self._configured_device_class, valid_state_entities + ) self._attr_native_unit_of_measurement = self._calculate_unit_of_measurement( - self._native_unit_of_measurement + self._configured_unit_of_measurement, valid_state_entities ) self._valid_units = self._get_valid_units() @callback def async_update_group_state(self) -> None: """Query all members and determine the sensor group state.""" + self.calculate_state_attributes(self._get_valid_entities()) states: list[StateType] = [] + valid_units = self._valid_units valid_states: list[bool] = [] sensor_values: list[tuple[str, float, State]] = [] for entity_id in self._entity_ids: @@ -392,20 +380,18 @@ class SensorGroup(GroupEntity, SensorEntity): states.append(state.state) try: numeric_state = float(state.state) - if ( - self._valid_units - and (uom := state.attributes["unit_of_measurement"]) - in self._valid_units - and self._can_convert is True - ): + uom = state.attributes.get("unit_of_measurement") + + # Convert the state to the native unit of measurement when we have valid units + # and a correct device class + if valid_units and uom in valid_units and self._can_convert is True: numeric_state = UNIT_CONVERTERS[self.device_class].convert( numeric_state, uom, self.native_unit_of_measurement ) - if ( - self._valid_units - and (uom := state.attributes["unit_of_measurement"]) - not in self._valid_units - ): + + # If we have valid units and the entity's unit does not match + # we raise which skips the state and log a warning once + if valid_units and uom not in valid_units: raise HomeAssistantError("Not a valid unit") # noqa: TRY301 sensor_values.append((entity_id, numeric_state, state)) @@ -480,7 +466,9 @@ class SensorGroup(GroupEntity, SensorEntity): return None def _calculate_state_class( - self, state_class: SensorStateClass | None + self, + state_class: SensorStateClass | None, + valid_state_entities: list[str], ) -> SensorStateClass | None: """Calculate state class. @@ -491,8 +479,18 @@ class SensorGroup(GroupEntity, SensorEntity): """ if state_class: return state_class + + if not valid_state_entities: + return None + + if not self._ignore_non_numeric and len(valid_state_entities) < len( + self._entity_ids + ): + # Only return state class if all states are valid when not ignoring non numeric + return None + state_classes: list[SensorStateClass] = [] - for entity_id in self._entity_ids: + for entity_id in valid_state_entities: try: _state_class = get_capability(self.hass, entity_id, "state_class") except HomeAssistantError: @@ -523,7 +521,9 @@ class SensorGroup(GroupEntity, SensorEntity): return None def _calculate_device_class( - self, device_class: SensorDeviceClass | None + self, + device_class: SensorDeviceClass | None, + valid_state_entities: list[str], ) -> SensorDeviceClass | None: """Calculate device class. @@ -534,8 +534,18 @@ class SensorGroup(GroupEntity, SensorEntity): """ if device_class: return device_class + + if not valid_state_entities: + return None + + if not self._ignore_non_numeric and len(valid_state_entities) < len( + self._entity_ids + ): + # Only return device class if all states are valid when not ignoring non numeric + return None + device_classes: list[SensorDeviceClass] = [] - for entity_id in self._entity_ids: + for entity_id in valid_state_entities: try: _device_class = get_device_class(self.hass, entity_id) except HomeAssistantError: @@ -568,7 +578,9 @@ class SensorGroup(GroupEntity, SensorEntity): return None def _calculate_unit_of_measurement( - self, unit_of_measurement: str | None + self, + unit_of_measurement: str | None, + valid_state_entities: list[str], ) -> str | None: """Calculate the unit of measurement. @@ -579,8 +591,17 @@ class SensorGroup(GroupEntity, SensorEntity): if unit_of_measurement: return unit_of_measurement + if not valid_state_entities: + return None + + if not self._ignore_non_numeric and len(valid_state_entities) < len( + self._entity_ids + ): + # Only return device class if all states are valid when not ignoring non numeric + return None + unit_of_measurements: list[str] = [] - for entity_id in self._entity_ids: + for entity_id in valid_state_entities: try: _unit_of_measurement = get_unit_of_measurement(self.hass, entity_id) except HomeAssistantError: @@ -665,19 +686,31 @@ class SensorGroup(GroupEntity, SensorEntity): If device class is set and compatible unit of measurements. If device class is not set, use one unit of measurement. + Only calculate valid units if there are no valid units set. """ - if ( - device_class := self.device_class - ) in UNIT_CONVERTERS and self.native_unit_of_measurement: + if (valid_units := self._valid_units) and not self._ignore_non_numeric: + # If we have valid units already and not using ignore_non_numeric + # we should not recalculate. + return valid_units + + native_uom = self.native_unit_of_measurement + if (device_class := self.device_class) in UNIT_CONVERTERS and native_uom: self._can_convert = True return UNIT_CONVERTERS[device_class].VALID_UNITS - if ( - device_class - and (device_class) in DEVICE_CLASS_UNITS - and self.native_unit_of_measurement - ): + if device_class and (device_class) in DEVICE_CLASS_UNITS and native_uom: valid_uoms: set = DEVICE_CLASS_UNITS[device_class] return valid_uoms - if device_class is None and self.native_unit_of_measurement: - return {self.native_unit_of_measurement} + if device_class is None and native_uom: + return {native_uom} return set() + + def _get_valid_entities( + self, + ) -> list[str]: + """Return list of valid entities.""" + + return [ + entity_id + for entity_id in self._entity_ids + if _has_numeric_state(self.hass, entity_id) + ] diff --git a/tests/components/group/test_sensor.py b/tests/components/group/test_sensor.py index db642506361..de406cb251c 100644 --- a/tests/components/group/test_sensor.py +++ b/tests/components/group/test_sensor.py @@ -32,6 +32,7 @@ from homeassistant.const import ( SERVICE_RELOAD, STATE_UNAVAILABLE, STATE_UNKNOWN, + UnitOfTemperature, ) from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir @@ -496,7 +497,7 @@ async def test_sensor_with_uoms_but_no_device_class( state = hass.states.get("sensor.test_sum") assert state.attributes.get("device_class") is None assert state.attributes.get("state_class") is None - assert state.attributes.get("unit_of_measurement") == "W" + assert state.attributes.get("unit_of_measurement") is None assert state.state == STATE_UNKNOWN assert ( @@ -650,10 +651,10 @@ async def test_sensor_calculated_result_fails_on_uom(hass: HomeAssistant) -> Non await hass.async_block_till_done() state = hass.states.get("sensor.test_sum") - assert state.state == STATE_UNKNOWN + assert state.state == STATE_UNAVAILABLE assert state.attributes.get("device_class") == "energy" assert state.attributes.get("state_class") == "total" - assert state.attributes.get("unit_of_measurement") == "kWh" + assert state.attributes.get("unit_of_measurement") is None async def test_sensor_calculated_properties_not_convertible_device_class( @@ -730,7 +731,7 @@ async def test_sensor_calculated_properties_not_convertible_device_class( assert state.state == STATE_UNKNOWN assert state.attributes.get("device_class") == "humidity" assert state.attributes.get("state_class") == "measurement" - assert state.attributes.get("unit_of_measurement") == "%" + assert state.attributes.get("unit_of_measurement") is None assert ( "Unable to use state. Only entities with correct unit of measurement is" @@ -812,3 +813,197 @@ async def test_sensors_attributes_added_when_entity_info_available( assert state.attributes.get(ATTR_ICON) is None assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.TOTAL assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "L" + + +async def test_sensor_state_class_no_uom_not_available( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test when input sensors drops unit of measurement.""" + + # If we have a valid unit of measurement from all input sensors + # the group sensor will go unknown in the case any input sensor + # drops the unit of measurement and log a warning. + + config = { + SENSOR_DOMAIN: { + "platform": GROUP_DOMAIN, + "name": "test_sum", + "type": "sum", + "entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"], + "unique_id": "very_unique_id_sum_sensor", + } + } + + entity_ids = config["sensor"]["entities"] + + input_attributes = { + "state_class": SensorStateClass.MEASUREMENT, + "unit_of_measurement": PERCENTAGE, + } + + hass.states.async_set(entity_ids[0], VALUES[0], input_attributes) + hass.states.async_set(entity_ids[1], VALUES[1], input_attributes) + hass.states.async_set(entity_ids[2], VALUES[2], input_attributes) + await hass.async_block_till_done() + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test_sum") + assert state.state == str(sum(VALUES)) + assert state.attributes.get("state_class") == "measurement" + assert state.attributes.get("unit_of_measurement") == "%" + + assert ( + "Unable to use state. Only entities with correct unit of measurement is" + " supported" + ) not in caplog.text + + # sensor.test_3 drops the unit of measurement + hass.states.async_set( + entity_ids[2], + VALUES[2], + { + "state_class": SensorStateClass.MEASUREMENT, + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test_sum") + assert state.state == STATE_UNKNOWN + assert state.attributes.get("state_class") == "measurement" + assert state.attributes.get("unit_of_measurement") is None + + assert ( + "Unable to use state. Only entities with correct unit of measurement is" + " supported, entity sensor.test_3, value 15.3 with" + " device class None and unit of measurement None excluded from calculation" + " in sensor.test_sum" + ) in caplog.text + + +async def test_sensor_different_attributes_ignore_non_numeric( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the sensor handles calculating attributes when using ignore_non_numeric.""" + config = { + SENSOR_DOMAIN: { + "platform": GROUP_DOMAIN, + "name": "test_sum", + "type": "sum", + "ignore_non_numeric": True, + "entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"], + "unique_id": "very_unique_id_sum_sensor", + } + } + + entity_ids = config["sensor"]["entities"] + + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test_sum") + assert state.state == STATE_UNAVAILABLE + assert state.attributes.get("state_class") is None + assert state.attributes.get("device_class") is None + assert state.attributes.get("unit_of_measurement") is None + + test_cases = [ + { + "entity": entity_ids[0], + "value": VALUES[0], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + "unit_of_measurement": PERCENTAGE, + }, + "expected_state": str(float(VALUES[0])), + "expected_state_class": SensorStateClass.MEASUREMENT, + "expected_device_class": None, + "expected_unit_of_measurement": PERCENTAGE, + }, + { + "entity": entity_ids[1], + "value": VALUES[1], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + "device_class": SensorDeviceClass.HUMIDITY, + "unit_of_measurement": PERCENTAGE, + }, + "expected_state": str(float(sum([VALUES[0], VALUES[1]]))), + "expected_state_class": SensorStateClass.MEASUREMENT, + "expected_device_class": None, + "expected_unit_of_measurement": PERCENTAGE, + }, + { + "entity": entity_ids[2], + "value": VALUES[2], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + "device_class": SensorDeviceClass.TEMPERATURE, + "unit_of_measurement": UnitOfTemperature.CELSIUS, + }, + "expected_state": str(float(sum(VALUES))), + "expected_state_class": SensorStateClass.MEASUREMENT, + "expected_device_class": None, + "expected_unit_of_measurement": None, + }, + { + "entity": entity_ids[2], + "value": VALUES[2], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + "device_class": SensorDeviceClass.HUMIDITY, + "unit_of_measurement": PERCENTAGE, + }, + "expected_state": str(float(sum(VALUES))), + "expected_state_class": SensorStateClass.MEASUREMENT, + # One sensor does not have a device class + "expected_device_class": None, + "expected_unit_of_measurement": PERCENTAGE, + }, + { + "entity": entity_ids[0], + "value": VALUES[0], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + "device_class": SensorDeviceClass.HUMIDITY, + "unit_of_measurement": PERCENTAGE, + }, + "expected_state": str(float(sum(VALUES))), + "expected_state_class": SensorStateClass.MEASUREMENT, + # First sensor now has a device class + "expected_device_class": SensorDeviceClass.HUMIDITY, + "expected_unit_of_measurement": PERCENTAGE, + }, + { + "entity": entity_ids[0], + "value": VALUES[0], + "attributes": { + "state_class": SensorStateClass.MEASUREMENT, + }, + "expected_state": str(float(sum(VALUES))), + "expected_state_class": SensorStateClass.MEASUREMENT, + "expected_device_class": None, + "expected_unit_of_measurement": None, + }, + ] + + for test_case in test_cases: + hass.states.async_set( + test_case["entity"], + test_case["value"], + test_case["attributes"], + ) + await hass.async_block_till_done() + state = hass.states.get("sensor.test_sum") + assert state.state == test_case["expected_state"] + assert state.attributes.get("state_class") == test_case["expected_state_class"] + assert ( + state.attributes.get("device_class") == test_case["expected_device_class"] + ) + assert ( + state.attributes.get("unit_of_measurement") + == test_case["expected_unit_of_measurement"] + )