Reduce the number of template re-renders when we are only counting states (#40272)

This commit is contained in:
J. Nick Koston 2020-09-26 16:29:49 -05:00 committed by GitHub
parent b8f837365c
commit 3261a904da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 21 deletions

View File

@ -912,7 +912,7 @@ class StateMachine:
This method must be run in the event loop.
"""
if domain_filter is None:
return list(self._states.keys())
return list(self._states)
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)
@ -932,7 +932,7 @@ class StateMachine:
This method must be run in the event loop.
"""
if domain_filter is None:
return len(self._states.keys())
return len(self._states)
if isinstance(domain_filter, str):
domain_filter = (domain_filter.lower(),)

View File

@ -602,7 +602,10 @@ class _TrackTemplateResultInfo:
template = track_template_.template
# Tracking all states
if self._info[template].all_states:
if (
self._info[template].all_states
or self._info[template].all_states_lifecycle
):
return True
# Previous call had an exception
@ -719,6 +722,9 @@ class _TrackTemplateResultInfo:
@callback
def _refresh(self, event: Optional[Event]) -> None:
entity_id = event and event.data.get(ATTR_ENTITY_ID)
lifecycle_event = event and (
event.data.get("new_state") is None or event.data.get("old_state") is None
)
updates = []
info_changed = False
@ -726,13 +732,18 @@ class _TrackTemplateResultInfo:
template = track_template_.template
if (
entity_id
and len(self._last_info) > 1
and not self._last_info[template].filter_lifecycle(entity_id)
and not self._last_info[template].filter(entity_id)
and (
not lifecycle_event
or not self._last_info[template].filter_lifecycle(entity_id)
)
):
continue
_LOGGER.debug(
"Template update %s triggered by event: %s", template.template, event
"Template update %s triggered by event: %s",
template.template,
event,
)
self._info[template] = template.async_render_to_info(
@ -1229,4 +1240,6 @@ def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set
entities.update(render_info.entities)
if render_info.domains:
domains.update(render_info.domains)
if render_info.domains_lifecycle:
domains.update(render_info.domains_lifecycle)
return entities, domains

View File

@ -164,6 +164,10 @@ def _true(arg: Any) -> bool:
return True
def _false(arg: Any) -> bool:
return False
class RenderInfo:
"""Holds information about a template render."""
@ -172,23 +176,30 @@ class RenderInfo:
self.template = template
# Will be set sensibly once frozen.
self.filter_lifecycle = _true
self.filter = _true
self._result = None
self.is_static = False
self.exception = None
self.all_states = False
self.all_states_lifecycle = False
self.domains = set()
self.domains_lifecycle = set()
self.entities = set()
def filter(self, entity_id: str) -> bool:
"""Template should re-render if the state changes."""
return entity_id in self.entities
def __repr__(self) -> str:
"""Representation of RenderInfo."""
return f"<RenderInfo {self.template} all_states={self.all_states} all_states_lifecycle={self.all_states_lifecycle} domains={self.domains} domains_lifecycle={self.domains_lifecycle} entities={self.entities}>"
def _filter_lifecycle(self, entity_id: str) -> bool:
"""Template should re-render if the state changes."""
def _filter_domains_and_entities(self, entity_id: str) -> bool:
"""Template should re-render if the entity state changes when we match specific domains or entities."""
return (
split_entity_id(entity_id)[0] in self.domains or entity_id in self.entities
)
def _filter_lifecycle_domains(self, entity_id: str) -> bool:
"""Template should re-render if the entity is added or removed with domains watched."""
return split_entity_id(entity_id)[0] in self.domains_lifecycle
def result(self) -> str:
"""Results of the template computation."""
if self.exception is not None:
@ -199,19 +210,30 @@ class RenderInfo:
self.is_static = True
self.entities = frozenset(self.entities)
self.domains = frozenset(self.domains)
self.domains_lifecycle = frozenset(self.domains_lifecycle)
self.all_states = False
def _freeze(self) -> None:
self.entities = frozenset(self.entities)
self.domains = frozenset(self.domains)
self.domains_lifecycle = frozenset(self.domains_lifecycle)
if self.all_states or self.exception:
if self.exception:
return
if not self.domains:
self.filter_lifecycle = self.filter
if not self.all_states_lifecycle:
if self.domains_lifecycle:
self.filter_lifecycle = self._filter_lifecycle_domains
else:
self.filter_lifecycle = self._filter_lifecycle
self.filter_lifecycle = _false
if self.all_states:
return
if self.entities or self.domains:
self.filter = self._filter_domains_and_entities
else:
self.filter = _false
class Template:
@ -422,6 +444,11 @@ class AllStates:
if render_info is not None:
render_info.all_states = True
def _collect_all_lifecycle(self) -> None:
render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
render_info.all_states_lifecycle = True
def __iter__(self):
"""Return all states."""
self._collect_all()
@ -429,7 +456,7 @@ class AllStates:
def __len__(self) -> int:
"""Return number of states."""
self._collect_all()
self._collect_all_lifecycle()
return self._hass.states.async_entity_ids_count()
def __call__(self, entity_id):
@ -462,6 +489,11 @@ class DomainStates:
if entity_collect is not None:
entity_collect.domains.add(self._domain)
def _collect_domain_lifecycle(self) -> None:
entity_collect = self._hass.data.get(_RENDER_INFO)
if entity_collect is not None:
entity_collect.domains_lifecycle.add(self._domain)
def __iter__(self):
"""Return the iteration over all the states."""
self._collect_domain()
@ -469,7 +501,7 @@ class DomainStates:
def __len__(self) -> int:
"""Return number of states."""
self._collect_domain()
self._collect_domain_lifecycle()
return self._hass.states.async_entity_ids_count(self._domain)
def __repr__(self) -> str:

View File

@ -54,16 +54,17 @@ def assert_result_info(info, result, entities=None, domains=None, all_states=Fal
"""Check result info."""
assert info.result() == result
assert info.all_states == all_states
assert info.filter_lifecycle("invalid_entity_name.somewhere") == all_states
assert info.filter("invalid_entity_name.somewhere") == all_states
if entities is not None:
assert info.entities == frozenset(entities)
assert all([info.filter(entity) for entity in entities])
if not all_states:
assert not info.filter("invalid_entity_name.somewhere")
else:
assert not info.entities
if domains is not None:
assert info.domains == frozenset(domains)
assert all([info.filter_lifecycle(domain + ".entity") for domain in domains])
assert all([info.filter(domain + ".entity") for domain in domains])
else:
assert not hasattr(info, "_domains")
@ -1958,7 +1959,8 @@ def test_generate_select(hass):
tmp = template.Template(template_str, hass)
info = tmp.async_render_to_info()
assert_result_info(info, "", [], ["sensor"])
assert_result_info(info, "", [], [])
assert info.domains_lifecycle == {"sensor"}
hass.states.async_set("sensor.test_sensor", "off", {"attr": "value"})
hass.states.async_set("sensor.test_sensor_on", "on")
@ -1970,6 +1972,7 @@ def test_generate_select(hass):
["sensor.test_sensor", "sensor.test_sensor_on"],
["sensor"],
)
assert info.domains_lifecycle == {"sensor"}
async def test_async_render_to_info_in_conditional(hass):
@ -2431,3 +2434,24 @@ async def test_slice_states(hass):
hass,
)
assert tpl.async_render() == "sensor.test"
async def test_lifecycle(hass):
"""Test that we limit template render info for lifecycle events."""
hass.states.async_set("sun.sun", "above", {"elevation": 50, "next_rising": "later"})
for i in range(2):
hass.states.async_set(f"sensor.sensor{i}", "on")
tmp = template.Template("{{ states | count }}", hass)
info = tmp.async_render_to_info()
assert info.all_states is False
assert info.all_states_lifecycle is True
assert info.entities == set()
assert info.domains == set()
assert info.domains_lifecycle == set()
assert info.filter("sun.sun") is False
assert info.filter("sensor.sensor1") is False
assert info.filter_lifecycle("sensor.new") is True
assert info.filter_lifecycle("sensor.removed") is True