From 91f6e58e9aa15f5fc28efcab4e723aa324280b1a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 30 Mar 2022 15:43:04 +0200 Subject: [PATCH] Allow customizing unit for temperature and pressure sensors (#64366) * Allow customizing unit for temperature and pressure sensors * pylint * Adjust google_wifi tests * Address review comments and add tests * Improve rounding when scaling * Tweak rounding * Further tweak rounding * Allow setting entity options with config/entity_registry/update * Address review comments * Tweak tests * Load custom unit when sensor is added * Override async_internal_added_to_hass * Adjust tests after rebase * Apply suggestions from code review Co-authored-by: Paulus Schoutsen * Address review comments Co-authored-by: Paulus Schoutsen --- .../components/config/entity_registry.py | 34 +++-- homeassistant/components/sensor/__init__.py | 112 +++++++++++++++-- homeassistant/helpers/entity.py | 8 ++ homeassistant/helpers/entity_registry.py | 6 +- homeassistant/util/temperature.py | 6 + .../components/config/test_entity_registry.py | 42 +++++++ tests/components/google_wifi/test_sensor.py | 4 +- tests/components/sensor/test_init.py | 119 ++++++++++++++++++ 8 files changed, 306 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 719e028cab1..2bb585e12c6 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -86,6 +86,8 @@ def websocket_get_entity(hass, connection, msg): er.RegistryEntryHider.USER.value, ), ), + vol.Inclusive("options_domain", "entity_option"): str, + vol.Inclusive("options", "entity_option"): vol.Any(None, dict), } ) @callback @@ -96,7 +98,8 @@ def websocket_update_entity(hass, connection, msg): """ registry = er.async_get(hass) - if msg["entity_id"] not in registry.entities: + entity_id = msg["entity_id"] + if not (entity_entry := registry.async_get(entity_id)): connection.send_message( websocket_api.error_message(msg["id"], ERR_NOT_FOUND, "Entity not found") ) @@ -108,7 +111,7 @@ def websocket_update_entity(hass, connection, msg): if key in msg: changes[key] = msg[key] - if "new_entity_id" in msg and msg["new_entity_id"] != msg["entity_id"]: + if "new_entity_id" in msg and msg["new_entity_id"] != entity_id: changes["new_entity_id"] = msg["new_entity_id"] if hass.states.get(msg["new_entity_id"]) is not None: connection.send_message( @@ -122,10 +125,9 @@ def websocket_update_entity(hass, connection, msg): if "disabled_by" in msg and msg["disabled_by"] is None: # Don't allow enabling an entity of a disabled device - entity = registry.entities[msg["entity_id"]] - if entity.device_id: + if entity_entry.device_id: device_registry = dr.async_get(hass) - device = device_registry.async_get(entity.device_id) + device = device_registry.async_get(entity_entry.device_id) if device.disabled: connection.send_message( websocket_api.error_message( @@ -136,16 +138,31 @@ def websocket_update_entity(hass, connection, msg): try: if changes: - entry = registry.async_update_entity(msg["entity_id"], **changes) + entity_entry = registry.async_update_entity(entity_id, **changes) except ValueError as err: connection.send_message( websocket_api.error_message(msg["id"], "invalid_info", str(err)) ) return - result = {"entity_entry": _entry_ext_dict(entry)} + + if "new_entity_id" in msg: + entity_id = msg["new_entity_id"] + + try: + if "options_domain" in msg: + entity_entry = registry.async_update_entity_options( + entity_id, msg["options_domain"], msg["options"] + ) + except ValueError as err: + connection.send_message( + websocket_api.error_message(msg["id"], "invalid_info", str(err)) + ) + return + + result = {"entity_entry": _entry_ext_dict(entity_entry)} if "disabled_by" in changes and changes["disabled_by"] is None: # Enabling an entity requires a config entry reload, or HA restart - config_entry = hass.config_entries.async_get_entry(entry.config_entry_id) + config_entry = hass.config_entries.async_get_entry(entity_entry.config_entry_id) if config_entry and not config_entry.supports_unload: result["require_restart"] = True else: @@ -201,6 +218,7 @@ def _entry_ext_dict(entry): data = _entry_dict(entry) data["capabilities"] = entry.capabilities data["device_class"] = entry.device_class + data["options"] = entry.options data["original_device_class"] = entry.original_device_class data["original_icon"] = entry.original_icon data["original_name"] = entry.original_name diff --git a/homeassistant/components/sensor/__init__.py b/homeassistant/components/sensor/__init__.py index 3414f13268f..b9cb3d94796 100644 --- a/homeassistant/components/sensor/__init__.py +++ b/homeassistant/components/sensor/__init__.py @@ -1,12 +1,13 @@ """Component to interface with various sensors that can be monitored.""" from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping from contextlib import suppress from dataclasses import dataclass from datetime import date, datetime, timedelta, timezone import inspect import logging +from math import floor, log10 from typing import Any, Final, cast, final import voluptuous as vol @@ -14,6 +15,7 @@ import voluptuous as vol from homeassistant.backports.enum import StrEnum from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( # noqa: F401 + CONF_UNIT_OF_MEASUREMENT, DEVICE_CLASS_AQI, DEVICE_CLASS_BATTERY, DEVICE_CLASS_CO, @@ -44,8 +46,9 @@ from homeassistant.const import ( # noqa: F401 DEVICE_CLASS_VOLTAGE, TEMP_CELSIUS, TEMP_FAHRENHEIT, + TEMP_KELVIN, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.config_validation import ( # noqa: F401 PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, @@ -54,7 +57,11 @@ from homeassistant.helpers.entity import Entity, EntityDescription from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity from homeassistant.helpers.typing import ConfigType, StateType -from homeassistant.util import dt as dt_util +from homeassistant.util import ( + dt as dt_util, + pressure as pressure_util, + temperature as temperature_util, +) from .const import CONF_STATE_CLASS # noqa: F401 @@ -194,6 +201,25 @@ STATE_CLASS_TOTAL: Final = "total" STATE_CLASS_TOTAL_INCREASING: Final = "total_increasing" STATE_CLASSES: Final[list[str]] = [cls.value for cls in SensorStateClass] +UNIT_CONVERSIONS: dict[str, Callable[[float, str, str], float]] = { + SensorDeviceClass.PRESSURE: pressure_util.convert, + SensorDeviceClass.TEMPERATURE: temperature_util.convert, +} + +UNIT_RATIOS: dict[str, dict[str, float]] = { + SensorDeviceClass.PRESSURE: pressure_util.UNIT_CONVERSION, + SensorDeviceClass.TEMPERATURE: { + TEMP_CELSIUS: 1.0, + TEMP_FAHRENHEIT: 1.8, + TEMP_KELVIN: 1.0, + }, +} + +VALID_UNITS: dict[str, tuple[str, ...]] = { + SensorDeviceClass.PRESSURE: pressure_util.VALID_UNITS, + SensorDeviceClass.TEMPERATURE: temperature_util.VALID_UNITS, +} + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Track states and offer events for sensors.""" @@ -264,10 +290,18 @@ class SensorEntity(Entity): ) _last_reset_reported = False _temperature_conversion_reported = False + _sensor_option_unit_of_measurement: str | None = None # Temporary private attribute to track if deprecation has been logged. __datetime_as_string_deprecation_logged = False + async def async_internal_added_to_hass(self) -> None: + """Call when the sensor entity is added to hass.""" + await super().async_internal_added_to_hass() + if not self.registry_entry: + return + self.async_registry_entry_updated() + @property def device_class(self) -> SensorDeviceClass | str | None: """Return the class of this entity.""" @@ -350,6 +384,9 @@ class SensorEntity(Entity): @property def unit_of_measurement(self) -> str | None: """Return the unit of measurement of the entity, after unit conversion.""" + if self._sensor_option_unit_of_measurement: + return self._sensor_option_unit_of_measurement + # Support for _attr_unit_of_measurement will be removed in Home Assistant 2021.11 if ( hasattr(self, "_attr_unit_of_measurement") @@ -368,7 +405,8 @@ class SensorEntity(Entity): @property def state(self) -> Any: """Return the state of the sensor and perform unit conversions, if needed.""" - unit_of_measurement = self.native_unit_of_measurement + native_unit_of_measurement = self.native_unit_of_measurement + unit_of_measurement = self.unit_of_measurement value = self.native_value device_class = self.device_class @@ -407,16 +445,48 @@ class SensorEntity(Entity): f"but does not provide a date state but {type(value)}" ) from err - units = self.hass.config.units if ( value is not None - and unit_of_measurement in (TEMP_CELSIUS, TEMP_FAHRENHEIT) - and unit_of_measurement != units.temperature_unit + and native_unit_of_measurement != unit_of_measurement + and self.device_class in UNIT_CONVERSIONS ): - if ( - self.device_class != DEVICE_CLASS_TEMPERATURE - and not self._temperature_conversion_reported - ): + assert unit_of_measurement + assert native_unit_of_measurement + + value_s = str(value) + prec = len(value_s) - value_s.index(".") - 1 if "." in value_s else 0 + + # Scale the precision when converting to a larger unit + # For example 1.1 kWh should be rendered as 0.0011 kWh, not 0.0 kWh + ratio_log = max( + 0, + log10( + UNIT_RATIOS[self.device_class][native_unit_of_measurement] + / UNIT_RATIOS[self.device_class][unit_of_measurement] + ), + ) + prec = prec + floor(ratio_log) + + # Suppress ValueError (Could not convert sensor_value to float) + with suppress(ValueError): + value_f = float(value) # type: ignore[arg-type] + value_f_new = UNIT_CONVERSIONS[self.device_class]( + value_f, + native_unit_of_measurement, + unit_of_measurement, + ) + + # Round to the wanted precision + value = round(value_f_new) if prec == 0 else round(value_f_new, prec) + + elif ( + value is not None + and self.device_class != DEVICE_CLASS_TEMPERATURE + and native_unit_of_measurement != self.hass.config.units.temperature_unit + and native_unit_of_measurement in (TEMP_CELSIUS, TEMP_FAHRENHEIT) + ): + units = self.hass.config.units + if not self._temperature_conversion_reported: self._temperature_conversion_reported = True report_issue = self._suggest_report_issue() _LOGGER.warning( @@ -429,7 +499,7 @@ class SensorEntity(Entity): self.entity_id, type(self), self.device_class, - unit_of_measurement, + native_unit_of_measurement, units.temperature_unit, report_issue, ) @@ -437,7 +507,7 @@ class SensorEntity(Entity): prec = len(value_s) - value_s.index(".") - 1 if "." in value_s else 0 # Suppress ValueError (Could not convert sensor_value to float) with suppress(ValueError): - temp = units.temperature(float(value), unit_of_measurement) # type: ignore[arg-type] + temp = units.temperature(float(value), native_unit_of_measurement) # type: ignore[arg-type] value = round(temp) if prec == 0 else round(temp, prec) return value @@ -453,6 +523,22 @@ class SensorEntity(Entity): return super().__repr__() + @callback + def async_registry_entry_updated(self) -> None: + """Run when the entity registry entry has been updated.""" + assert self.registry_entry + if ( + (sensor_options := self.registry_entry.options.get(DOMAIN)) + and (custom_unit := sensor_options.get(CONF_UNIT_OF_MEASUREMENT)) + and (device_class := self.device_class) in UNIT_CONVERSIONS + and self.native_unit_of_measurement in VALID_UNITS[device_class] + and custom_unit in VALID_UNITS[device_class] + ): + self._sensor_option_unit_of_measurement = custom_unit + return + + self._sensor_option_unit_of_measurement = None + @dataclass class SensorExtraStoredData(ExtraStoredData): diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 30df64d3e88..24f9c5a64d8 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -827,6 +827,13 @@ class Entity(ABC): To be extended by integrations. """ + @callback + def async_registry_entry_updated(self) -> None: + """Run when the entity registry entry has been updated. + + To be extended by integrations. + """ + async def async_internal_added_to_hass(self) -> None: """Run when entity about to be added to hass. @@ -888,6 +895,7 @@ class Entity(ABC): assert old is not None if self.registry_entry.entity_id == old.entity_id: + self.async_registry_entry_updated() self.async_write_ha_state() return diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 90ca1e55cf2..c02ffde4672 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -618,11 +618,11 @@ class EntityRegistry: @callback def async_update_entity_options( self, entity_id: str, domain: str, options: dict[str, Any] - ) -> None: + ) -> RegistryEntry: """Update entity options.""" old = self.entities[entity_id] new_options: Mapping[str, Mapping[str, Any]] = {**old.options, domain: options} - self.entities[entity_id] = attr.evolve(old, options=new_options) + new = self.entities[entity_id] = attr.evolve(old, options=new_options) self.async_schedule_save() @@ -634,6 +634,8 @@ class EntityRegistry: self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data) + return new + async def async_load(self) -> None: """Load the entity registry.""" async_setup_entity_restore(self.hass, self) diff --git a/homeassistant/util/temperature.py b/homeassistant/util/temperature.py index bc3cb4c1017..d7b7597d6d0 100644 --- a/homeassistant/util/temperature.py +++ b/homeassistant/util/temperature.py @@ -7,6 +7,12 @@ from homeassistant.const import ( UNIT_NOT_RECOGNIZED_TEMPLATE, ) +VALID_UNITS: tuple[str, ...] = ( + TEMP_CELSIUS, + TEMP_FAHRENHEIT, + TEMP_KELVIN, +) + def fahrenheit_to_celsius(fahrenheit: float, interval: bool = False) -> float: """Convert a temperature in Fahrenheit to Celsius.""" diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 19a07b3f8f7..e74e43de701 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -118,6 +118,7 @@ async def test_get_entity(hass, client): "hidden_by": None, "icon": None, "name": "Hello World", + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -146,6 +147,7 @@ async def test_get_entity(hass, client): "hidden_by": None, "icon": None, "name": None, + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -207,6 +209,7 @@ async def test_update_entity(hass, client): "hidden_by": "user", # We exchange strings over the WS API, not enums "icon": "icon:after update", "name": "after update", + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -277,6 +280,7 @@ async def test_update_entity(hass, client): "hidden_by": "user", # We exchange strings over the WS API, not enums "icon": "icon:after update", "name": "after update", + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -286,6 +290,41 @@ async def test_update_entity(hass, client): "reload_delay": 30, } + # UPDATE ENTITY OPTION + await client.send_json( + { + "id": 10, + "type": "config/entity_registry/update", + "entity_id": "test_domain.world", + "options_domain": "sensor", + "options": {"unit_of_measurement": "beard_second"}, + } + ) + + msg = await client.receive_json() + + assert msg["result"] == { + "entity_entry": { + "area_id": "mock-area-id", + "capabilities": None, + "config_entry_id": None, + "device_class": "custom_device_class", + "device_id": None, + "disabled_by": None, + "entity_category": None, + "entity_id": "test_domain.world", + "hidden_by": "user", # We exchange strings over the WS API, not enums + "icon": "icon:after update", + "name": "after update", + "options": {"sensor": {"unit_of_measurement": "beard_second"}}, + "original_device_class": None, + "original_icon": None, + "original_name": None, + "platform": "test_platform", + "unique_id": "1234", + }, + } + async def test_update_entity_require_restart(hass, client): """Test updating entity.""" @@ -335,6 +374,7 @@ async def test_update_entity_require_restart(hass, client): "icon": None, "hidden_by": None, "name": None, + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -440,6 +480,7 @@ async def test_update_entity_no_changes(hass, client): "hidden_by": None, "icon": None, "name": "name of entity", + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, @@ -524,6 +565,7 @@ async def test_update_entity_id(hass, client): "hidden_by": None, "icon": None, "name": None, + "options": {}, "original_device_class": None, "original_icon": None, "original_name": None, diff --git a/tests/components/google_wifi/test_sensor.py b/tests/components/google_wifi/test_sensor.py index add8ec04cbe..ae0715c640b 100644 --- a/tests/components/google_wifi/test_sensor.py +++ b/tests/components/google_wifi/test_sensor.py @@ -108,9 +108,9 @@ def test_name(requests_mock): assert test_name == sensor.name -def test_unit_of_measurement(requests_mock): +def test_unit_of_measurement(hass, requests_mock): """Test the unit of measurement.""" - api, sensor_dict = setup_api(None, MOCK_DATA, requests_mock) + api, sensor_dict = setup_api(hass, MOCK_DATA, requests_mock) for name in sensor_dict: sensor = sensor_dict[name]["sensor"] assert sensor_dict[name]["units"] == sensor.unit_of_measurement diff --git a/tests/components/sensor/test_init.py b/tests/components/sensor/test_init.py index df33cb1a081..1d9df34f657 100644 --- a/tests/components/sensor/test_init.py +++ b/tests/components/sensor/test_init.py @@ -7,11 +7,16 @@ from pytest import approx from homeassistant.components.sensor import SensorDeviceClass, SensorEntityDescription from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, + PRESSURE_HPA, + PRESSURE_INHG, + PRESSURE_KPA, + PRESSURE_MMHG, STATE_UNKNOWN, TEMP_CELSIUS, TEMP_FAHRENHEIT, ) from homeassistant.core import State +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -342,3 +347,117 @@ async def test_restore_sensor_restore_state( assert entity0.native_value == native_value assert type(entity0.native_value) == native_value_type assert entity0.native_unit_of_measurement == uom + + +@pytest.mark.parametrize( + "native_unit,custom_unit,state_unit,native_value,custom_value", + [ + # Smaller to larger unit, InHg is ~33x larger than hPa -> 1 more decimal + (PRESSURE_HPA, PRESSURE_INHG, PRESSURE_INHG, 1000.0, 29.53), + (PRESSURE_KPA, PRESSURE_HPA, PRESSURE_HPA, 1.234, 12.34), + (PRESSURE_HPA, PRESSURE_MMHG, PRESSURE_MMHG, 1000, 750), + # Not a supported pressure unit + (PRESSURE_HPA, "peer_pressure", PRESSURE_HPA, 1000, 1000), + ], +) +async def test_custom_unit( + hass, + enable_custom_integrations, + native_unit, + custom_unit, + state_unit, + native_value, + custom_value, +): + """Test custom unit.""" + entity_registry = er.async_get(hass) + + entry = entity_registry.async_get_or_create("sensor", "test", "very_unique") + entity_registry.async_update_entity_options( + entry.entity_id, "sensor", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + platform = getattr(hass.components, "test.sensor") + platform.init(empty=True) + platform.ENTITIES["0"] = platform.MockSensor( + name="Test", + native_value=str(native_value), + native_unit_of_measurement=native_unit, + device_class=SensorDeviceClass.PRESSURE, + unique_id="very_unique", + ) + + entity0 = platform.ENTITIES["0"] + assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}}) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == approx(float(custom_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit + + +@pytest.mark.parametrize( + "native_unit,custom_unit,state_unit,native_value,custom_value", + [ + # Smaller to larger unit, InHg is ~33x larger than hPa -> 1 more decimal + (PRESSURE_HPA, PRESSURE_INHG, PRESSURE_INHG, 1000.0, 29.53), + (PRESSURE_KPA, PRESSURE_HPA, PRESSURE_HPA, 1.234, 12.34), + (PRESSURE_HPA, PRESSURE_MMHG, PRESSURE_MMHG, 1000, 750), + # Not a supported pressure unit + (PRESSURE_HPA, "peer_pressure", PRESSURE_HPA, 1000, 1000), + ], +) +async def test_custom_unit_change( + hass, + enable_custom_integrations, + native_unit, + custom_unit, + state_unit, + native_value, + custom_value, +): + """Test custom unit changes are picked up.""" + entity_registry = er.async_get(hass) + platform = getattr(hass.components, "test.sensor") + platform.init(empty=True) + platform.ENTITIES["0"] = platform.MockSensor( + name="Test", + native_value=str(native_value), + native_unit_of_measurement=native_unit, + device_class=SensorDeviceClass.PRESSURE, + unique_id="very_unique", + ) + + entity0 = platform.ENTITIES["0"] + assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}}) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == approx(float(native_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit + + entity_registry.async_update_entity_options( + "sensor.test", "sensor", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == approx(float(custom_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit + + entity_registry.async_update_entity_options( + "sensor.test", "sensor", {"unit_of_measurement": native_unit} + ) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == approx(float(native_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit + + entity_registry.async_update_entity_options("sensor.test", "sensor", None) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert float(state.state) == approx(float(native_value)) + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit