From 91668f8599821261f25919ab21a63a3e43b1f81b Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 3 Feb 2023 10:49:41 +0100 Subject: [PATCH] Use SensorStateClass enum in sensor (#87066) * Use SensorStateClass enum in sensor * Apply suggestions from code review Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update homeassistant/components/sensor/recorder.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> --------- Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> --- homeassistant/components/sensor/recorder.py | 30 ++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index 656e9fb00f0..49a352bf17c 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -34,24 +34,22 @@ from homeassistant.core import HomeAssistant, State, callback, split_entity_id from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity import entity_sources from homeassistant.util import dt as dt_util +from homeassistant.util.enum import try_parse_enum -from . import ( +from .const import ( ATTR_LAST_RESET, ATTR_OPTIONS, ATTR_STATE_CLASS, DOMAIN, - STATE_CLASS_MEASUREMENT, - STATE_CLASS_TOTAL, - STATE_CLASS_TOTAL_INCREASING, - STATE_CLASSES, + SensorStateClass, ) _LOGGER = logging.getLogger(__name__) DEFAULT_STATISTICS = { - STATE_CLASS_MEASUREMENT: {"mean", "min", "max"}, - STATE_CLASS_TOTAL: {"sum"}, - STATE_CLASS_TOTAL_INCREASING: {"sum"}, + SensorStateClass.MEASUREMENT: {"mean", "min", "max"}, + SensorStateClass.TOTAL: {"sum"}, + SensorStateClass.TOTAL_INCREASING: {"sum"}, } EQUIVALENT_UNITS = { @@ -82,7 +80,7 @@ def _get_sensor_states(hass: HomeAssistant) -> list[State]: for state in all_sensors: if not is_entity_recorded(hass, state.entity_id): continue - if (state.attributes.get(ATTR_STATE_CLASS)) not in STATE_CLASSES: + if not try_parse_enum(SensorStateClass, state.attributes.get(ATTR_STATE_CLASS)): continue statistics_sensors.append(state) @@ -537,7 +535,7 @@ def _compile_statistics( # noqa: C901 for fstate, state in fstates: reset = False if ( - state_class != STATE_CLASS_TOTAL_INCREASING + state_class != SensorStateClass.TOTAL_INCREASING and ( last_reset := _last_reset_as_utc_isoformat( state.attributes.get("last_reset"), entity_id @@ -573,7 +571,7 @@ def _compile_statistics( # noqa: C901 entity_id, fstate, ) - elif state_class == STATE_CLASS_TOTAL_INCREASING: + elif state_class == SensorStateClass.TOTAL_INCREASING: try: if old_state is None or reset_detected( hass, entity_id, fstate, new_state, state @@ -648,7 +646,7 @@ def list_statistic_ids( if ( "sum" in provided_statistics and ATTR_LAST_RESET not in state.attributes - and state.attributes.get(ATTR_STATE_CLASS) == STATE_CLASS_MEASUREMENT + and state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.MEASUREMENT ): continue @@ -678,7 +676,9 @@ def validate_statistics( for state in sensor_states: entity_id = state.entity_id - state_class = state.attributes.get(ATTR_STATE_CLASS) + state_class = try_parse_enum( + SensorStateClass, state.attributes.get(ATTR_STATE_CLASS) + ) state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) if metadata := metadatas.get(entity_id): @@ -691,7 +691,7 @@ def validate_statistics( ) ) - if state_class not in STATE_CLASSES: + if state_class is None: # Sensor no longer has a valid state class validation_result[entity_id].append( statistics.ValidationIssue( @@ -731,7 +731,7 @@ def validate_statistics( }, ) ) - elif state_class in STATE_CLASSES: + elif state_class is not None: if not is_entity_recorded(hass, state.entity_id): # Sensor is not recorded validation_result[entity_id].append(