diff --git a/homeassistant/components/mqtt/button.py b/homeassistant/components/mqtt/button.py index 4006b8bfab9..4b8931375b9 100644 --- a/homeassistant/components/mqtt/button.py +++ b/homeassistant/components/mqtt/button.py @@ -6,8 +6,8 @@ import functools import voluptuous as vol from homeassistant.components import button -from homeassistant.components.button import ButtonEntity -from homeassistant.const import CONF_NAME +from homeassistant.components.button import ButtonDeviceClass, ButtonEntity +from homeassistant.const import CONF_DEVICE_CLASS, CONF_NAME from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv from homeassistant.helpers.reload import async_setup_reload_service @@ -25,6 +25,7 @@ DEFAULT_PAYLOAD_PRESS = "PRESS" PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( { vol.Required(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_DEVICE_CLASS): button.DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_PAYLOAD_PRESS, default=DEFAULT_PAYLOAD_PRESS): cv.string, vol.Optional(CONF_RETAIN, default=mqtt.DEFAULT_RETAIN): cv.boolean, @@ -74,6 +75,11 @@ class MqttButton(MqttEntity, ButtonEntity): async def _subscribe_topics(self): """(Re)Subscribe to topics.""" + @property + def device_class(self) -> ButtonDeviceClass | None: + """Return the device class of the sensor.""" + return self._config.get(CONF_DEVICE_CLASS) + async def async_press(self, **kwargs): """Turn the device on. diff --git a/tests/components/mqtt/test_button.py b/tests/components/mqtt/test_button.py index 0a1d8af2a0c..5eb92db7767 100644 --- a/tests/components/mqtt/test_button.py +++ b/tests/components/mqtt/test_button.py @@ -264,3 +264,60 @@ async def test_entity_id_update_discovery_update(hass, mqtt_mock): await help_test_entity_id_update_discovery_update( hass, mqtt_mock, button.DOMAIN, DEFAULT_CONFIG ) + + +async def test_invalid_device_class(hass, mqtt_mock): + """Test device_class option with invalid value.""" + assert await async_setup_component( + hass, + button.DOMAIN, + { + button.DOMAIN: { + "platform": "mqtt", + "name": "test", + "state_topic": "test-topic", + "device_class": "foobarnotreal", + } + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("button.test") + assert state is None + + +async def test_valid_device_class(hass, mqtt_mock): + """Test device_class option with valid values.""" + assert await async_setup_component( + hass, + button.DOMAIN, + { + button.DOMAIN: [ + { + "platform": "mqtt", + "name": "Test 1", + "command_topic": "test-topic", + "device_class": "update", + }, + { + "platform": "mqtt", + "name": "Test 2", + "command_topic": "test-topic", + "device_class": "restart", + }, + { + "platform": "mqtt", + "name": "Test 3", + "command_topic": "test-topic", + }, + ] + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("button.test_1") + assert state.attributes["device_class"] == button.ButtonDeviceClass.UPDATE + state = hass.states.get("button.test_2") + assert state.attributes["device_class"] == button.ButtonDeviceClass.RESTART + state = hass.states.get("button.test_3") + assert "device_class" not in state.attributes