Improve device class of utility meter (#114368)

This commit is contained in:
Erik Montnemery 2024-03-28 13:24:44 +01:00 committed by GitHub
parent 68d6f96a9d
commit 5b98a8458f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 146 additions and 29 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal, DecimalException, InvalidOperation from decimal import Decimal, DecimalException, InvalidOperation
@ -13,6 +14,7 @@ import voluptuous as vol
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
ATTR_LAST_RESET, ATTR_LAST_RESET,
DEVICE_CLASS_UNITS,
RestoreSensor, RestoreSensor,
SensorDeviceClass, SensorDeviceClass,
SensorExtraStoredData, SensorExtraStoredData,
@ -21,12 +23,12 @@ from homeassistant.components.sensor import (
from homeassistant.components.sensor.recorder import _suggest_report_issue from homeassistant.components.sensor.recorder import _suggest_report_issue
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_UNIT_OF_MEASUREMENT, ATTR_UNIT_OF_MEASUREMENT,
CONF_NAME, CONF_NAME,
CONF_UNIQUE_ID, CONF_UNIQUE_ID,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
UnitOfEnergy,
) )
from homeassistant.core import Event, HomeAssistant, State, callback from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -47,6 +49,7 @@ from homeassistant.helpers.template import is_number
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import slugify from homeassistant.util import slugify
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.enum import try_parse_enum
from .const import ( from .const import (
ATTR_CRON_PATTERN, ATTR_CRON_PATTERN,
@ -97,12 +100,6 @@ ATTR_LAST_PERIOD = "last_period"
ATTR_LAST_VALID_STATE = "last_valid_state" ATTR_LAST_VALID_STATE = "last_valid_state"
ATTR_TARIFF = "tariff" ATTR_TARIFF = "tariff"
DEVICE_CLASS_MAP = {
UnitOfEnergy.WATT_HOUR: SensorDeviceClass.ENERGY,
UnitOfEnergy.KILO_WATT_HOUR: SensorDeviceClass.ENERGY,
}
PRECISION = 3 PRECISION = 3
PAUSED = "paused" PAUSED = "paused"
COLLECTING = "collecting" COLLECTING = "collecting"
@ -313,6 +310,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
last_reset: datetime | None last_reset: datetime | None
last_valid_state: Decimal | None last_valid_state: Decimal | None
status: str status: str
input_device_class: SensorDeviceClass | None
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the utility sensor data.""" """Return a dict representation of the utility sensor data."""
@ -324,6 +322,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
str(self.last_valid_state) if self.last_valid_state else None str(self.last_valid_state) if self.last_valid_state else None
) )
data["status"] = self.status data["status"] = self.status
data["input_device_class"] = str(self.input_device_class)
return data return data
@ -343,6 +342,9 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
else None else None
) )
status: str = restored["status"] status: str = restored["status"]
input_device_class = try_parse_enum(
SensorDeviceClass, restored.get("input_device_class")
)
except KeyError: except KeyError:
# restored is a dict, but does not have all values # restored is a dict, but does not have all values
return None return None
@ -357,6 +359,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
last_reset, last_reset,
last_valid_state, last_valid_state,
status, status,
input_device_class,
) )
@ -397,6 +400,7 @@ class UtilityMeterSensor(RestoreSensor):
self._last_valid_state = None self._last_valid_state = None
self._collecting = None self._collecting = None
self._name = name self._name = name
self._input_device_class = None
self._unit_of_measurement = None self._unit_of_measurement = None
self._period = meter_type self._period = meter_type
if meter_type is not None: if meter_type is not None:
@ -416,9 +420,10 @@ class UtilityMeterSensor(RestoreSensor):
self._tariff = tariff self._tariff = tariff
self._tariff_entity = tariff_entity self._tariff_entity = tariff_entity
def start(self, unit): def start(self, attributes: Mapping[str, Any]) -> None:
"""Initialize unit and state upon source initial update.""" """Initialize unit and state upon source initial update."""
self._unit_of_measurement = unit self._input_device_class = attributes.get(ATTR_DEVICE_CLASS)
self._unit_of_measurement = attributes.get(ATTR_UNIT_OF_MEASUREMENT)
self._state = 0 self._state = 0
self.async_write_ha_state() self.async_write_ha_state()
@ -482,6 +487,7 @@ class UtilityMeterSensor(RestoreSensor):
new_state = event.data["new_state"] new_state = event.data["new_state"]
if new_state is None: if new_state is None:
return return
new_state_attributes: Mapping[str, Any] = new_state.attributes or {}
# First check if the new_state is valid (see discussion in PR #88446) # First check if the new_state is valid (see discussion in PR #88446)
if (new_state_val := self._validate_state(new_state)) is None: if (new_state_val := self._validate_state(new_state)) is None:
@ -498,7 +504,7 @@ class UtilityMeterSensor(RestoreSensor):
for sensor in self.hass.data[DATA_UTILITY][self._parent_meter][ for sensor in self.hass.data[DATA_UTILITY][self._parent_meter][
DATA_TARIFF_SENSORS DATA_TARIFF_SENSORS
]: ]:
sensor.start(new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)) sensor.start(new_state_attributes)
if self._unit_of_measurement is None: if self._unit_of_measurement is None:
_LOGGER.warning( _LOGGER.warning(
"Source sensor %s has no unit of measurement. Please %s", "Source sensor %s has no unit of measurement. Please %s",
@ -512,7 +518,8 @@ class UtilityMeterSensor(RestoreSensor):
# If net_consumption is off, the adjustment must be non-negative # If net_consumption is off, the adjustment must be non-negative
self._state += adjustment # type: ignore[operator] # self._state will be set to by the start function if it is None, therefore it always has a valid Decimal value at this line self._state += adjustment # type: ignore[operator] # self._state will be set to by the start function if it is None, therefore it always has a valid Decimal value at this line
self._unit_of_measurement = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) self._input_device_class = new_state_attributes.get(ATTR_DEVICE_CLASS)
self._unit_of_measurement = new_state_attributes.get(ATTR_UNIT_OF_MEASUREMENT)
self._last_valid_state = new_state_val self._last_valid_state = new_state_val
self.async_write_ha_state() self.async_write_ha_state()
@ -600,6 +607,7 @@ class UtilityMeterSensor(RestoreSensor):
if (last_sensor_data := await self.async_get_last_sensor_data()) is not None: if (last_sensor_data := await self.async_get_last_sensor_data()) is not None:
# new introduced in 2022.04 # new introduced in 2022.04
self._state = last_sensor_data.native_value self._state = last_sensor_data.native_value
self._input_device_class = last_sensor_data.input_device_class
self._unit_of_measurement = last_sensor_data.native_unit_of_measurement self._unit_of_measurement = last_sensor_data.native_unit_of_measurement
self._last_period = last_sensor_data.last_period self._last_period = last_sensor_data.last_period
self._last_reset = last_sensor_data.last_reset self._last_reset = last_sensor_data.last_reset
@ -693,7 +701,11 @@ class UtilityMeterSensor(RestoreSensor):
@property @property
def device_class(self): def device_class(self):
"""Return the device class of the sensor.""" """Return the device class of the sensor."""
return DEVICE_CLASS_MAP.get(self._unit_of_measurement) if self._input_device_class is not None:
return self._input_device_class
if self._unit_of_measurement in DEVICE_CLASS_UNITS[SensorDeviceClass.ENERGY]:
return SensorDeviceClass.ENERGY
return None
@property @property
def state_class(self): def state_class(self):
@ -744,6 +756,7 @@ class UtilityMeterSensor(RestoreSensor):
self._last_reset, self._last_reset,
self._last_valid_state, self._last_valid_state,
PAUSED if self._collecting is None else COLLECTING, PAUSED if self._collecting is None else COLLECTING,
self._input_device_class,
) )
async def async_get_last_sensor_data(self) -> UtilitySensorExtraStoredData | None: async def async_get_last_sensor_data(self) -> UtilitySensorExtraStoredData | None:

View File

@ -40,6 +40,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
UnitOfEnergy, UnitOfEnergy,
UnitOfVolume,
) )
from homeassistant.core import CoreState, HomeAssistant, State from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
@ -553,8 +554,66 @@ async def test_entity_name(hass: HomeAssistant, yaml_config, entity_id, name) ->
), ),
], ],
) )
@pytest.mark.parametrize(
(
"energy_sensor_attributes",
"gas_sensor_attributes",
"energy_meter_attributes",
"gas_meter_attributes",
),
[
(
{ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR},
{ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.ENERGY,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
),
(
{},
{},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: None,
},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: None,
},
),
(
{
ATTR_DEVICE_CLASS: SensorDeviceClass.GAS,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.WATER,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.GAS,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.WATER,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
),
],
)
async def test_device_class( async def test_device_class(
hass: HomeAssistant, yaml_config, config_entry_configs hass: HomeAssistant,
yaml_config,
config_entry_configs,
energy_sensor_attributes,
gas_sensor_attributes,
energy_meter_attributes,
gas_meter_attributes,
) -> None: ) -> None:
"""Test utility device_class.""" """Test utility device_class."""
if yaml_config: if yaml_config:
@ -579,27 +638,23 @@ async def test_device_class(
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(entity_id_energy, 2, energy_sensor_attributes)
entity_id_energy, 2, {ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR} hass.states.async_set(entity_id_gas, 2, gas_sensor_attributes)
)
hass.states.async_set(
entity_id_gas, 2, {ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"}
)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("sensor.energy_meter") state = hass.states.get("sensor.energy_meter")
assert state is not None assert state is not None
assert state.state == "0" assert state.state == "0"
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR for attr, value in energy_meter_attributes.items():
assert state.attributes.get(attr) == value
state = hass.states.get("sensor.gas_meter") state = hass.states.get("sensor.gas_meter")
assert state is not None assert state is not None
assert state.state == "0" assert state.state == "0"
assert state.attributes.get(ATTR_DEVICE_CLASS) is None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL_INCREASING assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL_INCREASING
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "some_archaic_unit" for attr, value in gas_meter_attributes.items():
assert state.attributes.get(attr) == value
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -610,7 +665,13 @@ async def test_device_class(
"utility_meter": { "utility_meter": {
"energy_bill": { "energy_bill": {
"source": "sensor.energy", "source": "sensor.energy",
"tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"], "tariffs": [
"tariff0",
"tariff1",
"tariff2",
"tariff3",
"tariff4",
],
} }
} }
}, },
@ -626,7 +687,13 @@ async def test_device_class(
"offset": 0, "offset": 0,
"periodically_resetting": True, "periodically_resetting": True,
"source": "sensor.energy", "source": "sensor.energy",
"tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"], "tariffs": [
"tariff0",
"tariff1",
"tariff2",
"tariff3",
"tariff4",
],
}, },
), ),
], ],
@ -644,7 +711,33 @@ async def test_restore_state(
mock_restore_cache_with_extra_data( mock_restore_cache_with_extra_data(
hass, hass,
[ [
# sensor.energy_bill_tariff1 is restored as expected # sensor.energy_bill_tariff0 is restored as expected, including device
# class
(
State(
"sensor.energy_bill_tariff0",
"0.1",
attributes={
ATTR_STATUS: PAUSED,
ATTR_LAST_RESET: last_reset_1,
ATTR_UNIT_OF_MEASUREMENT: UnitOfVolume.CUBIC_METERS,
},
),
{
"native_value": {
"__type": "<class 'decimal.Decimal'>",
"decimal_str": "0.2",
},
"native_unit_of_measurement": "gal",
"last_reset": last_reset_2,
"last_period": "1.3",
"last_valid_state": None,
"status": "collecting",
"input_device_class": "water",
},
),
# sensor.energy_bill_tariff1 is restored as expected, except device
# class
( (
State( State(
"sensor.energy_bill_tariff1", "sensor.energy_bill_tariff1",
@ -743,12 +836,21 @@ async def test_restore_state(
await hass.async_block_till_done() await hass.async_block_till_done()
# restore from cache # restore from cache
state = hass.states.get("sensor.energy_bill_tariff0")
assert state.state == "0.2"
assert state.attributes.get("status") == COLLECTING
assert state.attributes.get("last_reset") == last_reset_2
assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfVolume.GALLONS
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.WATER
state = hass.states.get("sensor.energy_bill_tariff1") state = hass.states.get("sensor.energy_bill_tariff1")
assert state.state == "1.2" assert state.state == "1.2"
assert state.attributes.get("status") == PAUSED assert state.attributes.get("status") == PAUSED
assert state.attributes.get("last_reset") == last_reset_2 assert state.attributes.get("last_reset") == last_reset_2
assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff2") state = hass.states.get("sensor.energy_bill_tariff2")
assert state.state == "2.1" assert state.state == "2.1"
@ -756,6 +858,7 @@ async def test_restore_state(
assert state.attributes.get("last_reset") == last_reset_1 assert state.attributes.get("last_reset") == last_reset_1
assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff3") state = hass.states.get("sensor.energy_bill_tariff3")
assert state.state == "3.1" assert state.state == "3.1"
@ -763,6 +866,7 @@ async def test_restore_state(
assert state.attributes.get("last_reset") == last_reset_1 assert state.attributes.get("last_reset") == last_reset_1
assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff4") state = hass.states.get("sensor.energy_bill_tariff4")
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
@ -770,16 +874,16 @@ async def test_restore_state(
# utility_meter is loaded, now set sensors according to utility_meter: # utility_meter is loaded, now set sensors according to utility_meter:
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("select.energy_bill") state = hass.states.get("select.energy_bill")
assert state.state == "tariff1" assert state.state == "tariff0"
state = hass.states.get("sensor.energy_bill_tariff1") state = hass.states.get("sensor.energy_bill_tariff0")
assert state.attributes.get("status") == COLLECTING assert state.attributes.get("status") == COLLECTING
for entity_id in ( for entity_id in (
"sensor.energy_bill_tariff1",
"sensor.energy_bill_tariff2", "sensor.energy_bill_tariff2",
"sensor.energy_bill_tariff3", "sensor.energy_bill_tariff3",
"sensor.energy_bill_tariff4", "sensor.energy_bill_tariff4",