diff --git a/homeassistant/components/integration/sensor.py b/homeassistant/components/integration/sensor.py index 106eb9cc79c..60cbee5549f 100644 --- a/homeassistant/components/integration/sensor.py +++ b/homeassistant/components/integration/sensor.py @@ -8,7 +8,7 @@ from datetime import UTC, datetime, timedelta from decimal import Decimal, InvalidOperation from enum import Enum import logging -from typing import Any, Final, Self +from typing import TYPE_CHECKING, Any, Final, Self import voluptuous as vol @@ -27,6 +27,8 @@ from homeassistant.const import ( CONF_METHOD, CONF_NAME, CONF_UNIQUE_ID, + EVENT_STATE_CHANGED, + EVENT_STATE_REPORTED, STATE_UNAVAILABLE, UnitOfTime, ) @@ -34,6 +36,7 @@ from homeassistant.core import ( CALLBACK_TYPE, Event, EventStateChangedData, + EventStateReportedData, HomeAssistant, State, callback, @@ -42,7 +45,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.event import async_call_later, async_track_state_change_event +from homeassistant.helpers.event import async_call_later from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import ( @@ -107,9 +110,7 @@ class _IntegrationMethod(ABC): return _NAME_TO_INTEGRATION_METHOD[method_name]() @abstractmethod - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: """Check state requirements for integration.""" @abstractmethod @@ -130,11 +131,9 @@ class _Trapezoidal(_IntegrationMethod): ) -> Decimal: return elapsed_time * (left + right) / 2 - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (left_dec := _decimal_state(left.state)) is None or ( - right_dec := _decimal_state(right.state) + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (left_dec := _decimal_state(left)) is None or ( + right_dec := _decimal_state(right) ) is None: return None return (left_dec, right_dec) @@ -146,10 +145,8 @@ class _Left(_IntegrationMethod): ) -> Decimal: return self.calculate_area_with_one_state(elapsed_time, left) - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (left_dec := _decimal_state(left.state)) is None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (left_dec := _decimal_state(left)) is None: return None return (left_dec, left_dec) @@ -160,10 +157,8 @@ class _Right(_IntegrationMethod): ) -> Decimal: return self.calculate_area_with_one_state(elapsed_time, right) - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (right_dec := _decimal_state(right.state)) is None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (right_dec := _decimal_state(right)) is None: return None return (right_dec, right_dec) @@ -183,7 +178,7 @@ _NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = { class _IntegrationTrigger(Enum): - StateChange = "state_change" + StateEvent = "state_event" TimeElapsed = "time_elapsed" @@ -343,7 +338,7 @@ class IntegrationSensor(RestoreSensor): ) self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None self._last_integration_time: datetime = datetime.now(tz=UTC) - self._last_integration_trigger = _IntegrationTrigger.StateChange + self._last_integration_trigger = _IntegrationTrigger.StateEvent self._attr_suggested_display_precision = round_digits or 2 def _calculate_unit(self, source_unit: str) -> str: @@ -433,9 +428,11 @@ class IntegrationSensor(RestoreSensor): source_state = self.hass.states.get(self._sensor_source_id) self._schedule_max_sub_interval_exceeded_if_state_is_numeric(source_state) self.async_on_remove(self._cancel_max_sub_interval_exceeded_callback) - handle_state_change = self._integrate_on_state_change_and_max_sub_interval + handle_state_change = self._integrate_on_state_change_with_max_sub_interval + handle_state_report = self._integrate_on_state_report_with_max_sub_interval else: handle_state_change = self._integrate_on_state_change_callback + handle_state_report = self._integrate_on_state_report_callback if ( state := self.hass.states.get(self._source_entity) @@ -443,16 +440,50 @@ class IntegrationSensor(RestoreSensor): self._derive_and_set_attributes_from_state(state) self.async_on_remove( - async_track_state_change_event( - self.hass, - [self._sensor_source_id], + self.hass.bus.async_listen( + EVENT_STATE_CHANGED, handle_state_change, + event_filter=callback( + lambda event_data: event_data["entity_id"] == self._sensor_source_id + ), + run_immediately=True, + ) + ) + self.async_on_remove( + self.hass.bus.async_listen( + EVENT_STATE_REPORTED, + handle_state_report, + event_filter=callback( + lambda event_data: event_data["entity_id"] == self._sensor_source_id + ), + run_immediately=True, ) ) @callback - def _integrate_on_state_change_and_max_sub_interval( + def _integrate_on_state_change_with_max_sub_interval( self, event: Event[EventStateChangedData] + ) -> None: + """Handle sensor state update when sub interval is configured.""" + self._integrate_on_state_update_with_max_sub_interval( + None, event.data["old_state"], event.data["new_state"] + ) + + @callback + def _integrate_on_state_report_with_max_sub_interval( + self, event: Event[EventStateReportedData] + ) -> None: + """Handle sensor state report when sub interval is configured.""" + self._integrate_on_state_update_with_max_sub_interval( + event.data["old_last_reported"], None, event.data["new_state"] + ) + + @callback + def _integrate_on_state_update_with_max_sub_interval( + self, + old_last_reported: datetime | None, + old_state: State | None, + new_state: State | None, ) -> None: """Integrate based on state change and time. @@ -460,11 +491,9 @@ class IntegrationSensor(RestoreSensor): reschedules time based integration. """ self._cancel_max_sub_interval_exceeded_callback() - old_state = event.data["old_state"] - new_state = event.data["new_state"] try: - self._integrate_on_state_change(old_state, new_state) - self._last_integration_trigger = _IntegrationTrigger.StateChange + self._integrate_on_state_change(old_last_reported, old_state, new_state) + self._last_integration_trigger = _IntegrationTrigger.StateEvent self._last_integration_time = datetime.now(tz=UTC) finally: # When max_sub_interval exceeds without state change the source is assumed @@ -475,13 +504,25 @@ class IntegrationSensor(RestoreSensor): def _integrate_on_state_change_callback( self, event: Event[EventStateChangedData] ) -> None: - """Handle the sensor state changes.""" - old_state = event.data["old_state"] - new_state = event.data["new_state"] - return self._integrate_on_state_change(old_state, new_state) + """Handle sensor state change.""" + return self._integrate_on_state_change( + None, event.data["old_state"], event.data["new_state"] + ) + + @callback + def _integrate_on_state_report_callback( + self, event: Event[EventStateReportedData] + ) -> None: + """Handle sensor state report.""" + return self._integrate_on_state_change( + event.data["old_last_reported"], None, event.data["new_state"] + ) def _integrate_on_state_change( - self, old_state: State | None, new_state: State | None + self, + old_last_reported: datetime | None, + old_state: State | None, + new_state: State | None, ) -> None: if new_state is None: return @@ -491,21 +532,33 @@ class IntegrationSensor(RestoreSensor): self.async_write_ha_state() return + if old_state: + # state has changed, we recover old_state from the event + old_state_state = old_state.state + old_last_reported = old_state.last_reported + else: + # event state reported without any state change + old_state_state = new_state.state + self._attr_available = True self._derive_and_set_attributes_from_state(new_state) - if old_state is None: + if old_last_reported is None and old_state is None: self.async_write_ha_state() return - if not (states := self._method.validate_states(old_state, new_state)): + if not ( + states := self._method.validate_states(old_state_state, new_state.state) + ): self.async_write_ha_state() return + if TYPE_CHECKING: + assert old_last_reported is not None elapsed_seconds = Decimal( - (new_state.last_updated - old_state.last_updated).total_seconds() - if self._last_integration_trigger == _IntegrationTrigger.StateChange - else (new_state.last_updated - self._last_integration_time).total_seconds() + (new_state.last_reported - old_last_reported).total_seconds() + if self._last_integration_trigger == _IntegrationTrigger.StateEvent + else (new_state.last_reported - self._last_integration_time).total_seconds() ) area = self._method.calculate_area_with_two_states(elapsed_seconds, *states) diff --git a/tests/components/integration/test_sensor.py b/tests/components/integration/test_sensor.py index 10f921ce603..974c8bb8691 100644 --- a/tests/components/integration/test_sensor.py +++ b/tests/components/integration/test_sensor.py @@ -294,28 +294,16 @@ async def test_restore_state_failed(hass: HomeAssistant, extra_attributes) -> No assert state.state == STATE_UNKNOWN +@pytest.mark.parametrize("force_update", [False, True]) @pytest.mark.parametrize( - ("force_update", "sequence"), + "sequence", [ ( - False, - ( - (20, 10, 1.67), - (30, 30, 5.0), - (40, 5, 7.92), - (50, 5, 7.92), - (60, 0, 8.75), - ), - ), - ( - True, - ( - (20, 10, 1.67), - (30, 30, 5.0), - (40, 5, 7.92), - (50, 5, 8.75), - (60, 0, 9.17), - ), + (20, 10, 1.67), + (30, 30, 5.0), + (40, 5, 7.92), + (50, 5, 8.75), + (60, 0, 9.17), ), ], ) @@ -358,28 +346,16 @@ async def test_trapezoidal( assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR +@pytest.mark.parametrize("force_update", [False, True]) @pytest.mark.parametrize( - ("force_update", "sequence"), + "sequence", [ ( - False, - ( - (20, 10, 0.0), - (30, 30, 1.67), - (40, 5, 6.67), - (50, 5, 6.67), - (60, 0, 8.33), - ), - ), - ( - True, - ( - (20, 10, 0.0), - (30, 30, 1.67), - (40, 5, 6.67), - (50, 5, 7.5), - (60, 0, 8.33), - ), + (20, 10, 0.0), + (30, 30, 1.67), + (40, 5, 6.67), + (50, 5, 7.5), + (60, 0, 8.33), ), ], ) @@ -425,28 +401,16 @@ async def test_left( assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR +@pytest.mark.parametrize("force_update", [False, True]) @pytest.mark.parametrize( - ("force_update", "sequence"), + "sequence", [ ( - False, - ( - (20, 10, 3.33), - (30, 30, 8.33), - (40, 5, 9.17), - (50, 5, 9.17), - (60, 0, 9.17), - ), - ), - ( - True, - ( - (20, 10, 3.33), - (30, 30, 8.33), - (40, 5, 9.17), - (50, 5, 10.0), - (60, 0, 10.0), - ), + (20, 10, 3.33), + (30, 30, 8.33), + (40, 5, 9.17), + (50, 5, 10.0), + (60, 0, 10.0), ), ], )