diff --git a/homeassistant/components/integration/sensor.py b/homeassistant/components/integration/sensor.py index 106eb9cc79c..3baa1e3c1d5 100644 --- a/homeassistant/components/integration/sensor.py +++ b/homeassistant/components/integration/sensor.py @@ -8,7 +8,7 @@ from datetime import UTC, datetime, timedelta from decimal import Decimal, InvalidOperation from enum import Enum import logging -from typing import Any, Final, Self +from typing import TYPE_CHECKING, Any, Final, Self import voluptuous as vol @@ -27,13 +27,14 @@ from homeassistant.const import ( CONF_METHOD, CONF_NAME, CONF_UNIQUE_ID, + EVENT_STATE_REPORTED, STATE_UNAVAILABLE, UnitOfTime, ) from homeassistant.core import ( CALLBACK_TYPE, Event, - EventStateChangedData, + EventStateReportedData, HomeAssistant, State, callback, @@ -42,7 +43,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.event import async_call_later, async_track_state_change_event +from homeassistant.helpers.event import async_call_later from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import ( @@ -107,9 +108,7 @@ class _IntegrationMethod(ABC): return _NAME_TO_INTEGRATION_METHOD[method_name]() @abstractmethod - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: """Check state requirements for integration.""" @abstractmethod @@ -130,11 +129,9 @@ class _Trapezoidal(_IntegrationMethod): ) -> Decimal: return elapsed_time * (left + right) / 2 - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (left_dec := _decimal_state(left.state)) is None or ( - right_dec := _decimal_state(right.state) + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (left_dec := _decimal_state(left)) is None or ( + right_dec := _decimal_state(right) ) is None: return None return (left_dec, right_dec) @@ -146,10 +143,8 @@ class _Left(_IntegrationMethod): ) -> Decimal: return self.calculate_area_with_one_state(elapsed_time, left) - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (left_dec := _decimal_state(left.state)) is None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (left_dec := _decimal_state(left)) is None: return None return (left_dec, left_dec) @@ -160,10 +155,8 @@ class _Right(_IntegrationMethod): ) -> Decimal: return self.calculate_area_with_one_state(elapsed_time, right) - def validate_states( - self, left: State, right: State - ) -> tuple[Decimal, Decimal] | None: - if (right_dec := _decimal_state(right.state)) is None: + def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None: + if (right_dec := _decimal_state(right)) is None: return None return (right_dec, right_dec) @@ -183,7 +176,7 @@ _NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = { class _IntegrationTrigger(Enum): - StateChange = "state_change" + StateReport = "state_report" TimeElapsed = "time_elapsed" @@ -343,7 +336,7 @@ class IntegrationSensor(RestoreSensor): ) self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None self._last_integration_time: datetime = datetime.now(tz=UTC) - self._last_integration_trigger = _IntegrationTrigger.StateChange + self._last_integration_trigger = _IntegrationTrigger.StateReport self._attr_suggested_display_precision = round_digits or 2 def _calculate_unit(self, source_unit: str) -> str: @@ -443,16 +436,19 @@ class IntegrationSensor(RestoreSensor): self._derive_and_set_attributes_from_state(state) self.async_on_remove( - async_track_state_change_event( - self.hass, - [self._sensor_source_id], + self.hass.bus.async_listen( + EVENT_STATE_REPORTED, handle_state_change, + event_filter=callback( + lambda event_data: event_data["entity_id"] == self._sensor_source_id + ), + run_immediately=True, ) ) @callback def _integrate_on_state_change_and_max_sub_interval( - self, event: Event[EventStateChangedData] + self, event: Event[EventStateReportedData] ) -> None: """Integrate based on state change and time. @@ -460,11 +456,12 @@ class IntegrationSensor(RestoreSensor): reschedules time based integration. """ self._cancel_max_sub_interval_exceeded_callback() - old_state = event.data["old_state"] + old_last_reported = event.data.get("old_last_reported") + old_state = event.data.get("old_state") new_state = event.data["new_state"] try: - self._integrate_on_state_change(old_state, new_state) - self._last_integration_trigger = _IntegrationTrigger.StateChange + self._integrate_on_state_change(old_last_reported, old_state, new_state) + self._last_integration_trigger = _IntegrationTrigger.StateReport self._last_integration_time = datetime.now(tz=UTC) finally: # When max_sub_interval exceeds without state change the source is assumed @@ -473,15 +470,19 @@ class IntegrationSensor(RestoreSensor): @callback def _integrate_on_state_change_callback( - self, event: Event[EventStateChangedData] + self, event: Event[EventStateReportedData] ) -> None: """Handle the sensor state changes.""" - old_state = event.data["old_state"] + old_last_reported = event.data.get("old_last_reported") + old_state = event.data.get("old_state") new_state = event.data["new_state"] - return self._integrate_on_state_change(old_state, new_state) + return self._integrate_on_state_change(old_last_reported, old_state, new_state) def _integrate_on_state_change( - self, old_state: State | None, new_state: State | None + self, + old_last_reported: datetime | None, + old_state: State | None, + new_state: State | None, ) -> None: if new_state is None: return @@ -491,21 +492,33 @@ class IntegrationSensor(RestoreSensor): self.async_write_ha_state() return + if old_state: + # state has changed, we recover old_state from the event + old_state_state = old_state.state + old_last_reported = old_state.last_reported + else: + # event state reported without any state change + old_state_state = new_state.state + self._attr_available = True self._derive_and_set_attributes_from_state(new_state) - if old_state is None: + if old_last_reported is None and old_state is None: self.async_write_ha_state() return - if not (states := self._method.validate_states(old_state, new_state)): + if not ( + states := self._method.validate_states(old_state_state, new_state.state) + ): self.async_write_ha_state() return + if TYPE_CHECKING: + assert old_last_reported is not None elapsed_seconds = Decimal( - (new_state.last_updated - old_state.last_updated).total_seconds() - if self._last_integration_trigger == _IntegrationTrigger.StateChange - else (new_state.last_updated - self._last_integration_time).total_seconds() + (new_state.last_reported - old_last_reported).total_seconds() + if self._last_integration_trigger == _IntegrationTrigger.StateReport + else (new_state.last_reported - self._last_integration_time).total_seconds() ) area = self._method.calculate_area_with_two_states(elapsed_seconds, *states)