Use reported state in Template

This commit is contained in:
G Johansson 2024-11-24 18:38:17 +00:00
parent 5b27f07f81
commit 6b7b149de6
4 changed files with 239 additions and 9 deletions

View File

@ -288,7 +288,11 @@ class SensorTemplate(TemplateEntity, SensorEntity):
def _async_setup_templates(self) -> None:
"""Set up templates."""
self.add_template_attribute(
"_attr_native_value", self._template, None, self._update_state
"_attr_native_value",
self._template,
None,
self._update_state,
use_reported=True,
)
if self._attr_last_reset_template is not None:
self.add_template_attribute(

View File

@ -164,6 +164,7 @@ class _TemplateAttribute:
validator: Callable[[Any], Any] | None = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool | None = False,
use_reported: bool = False,
) -> None:
"""Template attribute."""
self._entity = entity
@ -173,6 +174,7 @@ class _TemplateAttribute:
self.on_update = on_update
self.async_update = None
self.none_on_template_error = none_on_template_error
self.use_reported = use_reported
@callback
def async_setup(self) -> None:
@ -394,6 +396,7 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
validator: Callable[[Any], Any] | None = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool = False,
use_reported: bool = False,
) -> None:
"""Call in the constructor to add a template linked to a attribute.
@ -412,6 +415,8 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
if the template or validator resulted in an error.
none_on_template_error
If True, the attribute will be set to None if the template errors.
use_reported
If True, also update the attribute on reported values (not only changed).
"""
if self.hass is None:
@ -419,7 +424,13 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
if template.hass is None:
raise ValueError("template.hass cannot be None")
template_attribute = _TemplateAttribute(
self, attribute, template, validator, on_update, none_on_template_error
self,
attribute,
template,
validator,
on_update,
none_on_template_error,
use_reported,
)
self._template_attrs.setdefault(template, [])
self._template_attrs[template].append(template_attribute)
@ -492,7 +503,13 @@ class TemplateEntity(Entity): # pylint: disable=hass-enforce-class-module
}
for template, attributes in self._template_attrs.items():
template_var_tup = TrackTemplate(template, variables)
use_reported = False
for attribute in attributes:
if attribute.use_reported:
use_reported = True
template_var_tup = TrackTemplate(
template, variables, use_reported=use_reported
)
is_availability_template = False
for attribute in attributes:
if attribute._attribute == "_attr_available": # noqa: SLF001

View File

@ -151,6 +151,7 @@ class TrackTemplate:
template: Template
variables: TemplateVarsType
rate_limit: float | None = None
use_reported: bool = False
@dataclass(slots=True)
@ -643,7 +644,7 @@ def _async_domain_added_filter(
def async_track_state_added_domain(
hass: HomeAssistant,
domains: str | Iterable[str],
action: Callable[[Event[EventStateChangedData]], Any],
action: Callable[[Event[EventStateChangedData | EventStateReportedData]], Any],
job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
"""Track state change events when an entity is added to domains."""
@ -881,6 +882,169 @@ def async_track_state_change_filtered(
return tracker
class _TrackStateReportedFiltered:
"""Handle removal / refresh of tracker."""
def __init__(
self,
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event[EventStateReportedData]], Any],
) -> None:
"""Handle removal / refresh of tracker init."""
self.hass = hass
self._action = action
self._action_as_hassjob = HassJob(
action, f"track state report filtered {track_states}"
)
self._listeners: dict[str, Callable[[], None]] = {}
self._last_track_states: TrackStates = track_states
@callback
def async_setup(self) -> None:
"""Create listeners to track states."""
track_states = self._last_track_states
if (
not track_states.all_states
and not track_states.domains
and not track_states.entities
):
return
if track_states.all_states:
self._setup_all_listener()
return
self._setup_domains_listener(track_states.domains)
self._setup_entities_listener(track_states.domains, track_states.entities)
@property
def listeners(self) -> dict[str, bool | set[str]]:
"""State changes that will cause a re-render."""
track_states = self._last_track_states
return {
_ALL_LISTENER: track_states.all_states,
_ENTITIES_LISTENER: track_states.entities,
_DOMAINS_LISTENER: track_states.domains,
}
@callback
def async_update_listeners(self, new_track_states: TrackStates) -> None:
"""Update the listeners based on the new TrackStates."""
last_track_states = self._last_track_states
self._last_track_states = new_track_states
had_all_listener = last_track_states.all_states
if new_track_states.all_states:
if had_all_listener:
return
self._cancel_listener(_DOMAINS_LISTENER)
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_all_listener()
return
if had_all_listener:
self._cancel_listener(_ALL_LISTENER)
domains_changed = new_track_states.domains != last_track_states.domains
if had_all_listener or domains_changed:
domains_changed = True
self._cancel_listener(_DOMAINS_LISTENER)
self._setup_domains_listener(new_track_states.domains)
if (
had_all_listener
or domains_changed
or new_track_states.entities != last_track_states.entities
):
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_entities_listener(
new_track_states.domains, new_track_states.entities
)
@callback
def async_remove(self) -> None:
"""Cancel the listeners."""
for key in list(self._listeners):
self._listeners.pop(key)()
@callback
def _cancel_listener(self, listener_name: str) -> None:
if listener_name not in self._listeners:
return
self._listeners.pop(listener_name)()
@callback
def _setup_entities_listener(self, domains: set[str], entities: set[str]) -> None:
if domains:
entities = entities.copy()
entities.update(self.hass.states.async_entity_ids(domains))
# Entities has changed to none
if not entities:
return
self._listeners[_ENTITIES_LISTENER] = async_track_state_report_event(
self.hass, entities, self._action, self._action_as_hassjob.job_type
)
@callback
def _state_reported(self, event: Event[EventStateReportedData]) -> None:
self._cancel_listener(_ENTITIES_LISTENER)
self._setup_entities_listener(
self._last_track_states.domains, self._last_track_states.entities
)
self.hass.async_run_hass_job(self._action_as_hassjob, event)
@callback
def _setup_domains_listener(self, domains: set[str]) -> None:
if not domains:
return
self._listeners[_DOMAINS_LISTENER] = _async_track_state_added_domain(
self.hass, domains, self._state_reported, HassJobType.Callback
)
@callback
def _setup_all_listener(self) -> None:
self._listeners[_ALL_LISTENER] = self.hass.bus.async_listen(
EVENT_STATE_REPORTED, self._action
)
@callback
@bind_hass
def async_track_state_report_filtered(
hass: HomeAssistant,
track_states: TrackStates,
action: Callable[[Event[EventStateReportedData]], Any],
) -> _TrackStateReportedFiltered:
"""Track state changes with a TrackStates filter that can be updated.
Parameters
----------
hass
Home assistant object.
track_states
A TrackStates data class.
action
Callable to call with results.
Returns
-------
Object used to update the listeners (async_update_listeners) with a new
TrackStates or cancel the tracking (async_remove).
"""
tracker = _TrackStateReportedFiltered(hass, track_states, action)
tracker.async_setup()
return tracker
@callback
@bind_hass
def async_track_template(
@ -992,7 +1156,10 @@ class TrackTemplateResultInfo:
self._last_result: dict[Template, bool | str | TemplateError] = {}
self._uses_reported = False
for track_template_ in track_templates:
if track_template_.use_reported:
self._uses_reported = True
if track_template_.template.hass:
continue
@ -1005,7 +1172,9 @@ class TrackTemplateResultInfo:
self._rate_limit = KeyedRateLimit(hass)
self._info: dict[Template, RenderInfo] = {}
self._info_reports: dict[Template, RenderInfo] = {}
self._track_state_changes: _TrackStateChangeFiltered | None = None
self._track_state_reports: _TrackStateReportedFiltered | None = None
self._time_listeners: dict[Template, Callable[[], None]] = {}
def __repr__(self) -> str:
@ -1049,6 +1218,10 @@ class TrackTemplateResultInfo:
self._info[template] = info = template.async_render_to_info(
variables, strict=strict, log_fn=log_fn
)
if track_template_.use_reported:
self._info_reports[template] = template.async_render_to_info(
variables, strict=strict, log_fn=log_fn
)
if info.exception:
if not log_fn:
@ -1063,6 +1236,11 @@ class TrackTemplateResultInfo:
self._track_state_changes = async_track_state_change_filtered(
self.hass, _render_infos_to_track_states(self._info.values()), self._refresh
)
self._track_state_reports = async_track_state_report_filtered(
self.hass,
_render_infos_to_track_states(self._info_reports.values()),
self._refresh,
)
self._update_time_listeners()
_LOGGER.debug(
(
@ -1078,6 +1256,13 @@ class TrackTemplateResultInfo:
def listeners(self) -> dict[str, bool | set[str]]:
"""State changes that will cause a re-render."""
assert self._track_state_changes
if self._uses_reported:
assert self._track_state_reports
return {
**self._track_state_changes.listeners,
**self._track_state_reports.listeners,
"time": bool(self._time_listeners),
}
return {
**self._track_state_changes.listeners,
"time": bool(self._time_listeners),
@ -1131,7 +1316,7 @@ class TrackTemplateResultInfo:
self,
track_template_: TrackTemplate,
now: float,
event: Event[EventStateChangedData] | None,
event: Event[EventStateChangedData | EventStateReportedData] | None,
) -> bool | TrackTemplateResult:
"""Re-render the template if conditions match.
@ -1183,7 +1368,11 @@ class TrackTemplateResultInfo:
last_result = self._last_result.get(template)
# Check to see if the result has changed or is new
if result == last_result and template in self._last_result:
if (
not track_template_.use_reported
and result == last_result
and template in self._last_result
):
return True
if isinstance(result, TemplateError) and isinstance(last_result, TemplateError):
@ -1220,7 +1409,7 @@ class TrackTemplateResultInfo:
@callback
def _refresh(
self,
event: Event[EventStateChangedData] | None,
event: Event[EventStateChangedData | EventStateReportedData] | None,
track_templates: Iterable[TrackTemplate] | None = None,
replayed: bool | None = False,
) -> None:
@ -1284,6 +1473,18 @@ class TrackTemplateResultInfo:
)
if info_changed:
if event and event.event_type is EventStateReportedData:
assert self._track_state_reports
self._track_state_reports.async_update_listeners(
_render_infos_to_track_states(
[
_suppress_domain_all_in_render_info(info)
if self._rate_limit.async_has_timer(template)
else info
for template, info in self._info.items()
]
)
)
assert self._track_state_changes
self._track_state_changes.async_update_listeners(
_render_infos_to_track_states(
@ -1316,7 +1517,7 @@ class TrackTemplateResultInfo:
type TrackTemplateResultListener = Callable[
[
Event[EventStateChangedData] | None,
Event[EventStateChangedData] | Event[EventStateReportedData] | None,
list[TrackTemplateResult],
],
Coroutine[Any, Any, None] | None,
@ -1982,7 +2183,7 @@ def _event_triggers_rerender(
@callback
def _rate_limit_for_event(
event: Event[EventStateChangedData],
event: Event[EventStateChangedData | EventStateReportedData],
info: RenderInfo,
track_template_: TrackTemplate,
) -> float | None:

View File

@ -115,6 +115,14 @@ async def test_template_legacy(hass: HomeAssistant) -> None:
hass.states.async_set("sensor.test_state", "Works")
await hass.async_block_till_done()
assert hass.states.get(TEST_NAME).state == "It Works."
entity_reported = hass.states.get(TEST_NAME).last_reported_timestamp
entity_changed = hass.states.get(TEST_NAME).last_changed_timestamp
hass.states.async_set("sensor.test_state", "Works")
await hass.async_block_till_done()
assert hass.states.get(TEST_NAME).state == "It Works."
assert hass.states.get(TEST_NAME).last_reported_timestamp > entity_reported
assert hass.states.get(TEST_NAME).last_changed_timestamp == entity_changed
@pytest.mark.parametrize(("count", "domain"), [(1, sensor.DOMAIN)])