Make derivative sensor unavailable when source sensor is unavailable (#147468)

This commit is contained in:
karwosts 2025-07-04 13:48:48 -07:00 committed by GitHub
parent 57c04f3a56
commit 22e46d9977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 220 additions and 26 deletions

View File

@ -198,6 +198,7 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
self._attr_native_value = round(Decimal(0), round_digits)
# List of tuples with (timestamp_start, timestamp_end, derivative)
self._state_list: list[tuple[datetime, datetime, Decimal]] = []
self._last_valid_state_time: tuple[str, datetime] | None = None
self._attr_name = name if name is not None else f"{source_entity} derivative"
self._attr_extra_state_attributes = {ATTR_SOURCE_ID: source_entity}
@ -242,6 +243,25 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
if (current_time - time_end).total_seconds() < self._time_window
]
def _handle_invalid_source_state(self, state: State | None) -> bool:
# Check the source state for unknown/unavailable condition. If unusable, write unknown/unavailable state and return false.
if not state or state.state == STATE_UNAVAILABLE:
self._attr_available = False
self.async_write_ha_state()
return False
if not _is_decimal_state(state.state):
self._attr_available = True
self._write_native_value(None)
return False
self._attr_available = True
return True
def _write_native_value(self, derivative: Decimal | None) -> None:
self._attr_native_value = (
None if derivative is None else round(derivative, self._round_digits)
)
self.async_write_ha_state()
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
await super().async_added_to_hass()
@ -255,8 +275,8 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
Decimal(restored_data.native_value), # type: ignore[arg-type]
self._round_digits,
)
except SyntaxError as err:
_LOGGER.warning("Could not restore last state: %s", err)
except (InvalidOperation, TypeError):
self._attr_native_value = None
def schedule_max_sub_interval_exceeded(source_state: State | None) -> None:
"""Schedule calculation using the source state and max_sub_interval.
@ -280,9 +300,7 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
self._prune_state_list(now)
derivative = self._calc_derivative_from_state_list(now)
self._attr_native_value = round(derivative, self._round_digits)
self.async_write_ha_state()
self._write_native_value(derivative)
# If derivative is now zero, don't schedule another timeout callback, as it will have no effect
if derivative != 0:
@ -299,36 +317,46 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
"""Handle constant sensor state."""
self._cancel_max_sub_interval_exceeded_callback()
new_state = event.data["new_state"]
if not self._handle_invalid_source_state(new_state):
return
assert new_state
if self._attr_native_value == Decimal(0):
# If the derivative is zero, and the source sensor hasn't
# changed state, then we know it will still be zero.
return
schedule_max_sub_interval_exceeded(new_state)
new_state = event.data["new_state"]
if new_state is not None:
calc_derivative(
new_state, new_state.state, event.data["old_last_reported"]
)
calc_derivative(new_state, new_state.state, event.data["old_last_reported"])
@callback
def on_state_changed(event: Event[EventStateChangedData]) -> None:
"""Handle changed sensor state."""
self._cancel_max_sub_interval_exceeded_callback()
new_state = event.data["new_state"]
if not self._handle_invalid_source_state(new_state):
return
assert new_state
schedule_max_sub_interval_exceeded(new_state)
old_state = event.data["old_state"]
if new_state is not None and old_state is not None:
if old_state is not None:
calc_derivative(new_state, old_state.state, old_state.last_reported)
else:
# On first state change from none, update availability
self.async_write_ha_state()
def calc_derivative(
new_state: State, old_value: str, old_last_reported: datetime
) -> None:
"""Handle the sensor state changes."""
if old_value in (STATE_UNKNOWN, STATE_UNAVAILABLE) or new_state.state in (
STATE_UNKNOWN,
STATE_UNAVAILABLE,
):
return
if not _is_decimal_state(old_value):
if self._last_valid_state_time:
old_value = self._last_valid_state_time[0]
old_last_reported = self._last_valid_state_time[1]
else:
# Sensor becomes valid for the first time, just keep the restored value
self.async_write_ha_state()
return
if self.native_unit_of_measurement is None:
unit = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
@ -373,6 +401,10 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
self._state_list.append(
(old_last_reported, new_state.last_reported, new_derivative)
)
self._last_valid_state_time = (
new_state.state,
new_state.last_reported,
)
# If outside of time window just report derivative (is the same as modeling it in the window),
# otherwise take the weighted average with the previous derivatives
@ -382,11 +414,16 @@ class DerivativeSensor(RestoreSensor, SensorEntity):
derivative = self._calc_derivative_from_state_list(
new_state.last_reported
)
self._attr_native_value = round(derivative, self._round_digits)
self.async_write_ha_state()
self._write_native_value(derivative)
source_state = self.hass.states.get(self._sensor_source_id)
if source_state is None or source_state.state in [
STATE_UNAVAILABLE,
STATE_UNKNOWN,
]:
self._attr_available = False
if self._max_sub_interval is not None:
source_state = self.hass.states.get(self._sensor_source_id)
schedule_max_sub_interval_exceeded(source_state)
@callback

View File

@ -99,6 +99,9 @@ async def test_setup_and_remove_config_entry(
input_sensor_entity_id = "sensor.input"
derivative_entity_id = "sensor.my_derivative"
hass.states.async_set(input_sensor_entity_id, "10.0", {})
await hass.async_block_till_done()
# Setup the config entry
config_entry = MockConfigEntry(
data={},

View File

@ -6,16 +6,26 @@ import random
from typing import Any
from freezegun import freeze_time
import pytest
from homeassistant.components.derivative.const import DOMAIN
from homeassistant.components.sensor import ATTR_STATE_CLASS, SensorStateClass
from homeassistant.const import STATE_UNAVAILABLE, UnitOfPower, UnitOfTime
from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
UnitOfPower,
UnitOfTime,
)
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry, async_fire_time_changed
from tests.common import (
MockConfigEntry,
async_fire_time_changed,
mock_restore_cache_with_extra_data,
)
async def test_state(hass: HomeAssistant) -> None:
@ -106,6 +116,7 @@ async def _setup_sensor(
config = {"sensor": dict(default_config, **config)}
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
entity_id = config["sensor"]["source"]
hass.states.async_set(entity_id, 0, {})
@ -440,16 +451,14 @@ async def test_sub_intervals_instantaneous(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
derivative = round(float(state.state), config["sensor"]["round"])
assert derivative == -0.29
assert state.state == STATE_UNAVAILABLE
now += timedelta(seconds=60)
async_fire_time_changed(hass, now)
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
derivative = round(float(state.state), config["sensor"]["round"])
assert derivative == -0.29
assert state.state == STATE_UNAVAILABLE
now += timedelta(seconds=10)
freezer.move_to(now)
@ -458,7 +467,7 @@ async def test_sub_intervals_instantaneous(hass: HomeAssistant) -> None:
state = hass.states.get("sensor.power")
derivative = round(float(state.state), config["sensor"]["round"])
assert derivative == -0.29
assert derivative == 0
now += timedelta(seconds=max_sub_interval + 1)
async_fire_time_changed(hass, now)
@ -693,3 +702,148 @@ async def test_device_id(
derivative_entity = entity_registry.async_get("sensor.derivative")
assert derivative_entity is not None
assert derivative_entity.device_id == source_entity.device_id
@pytest.mark.parametrize("bad_state", [STATE_UNAVAILABLE, STATE_UNKNOWN, "foo"])
async def test_unavailable(
bad_state: str,
hass: HomeAssistant,
) -> None:
"""Test derivative sensor state when unavailable."""
config, entity_id = await _setup_sensor(hass, {"unit_time": "s"})
times = [0, 1, 2, 3]
values = [0, 1, bad_state, 2]
expected_state = [
0,
1,
STATE_UNAVAILABLE if bad_state == STATE_UNAVAILABLE else STATE_UNKNOWN,
0.5,
]
# Testing a energy sensor with non-monotonic intervals and values
base = dt_util.utcnow()
with freeze_time(base) as freezer:
for time, value, expect in zip(times, values, expected_state, strict=False):
freezer.move_to(base + timedelta(seconds=time))
hass.states.async_set(entity_id, value, {})
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
assert state is not None
rounded_state = (
state.state
if expect in [STATE_UNKNOWN, STATE_UNAVAILABLE]
else round(float(state.state), config["sensor"]["round"])
)
assert rounded_state == expect
@pytest.mark.parametrize("bad_state", [STATE_UNAVAILABLE, STATE_UNKNOWN, "foo"])
async def test_unavailable_2(
bad_state: str,
hass: HomeAssistant,
) -> None:
"""Test derivative sensor state when unavailable with a time window."""
config, entity_id = await _setup_sensor(
hass, {"unit_time": "s", "time_window": {"seconds": 10}}
)
# Monotonically increasing by 1, with some unavailable holes
times = list(range(21))
values = list(range(21))
values[3] = bad_state
values[6] = bad_state
values[7] = bad_state
values[8] = bad_state
base = dt_util.utcnow()
with freeze_time(base) as freezer:
for time, value in zip(times, values, strict=False):
freezer.move_to(base + timedelta(seconds=time))
hass.states.async_set(entity_id, value, {})
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
assert state is not None
if value == bad_state:
assert (
state.state == STATE_UNAVAILABLE
if bad_state is STATE_UNAVAILABLE
else STATE_UNKNOWN
)
else:
expect = (time / 10) if time < 10 else 1
assert round(float(state.state), config["sensor"]["round"]) == round(
expect, config["sensor"]["round"]
)
@pytest.mark.parametrize("restore_state", ["3.00", STATE_UNKNOWN])
async def test_unavailable_boot(
restore_state,
hass: HomeAssistant,
) -> None:
"""Test that the booting sequence does not leave derivative in a bad state."""
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"sensor.power",
restore_state,
{
"unit_of_measurement": "W",
},
),
{
"native_value": restore_state,
"native_unit_of_measurement": "W",
},
),
],
)
config = {
"platform": "derivative",
"name": "power",
"source": "sensor.energy",
"round": 2,
"unit_time": "s",
}
config = {"sensor": config}
entity_id = config["sensor"]["source"]
hass.states.async_set(entity_id, STATE_UNAVAILABLE, {})
await hass.async_block_till_done()
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
assert state is not None
# Sensor is unavailable as source is unavailable
assert state.state == STATE_UNAVAILABLE
base = dt_util.utcnow()
with freeze_time(base) as freezer:
freezer.move_to(base + timedelta(seconds=1))
hass.states.async_set(entity_id, 10, {})
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
assert state is not None
# The source sensor has moved to a valid value, but we need 2 points to derive,
# so just hold until the next tick
assert state.state == restore_state
freezer.move_to(base + timedelta(seconds=2))
hass.states.async_set(entity_id, 15, {})
await hass.async_block_till_done()
state = hass.states.get("sensor.power")
assert state is not None
# Now that the source sensor has two valid datapoints, we can calculate derivative
assert state.state == "5.00"