Add device_class to MQTT number and migrate to native_value (#73534)

This commit is contained in:
Erik Montnemery 2022-06-16 13:34:54 +02:00 committed by GitHub
parent 67b0354632
commit dea8041461
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 57 deletions

View File

@ -11,10 +11,12 @@ from homeassistant.components.number import (
DEFAULT_MAX_VALUE, DEFAULT_MAX_VALUE,
DEFAULT_MIN_VALUE, DEFAULT_MIN_VALUE,
DEFAULT_STEP, DEFAULT_STEP,
NumberEntity, DEVICE_CLASSES_SCHEMA,
RestoreNumber,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_NAME, CONF_NAME,
CONF_OPTIMISTIC, CONF_OPTIMISTIC,
CONF_UNIT_OF_MEASUREMENT, CONF_UNIT_OF_MEASUREMENT,
@ -23,7 +25,6 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
@ -78,6 +79,7 @@ def validate_config(config):
_PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend( _PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend(
{ {
vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, vol.Optional(CONF_COMMAND_TEMPLATE): cv.template,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_MAX, default=DEFAULT_MAX_VALUE): vol.Coerce(float), vol.Optional(CONF_MAX, default=DEFAULT_MAX_VALUE): vol.Coerce(float),
vol.Optional(CONF_MIN, default=DEFAULT_MIN_VALUE): vol.Coerce(float), vol.Optional(CONF_MIN, default=DEFAULT_MIN_VALUE): vol.Coerce(float),
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
@ -152,7 +154,7 @@ async def _async_setup_entity(
async_add_entities([MqttNumber(hass, config, config_entry, discovery_data)]) async_add_entities([MqttNumber(hass, config, config_entry, discovery_data)])
class MqttNumber(MqttEntity, NumberEntity, RestoreEntity): class MqttNumber(MqttEntity, RestoreNumber):
"""representation of an MQTT number.""" """representation of an MQTT number."""
_entity_id_format = number.ENTITY_ID_FORMAT _entity_id_format = number.ENTITY_ID_FORMAT
@ -166,7 +168,7 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
self._current_number = None self._current_number = None
NumberEntity.__init__(self) RestoreNumber.__init__(self)
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
@ -243,35 +245,37 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (
self._current_number = last_state.state last_number_data := await self.async_get_last_number_data()
):
self._current_number = last_number_data.native_value
@property @property
def min_value(self) -> float: def native_min_value(self) -> float:
"""Return the minimum value.""" """Return the minimum value."""
return self._config[CONF_MIN] return self._config[CONF_MIN]
@property @property
def max_value(self) -> float: def native_max_value(self) -> float:
"""Return the maximum value.""" """Return the maximum value."""
return self._config[CONF_MAX] return self._config[CONF_MAX]
@property @property
def step(self) -> float: def native_step(self) -> float:
"""Return the increment/decrement step.""" """Return the increment/decrement step."""
return self._config[CONF_STEP] return self._config[CONF_STEP]
@property @property
def unit_of_measurement(self) -> str | None: def native_unit_of_measurement(self) -> str | None:
"""Return the unit of measurement.""" """Return the unit of measurement."""
return self._config.get(CONF_UNIT_OF_MEASUREMENT) return self._config.get(CONF_UNIT_OF_MEASUREMENT)
@property @property
def value(self): def native_value(self):
"""Return the current value.""" """Return the current value."""
return self._current_number return self._current_number
async def async_set_value(self, value: float) -> None: async def async_set_native_value(self, value: float) -> None:
"""Update the current value.""" """Update the current value."""
current_number = value current_number = value
@ -295,3 +299,8 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
def assumed_state(self): def assumed_state(self):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""
return self._optimistic return self._optimistic
@property
def device_class(self) -> str | None:
"""Return the device class of the sensor."""
return self._config.get(CONF_DEVICE_CLASS)

View File

@ -8,7 +8,7 @@ from datetime import timedelta
import inspect import inspect
import logging import logging
from math import ceil, floor from math import ceil, floor
from typing import Any, final from typing import Any, Final, final
import voluptuous as vol import voluptuous as vol
@ -54,6 +54,9 @@ class NumberDeviceClass(StrEnum):
TEMPERATURE = "temperature" TEMPERATURE = "temperature"
DEVICE_CLASSES_SCHEMA: Final = vol.All(vol.Lower, vol.Coerce(NumberDeviceClass))
class NumberMode(StrEnum): class NumberMode(StrEnum):
"""Modes for number entities.""" """Modes for number entities."""

View File

@ -18,11 +18,14 @@ from homeassistant.components.number import (
ATTR_VALUE, ATTR_VALUE,
DOMAIN as NUMBER_DOMAIN, DOMAIN as NUMBER_DOMAIN,
SERVICE_SET_VALUE, SERVICE_SET_VALUE,
NumberDeviceClass,
) )
from homeassistant.const import ( from homeassistant.const import (
ATTR_ASSUMED_STATE, ATTR_ASSUMED_STATE,
ATTR_DEVICE_CLASS,
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_UNIT_OF_MEASUREMENT, ATTR_UNIT_OF_MEASUREMENT,
TEMP_FAHRENHEIT,
Platform, Platform,
) )
import homeassistant.core as ha import homeassistant.core as ha
@ -58,7 +61,7 @@ from .test_common import (
help_test_update_with_json_attrs_not_dict, help_test_update_with_json_attrs_not_dict,
) )
from tests.common import async_fire_mqtt_message from tests.common import async_fire_mqtt_message, mock_restore_cache_with_extra_data
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
number.DOMAIN: {"platform": "mqtt", "name": "test", "command_topic": "test-topic"} number.DOMAIN: {"platform": "mqtt", "name": "test", "command_topic": "test-topic"}
@ -84,7 +87,8 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config):
"state_topic": topic, "state_topic": topic,
"command_topic": topic, "command_topic": topic,
"name": "Test Number", "name": "Test Number",
"unit_of_measurement": "my unit", "device_class": "temperature",
"unit_of_measurement": TEMP_FAHRENHEIT,
"payload_reset": "reset!", "payload_reset": "reset!",
} }
}, },
@ -97,16 +101,18 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config):
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("number.test_number") state = hass.states.get("number.test_number")
assert state.state == "10" assert state.state == "-12.0" # 10 °F -> -12 °C
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C"
async_fire_mqtt_message(hass, topic, "20.5") async_fire_mqtt_message(hass, topic, "20.5")
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("number.test_number") state = hass.states.get("number.test_number")
assert state.state == "20.5" assert state.state == "-6.4" # 20.5 °F -> -6.4 °C
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C"
async_fire_mqtt_message(hass, topic, "reset!") async_fire_mqtt_message(hass, topic, "reset!")
@ -114,7 +120,8 @@ async def test_run_number_setup(hass, mqtt_mock_entry_with_yaml_config):
state = hass.states.get("number.test_number") state = hass.states.get("number.test_number")
assert state.state == "unknown" assert state.state == "unknown"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "my unit" assert state.attributes.get(ATTR_DEVICE_CLASS) == NumberDeviceClass.TEMPERATURE
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "°C"
async def test_value_template(hass, mqtt_mock_entry_with_yaml_config): async def test_value_template(hass, mqtt_mock_entry_with_yaml_config):
@ -158,29 +165,70 @@ async def test_value_template(hass, mqtt_mock_entry_with_yaml_config):
assert state.state == "unknown" assert state.state == "unknown"
async def test_restore_native_value(hass, mqtt_mock_entry_with_yaml_config):
"""Test that the stored native_value is restored."""
topic = "test/number"
RESTORE_DATA = {
"native_max_value": None, # Ignored by MQTT number
"native_min_value": None, # Ignored by MQTT number
"native_step": None, # Ignored by MQTT number
"native_unit_of_measurement": None, # Ignored by MQTT number
"native_value": 100.0,
}
mock_restore_cache_with_extra_data(
hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),)
)
assert await async_setup_component(
hass,
number.DOMAIN,
{
"number": {
"platform": "mqtt",
"command_topic": topic,
"device_class": "temperature",
"unit_of_measurement": TEMP_FAHRENHEIT,
"name": "Test Number",
}
},
)
await hass.async_block_till_done()
await mqtt_mock_entry_with_yaml_config()
state = hass.states.get("number.test_number")
assert state.state == "37.8"
assert state.attributes.get(ATTR_ASSUMED_STATE)
async def test_run_number_service_optimistic(hass, mqtt_mock_entry_with_yaml_config): async def test_run_number_service_optimistic(hass, mqtt_mock_entry_with_yaml_config):
"""Test that set_value service works in optimistic mode.""" """Test that set_value service works in optimistic mode."""
topic = "test/number" topic = "test/number"
fake_state = ha.State("switch.test", "3") RESTORE_DATA = {
"native_max_value": None, # Ignored by MQTT number
"native_min_value": None, # Ignored by MQTT number
"native_step": None, # Ignored by MQTT number
"native_unit_of_measurement": None, # Ignored by MQTT number
"native_value": 3,
}
with patch( mock_restore_cache_with_extra_data(
"homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),)
return_value=fake_state, )
): assert await async_setup_component(
assert await async_setup_component( hass,
hass, number.DOMAIN,
number.DOMAIN, {
{ "number": {
"number": { "platform": "mqtt",
"platform": "mqtt", "command_topic": topic,
"command_topic": topic, "name": "Test Number",
"name": "Test Number", }
} },
}, )
) await hass.async_block_till_done()
await hass.async_block_till_done() mqtt_mock = await mqtt_mock_entry_with_yaml_config()
mqtt_mock = await mqtt_mock_entry_with_yaml_config()
state = hass.states.get("number.test_number") state = hass.states.get("number.test_number")
assert state.state == "3" assert state.state == "3"
@ -232,26 +280,31 @@ async def test_run_number_service_optimistic_with_command_template(
"""Test that set_value service works in optimistic mode and with a command_template.""" """Test that set_value service works in optimistic mode and with a command_template."""
topic = "test/number" topic = "test/number"
fake_state = ha.State("switch.test", "3") RESTORE_DATA = {
"native_max_value": None, # Ignored by MQTT number
"native_min_value": None, # Ignored by MQTT number
"native_step": None, # Ignored by MQTT number
"native_unit_of_measurement": None, # Ignored by MQTT number
"native_value": 3,
}
with patch( mock_restore_cache_with_extra_data(
"homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", hass, ((ha.State("number.test_number", "abc"), RESTORE_DATA),)
return_value=fake_state, )
): assert await async_setup_component(
assert await async_setup_component( hass,
hass, number.DOMAIN,
number.DOMAIN, {
{ "number": {
"number": { "platform": "mqtt",
"platform": "mqtt", "command_topic": topic,
"command_topic": topic, "name": "Test Number",
"name": "Test Number", "command_template": '{"number": {{ value }} }',
"command_template": '{"number": {{ value }} }', }
} },
}, )
) await hass.async_block_till_done()
await hass.async_block_till_done() mqtt_mock = await mqtt_mock_entry_with_yaml_config()
mqtt_mock = await mqtt_mock_entry_with_yaml_config()
state = hass.states.get("number.test_number") state = hass.states.get("number.test_number")
assert state.state == "3" assert state.state == "3"