diff --git a/homeassistant/components/template/sensor.py b/homeassistant/components/template/sensor.py index ee24407699d..06d8824642b 100644 --- a/homeassistant/components/template/sensor.py +++ b/homeassistant/components/template/sensor.py @@ -288,7 +288,11 @@ class SensorTemplate(TemplateEntity, SensorEntity): def _async_setup_templates(self) -> None: """Set up templates.""" self.add_template_attribute( - "_attr_native_value", self._template, None, self._update_state + "_attr_native_value", + self._template, + None, + self._update_state, + use_reported=True, ) if self._attr_last_reset_template is not None: self.add_template_attribute( diff --git a/homeassistant/components/template/template_entity.py b/homeassistant/components/template/template_entity.py index f5b84b1ad7a..d71b90b8a91 100644 --- a/homeassistant/components/template/template_entity.py +++ b/homeassistant/components/template/template_entity.py @@ -164,6 +164,7 @@ class _TemplateAttribute: validator: Callable[[Any], Any] | None = None, on_update: Callable[[Any], None] | None = None, none_on_template_error: bool | None = False, + use_reported: bool = False, ) -> None: """Template attribute.""" self._entity = entity @@ -173,6 +174,7 @@ class _TemplateAttribute: self.on_update = on_update self.async_update = None self.none_on_template_error = none_on_template_error + self.use_reported = use_reported @callback def async_setup(self) -> None: @@ -394,6 +396,7 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module validator: Callable[[Any], Any] | None = None, on_update: Callable[[Any], None] | None = None, none_on_template_error: bool = False, + use_reported: bool = False, ) -> None: """Call in the constructor to add a template linked to a attribute. @@ -412,6 +415,8 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module if the template or validator resulted in an error. none_on_template_error If True, the attribute will be set to None if the template errors. + use_reported + If True, also update the attribute on reported values (not only changed). """ if self.hass is None: @@ -419,7 +424,13 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module if template.hass is None: raise ValueError("template.hass cannot be None") template_attribute = _TemplateAttribute( - self, attribute, template, validator, on_update, none_on_template_error + self, + attribute, + template, + validator, + on_update, + none_on_template_error, + use_reported, ) self._template_attrs.setdefault(template, []) self._template_attrs[template].append(template_attribute) @@ -492,7 +503,13 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module } for template, attributes in self._template_attrs.items(): - template_var_tup = TrackTemplate(template, variables) + use_reported = False + for attribute in attributes: + if attribute.use_reported: + use_reported = True + template_var_tup = TrackTemplate( + template, variables, use_reported=use_reported + ) is_availability_template = False for attribute in attributes: if attribute._attribute == "_attr_available": # noqa: SLF001 diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 578132f358f..2fdae476e15 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -151,6 +151,7 @@ class TrackTemplate: template: Template variables: TemplateVarsType rate_limit: float | None = None + use_reported: bool = False @dataclass(slots=True) @@ -643,7 +644,7 @@ def _async_domain_added_filter( def async_track_state_added_domain( hass: HomeAssistant, domains: str | Iterable[str], - action: Callable[[Event[EventStateChangedData]], Any], + action: Callable[[Event[EventStateChangedData | EventStateReportedData]], Any], job_type: HassJobType | None = None, ) -> CALLBACK_TYPE: """Track state change events when an entity is added to domains.""" @@ -881,6 +882,169 @@ def async_track_state_change_filtered( return tracker +class _TrackStateReportedFiltered: + """Handle removal / refresh of tracker.""" + + def __init__( + self, + hass: HomeAssistant, + track_states: TrackStates, + action: Callable[[Event[EventStateReportedData]], Any], + ) -> None: + """Handle removal / refresh of tracker init.""" + self.hass = hass + self._action = action + self._action_as_hassjob = HassJob( + action, f"track state report filtered {track_states}" + ) + self._listeners: dict[str, Callable[[], None]] = {} + self._last_track_states: TrackStates = track_states + + @callback + def async_setup(self) -> None: + """Create listeners to track states.""" + track_states = self._last_track_states + + if ( + not track_states.all_states + and not track_states.domains + and not track_states.entities + ): + return + + if track_states.all_states: + self._setup_all_listener() + return + + self._setup_domains_listener(track_states.domains) + self._setup_entities_listener(track_states.domains, track_states.entities) + + @property + def listeners(self) -> dict[str, bool | set[str]]: + """State changes that will cause a re-render.""" + track_states = self._last_track_states + return { + _ALL_LISTENER: track_states.all_states, + _ENTITIES_LISTENER: track_states.entities, + _DOMAINS_LISTENER: track_states.domains, + } + + @callback + def async_update_listeners(self, new_track_states: TrackStates) -> None: + """Update the listeners based on the new TrackStates.""" + last_track_states = self._last_track_states + self._last_track_states = new_track_states + + had_all_listener = last_track_states.all_states + + if new_track_states.all_states: + if had_all_listener: + return + self._cancel_listener(_DOMAINS_LISTENER) + self._cancel_listener(_ENTITIES_LISTENER) + self._setup_all_listener() + return + + if had_all_listener: + self._cancel_listener(_ALL_LISTENER) + + domains_changed = new_track_states.domains != last_track_states.domains + + if had_all_listener or domains_changed: + domains_changed = True + self._cancel_listener(_DOMAINS_LISTENER) + self._setup_domains_listener(new_track_states.domains) + + if ( + had_all_listener + or domains_changed + or new_track_states.entities != last_track_states.entities + ): + self._cancel_listener(_ENTITIES_LISTENER) + self._setup_entities_listener( + new_track_states.domains, new_track_states.entities + ) + + @callback + def async_remove(self) -> None: + """Cancel the listeners.""" + for key in list(self._listeners): + self._listeners.pop(key)() + + @callback + def _cancel_listener(self, listener_name: str) -> None: + if listener_name not in self._listeners: + return + + self._listeners.pop(listener_name)() + + @callback + def _setup_entities_listener(self, domains: set[str], entities: set[str]) -> None: + if domains: + entities = entities.copy() + entities.update(self.hass.states.async_entity_ids(domains)) + + # Entities has changed to none + if not entities: + return + + self._listeners[_ENTITIES_LISTENER] = async_track_state_report_event( + self.hass, entities, self._action, self._action_as_hassjob.job_type + ) + + @callback + def _state_reported(self, event: Event[EventStateReportedData]) -> None: + self._cancel_listener(_ENTITIES_LISTENER) + self._setup_entities_listener( + self._last_track_states.domains, self._last_track_states.entities + ) + self.hass.async_run_hass_job(self._action_as_hassjob, event) + + @callback + def _setup_domains_listener(self, domains: set[str]) -> None: + if not domains: + return + + self._listeners[_DOMAINS_LISTENER] = _async_track_state_added_domain( + self.hass, domains, self._state_reported, HassJobType.Callback + ) + + @callback + def _setup_all_listener(self) -> None: + self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen( + EVENT_STATE_REPORTED, self._action + ) + + +@callback +@bind_hass +def async_track_state_report_filtered( + hass: HomeAssistant, + track_states: TrackStates, + action: Callable[[Event[EventStateReportedData]], Any], +) -> _TrackStateReportedFiltered: + """Track state changes with a TrackStates filter that can be updated. + + Parameters + ---------- + hass + Home assistant object. + track_states + A TrackStates data class. + action + Callable to call with results. + + Returns + ------- + Object used to update the listeners (async_update_listeners) with a new + TrackStates or cancel the tracking (async_remove). + + """ + tracker = _TrackStateReportedFiltered(hass, track_states, action) + tracker.async_setup() + return tracker + + @callback @bind_hass def async_track_template( @@ -992,7 +1156,10 @@ class TrackTemplateResultInfo: self._last_result: dict[Template, bool | str | TemplateError] = {} + self._uses_reported = False for track_template_ in track_templates: + if track_template_.use_reported: + self._uses_reported = True if track_template_.template.hass: continue @@ -1005,7 +1172,9 @@ class TrackTemplateResultInfo: self._rate_limit = KeyedRateLimit(hass) self._info: dict[Template, RenderInfo] = {} + self._info_reports: dict[Template, RenderInfo] = {} self._track_state_changes: _TrackStateChangeFiltered | None = None + self._track_state_reports: _TrackStateReportedFiltered | None = None self._time_listeners: dict[Template, Callable[[], None]] = {} def __repr__(self) -> str: @@ -1049,6 +1218,10 @@ class TrackTemplateResultInfo: self._info[template] = info = template.async_render_to_info( variables, strict=strict, log_fn=log_fn ) + if track_template_.use_reported: + self._info_reports[template] = template.async_render_to_info( + variables, strict=strict, log_fn=log_fn + ) if info.exception: if not log_fn: @@ -1063,6 +1236,11 @@ class TrackTemplateResultInfo: self._track_state_changes = async_track_state_change_filtered( self.hass, _render_infos_to_track_states(self._info.values()), self._refresh ) + self._track_state_reports = async_track_state_report_filtered( + self.hass, + _render_infos_to_track_states(self._info_reports.values()), + self._refresh, + ) self._update_time_listeners() _LOGGER.debug( ( @@ -1078,6 +1256,13 @@ class TrackTemplateResultInfo: def listeners(self) -> dict[str, bool | set[str]]: """State changes that will cause a re-render.""" assert self._track_state_changes + if self._uses_reported: + assert self._track_state_reports + return { + **self._track_state_changes.listeners, + **self._track_state_reports.listeners, + "time": bool(self._time_listeners), + } return { **self._track_state_changes.listeners, "time": bool(self._time_listeners), @@ -1131,7 +1316,7 @@ class TrackTemplateResultInfo: self, track_template_: TrackTemplate, now: float, - event: Event[EventStateChangedData] | None, + event: Event[EventStateChangedData | EventStateReportedData] | None, ) -> bool | TrackTemplateResult: """Re-render the template if conditions match. @@ -1183,7 +1368,11 @@ class TrackTemplateResultInfo: last_result = self._last_result.get(template) # Check to see if the result has changed or is new - if result == last_result and template in self._last_result: + if ( + not track_template_.use_reported + and result == last_result + and template in self._last_result + ): return True if isinstance(result, TemplateError) and isinstance(last_result, TemplateError): @@ -1220,7 +1409,7 @@ class TrackTemplateResultInfo: @callback def _refresh( self, - event: Event[EventStateChangedData] | None, + event: Event[EventStateChangedData | EventStateReportedData] | None, track_templates: Iterable[TrackTemplate] | None = None, replayed: bool | None = False, ) -> None: @@ -1284,6 +1473,18 @@ class TrackTemplateResultInfo: ) if info_changed: + if event and event.event_type is EventStateReportedData: + assert self._track_state_reports + self._track_state_reports.async_update_listeners( + _render_infos_to_track_states( + [ + _suppress_domain_all_in_render_info(info) + if self._rate_limit.async_has_timer(template) + else info + for template, info in self._info.items() + ] + ) + ) assert self._track_state_changes self._track_state_changes.async_update_listeners( _render_infos_to_track_states( @@ -1316,7 +1517,7 @@ class TrackTemplateResultInfo: type TrackTemplateResultListener = Callable[ [ - Event[EventStateChangedData] | None, + Event[EventStateChangedData] | Event[EventStateReportedData] | None, list[TrackTemplateResult], ], Coroutine[Any, Any, None] | None, @@ -1982,7 +2183,7 @@ def _event_triggers_rerender( @callback def _rate_limit_for_event( - event: Event[EventStateChangedData], + event: Event[EventStateChangedData | EventStateReportedData], info: RenderInfo, track_template_: TrackTemplate, ) -> float | None: diff --git a/tests/components/template/test_sensor.py b/tests/components/template/test_sensor.py index 929a890ab38..f3753736dbb 100644 --- a/tests/components/template/test_sensor.py +++ b/tests/components/template/test_sensor.py @@ -115,6 +115,14 @@ async def test_template_legacy(hass: HomeAssistant) -> None: hass.states.async_set("sensor.test_state", "Works") await hass.async_block_till_done() assert hass.states.get(TEST_NAME).state == "It Works." + entity_reported = hass.states.get(TEST_NAME).last_reported_timestamp + entity_changed = hass.states.get(TEST_NAME).last_changed_timestamp + + hass.states.async_set("sensor.test_state", "Works") + await hass.async_block_till_done() + assert hass.states.get(TEST_NAME).state == "It Works." + assert hass.states.get(TEST_NAME).last_reported_timestamp > entity_reported + assert hass.states.get(TEST_NAME).last_changed_timestamp == entity_changed @pytest.mark.parametrize(("count", "domain"), [(1, sensor.DOMAIN)])