diff --git a/homeassistant/components/sensor/__init__.py b/homeassistant/components/sensor/__init__.py index 35ffc1c3d2a..69b4d1cecd8 100644 --- a/homeassistant/components/sensor/__init__.py +++ b/homeassistant/components/sensor/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio from collections.abc import Mapping -from contextlib import suppress from dataclasses import dataclass from datetime import date, datetime, timedelta, timezone from decimal import Decimal, InvalidOperation as DecimalInvalidOperation @@ -59,6 +58,7 @@ from homeassistant.helpers.entity_platform import EntityPlatform from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity from homeassistant.helpers.typing import UNDEFINED, ConfigType, StateType, UndefinedType from homeassistant.util import dt as dt_util +from homeassistant.util.enum import try_parse_enum from .const import ( # noqa: F401 ATTR_LAST_RESET, @@ -464,11 +464,9 @@ class SensorEntity(Entity): native_unit_of_measurement = self.native_unit_of_measurement unit_of_measurement = self.unit_of_measurement value = self.native_value - device_class: SensorDeviceClass | None = None - with suppress(ValueError): - # For the sake of validation, we can ignore custom device classes - # (customization and legacy style translations) - device_class = SensorDeviceClass(str(self.device_class)) + # For the sake of validation, we can ignore custom device classes + # (customization and legacy style translations) + device_class = try_parse_enum(SensorDeviceClass, self.device_class) state_class = self.state_class # Sensors with device classes indicating a non-numeric value diff --git a/homeassistant/util/enum.py b/homeassistant/util/enum.py new file mode 100644 index 00000000000..7d1e3970586 --- /dev/null +++ b/homeassistant/util/enum.py @@ -0,0 +1,16 @@ +"""Helpers for working with enums.""" +import contextlib +from enum import Enum +from typing import Any, TypeVar + +_EnumT = TypeVar("_EnumT", bound=Enum) + + +def try_parse_enum(cls: type[_EnumT], value: Any) -> _EnumT | None: + """Try to parse the value into an Enum. + + Return None if parsing fails. + """ + with contextlib.suppress(ValueError): + return cls(value) + return None diff --git a/tests/util/test_enum.py b/tests/util/test_enum.py new file mode 100644 index 00000000000..bf30baefa1e --- /dev/null +++ b/tests/util/test_enum.py @@ -0,0 +1,51 @@ +"""Test enum helpers.""" +from enum import Enum, IntEnum, IntFlag +from typing import Any + +import pytest + +from homeassistant.backports.enum import StrEnum +from homeassistant.util.enum import try_parse_enum + + +class _AStrEnum(StrEnum): + VALUE = "value" + + +class _AnIntEnum(IntEnum): + VALUE = 1 + + +class _AnIntFlag(IntFlag): + VALUE = 1 + + +@pytest.mark.parametrize( + "enum_type,value,expected", + [ + # StrEnum valid checks + (_AStrEnum, _AStrEnum.VALUE, _AStrEnum.VALUE), + (_AStrEnum, "value", _AStrEnum.VALUE), + # StrEnum invalid checks + (_AStrEnum, "invalid", None), + (_AStrEnum, 1, None), + (_AStrEnum, None, None), + # IntEnum valid checks + (_AnIntEnum, _AnIntEnum.VALUE, _AnIntEnum.VALUE), + (_AnIntEnum, 1, _AnIntEnum.VALUE), + # IntEnum invalid checks + (_AnIntEnum, "value", None), + (_AnIntEnum, 2, None), + (_AnIntEnum, None, None), + # IntFlag valid checks + (_AnIntFlag, _AnIntFlag.VALUE, _AnIntFlag.VALUE), + (_AnIntFlag, 1, _AnIntFlag.VALUE), + (_AnIntFlag, 2, _AnIntFlag(2)), + # IntFlag invalid checks + (_AnIntFlag, "value", None), + (_AnIntFlag, None, None), + ], +) +def test_try_parse(enum_type: type[Enum], value: Any, expected: Enum | None) -> None: + """Test parsing of values into an Enum.""" + assert try_parse_enum(enum_type, value) is expected