From 6420837d58472e6d8ed3aeff8f26696bf643e955 Mon Sep 17 00:00:00 2001 From: Joakim Plate Date: Fri, 21 Jun 2024 12:47:57 +0200 Subject: [PATCH] Calculate device class as soon as it is known in integral (#119940) --- .../components/integration/sensor.py | 46 ++++++++++--- .../integration/snapshots/test_sensor.ambr | 69 +++++++++++++++++++ tests/components/integration/test_sensor.py | 65 ++++++++++++++++- 3 files changed, 169 insertions(+), 11 deletions(-) create mode 100644 tests/components/integration/snapshots/test_sensor.ambr diff --git a/homeassistant/components/integration/sensor.py b/homeassistant/components/integration/sensor.py index 02451773558..d201fab0c6f 100644 --- a/homeassistant/components/integration/sensor.py +++ b/homeassistant/components/integration/sensor.py @@ -13,6 +13,7 @@ from typing import Any, Final, Self import voluptuous as vol from homeassistant.components.sensor import ( + DEVICE_CLASS_UNITS, PLATFORM_SCHEMA, RestoreSensor, SensorDeviceClass, @@ -75,6 +76,10 @@ UNIT_TIME = { UnitOfTime.DAYS: 24 * 60 * 60, } +DEVICE_CLASS_MAP = { + SensorDeviceClass.POWER: SensorDeviceClass.ENERGY, +} + DEFAULT_ROUND = 3 PLATFORM_SCHEMA = vol.All( @@ -381,6 +386,22 @@ class IntegrationSensor(RestoreSensor): return f"{self._unit_prefix_string}{integral_unit}" + def _calculate_device_class( + self, + source_device_class: SensorDeviceClass | None, + unit_of_measurement: str | None, + ) -> SensorDeviceClass | None: + """Deduce device class if possible from source device class and target unit.""" + if source_device_class is None: + return None + + if (device_class := DEVICE_CLASS_MAP.get(source_device_class)) is None: + return None + + if unit_of_measurement not in DEVICE_CLASS_UNITS.get(device_class, set()): + return None + return device_class + def _derive_and_set_attributes_from_state(self, source_state: State) -> None: source_unit = source_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if source_unit is not None: @@ -389,13 +410,13 @@ class IntegrationSensor(RestoreSensor): # If the source has no defined unit we cannot derive a unit for the integral self._unit_of_measurement = None - if ( - self.device_class is None - and source_state.attributes.get(ATTR_DEVICE_CLASS) - == SensorDeviceClass.POWER - ): - self._attr_device_class = SensorDeviceClass.ENERGY - self._attr_icon = None # Remove this sensors icon default and allow to fallback to the ENERGY default + self._attr_device_class = self._calculate_device_class( + source_state.attributes.get(ATTR_DEVICE_CLASS), self.unit_of_measurement + ) + if self._attr_device_class: + self._attr_icon = None # Remove this sensors icon default and allow to fallback to the device class default + else: + self._attr_icon = "mdi:chart-histogram" def _update_integral(self, area: Decimal) -> None: area_scaled = area / (self._unit_prefix * self._unit_time) @@ -436,6 +457,11 @@ class IntegrationSensor(RestoreSensor): else: handle_state_change = self._integrate_on_state_change_callback + if ( + state := self.hass.states.get(self._source_entity) + ) and state.state != STATE_UNAVAILABLE: + self._derive_and_set_attributes_from_state(state) + self.async_on_remove( async_track_state_change_event( self.hass, @@ -477,7 +503,7 @@ class IntegrationSensor(RestoreSensor): def _integrate_on_state_change( self, old_state: State | None, new_state: State | None ) -> None: - if old_state is None or new_state is None: + if new_state is None: return if new_state.state == STATE_UNAVAILABLE: @@ -488,6 +514,10 @@ class IntegrationSensor(RestoreSensor): self._attr_available = True self._derive_and_set_attributes_from_state(new_state) + if old_state is None: + self.async_write_ha_state() + return + if not (states := self._method.validate_states(old_state, new_state)): self.async_write_ha_state() return diff --git a/tests/components/integration/snapshots/test_sensor.ambr b/tests/components/integration/snapshots/test_sensor.ambr new file mode 100644 index 00000000000..5747e6489b9 --- /dev/null +++ b/tests/components/integration/snapshots/test_sensor.ambr @@ -0,0 +1,69 @@ +# serializer version: 1 +# name: test_initial_state[BTU/h-power-h] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'integration', + 'icon': 'mdi:chart-histogram', + 'source': 'sensor.source', + 'state_class': , + 'unit_of_measurement': 'BTU', + }), + 'context': , + 'entity_id': 'sensor.integration', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'unknown', + }) +# --- +# name: test_initial_state[ft\xb3/min-volume_flow_rate-min] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'integration', + 'icon': 'mdi:chart-histogram', + 'source': 'sensor.source', + 'state_class': , + 'unit_of_measurement': 'ft³', + }), + 'context': , + 'entity_id': 'sensor.integration', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'unknown', + }) +# --- +# name: test_initial_state[kW-None-h] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'integration', + 'icon': 'mdi:chart-histogram', + 'source': 'sensor.source', + 'state_class': , + 'unit_of_measurement': 'kWh', + }), + 'context': , + 'entity_id': 'sensor.integration', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'unknown', + }) +# --- +# name: test_initial_state[kW-power-h] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'device_class': 'energy', + 'friendly_name': 'integration', + 'source': 'sensor.source', + 'state_class': , + 'unit_of_measurement': 'kWh', + }), + 'context': , + 'entity_id': 'sensor.integration', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'unknown', + }) +# --- diff --git a/tests/components/integration/test_sensor.py b/tests/components/integration/test_sensor.py index 1a729f6254e..243504cb3e0 100644 --- a/tests/components/integration/test_sensor.py +++ b/tests/components/integration/test_sensor.py @@ -5,10 +5,12 @@ from typing import Any from freezegun import freeze_time import pytest +from syrupy.assertion import SnapshotAssertion from homeassistant.components.integration.const import DOMAIN from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, STATE_UNAVAILABLE, STATE_UNKNOWN, @@ -17,6 +19,7 @@ from homeassistant.const import ( UnitOfInformation, UnitOfPower, UnitOfTime, + UnitOfVolumeFlowRate, ) from homeassistant.core import HomeAssistant, State from homeassistant.helpers import ( @@ -36,6 +39,52 @@ from tests.common import ( DEFAULT_MAX_SUB_INTERVAL = {"minutes": 1} +@pytest.mark.parametrize( + ("unit_of_measurement", "device_class", "unit_time"), + [ + (UnitOfPower.KILO_WATT, SensorDeviceClass.POWER, "h"), + (UnitOfPower.KILO_WATT, None, "h"), + (UnitOfPower.BTU_PER_HOUR, SensorDeviceClass.POWER, "h"), + ( + UnitOfVolumeFlowRate.CUBIC_FEET_PER_MINUTE, + SensorDeviceClass.VOLUME_FLOW_RATE, + "min", + ), + ], +) +async def test_initial_state( + hass: HomeAssistant, + unit_of_measurement: str, + device_class: SensorDeviceClass, + unit_time: str, + snapshot: SnapshotAssertion, +) -> None: + """Test integration sensor state.""" + config = { + "sensor": { + "platform": "integration", + "name": "integration", + "source": "sensor.source", + "round": 2, + "method": "left", + "unit_time": unit_time, + } + } + + assert await async_setup_component(hass, "sensor", config) + hass.states.async_set( + "sensor.source", + "1", + { + ATTR_DEVICE_CLASS: device_class, + ATTR_UNIT_OF_MEASUREMENT: unit_of_measurement, + }, + ) + await hass.async_block_till_done() + + assert hass.states.get("sensor.integration") == snapshot + + @pytest.mark.parametrize("method", ["trapezoidal", "left", "right"]) async def test_state(hass: HomeAssistant, method) -> None: """Test integration sensor state.""" @@ -49,13 +98,23 @@ async def test_state(hass: HomeAssistant, method) -> None: } } + assert await async_setup_component(hass, "sensor", config) + await hass.async_block_till_done() + + state = hass.states.get("sensor.integration") + assert state is not None + assert state.attributes.get("state_class") is SensorStateClass.TOTAL + assert "device_class" not in state.attributes + now = dt_util.utcnow() with freeze_time(now): - assert await async_setup_component(hass, "sensor", config) - entity_id = config["sensor"]["source"] hass.states.async_set( - entity_id, 1, {ATTR_UNIT_OF_MEASUREMENT: UnitOfPower.KILO_WATT} + entity_id, + 1, + { + ATTR_UNIT_OF_MEASUREMENT: UnitOfPower.KILO_WATT, + }, ) await hass.async_block_till_done()