Fix calculation of attributes in group sensor (#128601)

* Fix calculation of attributes in group sensor

* Fixes

* Fixes

* Make module level function
This commit is contained in:
G Johansson 2024-10-23 20:51:18 +02:00 committed by GitHub
parent 80984c94a1
commit 6ee6a8a74f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 296 additions and 68 deletions

View File

@ -36,14 +36,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import ( from homeassistant.core import HomeAssistant, State, callback
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
State,
callback,
)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity import ( from homeassistant.helpers.entity import (
@ -52,7 +45,6 @@ from homeassistant.helpers.entity import (
get_unit_of_measurement, get_unit_of_measurement,
) )
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.issue_registry import ( from homeassistant.helpers.issue_registry import (
IssueSeverity, IssueSeverity,
async_create_issue, async_create_issue,
@ -180,6 +172,17 @@ def async_create_preview_sensor(
) )
def _has_numeric_state(hass: HomeAssistant, entity_id: str) -> bool:
"""Test if state is numeric."""
if not (state := hass.states.get(entity_id)):
return False
try:
float(state.state)
except ValueError:
return False
return True
def calc_min( def calc_min(
sensor_values: list[tuple[str, float, State]], sensor_values: list[tuple[str, float, State]],
) -> tuple[dict[str, str | None], float | None]: ) -> tuple[dict[str, str | None], float | None]:
@ -332,12 +335,11 @@ class SensorGroup(GroupEntity, SensorEntity):
self.hass = hass self.hass = hass
self._entity_ids = entity_ids self._entity_ids = entity_ids
self._sensor_type = sensor_type self._sensor_type = sensor_type
self._state_class = state_class self._configured_state_class = state_class
self._device_class = device_class self._configured_device_class = device_class
self._native_unit_of_measurement = unit_of_measurement self._configured_unit_of_measurement = unit_of_measurement
self._valid_units: set[str | None] = set() self._valid_units: set[str | None] = set()
self._can_convert: bool = False self._can_convert: bool = False
self.calculate_attributes_later: CALLBACK_TYPE | None = None
self._attr_name = name self._attr_name = name
if name == DEFAULT_NAME: if name == DEFAULT_NAME:
self._attr_name = f"{DEFAULT_NAME} {sensor_type}".capitalize() self._attr_name = f"{DEFAULT_NAME} {sensor_type}".capitalize()
@ -352,39 +354,25 @@ class SensorGroup(GroupEntity, SensorEntity):
self._state_incorrect: set[str] = set() self._state_incorrect: set[str] = set()
self._extra_state_attribute: dict[str, Any] = {} self._extra_state_attribute: dict[str, Any] = {}
async def async_added_to_hass(self) -> None: def calculate_state_attributes(self, valid_state_entities: list[str]) -> None:
"""When added to hass."""
for entity_id in self._entity_ids:
if self.hass.states.get(entity_id) is None:
self.calculate_attributes_later = async_track_state_change_event(
self.hass, self._entity_ids, self.calculate_state_attributes
)
break
if not self.calculate_attributes_later:
await self.calculate_state_attributes()
await super().async_added_to_hass()
async def calculate_state_attributes(
self, event: Event[EventStateChangedData] | None = None
) -> None:
"""Calculate state attributes.""" """Calculate state attributes."""
for entity_id in self._entity_ids: self._attr_state_class = self._calculate_state_class(
if self.hass.states.get(entity_id) is None: self._configured_state_class, valid_state_entities
return )
if self.calculate_attributes_later: self._attr_device_class = self._calculate_device_class(
self.calculate_attributes_later() self._configured_device_class, valid_state_entities
self.calculate_attributes_later = None )
self._attr_state_class = self._calculate_state_class(self._state_class)
self._attr_device_class = self._calculate_device_class(self._device_class)
self._attr_native_unit_of_measurement = self._calculate_unit_of_measurement( self._attr_native_unit_of_measurement = self._calculate_unit_of_measurement(
self._native_unit_of_measurement self._configured_unit_of_measurement, valid_state_entities
) )
self._valid_units = self._get_valid_units() self._valid_units = self._get_valid_units()
@callback @callback
def async_update_group_state(self) -> None: def async_update_group_state(self) -> None:
"""Query all members and determine the sensor group state.""" """Query all members and determine the sensor group state."""
self.calculate_state_attributes(self._get_valid_entities())
states: list[StateType] = [] states: list[StateType] = []
valid_units = self._valid_units
valid_states: list[bool] = [] valid_states: list[bool] = []
sensor_values: list[tuple[str, float, State]] = [] sensor_values: list[tuple[str, float, State]] = []
for entity_id in self._entity_ids: for entity_id in self._entity_ids:
@ -392,20 +380,18 @@ class SensorGroup(GroupEntity, SensorEntity):
states.append(state.state) states.append(state.state)
try: try:
numeric_state = float(state.state) numeric_state = float(state.state)
if ( uom = state.attributes.get("unit_of_measurement")
self._valid_units
and (uom := state.attributes["unit_of_measurement"]) # Convert the state to the native unit of measurement when we have valid units
in self._valid_units # and a correct device class
and self._can_convert is True if valid_units and uom in valid_units and self._can_convert is True:
):
numeric_state = UNIT_CONVERTERS[self.device_class].convert( numeric_state = UNIT_CONVERTERS[self.device_class].convert(
numeric_state, uom, self.native_unit_of_measurement numeric_state, uom, self.native_unit_of_measurement
) )
if (
self._valid_units # If we have valid units and the entity's unit does not match
and (uom := state.attributes["unit_of_measurement"]) # we raise which skips the state and log a warning once
not in self._valid_units if valid_units and uom not in valid_units:
):
raise HomeAssistantError("Not a valid unit") # noqa: TRY301 raise HomeAssistantError("Not a valid unit") # noqa: TRY301
sensor_values.append((entity_id, numeric_state, state)) sensor_values.append((entity_id, numeric_state, state))
@ -480,7 +466,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None return None
def _calculate_state_class( def _calculate_state_class(
self, state_class: SensorStateClass | None self,
state_class: SensorStateClass | None,
valid_state_entities: list[str],
) -> SensorStateClass | None: ) -> SensorStateClass | None:
"""Calculate state class. """Calculate state class.
@ -491,8 +479,18 @@ class SensorGroup(GroupEntity, SensorEntity):
""" """
if state_class: if state_class:
return state_class return state_class
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return state class if all states are valid when not ignoring non numeric
return None
state_classes: list[SensorStateClass] = [] state_classes: list[SensorStateClass] = []
for entity_id in self._entity_ids: for entity_id in valid_state_entities:
try: try:
_state_class = get_capability(self.hass, entity_id, "state_class") _state_class = get_capability(self.hass, entity_id, "state_class")
except HomeAssistantError: except HomeAssistantError:
@ -523,7 +521,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None return None
def _calculate_device_class( def _calculate_device_class(
self, device_class: SensorDeviceClass | None self,
device_class: SensorDeviceClass | None,
valid_state_entities: list[str],
) -> SensorDeviceClass | None: ) -> SensorDeviceClass | None:
"""Calculate device class. """Calculate device class.
@ -534,8 +534,18 @@ class SensorGroup(GroupEntity, SensorEntity):
""" """
if device_class: if device_class:
return device_class return device_class
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return device class if all states are valid when not ignoring non numeric
return None
device_classes: list[SensorDeviceClass] = [] device_classes: list[SensorDeviceClass] = []
for entity_id in self._entity_ids: for entity_id in valid_state_entities:
try: try:
_device_class = get_device_class(self.hass, entity_id) _device_class = get_device_class(self.hass, entity_id)
except HomeAssistantError: except HomeAssistantError:
@ -568,7 +578,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None return None
def _calculate_unit_of_measurement( def _calculate_unit_of_measurement(
self, unit_of_measurement: str | None self,
unit_of_measurement: str | None,
valid_state_entities: list[str],
) -> str | None: ) -> str | None:
"""Calculate the unit of measurement. """Calculate the unit of measurement.
@ -579,8 +591,17 @@ class SensorGroup(GroupEntity, SensorEntity):
if unit_of_measurement: if unit_of_measurement:
return unit_of_measurement return unit_of_measurement
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return device class if all states are valid when not ignoring non numeric
return None
unit_of_measurements: list[str] = [] unit_of_measurements: list[str] = []
for entity_id in self._entity_ids: for entity_id in valid_state_entities:
try: try:
_unit_of_measurement = get_unit_of_measurement(self.hass, entity_id) _unit_of_measurement = get_unit_of_measurement(self.hass, entity_id)
except HomeAssistantError: except HomeAssistantError:
@ -665,19 +686,31 @@ class SensorGroup(GroupEntity, SensorEntity):
If device class is set and compatible unit of measurements. If device class is set and compatible unit of measurements.
If device class is not set, use one unit of measurement. If device class is not set, use one unit of measurement.
Only calculate valid units if there are no valid units set.
""" """
if ( if (valid_units := self._valid_units) and not self._ignore_non_numeric:
device_class := self.device_class # If we have valid units already and not using ignore_non_numeric
) in UNIT_CONVERTERS and self.native_unit_of_measurement: # we should not recalculate.
return valid_units
native_uom = self.native_unit_of_measurement
if (device_class := self.device_class) in UNIT_CONVERTERS and native_uom:
self._can_convert = True self._can_convert = True
return UNIT_CONVERTERS[device_class].VALID_UNITS return UNIT_CONVERTERS[device_class].VALID_UNITS
if ( if device_class and (device_class) in DEVICE_CLASS_UNITS and native_uom:
device_class
and (device_class) in DEVICE_CLASS_UNITS
and self.native_unit_of_measurement
):
valid_uoms: set = DEVICE_CLASS_UNITS[device_class] valid_uoms: set = DEVICE_CLASS_UNITS[device_class]
return valid_uoms return valid_uoms
if device_class is None and self.native_unit_of_measurement: if device_class is None and native_uom:
return {self.native_unit_of_measurement} return {native_uom}
return set() return set()
def _get_valid_entities(
self,
) -> list[str]:
"""Return list of valid entities."""
return [
entity_id
for entity_id in self._entity_ids
if _has_numeric_state(self.hass, entity_id)
]

View File

@ -32,6 +32,7 @@ from homeassistant.const import (
SERVICE_RELOAD, SERVICE_RELOAD,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
UnitOfTemperature,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import issue_registry as ir
@ -496,7 +497,7 @@ async def test_sensor_with_uoms_but_no_device_class(
state = hass.states.get("sensor.test_sum") state = hass.states.get("sensor.test_sum")
assert state.attributes.get("device_class") is None assert state.attributes.get("device_class") is None
assert state.attributes.get("state_class") is None assert state.attributes.get("state_class") is None
assert state.attributes.get("unit_of_measurement") == "W" assert state.attributes.get("unit_of_measurement") is None
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
assert ( assert (
@ -650,10 +651,10 @@ async def test_sensor_calculated_result_fails_on_uom(hass: HomeAssistant) -> Non
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum") state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNAVAILABLE
assert state.attributes.get("device_class") == "energy" assert state.attributes.get("device_class") == "energy"
assert state.attributes.get("state_class") == "total" assert state.attributes.get("state_class") == "total"
assert state.attributes.get("unit_of_measurement") == "kWh" assert state.attributes.get("unit_of_measurement") is None
async def test_sensor_calculated_properties_not_convertible_device_class( async def test_sensor_calculated_properties_not_convertible_device_class(
@ -730,7 +731,7 @@ async def test_sensor_calculated_properties_not_convertible_device_class(
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
assert state.attributes.get("device_class") == "humidity" assert state.attributes.get("device_class") == "humidity"
assert state.attributes.get("state_class") == "measurement" assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") == "%" assert state.attributes.get("unit_of_measurement") is None
assert ( assert (
"Unable to use state. Only entities with correct unit of measurement is" "Unable to use state. Only entities with correct unit of measurement is"
@ -812,3 +813,197 @@ async def test_sensors_attributes_added_when_entity_info_available(
assert state.attributes.get(ATTR_ICON) is None assert state.attributes.get(ATTR_ICON) is None
assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.TOTAL assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.TOTAL
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "L" assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "L"
async def test_sensor_state_class_no_uom_not_available(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test when input sensors drops unit of measurement."""
# If we have a valid unit of measurement from all input sensors
# the group sensor will go unknown in the case any input sensor
# drops the unit of measurement and log a warning.
config = {
SENSOR_DOMAIN: {
"platform": GROUP_DOMAIN,
"name": "test_sum",
"type": "sum",
"entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"],
"unique_id": "very_unique_id_sum_sensor",
}
}
entity_ids = config["sensor"]["entities"]
input_attributes = {
"state_class": SensorStateClass.MEASUREMENT,
"unit_of_measurement": PERCENTAGE,
}
hass.states.async_set(entity_ids[0], VALUES[0], input_attributes)
hass.states.async_set(entity_ids[1], VALUES[1], input_attributes)
hass.states.async_set(entity_ids[2], VALUES[2], input_attributes)
await hass.async_block_till_done()
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == str(sum(VALUES))
assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") == "%"
assert (
"Unable to use state. Only entities with correct unit of measurement is"
" supported"
) not in caplog.text
# sensor.test_3 drops the unit of measurement
hass.states.async_set(
entity_ids[2],
VALUES[2],
{
"state_class": SensorStateClass.MEASUREMENT,
},
)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNKNOWN
assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") is None
assert (
"Unable to use state. Only entities with correct unit of measurement is"
" supported, entity sensor.test_3, value 15.3 with"
" device class None and unit of measurement None excluded from calculation"
" in sensor.test_sum"
) in caplog.text
async def test_sensor_different_attributes_ignore_non_numeric(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the sensor handles calculating attributes when using ignore_non_numeric."""
config = {
SENSOR_DOMAIN: {
"platform": GROUP_DOMAIN,
"name": "test_sum",
"type": "sum",
"ignore_non_numeric": True,
"entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"],
"unique_id": "very_unique_id_sum_sensor",
}
}
entity_ids = config["sensor"]["entities"]
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNAVAILABLE
assert state.attributes.get("state_class") is None
assert state.attributes.get("device_class") is None
assert state.attributes.get("unit_of_measurement") is None
test_cases = [
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(VALUES[0])),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[1],
"value": VALUES[1],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum([VALUES[0], VALUES[1]]))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[2],
"value": VALUES[2],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.TEMPERATURE,
"unit_of_measurement": UnitOfTemperature.CELSIUS,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": None,
},
{
"entity": entity_ids[2],
"value": VALUES[2],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
# One sensor does not have a device class
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
# First sensor now has a device class
"expected_device_class": SensorDeviceClass.HUMIDITY,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": None,
},
]
for test_case in test_cases:
hass.states.async_set(
test_case["entity"],
test_case["value"],
test_case["attributes"],
)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == test_case["expected_state"]
assert state.attributes.get("state_class") == test_case["expected_state_class"]
assert (
state.attributes.get("device_class") == test_case["expected_device_class"]
)
assert (
state.attributes.get("unit_of_measurement")
== test_case["expected_unit_of_measurement"]
)