Improve performance of accessing template state (#40323)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
J. Nick Koston 2020-09-28 10:35:12 -05:00 committed by GitHub
parent 3596eb39f2
commit e564af0b5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 221 additions and 47 deletions

View File

@ -759,6 +759,7 @@ class State:
last_updated: last time this object was updated.
context: Context in which it was created
domain: Domain of this state.
object_id: Object id of this state.
"""
__slots__ = [
@ -769,6 +770,7 @@ class State:
"last_updated",
"context",
"domain",
"object_id",
]
def __init__(
@ -802,12 +804,7 @@ class State:
self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated
self.context = context or Context()
self.domain = split_entity_id(self.entity_id)[0]
@property
def object_id(self) -> str:
"""Object id of this state."""
return split_entity_id(self.entity_id)[1]
self.domain, self.object_id = split_entity_id(self.entity_id)
@property
def name(self) -> str:

View File

@ -61,6 +61,17 @@ _RESERVED_NAMES = {"contextfunction", "evalcontextfunction", "environmentfunctio
_GROUP_DOMAIN_PREFIX = "group."
_COLLECTABLE_STATE_ATTRIBUTES = {
"state",
"attributes",
"last_changed",
"last_updated",
"context",
"domain",
"object_id",
"name",
}
@bind_hass
def attach(hass: HomeAssistantType, obj: Any) -> None:
@ -477,9 +488,7 @@ class AllStates:
def __getattr__(self, name):
"""Return the domain state."""
if "." in name:
if not valid_entity_id(name):
raise TemplateError(f"Invalid entity ID '{name}'")
return _get_state(self._hass, name)
return _get_state_if_valid(self._hass, name)
if name in _RESERVED_NAMES:
return None
@ -489,6 +498,10 @@ class AllStates:
return DomainStates(self._hass, name)
# Jinja will try __getitem__ first and it avoids the need
# to call is_safe_attribute
__getitem__ = __getattr__
def _collect_all(self) -> None:
render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
@ -529,10 +542,11 @@ class DomainStates:
def __getattr__(self, name):
"""Return the states."""
entity_id = f"{self._domain}.{name}"
if not valid_entity_id(entity_id):
raise TemplateError(f"Invalid entity ID '{entity_id}'")
return _get_state(self._hass, entity_id)
return _get_state_if_valid(self._hass, f"{self._domain}.{name}")
# Jinja will try __getitem__ first and it avoids the need
# to call is_safe_attribute
__getitem__ = __getattr__
def _collect_domain(self) -> None:
entity_collect = self._hass.data.get(_RENDER_INFO)
@ -571,46 +585,96 @@ class TemplateState(State):
self._hass = hass
self._state = state
def _access_state(self):
state = object.__getattribute__(self, "_state")
hass = object.__getattribute__(self, "_hass")
_collect_state(hass, state.entity_id)
return state
def _collect_state(self):
if _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
# Jinja will try __getitem__ first and it avoids the need
# to call is_safe_attribute
def __getitem__(self, item):
"""Return a property as an attribute for jinja."""
if item in _COLLECTABLE_STATE_ATTRIBUTES:
# _collect_state inlined here for performance
if _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
return getattr(self._state, item)
if item == "entity_id":
return self._state.entity_id
if item == "state_with_unit":
return self.state_with_unit
raise KeyError
@property
def entity_id(self):
"""Wrap State.entity_id.
Intentionally does not collect state
"""
return self._state.entity_id
@property
def state(self):
"""Wrap State.state."""
self._collect_state()
return self._state.state
@property
def attributes(self):
"""Wrap State.attributes."""
self._collect_state()
return self._state.attributes
@property
def last_changed(self):
"""Wrap State.last_changed."""
self._collect_state()
return self._state.last_changed
@property
def last_updated(self):
"""Wrap State.last_updated."""
self._collect_state()
return self._state.last_updated
@property
def context(self):
"""Wrap State.context."""
self._collect_state()
return self._state.context
@property
def domain(self):
"""Wrap State.domain."""
self._collect_state()
return self._state.domain
@property
def object_id(self):
"""Wrap State.object_id."""
self._collect_state()
return self._state.object_id
@property
def name(self):
"""Wrap State.name."""
self._collect_state()
return self._state.name
@property
def state_with_unit(self) -> str:
"""Return the state concatenated with the unit if available."""
state = object.__getattribute__(self, "_access_state")()
unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
if unit is None:
return state.state
return f"{state.state} {unit}"
self._collect_state()
unit = self._state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
return f"{self._state.state} {unit}" if unit else self._state.state
def __eq__(self, other: Any) -> bool:
"""Ensure we collect on equality check."""
state = object.__getattribute__(self, "_state")
hass = object.__getattribute__(self, "_hass")
_collect_state(hass, state.entity_id)
return super().__eq__(other)
def __getattribute__(self, name):
"""Return an attribute of the state."""
# This one doesn't count as an access of the state
# since we either found it by looking direct for the ID
# or got it off an iterator.
if name == "entity_id" or name in object.__dict__:
state = object.__getattribute__(self, "_state")
return getattr(state, name)
if name in TemplateState.__dict__:
return object.__getattribute__(self, name)
state = object.__getattribute__(self, "_access_state")()
return getattr(state, name)
self._collect_state()
return self._state.__eq__(other)
def __repr__(self) -> str:
"""Representation of Template State."""
state = object.__getattribute__(self, "_access_state")()
rep = state.__repr__()
return f"<template {rep[1:]}"
return f"<template TemplateState({self._state.__repr__()})>"
def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
@ -625,8 +689,22 @@ def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generato
yield TemplateState(hass, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
def _get_state_if_valid(
hass: HomeAssistantType, entity_id: str
) -> Optional[TemplateState]:
state = hass.states.get(entity_id)
if state is None and not valid_entity_id(entity_id):
raise TemplateError(f"Invalid entity ID '{entity_id}'") # type: ignore
return _get_template_state_from_state(hass, entity_id, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
return _get_template_state_from_state(hass, entity_id, hass.states.get(entity_id))
def _get_template_state_from_state(
hass: HomeAssistantType, entity_id: str, state: Optional[State]
) -> Optional[TemplateState]:
if state is None:
# Only need to collect if none, if not none collect first actual
# access to the state properties in the state wrapper.
@ -1208,12 +1286,12 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
def is_safe_attribute(self, obj, attr, value):
"""Test if attribute is safe."""
if isinstance(obj, (AllStates, DomainStates, TemplateState)):
return not attr[0] == "_"
if isinstance(obj, Namespace):
return True
if isinstance(obj, (AllStates, DomainStates, TemplateState)):
return not attr.startswith("_")
return super().is_safe_attribute(obj, attr, value)
def compile(self, source, name=None, filename=None, raw=False, defer_init=False):

View File

@ -2483,3 +2483,102 @@ async def test_template_timeout(hass):
"""
tmp5 = template.Template(slow_template_str, hass)
assert await tmp5.async_render_will_timeout(0.000001) is True
async def test_lights(hass):
"""Test we can sort lights."""
tmpl = """
{% set lights_on = states.light|selectattr('state','eq','on')|map(attribute='name')|list %}
{% if lights_on|length == 0 %}
No lights on. Sleep well..
{% elif lights_on|length == 1 %}
The {{lights_on[0]}} light is on.
{% elif lights_on|length == 2 %}
The {{lights_on[0]}} and {{lights_on[1]}} lights are on.
{% else %}
The {{lights_on[:-1]|join(', ')}}, and {{lights_on[-1]}} lights are on.
{% endif %}
"""
states = []
for i in range(10):
states.append(f"light.sensor{i}")
hass.states.async_set(f"light.sensor{i}", "on")
tmp = template.Template(tmpl, hass)
info = tmp.async_render_to_info()
assert info.entities == set(states)
assert "lights are on" in info.result()
for i in range(10):
assert f"sensor{i}" in info.result()
async def test_state_attributes(hass):
"""Test state attributes."""
hass.states.async_set("sensor.test", "23")
tpl = template.Template(
"{{ states.sensor.test.last_changed }}",
hass,
)
assert tpl.async_render() == str(hass.states.get("sensor.test").last_changed)
tpl = template.Template(
"{{ states.sensor.test.object_id }}",
hass,
)
assert tpl.async_render() == hass.states.get("sensor.test").object_id
tpl = template.Template(
"{{ states.sensor.test.domain }}",
hass,
)
assert tpl.async_render() == hass.states.get("sensor.test").domain
tpl = template.Template(
"{{ states.sensor.test.context.id }}",
hass,
)
assert tpl.async_render() == hass.states.get("sensor.test").context.id
tpl = template.Template(
"{{ states.sensor.test.state_with_unit }}",
hass,
)
assert tpl.async_render() == "23"
tpl = template.Template(
"{{ states.sensor.test.invalid_prop }}",
hass,
)
assert tpl.async_render() == ""
tpl = template.Template(
"{{ states.sensor.test.invalid_prop.xx }}",
hass,
)
with pytest.raises(TemplateError):
tpl.async_render()
async def test_unavailable_states(hass):
"""Test watching unavailable states."""
for i in range(10):
hass.states.async_set(f"light.sensor{i}", "on")
hass.states.async_set("light.unavailable", "unavailable")
hass.states.async_set("light.unknown", "unknown")
hass.states.async_set("light.none", "none")
tpl = template.Template(
"{{ states | selectattr('state', 'in', ['unavailable','unknown','none']) | map(attribute='entity_id') | list | join(', ') }}",
hass,
)
assert tpl.async_render() == "light.none, light.unavailable, light.unknown"
tpl = template.Template(
"{{ states.light | selectattr('state', 'in', ['unavailable','unknown','none']) | map(attribute='entity_id') | list | join(', ') }}",
hass,
)
assert tpl.async_render() == "light.none, light.unavailable, light.unknown"