diff --git a/homeassistant/components/button/__init__.py b/homeassistant/components/button/__init__.py index 621effd5d16..583310f0be9 100644 --- a/homeassistant/components/button/__init__.py +++ b/homeassistant/components/button/__init__.py @@ -19,6 +19,7 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType from homeassistant.util import dt as dt_util +from homeassistant.util.enum import StrEnum from .const import DOMAIN, SERVICE_PRESS @@ -30,12 +31,15 @@ MIN_TIME_BETWEEN_SCANS = timedelta(seconds=10) _LOGGER = logging.getLogger(__name__) -DEVICE_CLASS_RESTART = "restart" -DEVICE_CLASS_UPDATE = "update" -DEVICE_CLASSES = [DEVICE_CLASS_RESTART, DEVICE_CLASS_UPDATE] +class ButtonDeviceClass(StrEnum): + """Device class for buttons.""" -DEVICE_CLASSES_SCHEMA = vol.All(vol.Lower, vol.In(DEVICE_CLASSES)) + RESTART = "restart" + UPDATE = "update" + + +DEVICE_CLASSES_SCHEMA = vol.All(vol.Lower, vol.Coerce(ButtonDeviceClass)) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -70,16 +74,27 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: class ButtonEntityDescription(EntityDescription): """A class that describes button entities.""" + device_class: ButtonDeviceClass | None = None + class ButtonEntity(RestoreEntity): """Representation of a Button entity.""" entity_description: ButtonEntityDescription _attr_should_poll = False - _attr_device_class: None = None + _attr_device_class: ButtonDeviceClass | None = None _attr_state: None = None __last_pressed: datetime | None = None + @property + def device_class(self) -> ButtonDeviceClass | None: + """Return the class of this entity.""" + if hasattr(self, "_attr_device_class"): + return self._attr_device_class + if hasattr(self, "entity_description"): + return self.entity_description.device_class + return None + @property @final def state(self) -> str | None: diff --git a/homeassistant/components/esphome/button.py b/homeassistant/components/esphome/button.py index 914ddb38da1..5b6f2c153c8 100644 --- a/homeassistant/components/esphome/button.py +++ b/homeassistant/components/esphome/button.py @@ -1,11 +1,12 @@ """Support for ESPHome buttons.""" from __future__ import annotations +from contextlib import suppress from typing import Any from aioesphomeapi import ButtonInfo, EntityState -from homeassistant.components.button import DEVICE_CLASSES, ButtonEntity +from homeassistant.components.button import ButtonDeviceClass, ButtonEntity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -32,11 +33,11 @@ class EsphomeButton(EsphomeEntity[ButtonInfo, EntityState], ButtonEntity): """A button implementation for ESPHome.""" @property - def device_class(self) -> str | None: - """Return the class of this device, from component DEVICE_CLASSES.""" - if self._static_info.device_class not in DEVICE_CLASSES: - return None - return self._static_info.device_class + def device_class(self) -> ButtonDeviceClass | None: + """Return the class of this entity.""" + with suppress(ValueError): + return ButtonDeviceClass(self._static_info.device_class) + return None @callback def _on_device_update(self) -> None: