diff --git a/homeassistant/components/sensor/mqtt.py b/homeassistant/components/sensor/mqtt.py index d7d66a3a145..997fd312a6a 100644 --- a/homeassistant/components/sensor/mqtt.py +++ b/homeassistant/components/sensor/mqtt.py @@ -15,9 +15,10 @@ from homeassistant.core import callback from homeassistant.components.mqtt import ( CONF_AVAILABILITY_TOPIC, CONF_STATE_TOPIC, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, MqttAvailability) +from homeassistant.components.sensor import DEVICE_CLASSES_SCHEMA from homeassistant.const import ( CONF_FORCE_UPDATE, CONF_NAME, CONF_VALUE_TEMPLATE, STATE_UNKNOWN, - CONF_UNIT_OF_MEASUREMENT, CONF_ICON) + CONF_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_DEVICE_CLASS) from homeassistant.helpers.entity import Entity import homeassistant.components.mqtt as mqtt import homeassistant.helpers.config_validation as cv @@ -39,6 +40,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RO_PLATFORM_SCHEMA.extend({ vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_ICON): cv.icon, + vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_JSON_ATTRS, default=[]): cv.ensure_list_csv, vol.Optional(CONF_EXPIRE_AFTER): cv.positive_int, vol.Optional(CONF_FORCE_UPDATE, default=DEFAULT_FORCE_UPDATE): cv.boolean, @@ -66,6 +68,7 @@ async def async_setup_platform(hass: HomeAssistantType, config: ConfigType, config.get(CONF_FORCE_UPDATE), config.get(CONF_EXPIRE_AFTER), config.get(CONF_ICON), + config.get(CONF_DEVICE_CLASS), value_template, config.get(CONF_JSON_ATTRS), config.get(CONF_UNIQUE_ID), @@ -79,8 +82,8 @@ class MqttSensor(MqttAvailability, Entity): """Representation of a sensor that can be updated using MQTT.""" def __init__(self, name, state_topic, qos, unit_of_measurement, - force_update, expire_after, icon, value_template, - json_attributes, unique_id: Optional[str], + force_update, expire_after, icon, device_class: Optional[str], + value_template, json_attributes, unique_id: Optional[str], availability_topic, payload_available, payload_not_available): """Initialize the sensor.""" @@ -95,6 +98,7 @@ class MqttSensor(MqttAvailability, Entity): self._template = value_template self._expire_after = expire_after self._icon = icon + self._device_class = device_class self._expiration_trigger = None self._json_attributes = set(json_attributes) self._unique_id = unique_id @@ -190,3 +194,8 @@ class MqttSensor(MqttAvailability, Entity): def icon(self): """Return the icon.""" return self._icon + + @property + def device_class(self) -> Optional[str]: + """Return the device class of the sensor.""" + return self._device_class diff --git a/tests/components/sensor/test_mqtt.py b/tests/components/sensor/test_mqtt.py index 88e74e11008..2583f52b3d2 100644 --- a/tests/components/sensor/test_mqtt.py +++ b/tests/components/sensor/test_mqtt.py @@ -10,7 +10,8 @@ import homeassistant.components.sensor as sensor from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE import homeassistant.util.dt as dt_util -from tests.common import mock_mqtt_component, fire_mqtt_message +from tests.common import mock_mqtt_component, fire_mqtt_message, \ + assert_setup_component from tests.common import get_test_home_assistant, mock_component @@ -350,3 +351,36 @@ class TestSensorMQTT(unittest.TestCase): self.hass.block_till_done() assert len(self.hass.states.all()) == 1 + + def test_invalid_device_class(self): + """Test device_class option with invalid value.""" + with assert_setup_component(0): + assert setup_component(self.hass, 'sensor', { + 'sensor': { + 'platform': 'mqtt', + 'name': 'Test 1', + 'state_topic': 'test-topic', + 'device_class': 'foobarnotreal' + } + }) + + def test_valid_device_class(self): + """Test device_class option with valid values.""" + assert setup_component(self.hass, 'sensor', { + 'sensor': [{ + 'platform': 'mqtt', + 'name': 'Test 1', + 'state_topic': 'test-topic', + 'device_class': 'temperature' + }, { + 'platform': 'mqtt', + 'name': 'Test 2', + 'state_topic': 'test-topic', + }] + }) + self.hass.block_till_done() + + state = self.hass.states.get('sensor.test_1') + assert state.attributes['device_class'] == 'temperature' + state = self.hass.states.get('sensor.test_2') + assert 'device_class' not in state.attributes