From dea80414614c2f201686e840b4d2c1a9411279d8 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 16 Jun 2022 13:34:54 +0200 Subject: [PATCH] Add device_class to MQTT number and migrate to native_value (#73534) --- homeassistant/components/mqtt/number.py | 33 +++-- homeassistant/components/number/__init__.py | 5 +- tests/components/mqtt/test_number.py | 141 ++++++++++++++------ 3 files changed, 122 insertions(+), 57 deletions(-) diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 1404dc86a3c..bbc78ae07db 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -11,10 +11,12 @@ from homeassistant.components.number import ( DEFAULT_MAX_VALUE, DEFAULT_MIN_VALUE, DEFAULT_STEP, - NumberEntity, + DEVICE_CLASSES_SCHEMA, + RestoreNumber, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( + CONF_DEVICE_CLASS, CONF_NAME, CONF_OPTIMISTIC, CONF_UNIT_OF_MEASUREMENT, @@ -23,7 +25,6 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import subscription @@ -78,6 +79,7 @@ def validate_config(config): _PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, + vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_MAX, default=DEFAULT_MAX_VALUE): vol.Coerce(float), vol.Optional(CONF_MIN, default=DEFAULT_MIN_VALUE): vol.Coerce(float), vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, @@ -152,7 +154,7 @@ async def _async_setup_entity( async_add_entities([MqttNumber(hass, config, config_entry, discovery_data)]) -class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): +class MqttNumber(MqttEntity, RestoreNumber): """representation of an MQTT number.""" _entity_id_format = number.ENTITY_ID_FORMAT @@ -166,7 +168,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): self._current_number = None - NumberEntity.__init__(self) + RestoreNumber.__init__(self) MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod @@ -243,35 +245,37 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) - if self._optimistic and (last_state := await self.async_get_last_state()): - self._current_number = last_state.state + if self._optimistic and ( + last_number_data := await self.async_get_last_number_data() + ): + self._current_number = last_number_data.native_value @property - def min_value(self) -> float: + def native_min_value(self) -> float: """Return the minimum value.""" return self._config[CONF_MIN] @property - def max_value(self) -> float: + def native_max_value(self) -> float: """Return the maximum value.""" return self._config[CONF_MAX] @property - def step(self) -> float: + def native_step(self) -> float: """Return the increment/decrement step.""" return self._config[CONF_STEP] @property - def unit_of_measurement(self) -> str | None: + def native_unit_of_measurement(self) -> str | None: """Return the unit of measurement.""" return self._config.get(CONF_UNIT_OF_MEASUREMENT) @property - def value(self): + def native_value(self): """Return the current value.""" return self._current_number - async def async_set_value(self, value: float) -> None: + async def async_set_native_value(self, value: float) -> None: """Update the current value.""" current_number = value @@ -295,3 +299,8 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): def assumed_state(self): """Return true if we do optimistic updates.""" return self._optimistic + + @property + def device_class(self) -> str | None: + """Return the device class of the sensor.""" + return self._config.get(CONF_DEVICE_CLASS) diff --git a/homeassistant/components/number/__init__.py b/homeassistant/components/number/__init__.py index f0dc77b7dfb..f0095e2aecb 100644 --- a/homeassistant/components/number/__init__.py +++ b/homeassistant/components/number/__init__.py @@ -8,7 +8,7 @@ from datetime import timedelta import inspect import logging from math import ceil, floor -from typing import Any, final +from typing import Any, Final, final import voluptuous as vol @@ -54,6 +54,9 @@ class NumberDeviceClass(StrEnum): TEMPERATURE = "temperature" +DEVICE_CLASSES_SCHEMA: Final = vol.All(vol.Lower, vol.Coerce(NumberDeviceClass)) + + class NumberMode(StrEnum): """Modes for number entities.""" diff --git a/tests/components/mqtt/test_number.py b/tests/components/mqtt/test_number.py index ea79c5cd7aa..1db7c5e3463 100644 --- a/tests/components/mqtt/test_number.py +++ b/tests/components/mqtt/test_number.py @@ -18,11 +18,14 @@ from homeassistant.components.number import ( ATTR_VALUE, DOMAIN as NUMBER_DOMAIN, SERVICE_SET_VALUE, + NumberDeviceClass, ) from homeassistant.const import ( ATTR_ASSUMED_STATE, + ATTR_DEVICE_CLASS, ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, + TEMP_FAHRENHEIT, Platform, ) import homeassistant.core as ha @@ -58,7 +61,7 @@ from .test_common import ( help_test_update_with_json_attrs_not_dict, ) -from tests.common import async_fire_mqtt_message +from tests.common import async_fire_mqtt_message, mock_restore_cache_with_extra_data DEFAULT_CONFIG = { number.DOMAIN: {"platform": "mqtt", "name": "test", "command_topic": "test-topic"} @@ -84,7 +87,8 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config): "state_topic": topic, "command_topic": topic, "name": "Test Number", - "unit_of_measurement": "my unit", + "device_class": "temperature", + "unit_of_measurement": TEMP_FAHRENHEIT, "payload_reset": "reset!", } }, @@ -97,16 +101,18 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config): await hass.async_block_till_done() state = hass.states.get("number.test_number") - assert state.state == "10" - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" + assert state.state == "-12.0" # 10 °F -> -12 °C + assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C" async_fire_mqtt_message(hass, topic, "20.5") await hass.async_block_till_done() state = hass.states.get("number.test_number") - assert state.state == "20.5" - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" + assert state.state == "-6.4" # 20.5 °F -> -6.4 °C + assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C" async_fire_mqtt_message(hass, topic, "reset!") @@ -114,7 +120,8 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config): state = hass.states.get("number.test_number") assert state.state == "unknown" - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" + assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C" async def test_value_template(hass, mqtt_mock_entry_with_yaml_config): @@ -158,29 +165,70 @@ async def test_value_template(hass, mqtt_mock_entry_with_yaml_config): assert state.state == "unknown" +async def test_restore_native_value(hass, mqtt_mock_entry_with_yaml_config): + """Test that the stored native_value is restored.""" + topic = "test/number" + + RESTORE_DATA = { + "native_max_value": None, # Ignored by MQTT number + "native_min_value": None, # Ignored by MQTT number + "native_step": None, # Ignored by MQTT number + "native_unit_of_measurement": None, # Ignored by MQTT number + "native_value": 100.0, + } + + mock_restore_cache_with_extra_data( + hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),) + ) + assert await async_setup_component( + hass, + number.DOMAIN, + { + "number": { + "platform": "mqtt", + "command_topic": topic, + "device_class": "temperature", + "unit_of_measurement": TEMP_FAHRENHEIT, + "name": "Test Number", + } + }, + ) + await hass.async_block_till_done() + await mqtt_mock_entry_with_yaml_config() + + state = hass.states.get("number.test_number") + assert state.state == "37.8" + assert state.attributes.get(ATTR_ASSUMED_STATE) + + async def test_run_number_service_optimistic(hass, mqtt_mock_entry_with_yaml_config): """Test that set_value service works in optimistic mode.""" topic = "test/number" - fake_state = ha.State("switch.test", "3") + RESTORE_DATA = { + "native_max_value": None, # Ignored by MQTT number + "native_min_value": None, # Ignored by MQTT number + "native_step": None, # Ignored by MQTT number + "native_unit_of_measurement": None, # Ignored by MQTT number + "native_value": 3, + } - with patch( - "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", - return_value=fake_state, - ): - assert await async_setup_component( - hass, - number.DOMAIN, - { - "number": { - "platform": "mqtt", - "command_topic": topic, - "name": "Test Number", - } - }, - ) - await hass.async_block_till_done() - mqtt_mock = await mqtt_mock_entry_with_yaml_config() + mock_restore_cache_with_extra_data( + hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),) + ) + assert await async_setup_component( + hass, + number.DOMAIN, + { + "number": { + "platform": "mqtt", + "command_topic": topic, + "name": "Test Number", + } + }, + ) + await hass.async_block_till_done() + mqtt_mock = await mqtt_mock_entry_with_yaml_config() state = hass.states.get("number.test_number") assert state.state == "3" @@ -232,26 +280,31 @@ async def test_run_number_service_optimistic_with_command_template( """Test that set_value service works in optimistic mode and with a command_template.""" topic = "test/number" - fake_state = ha.State("switch.test", "3") + RESTORE_DATA = { + "native_max_value": None, # Ignored by MQTT number + "native_min_value": None, # Ignored by MQTT number + "native_step": None, # Ignored by MQTT number + "native_unit_of_measurement": None, # Ignored by MQTT number + "native_value": 3, + } - with patch( - "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", - return_value=fake_state, - ): - assert await async_setup_component( - hass, - number.DOMAIN, - { - "number": { - "platform": "mqtt", - "command_topic": topic, - "name": "Test Number", - "command_template": '{"number": {{ value }} }', - } - }, - ) - await hass.async_block_till_done() - mqtt_mock = await mqtt_mock_entry_with_yaml_config() + mock_restore_cache_with_extra_data( + hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),) + ) + assert await async_setup_component( + hass, + number.DOMAIN, + { + "number": { + "platform": "mqtt", + "command_topic": topic, + "name": "Test Number", + "command_template": '{"number": {{ value }} }', + } + }, + ) + await hass.async_block_till_done() + mqtt_mock = await mqtt_mock_entry_with_yaml_config() state = hass.states.get("number.test_number") assert state.state == "3"