Allow customizing unit for temperature and pressure sensors (#64366)

* Allow customizing unit for temperature and pressure sensors

* pylint

* Adjust google_wifi tests

* Address review comments and add tests

* Improve rounding when scaling

* Tweak rounding

* Further tweak rounding

* Allow setting entity options with config/entity_registry/update

* Address review comments

* Tweak tests

* Load custom unit when sensor is added

* Override async_internal_added_to_hass

* Adjust tests after rebase

* Apply suggestions from code review

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Address review comments

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Erik Montnemery 2022-03-30 15:43:04 +02:00 committed by GitHub
parent 5b1e319947
commit 91f6e58e9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 306 additions and 25 deletions

View File

@ -86,6 +86,8 @@ def websocket_get_entity(hass, connection, msg):
er.RegistryEntryHider.USER.value,
),
),
vol.Inclusive("options_domain", "entity_option"): str,
vol.Inclusive("options", "entity_option"): vol.Any(None, dict),
}
)
@callback
@ -96,7 +98,8 @@ def websocket_update_entity(hass, connection, msg):
"""
registry = er.async_get(hass)
if msg["entity_id"] not in registry.entities:
entity_id = msg["entity_id"]
if not (entity_entry := registry.async_get(entity_id)):
connection.send_message(
websocket_api.error_message(msg["id"], ERR_NOT_FOUND, "Entity not found")
)
@ -108,7 +111,7 @@ def websocket_update_entity(hass, connection, msg):
if key in msg:
changes[key] = msg[key]
if "new_entity_id" in msg and msg["new_entity_id"] != msg["entity_id"]:
if "new_entity_id" in msg and msg["new_entity_id"] != entity_id:
changes["new_entity_id"] = msg["new_entity_id"]
if hass.states.get(msg["new_entity_id"]) is not None:
connection.send_message(
@ -122,10 +125,9 @@ def websocket_update_entity(hass, connection, msg):
if "disabled_by" in msg and msg["disabled_by"] is None:
# Don't allow enabling an entity of a disabled device
entity = registry.entities[msg["entity_id"]]
if entity.device_id:
if entity_entry.device_id:
device_registry = dr.async_get(hass)
device = device_registry.async_get(entity.device_id)
device = device_registry.async_get(entity_entry.device_id)
if device.disabled:
connection.send_message(
websocket_api.error_message(
@ -136,16 +138,31 @@ def websocket_update_entity(hass, connection, msg):
try:
if changes:
entry = registry.async_update_entity(msg["entity_id"], **changes)
entity_entry = registry.async_update_entity(entity_id, **changes)
except ValueError as err:
connection.send_message(
websocket_api.error_message(msg["id"], "invalid_info", str(err))
)
return
result = {"entity_entry": _entry_ext_dict(entry)}
if "new_entity_id" in msg:
entity_id = msg["new_entity_id"]
try:
if "options_domain" in msg:
entity_entry = registry.async_update_entity_options(
entity_id, msg["options_domain"], msg["options"]
)
except ValueError as err:
connection.send_message(
websocket_api.error_message(msg["id"], "invalid_info", str(err))
)
return
result = {"entity_entry": _entry_ext_dict(entity_entry)}
if "disabled_by" in changes and changes["disabled_by"] is None:
# Enabling an entity requires a config entry reload, or HA restart
config_entry = hass.config_entries.async_get_entry(entry.config_entry_id)
config_entry = hass.config_entries.async_get_entry(entity_entry.config_entry_id)
if config_entry and not config_entry.supports_unload:
result["require_restart"] = True
else:
@ -201,6 +218,7 @@ def _entry_ext_dict(entry):
data = _entry_dict(entry)
data["capabilities"] = entry.capabilities
data["device_class"] = entry.device_class
data["options"] = entry.options
data["original_device_class"] = entry.original_device_class
data["original_icon"] = entry.original_icon
data["original_name"] = entry.original_name

View File

@ -1,12 +1,13 @@
"""Component to interface with various sensors that can be monitored."""
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from contextlib import suppress
from dataclasses import dataclass
from datetime import date, datetime, timedelta, timezone
import inspect
import logging
from math import floor, log10
from typing import Any, Final, cast, final
import voluptuous as vol
@ -14,6 +15,7 @@ import voluptuous as vol
from homeassistant.backports.enum import StrEnum
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( # noqa: F401
CONF_UNIT_OF_MEASUREMENT,
DEVICE_CLASS_AQI,
DEVICE_CLASS_BATTERY,
DEVICE_CLASS_CO,
@ -44,8 +46,9 @@ from homeassistant.const import ( # noqa: F401
DEVICE_CLASS_VOLTAGE,
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
TEMP_KELVIN,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.config_validation import ( # noqa: F401
PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
@ -54,7 +57,11 @@ from homeassistant.helpers.entity import Entity, EntityDescription
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity
from homeassistant.helpers.typing import ConfigType, StateType
from homeassistant.util import dt as dt_util
from homeassistant.util import (
dt as dt_util,
pressure as pressure_util,
temperature as temperature_util,
)
from .const import CONF_STATE_CLASS # noqa: F401
@ -194,6 +201,25 @@ STATE_CLASS_TOTAL: Final = "total"
STATE_CLASS_TOTAL_INCREASING: Final = "total_increasing"
STATE_CLASSES: Final[list[str]] = [cls.value for cls in SensorStateClass]
UNIT_CONVERSIONS: dict[str, Callable[[float, str, str], float]] = {
SensorDeviceClass.PRESSURE: pressure_util.convert,
SensorDeviceClass.TEMPERATURE: temperature_util.convert,
}
UNIT_RATIOS: dict[str, dict[str, float]] = {
SensorDeviceClass.PRESSURE: pressure_util.UNIT_CONVERSION,
SensorDeviceClass.TEMPERATURE: {
TEMP_CELSIUS: 1.0,
TEMP_FAHRENHEIT: 1.8,
TEMP_KELVIN: 1.0,
},
}
VALID_UNITS: dict[str, tuple[str, ...]] = {
SensorDeviceClass.PRESSURE: pressure_util.VALID_UNITS,
SensorDeviceClass.TEMPERATURE: temperature_util.VALID_UNITS,
}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Track states and offer events for sensors."""
@ -264,10 +290,18 @@ class SensorEntity(Entity):
)
_last_reset_reported = False
_temperature_conversion_reported = False
_sensor_option_unit_of_measurement: str | None = None
# Temporary private attribute to track if deprecation has been logged.
__datetime_as_string_deprecation_logged = False
async def async_internal_added_to_hass(self) -> None:
"""Call when the sensor entity is added to hass."""
await super().async_internal_added_to_hass()
if not self.registry_entry:
return
self.async_registry_entry_updated()
@property
def device_class(self) -> SensorDeviceClass | str | None:
"""Return the class of this entity."""
@ -350,6 +384,9 @@ class SensorEntity(Entity):
@property
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of the entity, after unit conversion."""
if self._sensor_option_unit_of_measurement:
return self._sensor_option_unit_of_measurement
# Support for _attr_unit_of_measurement will be removed in Home Assistant 2021.11
if (
hasattr(self, "_attr_unit_of_measurement")
@ -368,7 +405,8 @@ class SensorEntity(Entity):
@property
def state(self) -> Any:
"""Return the state of the sensor and perform unit conversions, if needed."""
unit_of_measurement = self.native_unit_of_measurement
native_unit_of_measurement = self.native_unit_of_measurement
unit_of_measurement = self.unit_of_measurement
value = self.native_value
device_class = self.device_class
@ -407,16 +445,48 @@ class SensorEntity(Entity):
f"but does not provide a date state but {type(value)}"
) from err
units = self.hass.config.units
if (
value is not None
and unit_of_measurement in (TEMP_CELSIUS, TEMP_FAHRENHEIT)
and unit_of_measurement != units.temperature_unit
and native_unit_of_measurement != unit_of_measurement
and self.device_class in UNIT_CONVERSIONS
):
if (
self.device_class != DEVICE_CLASS_TEMPERATURE
and not self._temperature_conversion_reported
):
assert unit_of_measurement
assert native_unit_of_measurement
value_s = str(value)
prec = len(value_s) - value_s.index(".") - 1 if "." in value_s else 0
# Scale the precision when converting to a larger unit
# For example 1.1 kWh should be rendered as 0.0011 kWh, not 0.0 kWh
ratio_log = max(
0,
log10(
UNIT_RATIOS[self.device_class][native_unit_of_measurement]
/ UNIT_RATIOS[self.device_class][unit_of_measurement]
),
)
prec = prec + floor(ratio_log)
# Suppress ValueError (Could not convert sensor_value to float)
with suppress(ValueError):
value_f = float(value) # type: ignore[arg-type]
value_f_new = UNIT_CONVERSIONS[self.device_class](
value_f,
native_unit_of_measurement,
unit_of_measurement,
)
# Round to the wanted precision
value = round(value_f_new) if prec == 0 else round(value_f_new, prec)
elif (
value is not None
and self.device_class != DEVICE_CLASS_TEMPERATURE
and native_unit_of_measurement != self.hass.config.units.temperature_unit
and native_unit_of_measurement in (TEMP_CELSIUS, TEMP_FAHRENHEIT)
):
units = self.hass.config.units
if not self._temperature_conversion_reported:
self._temperature_conversion_reported = True
report_issue = self._suggest_report_issue()
_LOGGER.warning(
@ -429,7 +499,7 @@ class SensorEntity(Entity):
self.entity_id,
type(self),
self.device_class,
unit_of_measurement,
native_unit_of_measurement,
units.temperature_unit,
report_issue,
)
@ -437,7 +507,7 @@ class SensorEntity(Entity):
prec = len(value_s) - value_s.index(".") - 1 if "." in value_s else 0
# Suppress ValueError (Could not convert sensor_value to float)
with suppress(ValueError):
temp = units.temperature(float(value), unit_of_measurement) # type: ignore[arg-type]
temp = units.temperature(float(value), native_unit_of_measurement) # type: ignore[arg-type]
value = round(temp) if prec == 0 else round(temp, prec)
return value
@ -453,6 +523,22 @@ class SensorEntity(Entity):
return super().__repr__()
@callback
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated."""
assert self.registry_entry
if (
(sensor_options := self.registry_entry.options.get(DOMAIN))
and (custom_unit := sensor_options.get(CONF_UNIT_OF_MEASUREMENT))
and (device_class := self.device_class) in UNIT_CONVERSIONS
and self.native_unit_of_measurement in VALID_UNITS[device_class]
and custom_unit in VALID_UNITS[device_class]
):
self._sensor_option_unit_of_measurement = custom_unit
return
self._sensor_option_unit_of_measurement = None
@dataclass
class SensorExtraStoredData(ExtraStoredData):

View File

@ -827,6 +827,13 @@ class Entity(ABC):
To be extended by integrations.
"""
@callback
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated.
To be extended by integrations.
"""
async def async_internal_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.
@ -888,6 +895,7 @@ class Entity(ABC):
assert old is not None
if self.registry_entry.entity_id == old.entity_id:
self.async_registry_entry_updated()
self.async_write_ha_state()
return

View File

@ -618,11 +618,11 @@ class EntityRegistry:
@callback
def async_update_entity_options(
self, entity_id: str, domain: str, options: dict[str, Any]
) -> None:
) -> RegistryEntry:
"""Update entity options."""
old = self.entities[entity_id]
new_options: Mapping[str, Mapping[str, Any]] = {**old.options, domain: options}
self.entities[entity_id] = attr.evolve(old, options=new_options)
new = self.entities[entity_id] = attr.evolve(old, options=new_options)
self.async_schedule_save()
@ -634,6 +634,8 @@ class EntityRegistry:
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
return new
async def async_load(self) -> None:
"""Load the entity registry."""
async_setup_entity_restore(self.hass, self)

View File

@ -7,6 +7,12 @@ from homeassistant.const import (
UNIT_NOT_RECOGNIZED_TEMPLATE,
)
VALID_UNITS: tuple[str, ...] = (
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
TEMP_KELVIN,
)
def fahrenheit_to_celsius(fahrenheit: float, interval: bool = False) -> float:
"""Convert a temperature in Fahrenheit to Celsius."""

View File

@ -118,6 +118,7 @@ async def test_get_entity(hass, client):
"hidden_by": None,
"icon": None,
"name": "Hello World",
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -146,6 +147,7 @@ async def test_get_entity(hass, client):
"hidden_by": None,
"icon": None,
"name": None,
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -207,6 +209,7 @@ async def test_update_entity(hass, client):
"hidden_by": "user", # We exchange strings over the WS API, not enums
"icon": "icon:after update",
"name": "after update",
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -277,6 +280,7 @@ async def test_update_entity(hass, client):
"hidden_by": "user", # We exchange strings over the WS API, not enums
"icon": "icon:after update",
"name": "after update",
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -286,6 +290,41 @@ async def test_update_entity(hass, client):
"reload_delay": 30,
}
# UPDATE ENTITY OPTION
await client.send_json(
{
"id": 10,
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"options_domain": "sensor",
"options": {"unit_of_measurement": "beard_second"},
}
)
msg = await client.receive_json()
assert msg["result"] == {
"entity_entry": {
"area_id": "mock-area-id",
"capabilities": None,
"config_entry_id": None,
"device_class": "custom_device_class",
"device_id": None,
"disabled_by": None,
"entity_category": None,
"entity_id": "test_domain.world",
"hidden_by": "user", # We exchange strings over the WS API, not enums
"icon": "icon:after update",
"name": "after update",
"options": {"sensor": {"unit_of_measurement": "beard_second"}},
"original_device_class": None,
"original_icon": None,
"original_name": None,
"platform": "test_platform",
"unique_id": "1234",
},
}
async def test_update_entity_require_restart(hass, client):
"""Test updating entity."""
@ -335,6 +374,7 @@ async def test_update_entity_require_restart(hass, client):
"icon": None,
"hidden_by": None,
"name": None,
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -440,6 +480,7 @@ async def test_update_entity_no_changes(hass, client):
"hidden_by": None,
"icon": None,
"name": "name of entity",
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,
@ -524,6 +565,7 @@ async def test_update_entity_id(hass, client):
"hidden_by": None,
"icon": None,
"name": None,
"options": {},
"original_device_class": None,
"original_icon": None,
"original_name": None,

