Compare commits

...

2 Commits

Author SHA1 Message Date
abmantis
69a252092c Restore state in Energy cost sensors 2025-02-20 23:38:15 +00:00
abmantis
ec257a54f3 Simplify Energy cost sensor update method 2025-02-20 22:44:26 +00:00

View File

@@ -12,8 +12,8 @@ from typing import Any, Final, Literal, cast
from homeassistant.components.sensor import (
ATTR_LAST_RESET,
ATTR_STATE_CLASS,
RestoreSensor,
SensorDeviceClass,
SensorEntity,
SensorStateClass,
)
from homeassistant.components.sensor.recorder import reset_detected
@@ -222,7 +222,7 @@ def _set_result_unless_done(future: asyncio.Future[None]) -> None:
future.set_result(None)
class EnergyCostSensor(SensorEntity):
class EnergyCostSensor(RestoreSensor):
"""Calculate costs incurred by consuming energy.
This is intended as a fallback for when no specific cost sensor is available for the
@@ -312,40 +312,22 @@ class EnergyCostSensor(SensorEntity):
return
# Determine energy price
if self._config["entity_energy_price"] is not None:
energy_price_state = self.hass.states.get(
self._config["entity_energy_price"]
)
if energy_price_state is None:
return
try:
energy_price = float(energy_price_state.state)
except ValueError:
if self._last_energy_sensor_state is None:
# Initialize as it's the first time all required entities except
# price are in place. This means that the cost will update the first
# time the energy is updated after the price entity is in place.
self._reset(energy_state)
return
energy_price_unit: str | None = energy_price_state.attributes.get(
ATTR_UNIT_OF_MEASUREMENT, ""
).partition("/")[2]
# For backwards compatibility we don't validate the unit of the price
# If it is not valid, we assume it's our default price unit.
if energy_price_unit not in valid_units:
energy_price_unit = default_price_unit
else:
energy_price = cast(float, self._config["number_energy_price"])
energy_price_unit = default_price_unit
energy_price_tuple = self._get_energy_price(valid_units, default_price_unit)
if energy_price_tuple is None:
return
if self._last_energy_sensor_state is None:
# Initialize as it's the first time all required entities are in place.
self._reset(energy_state)
# Initialize as it's the first time all required entities are in place or
# only the price is missing. In the later case, cost will update the first
# time the energy is updated after the price entity is in place.
if self._attr_native_value is None:
self._reset(energy_state)
else:
self._last_energy_sensor_state = energy_state
return
energy_price, energy_price_unit = energy_price_tuple
if energy_price is None:
return
energy_unit: str | None = energy_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
@@ -383,20 +365,9 @@ class EnergyCostSensor(SensorEntity):
old_energy_value = float(self._last_energy_sensor_state.state)
cur_value = cast(float, self._attr_native_value)
if energy_price_unit is None:
converted_energy_price = energy_price
else:
converter: Callable[[float, str, str], float]
if energy_unit in VALID_ENERGY_UNITS:
converter = unit_conversion.EnergyConverter.convert
else:
converter = unit_conversion.VolumeConverter.convert
converted_energy_price = converter(
energy_price,
energy_unit,
energy_price_unit,
)
converted_energy_price = self._convert_energy_price(
energy_price, energy_price_unit, energy_unit
)
self._attr_native_value = (
cur_value + (energy - old_energy_value) * converted_energy_price
@@ -404,8 +375,53 @@ class EnergyCostSensor(SensorEntity):
self._last_energy_sensor_state = energy_state
def _get_energy_price(
self,
valid_units: set[str],
default_unit: str | None,
) -> tuple[float | None, str | None] | None:
if self._config["entity_energy_price"] is None:
return cast(float, self._config["number_energy_price"]), default_unit
energy_price_state = self.hass.states.get(self._config["entity_energy_price"])
if energy_price_state is None:
return None
try:
energy_price = float(energy_price_state.state)
except ValueError:
return (None, None)
energy_price_unit: str | None = energy_price_state.attributes.get(
ATTR_UNIT_OF_MEASUREMENT, ""
).partition("/")[2]
# For backwards compatibility we don't validate the unit of the price
# If it is not valid, we assume it's our default price unit.
if energy_price_unit not in valid_units:
energy_price_unit = default_unit
return energy_price, energy_price_unit
def _convert_energy_price(
self, energy_price: float, energy_price_unit: str | None, energy_unit: str
) -> float:
if energy_price_unit is None:
return energy_price
converter: Callable[[float, str, str], float]
if energy_unit in VALID_ENERGY_UNITS:
converter = unit_conversion.EnergyConverter.convert
else:
converter = unit_conversion.VolumeConverter.convert
return converter(energy_price, energy_unit, energy_price_unit)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
if (sensor_data := await self.async_get_last_sensor_data()) is not None:
self._attr_native_value = sensor_data.native_value
energy_state = self.hass.states.get(self._config[self._adapter.stat_energy_key])
if energy_state:
name = energy_state.name