diff --git a/homeassistant/components/utility_meter/sensor.py b/homeassistant/components/utility_meter/sensor.py index 84533efdcf5..e8d938928dd 100644 --- a/homeassistant/components/utility_meter/sensor.py +++ b/homeassistant/components/utility_meter/sensor.py @@ -5,7 +5,11 @@ import logging import voluptuous as vol -from homeassistant.components.sensor import STATE_CLASS_MEASUREMENT, SensorEntity +from homeassistant.components.sensor import ( + STATE_CLASS_MEASUREMENT, + STATE_CLASS_TOTAL_INCREASING, + SensorEntity, +) from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, CONF_NAME, @@ -330,7 +334,11 @@ class UtilityMeterSensor(RestoreEntity, SensorEntity): @property def state_class(self): """Return the device class of the sensor.""" - return STATE_CLASS_MEASUREMENT + return ( + STATE_CLASS_MEASUREMENT + if self._sensor_net_consumption + else STATE_CLASS_TOTAL_INCREASING + ) @property def native_unit_of_measurement(self): diff --git a/tests/components/utility_meter/test_sensor.py b/tests/components/utility_meter/test_sensor.py index c5075aa322b..a2d15c595b0 100644 --- a/tests/components/utility_meter/test_sensor.py +++ b/tests/components/utility_meter/test_sensor.py @@ -3,7 +3,11 @@ from contextlib import contextmanager from datetime import timedelta from unittest.mock import patch -from homeassistant.components.sensor import ATTR_STATE_CLASS, STATE_CLASS_MEASUREMENT +from homeassistant.components.sensor import ( + ATTR_STATE_CLASS, + STATE_CLASS_MEASUREMENT, + STATE_CLASS_TOTAL_INCREASING, +) from homeassistant.components.utility_meter.const import ( ATTR_TARIFF, ATTR_VALUE, @@ -165,6 +169,7 @@ async def test_device_class(hass): "utility_meter": { "energy_meter": { "source": "sensor.energy", + "net_consumption": True, }, "gas_meter": { "source": "sensor.gas", @@ -197,7 +202,7 @@ async def test_device_class(hass): assert state is not None assert state.state == "0" assert state.attributes.get(ATTR_DEVICE_CLASS) is None - assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_TOTAL_INCREASING assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) is None hass.states.async_set( @@ -219,7 +224,7 @@ async def test_device_class(hass): assert state is not None assert state.state == "1" assert state.attributes.get(ATTR_DEVICE_CLASS) is None - assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + assert state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_TOTAL_INCREASING assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "some_archaic_unit"