diff --git a/homeassistant/components/number/__init__.py b/homeassistant/components/number/__init__.py index fe438ea6aea..1820e28bc4c 100644 --- a/homeassistant/components/number/__init__.py +++ b/homeassistant/components/number/__init__.py @@ -14,8 +14,13 @@ import voluptuous as vol from homeassistant.backports.enum import StrEnum from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ATTR_MODE, TEMP_CELSIUS, TEMP_FAHRENHEIT -from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.const import ( + ATTR_MODE, + CONF_UNIT_OF_MEASUREMENT, + TEMP_CELSIUS, + TEMP_FAHRENHEIT, +) +from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers.config_validation import ( # noqa: F401 PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, @@ -69,6 +74,10 @@ UNIT_CONVERSIONS: dict[str, Callable[[float, str, str], float]] = { NumberDeviceClass.TEMPERATURE: temperature_util.convert, } +VALID_UNITS: dict[str, tuple[str, ...]] = { + NumberDeviceClass.TEMPERATURE: temperature_util.VALID_UNITS, +} + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up Number entities.""" @@ -193,6 +202,7 @@ class NumberEntity(Entity): _attr_native_value: float _attr_native_unit_of_measurement: str | None _deprecated_number_entity_reported = False + _number_option_unit_of_measurement: str | None = None def __init_subclass__(cls, **kwargs: Any) -> None: """Post initialisation processing.""" @@ -226,6 +236,13 @@ class NumberEntity(Entity): report_issue, ) + async def async_internal_added_to_hass(self) -> None: + """Call when the number entity is added to hass.""" + await super().async_internal_added_to_hass() + if not self.registry_entry: + return + self.async_registry_entry_updated() + @property def capability_attributes(self) -> dict[str, Any]: """Return capability attributes.""" @@ -348,6 +365,9 @@ class NumberEntity(Entity): @final def unit_of_measurement(self) -> str | None: """Return the unit of measurement of the entity, after unit conversion.""" + if self._number_option_unit_of_measurement: + return self._number_option_unit_of_measurement + if hasattr(self, "_attr_unit_of_measurement"): return self._attr_unit_of_measurement if ( @@ -467,6 +487,22 @@ class NumberEntity(Entity): report_issue, ) + @callback + def async_registry_entry_updated(self) -> None: + """Run when the entity registry entry has been updated.""" + assert self.registry_entry + if ( + (number_options := self.registry_entry.options.get(DOMAIN)) + and (custom_unit := number_options.get(CONF_UNIT_OF_MEASUREMENT)) + and (device_class := self.device_class) in UNIT_CONVERSIONS + and self.native_unit_of_measurement in VALID_UNITS[device_class] + and custom_unit in VALID_UNITS[device_class] + ): + self._number_option_unit_of_measurement = custom_unit + return + + self._number_option_unit_of_measurement = None + @dataclasses.dataclass class NumberExtraStoredData(ExtraStoredData): diff --git a/tests/components/number/test_init.py b/tests/components/number/test_init.py index 9921d2a639e..8d7f8a91ae8 100644 --- a/tests/components/number/test_init.py +++ b/tests/components/number/test_init.py @@ -22,6 +22,7 @@ from homeassistant.const import ( TEMP_FAHRENHEIT, ) from homeassistant.core import HomeAssistant, State +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM @@ -689,3 +690,161 @@ async def test_restore_number_restore_state( assert entity0.native_value == native_value assert type(entity0.native_value) == native_value_type assert entity0.native_unit_of_measurement == uom + + +@pytest.mark.parametrize( + "device_class,native_unit,custom_unit,state_unit,native_value,custom_value", + [ + # Not a supported temperature unit + ( + NumberDeviceClass.TEMPERATURE, + TEMP_CELSIUS, + "my_temperature_unit", + TEMP_CELSIUS, + 1000, + 1000, + ), + ( + NumberDeviceClass.TEMPERATURE, + TEMP_CELSIUS, + TEMP_FAHRENHEIT, + TEMP_FAHRENHEIT, + 37.5, + 99.5, + ), + ( + NumberDeviceClass.TEMPERATURE, + TEMP_FAHRENHEIT, + TEMP_CELSIUS, + TEMP_CELSIUS, + 100, + 38.0, + ), + ], +) +async def test_custom_unit( + hass, + enable_custom_integrations, + device_class, + native_unit, + custom_unit, + state_unit, + native_value, + custom_value, +): + """Test custom unit.""" + entity_registry = er.async_get(hass) + + entry = entity_registry.async_get_or_create("number", "test", "very_unique") + entity_registry.async_update_entity_options( + entry.entity_id, "number", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + platform = getattr(hass.components, "test.number") + platform.init(empty=True) + platform.ENTITIES.append( + platform.MockNumberEntity( + name="Test", + native_value=native_value, + native_unit_of_measurement=native_unit, + device_class=device_class, + unique_id="very_unique", + ) + ) + + entity0 = platform.ENTITIES[0] + assert await async_setup_component(hass, "number", {"number": {"platform": "test"}}) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == pytest.approx(float(custom_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit + + +@pytest.mark.parametrize( + "native_unit, custom_unit, used_custom_unit, default_unit, native_value, custom_value, default_value", + [ + ( + TEMP_CELSIUS, + TEMP_FAHRENHEIT, + TEMP_FAHRENHEIT, + TEMP_CELSIUS, + 37.5, + 99.5, + 37.5, + ), + ( + TEMP_FAHRENHEIT, + TEMP_FAHRENHEIT, + TEMP_FAHRENHEIT, + TEMP_CELSIUS, + 100, + 100, + 38.0, + ), + # Not a supported temperature unit + (TEMP_CELSIUS, "no_unit", TEMP_CELSIUS, TEMP_CELSIUS, 1000, 1000, 1000), + ], +) +async def test_custom_unit_change( + hass, + enable_custom_integrations, + native_unit, + custom_unit, + used_custom_unit, + default_unit, + native_value, + custom_value, + default_value, +): + """Test custom unit changes are picked up.""" + entity_registry = er.async_get(hass) + platform = getattr(hass.components, "test.number") + platform.init(empty=True) + platform.ENTITIES.append( + platform.MockNumberEntity( + name="Test", + native_value=native_value, + native_unit_of_measurement=native_unit, + device_class=NumberDeviceClass.TEMPERATURE, + unique_id="very_unique", + ) + ) + + entity0 = platform.ENTITIES[0] + assert await async_setup_component(hass, "number", {"number": {"platform": "test"}}) + await hass.async_block_till_done() + + # Default unit conversion according to unit system + state = hass.states.get(entity0.entity_id) + assert float(state.state) == pytest.approx(float(default_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit + + entity_registry.async_update_entity_options( + "number.test", "number", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + # Unit conversion to the custom unit + state = hass.states.get(entity0.entity_id) + assert float(state.state) == pytest.approx(float(custom_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == used_custom_unit + + entity_registry.async_update_entity_options( + "number.test", "number", {"unit_of_measurement": native_unit} + ) + await hass.async_block_till_done() + + # Unit conversion to another custom unit + state = hass.states.get(entity0.entity_id) + assert float(state.state) == pytest.approx(float(native_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit + + entity_registry.async_update_entity_options("number.test", "number", None) + await hass.async_block_till_done() + + # Default unit conversion according to unit system + state = hass.states.get(entity0.entity_id) + assert float(state.state) == pytest.approx(float(default_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit