Use SensorStateClass enum in sensor (#87066)

* Use SensorStateClass enum in sensor

* Apply suggestions from code review

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/sensor/recorder.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

---------

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2023-02-03 10:49:41 +01:00 committed by GitHub
parent 2349aa73b2
commit 91668f8599
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,24 +34,22 @@ from homeassistant.core import HomeAssistant, State, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import entity_sources from homeassistant.helpers.entity import entity_sources
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.enum import try_parse_enum
from . import ( from .const import (
ATTR_LAST_RESET, ATTR_LAST_RESET,
ATTR_OPTIONS, ATTR_OPTIONS,
ATTR_STATE_CLASS, ATTR_STATE_CLASS,
DOMAIN, DOMAIN,
STATE_CLASS_MEASUREMENT, SensorStateClass,
STATE_CLASS_TOTAL,
STATE_CLASS_TOTAL_INCREASING,
STATE_CLASSES,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_STATISTICS = { DEFAULT_STATISTICS = {
STATE_CLASS_MEASUREMENT: {"mean", "min", "max"}, SensorStateClass.MEASUREMENT: {"mean", "min", "max"},
STATE_CLASS_TOTAL: {"sum"}, SensorStateClass.TOTAL: {"sum"},
STATE_CLASS_TOTAL_INCREASING: {"sum"}, SensorStateClass.TOTAL_INCREASING: {"sum"},
} }
EQUIVALENT_UNITS = { EQUIVALENT_UNITS = {
@ -82,7 +80,7 @@ def _get_sensor_states(hass: HomeAssistant) -> list[State]:
for state in all_sensors: for state in all_sensors:
if not is_entity_recorded(hass, state.entity_id): if not is_entity_recorded(hass, state.entity_id):
continue continue
if (state.attributes.get(ATTR_STATE_CLASS)) not in STATE_CLASSES: if not try_parse_enum(SensorStateClass, state.attributes.get(ATTR_STATE_CLASS)):
continue continue
statistics_sensors.append(state) statistics_sensors.append(state)
@ -537,7 +535,7 @@ def _compile_statistics( # noqa: C901
for fstate, state in fstates: for fstate, state in fstates:
reset = False reset = False
if ( if (
state_class != STATE_CLASS_TOTAL_INCREASING state_class != SensorStateClass.TOTAL_INCREASING
and ( and (
last_reset := _last_reset_as_utc_isoformat( last_reset := _last_reset_as_utc_isoformat(
state.attributes.get("last_reset"), entity_id state.attributes.get("last_reset"), entity_id
@ -573,7 +571,7 @@ def _compile_statistics( # noqa: C901
entity_id, entity_id,
fstate, fstate,
) )
elif state_class == STATE_CLASS_TOTAL_INCREASING: elif state_class == SensorStateClass.TOTAL_INCREASING:
try: try:
if old_state is None or reset_detected( if old_state is None or reset_detected(
hass, entity_id, fstate, new_state, state hass, entity_id, fstate, new_state, state
@ -648,7 +646,7 @@ def list_statistic_ids(
if ( if (
"sum" in provided_statistics "sum" in provided_statistics
and ATTR_LAST_RESET not in state.attributes and ATTR_LAST_RESET not in state.attributes
and state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT and state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT
): ):
continue continue
@ -678,7 +676,9 @@ def validate_statistics(
for state in sensor_states: for state in sensor_states:
entity_id = state.entity_id entity_id = state.entity_id
state_class = state.attributes.get(ATTR_STATE_CLASS) state_class = try_parse_enum(
SensorStateClass, state.attributes.get(ATTR_STATE_CLASS)
)
state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
if metadata := metadatas.get(entity_id): if metadata := metadatas.get(entity_id):
@ -691,7 +691,7 @@ def validate_statistics(
) )
) )
if state_class not in STATE_CLASSES: if state_class is None:
# Sensor no longer has a valid state class # Sensor no longer has a valid state class
validation_result[entity_id].append( validation_result[entity_id].append(
statistics.ValidationIssue( statistics.ValidationIssue(
@ -731,7 +731,7 @@ def validate_statistics(
}, },
) )
) )
elif state_class in STATE_CLASSES: elif state_class is not None:
if not is_entity_recorded(hass, state.entity_id): if not is_entity_recorded(hass, state.entity_id):
# Sensor is not recorded # Sensor is not recorded
validation_result[entity_id].append( validation_result[entity_id].append(