View File

@ -108,9 +108,9 @@ def test_name(requests_mock):
assert test_name == sensor.name
def test_unit_of_measurement(requests_mock):
def test_unit_of_measurement(hass, requests_mock):
"""Test the unit of measurement."""
api, sensor_dict = setup_api(None, MOCK_DATA, requests_mock)
api, sensor_dict = setup_api(hass, MOCK_DATA, requests_mock)
for name in sensor_dict:
sensor = sensor_dict[name]["sensor"]
assert sensor_dict[name]["units"] == sensor.unit_of_measurement

View File

@ -7,11 +7,16 @@ from pytest import approx
from homeassistant.components.sensor import SensorDeviceClass, SensorEntityDescription
from homeassistant.const import (
ATTR_UNIT_OF_MEASUREMENT,
PRESSURE_HPA,
PRESSURE_INHG,
PRESSURE_KPA,
PRESSURE_MMHG,
STATE_UNKNOWN,
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
)
from homeassistant.core import State
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
@ -342,3 +347,117 @@ async def test_restore_sensor_restore_state(
assert entity0.native_value == native_value
assert type(entity0.native_value) == native_value_type
assert entity0.native_unit_of_measurement == uom
@pytest.mark.parametrize(
"native_unit,custom_unit,state_unit,native_value,custom_value",
[
# Smaller to larger unit, InHg is ~33x larger than hPa -> 1 more decimal
(PRESSURE_HPA, PRESSURE_INHG, PRESSURE_INHG, 1000.0, 29.53),
(PRESSURE_KPA, PRESSURE_HPA, PRESSURE_HPA, 1.234, 12.34),
(PRESSURE_HPA, PRESSURE_MMHG, PRESSURE_MMHG, 1000, 750),
# Not a supported pressure unit
(PRESSURE_HPA, "peer_pressure", PRESSURE_HPA, 1000, 1000),
],
)
async def test_custom_unit(
hass,
enable_custom_integrations,
native_unit,
custom_unit,
state_unit,
native_value,
custom_value,
):
"""Test custom unit."""
entity_registry = er.async_get(hass)
entry = entity_registry.async_get_or_create("sensor", "test", "very_unique")
entity_registry.async_update_entity_options(
entry.entity_id, "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockSensor(
name="Test",
native_value=str(native_value),
native_unit_of_measurement=native_unit,
device_class=SensorDeviceClass.PRESSURE,
unique_id="very_unique",
)
entity0 = platform.ENTITIES["0"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit
@pytest.mark.parametrize(
"native_unit,custom_unit,state_unit,native_value,custom_value",
[
# Smaller to larger unit, InHg is ~33x larger than hPa -> 1 more decimal
(PRESSURE_HPA, PRESSURE_INHG, PRESSURE_INHG, 1000.0, 29.53),
(PRESSURE_KPA, PRESSURE_HPA, PRESSURE_HPA, 1.234, 12.34),
(PRESSURE_HPA, PRESSURE_MMHG, PRESSURE_MMHG, 1000, 750),
# Not a supported pressure unit
(PRESSURE_HPA, "peer_pressure", PRESSURE_HPA, 1000, 1000),
],
)
async def test_custom_unit_change(
hass,
enable_custom_integrations,
native_unit,
custom_unit,
state_unit,
native_value,
custom_value,
):
"""Test custom unit changes are picked up."""
entity_registry = er.async_get(hass)
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockSensor(
name="Test",
native_value=str(native_value),
native_unit_of_measurement=native_unit,
device_class=SensorDeviceClass.PRESSURE,
unique_id="very_unique",
)
entity0 = platform.ENTITIES["0"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
entity_registry.async_update_entity_options(
"sensor.test", "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit
entity_registry.async_update_entity_options(
"sensor.test", "sensor", {"unit_of_measurement": native_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
entity_registry.async_update_entity_options("sensor.test", "sensor", None)
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit