diff --git a/homeassistant/core.py b/homeassistant/core.py index 14699aba33e..7b812096fcb 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -905,7 +905,9 @@ class StateMachine: return future.result() @callback - def async_entity_ids(self, domain_filter: Optional[str] = None) -> List[str]: + def async_entity_ids( + self, domain_filter: Optional[Union[str, Iterable]] = None + ) -> List[str]: """List of entity ids that are being tracked. This method must be run in the event loop. @@ -913,12 +915,13 @@ class StateMachine: if domain_filter is None: return list(self._states.keys()) - domain_filter = domain_filter.lower() + if isinstance(domain_filter, str): + domain_filter = (domain_filter.lower(),) return [ state.entity_id for state in self._states.values() - if state.domain == domain_filter + if state.domain in domain_filter ] def all(self) -> List[State]: diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index f6c423a35af..6c7d6771a13 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta import functools as ft import logging import time -from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Union +from typing import Any, Awaitable, Callable, Iterable, Optional, Union import attr @@ -25,9 +25,11 @@ from homeassistant.core import ( callback, split_entity_id, ) +from homeassistant.exceptions import TemplateError from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.sun import get_astral_event_next -from homeassistant.helpers.template import Template +from homeassistant.helpers.template import Template, result_as_boolean +from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe @@ -354,36 +356,315 @@ def async_track_state_added_domain( def async_track_template( hass: HomeAssistant, template: Template, - action: Callable[[str, State, State], None], - variables: Optional[Dict[str, Any]] = None, -) -> CALLBACK_TYPE: - """Add a listener that track state changes with template condition.""" - from . import condition # pylint: disable=import-outside-toplevel + action: Callable[[str, Optional[State], Optional[State]], None], + variables: Optional[TemplateVarsType] = None, +) -> Callable[[], None]: + """Add a listener that fires when a a template evaluates to 'true'. - # Local variable to keep track of if the action has already been triggered - already_triggered = False + Listen for the result of the template becoming true, or a true-like + string result, such as 'On', 'Open', or 'Yes'. If the template results + in an error state when the value changes, this will be logged and not + passed through. + + If the initial check of the template is invalid and results in an + exception, the listener will still be registered but will only + fire if the template result becomes true without an exception. + + Action arguments + ---------------- + entity_id + ID of the entity that triggered the state change. + old_state + The old state of the entity that changed. + new_state + New state of the entity that changed. + + Parameters + ---------- + hass + Home assistant object. + template + The template to calculate. + action + Callable to call with results. See above for arguments. + variables + Variables to pass to the template. + + Returns + ------- + Callable to unregister the listener. + + """ @callback - def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None: + def state_changed_listener( + event: Event, + template: Template, + last_result: Optional[str], + result: Union[str, TemplateError], + ) -> None: """Check if condition is correct and run action.""" - nonlocal already_triggered - template_result = condition.async_template(hass, template, variables) + if isinstance(result, TemplateError): + _LOGGER.exception(result) + return - # Check to see if template returns true - if template_result and not already_triggered: - already_triggered = True - hass.async_run_job(action, entity_id, from_s, to_s) - elif not template_result: - already_triggered = False + if result_as_boolean(last_result) or not result_as_boolean(result): + return - return async_track_state_change( - hass, template.extract_entities(variables), template_condition_listener + hass.async_run_job( + action, + event.data.get("entity_id"), + event.data.get("old_state"), + event.data.get("new_state"), + ) + + info = async_track_template_result( + hass, template, state_changed_listener, variables ) + return info.async_remove + track_template = threaded_listener_factory(async_track_template) +_UNCHANGED = object() + + +class TrackTemplateResultInfo: + """Handle removal / refresh of tracker.""" + + def __init__( + self, + hass: HomeAssistant, + template: Template, + action: Callable, + variables: Optional[TemplateVarsType], + ): + """Handle removal / refresh of tracker init.""" + self.hass = hass + self._template = template + self._action = action + self._variables = variables + self._last_result: Optional[str] = None + self._last_exception = False + self._all_listener: Optional[Callable] = None + self._domains_listener: Optional[Callable] = None + self._entities_listener: Optional[Callable] = None + self._info = template.async_render_to_info(variables) + if self._info.exception: + self._last_exception = True + _LOGGER.exception(self._info.exception) + self._create_listeners() + self._last_info = self._info + + @property + def _needs_all_listener(self) -> bool: + # Tracking all states + if self._info.all_states: + return True + + # Previous call had an exception + # so we do not know which states + # to track + if self._info.exception: + return True + + # There are no entities in the template + # to track so this template will + # re-render on EVERY state change + if not self._info.domains and not self._info.entities: + return True + + return False + + @callback + def _create_listeners(self) -> None: + if self._info.is_static: + return + + if self._needs_all_listener: + self._setup_all_listener() + return + + if self._info.domains: + self._setup_domains_listener() + + if self._info.entities or self._info.domains: + self._setup_entities_listener() + + @callback + def _cancel_domains_listener(self) -> None: + if self._domains_listener is None: + return + self._domains_listener() + self._domains_listener = None + + @callback + def _cancel_entities_listener(self) -> None: + if self._entities_listener is None: + return + self._entities_listener() + self._entities_listener = None + + @callback + def _cancel_all_listener(self) -> None: + if self._all_listener is None: + return + self._all_listener() + self._all_listener = None + + @callback + def _update_listeners(self) -> None: + if self._needs_all_listener: + if self._all_listener: + return + self._cancel_domains_listener() + self._cancel_entities_listener() + self._setup_all_listener() + return + + had_all_listener = self._all_listener is not None + if had_all_listener: + self._cancel_all_listener() + + domains_changed = self._info.domains != self._last_info.domains + if had_all_listener or domains_changed: + domains_changed = True + self._cancel_domains_listener() + self._setup_domains_listener() + + if ( + had_all_listener + or domains_changed + or self._info.entities != self._last_info.entities + ): + self._cancel_entities_listener() + self._setup_entities_listener() + + @callback + def _setup_entities_listener(self) -> None: + entities = set(self._info.entities) + for entity_id in self.hass.states.async_entity_ids(self._info.domains): + entities.add(entity_id) + self._entities_listener = async_track_state_change_event( + self.hass, entities, self._refresh + ) + + @callback + def _setup_domains_listener(self) -> None: + self._domains_listener = async_track_state_added_domain( + self.hass, self._info.domains, self._refresh + ) + + @callback + def _setup_all_listener(self) -> None: + self._all_listener = self.hass.bus.async_listen( + EVENT_STATE_CHANGED, self._refresh + ) + + @callback + def async_remove(self) -> None: + """Cancel the listener.""" + self._cancel_all_listener() + self._cancel_domains_listener() + self._cancel_entities_listener() + + @callback + def async_refresh(self, variables: Any = _UNCHANGED) -> None: + """Force recalculate the template.""" + if variables is not _UNCHANGED: + self._variables = variables + self._refresh(None) + + def _refresh(self, event: Optional[Event]) -> None: + self._info = self._template.async_render_to_info(self._variables) + self._update_listeners() + self._last_info = self._info + + try: + result = self._info.result + except TemplateError as ex: + if not self._last_exception: + self.hass.async_run_job( + self._action, event, self._template, self._last_result, ex + ) + self._last_exception = True + return + self._last_exception = False + + # Check to see if the result has changed + if result == self._last_result: + return + + self.hass.async_run_job( + self._action, event, self._template, self._last_result, result + ) + self._last_result = result + + +TrackTemplateResultListener = Callable[ + [Event, Template, Optional[str], Union[str, TemplateError]], None +] +"""Type for the listener for template results. + + Action arguments + ---------------- + event + Event that caused the template to change output. None if not + triggered by an event. + template + The template that has changed. + last_result + The output from the template on the last successful run, or None + if no previous successful run. + result + Result from the template run. This will be a string or an + TemplateError if the template resulted in an error. +""" + + +@callback +@bind_hass +def async_track_template_result( + hass: HomeAssistant, + template: Template, + action: TrackTemplateResultListener, + variables: Optional[TemplateVarsType] = None, +) -> TrackTemplateResultInfo: + """Add a listener that fires when a the result of a template changes. + + The action will fire with the initial result from the template, and + then whenever the output from the template changes. The template will + be reevaluated if any states referenced in the last run of the + template change, or if manually triggered. If the result of the + evaluation is different from the previous run, the listener is passed + the result. + + If the template results in an TemplateError, this will be returned to + the listener the first time this happens but not for subsequent errors. + Once the template returns to a non-error condition the result is sent + to the action as usual. + + Parameters + ---------- + hass + Home assistant object. + template + The template to calculate. + action + Callable to call with results. + variables + Variables to pass to the template. + + Returns + ------- + Info object used to unregister the listener, and refresh the template. + + """ + return TrackTemplateResultInfo(hass, template, action, variables) + + @callback @bind_hass def async_track_same_state( diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 140c233f41e..00f2c4c7296 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -8,7 +8,7 @@ import logging import math import random import re -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Iterable, List, Optional, Union from urllib.parse import urlencode as urllib_urlencode import weakref @@ -16,6 +16,7 @@ import jinja2 from jinja2 import contextfilter, contextfunction from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.utils import Namespace # type: ignore +import voluptuous as vol from homeassistant.const import ( ATTR_ENTITY_ID, @@ -28,7 +29,7 @@ from homeassistant.const import ( ) from homeassistant.core import State, callback, split_entity_id, valid_entity_id from homeassistant.exceptions import TemplateError -from homeassistant.helpers import location as loc_helper +from homeassistant.helpers import config_validation as cv, location as loc_helper from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType from homeassistant.loader import bind_hass from homeassistant.util import convert, dt as dt_util, location as loc_util @@ -49,8 +50,13 @@ _RE_GET_ENTITIES = re.compile( r"(?:(?:(?:states\.|(?Pis_state|is_state_attr|state_attr|states|expand)\((?:[\ \'\"]?))(?P[\w]+\.[\w]+)|states\.(?P[a-z]+)|states\[(?:[\'\"]?)(?P[\w]+))|(?P[\w]+))", re.I | re.M, ) + _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{") +_RESERVED_NAMES = {"contextfunction", "evalcontextfunction", "environmentfunction"} + +_GROUP_DOMAIN_PREFIX = "group." + @bind_hass def attach(hass: HomeAssistantType, obj: Any) -> None: @@ -79,7 +85,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: def extract_entities( hass: HomeAssistantType, template: Optional[str], - variables: Optional[Dict[str, Any]] = None, + variables: TemplateVarsType = None, ) -> Union[str, List[str]]: """Extract all entities for state_changed listener from template string.""" if template is None or _RE_JINJA_DELIMITERS.search(template) is None: @@ -137,39 +143,45 @@ class RenderInfo: # Will be set sensibly once frozen. self.filter_lifecycle = _true self._result = None - self._exception = None - self._all_states = False - self._domains = [] - self._entities = [] + self.is_static = False + self.exception = None + self.all_states = False + self.domains = set() + self.entities = set() def filter(self, entity_id: str) -> bool: """Template should re-render if the state changes.""" - return entity_id in self._entities + return entity_id in self.entities def _filter_lifecycle(self, entity_id: str) -> bool: """Template should re-render if the state changes.""" return ( - split_entity_id(entity_id)[0] in self._domains - or entity_id in self._entities + split_entity_id(entity_id)[0] in self.domains or entity_id in self.entities ) @property def result(self) -> str: """Results of the template computation.""" - if self._exception is not None: - raise self._exception + if self.exception is not None: + raise self.exception return self._result + def _freeze_static(self) -> None: + self.is_static = True + self.entities = frozenset(self.entities) + self.domains = frozenset(self.domains) + self.all_states = False + def _freeze(self) -> None: - self._entities = frozenset(self._entities) - if self._all_states: - # Leave lifecycle_filter as True - del self._domains - elif not self._domains: - del self._domains + self.entities = frozenset(self.entities) + self.domains = frozenset(self.domains) + + if self.all_states: + return + + if not self.domains: self.filter_lifecycle = self.filter else: - self._domains = frozenset(self._domains) self.filter_lifecycle = self._filter_lifecycle @@ -206,7 +218,7 @@ class Template: raise TemplateError(err) def extract_entities( - self, variables: Optional[Dict[str, Any]] = None + self, variables: TemplateVarsType = None ) -> Union[str, List[str]]: """Extract all entities for state_changed listener.""" return extract_entities(self.hass, self.template, variables) @@ -247,10 +259,13 @@ class Template: try: render_info._result = self.async_render(variables, **kwargs) except TemplateError as ex: - render_info._exception = ex + render_info.exception = ex finally: del self.hass.data[_RENDER_INFO] - render_info._freeze() + if _RE_JINJA_DELIMITERS.search(self.template) is None: + render_info._freeze_static() + else: + render_info._freeze() return render_info def render_with_possible_json_value(self, value, error_value=_SENTINEL): @@ -342,15 +357,19 @@ class AllStates: if not valid_entity_id(name): raise TemplateError(f"Invalid entity ID '{name}'") return _get_state(self._hass, name) + + if name in _RESERVED_NAMES: + return None + if not valid_entity_id(f"{name}.entity"): raise TemplateError(f"Invalid domain name '{name}'") + return DomainStates(self._hass, name) def _collect_all(self) -> None: render_info = self._hass.data.get(_RENDER_INFO) if render_info is not None: - # pylint: disable=protected-access - render_info._all_states = True + render_info.all_states = True def __iter__(self): """Return all states.""" @@ -395,8 +414,7 @@ class DomainStates: def _collect_domain(self) -> None: entity_collect = self._hass.data.get(_RENDER_INFO) if entity_collect is not None: - # pylint: disable=protected-access - entity_collect._domains.append(self._domain) + entity_collect.domains.add(self._domain) def __iter__(self): """Return the iteration over all the states.""" @@ -435,7 +453,6 @@ class TemplateState(State): def _access_state(self): state = object.__getattribute__(self, "_state") hass = object.__getattribute__(self, "_hass") - _collect_state(hass, state.entity_id) return state @@ -448,6 +465,13 @@ class TemplateState(State): return state.state return f"{state.state} {unit}" + def __eq__(self, other: Any) -> bool: + """Ensure we collect on equality check.""" + state = object.__getattribute__(self, "_state") + hass = object.__getattribute__(self, "_hass") + _collect_state(hass, state.entity_id) + return super().__eq__(other) + def __getattribute__(self, name): """Return an attribute of the state.""" # This one doesn't count as an access of the state @@ -471,8 +495,7 @@ class TemplateState(State): def _collect_state(hass: HomeAssistantType, entity_id: str) -> None: entity_collect = hass.data.get(_RENDER_INFO) if entity_collect is not None: - # pylint: disable=protected-access - entity_collect._entities.append(entity_id) + entity_collect.entities.add(entity_id) def _wrap_state( @@ -503,6 +526,19 @@ def _resolve_state( return None +def result_as_boolean(template_result: Optional[str]) -> bool: + """Convert the template result to a boolean. + + True/not 0/'1'/'true'/'yes'/'on'/'enable' are considered truthy + False/0/None/'0'/'false'/'no'/'off'/'disable' are considered falsy + + """ + try: + return cv.boolean(template_result) + except vol.Invalid: + return False + + def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]: """Expand out any groups into entity states.""" search = list(args) @@ -523,16 +559,15 @@ def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]: # ignore other types continue - # pylint: disable=import-outside-toplevel - from homeassistant.components import group - - if split_entity_id(entity_id)[0] == group.DOMAIN: + if entity_id.startswith(_GROUP_DOMAIN_PREFIX): # Collect state will be called in here since it's wrapped group_entities = entity.attributes.get(ATTR_ENTITY_ID) if group_entities: search += group_entities else: + _collect_state(hass, entity_id) found[entity_id] = entity + return sorted(found.values(), key=lambda a: a.entity_id) @@ -618,7 +653,10 @@ def distance(hass, *args): while to_process: value = to_process.pop(0) - point_state = _resolve_state(hass, value) + if isinstance(value, str) and not valid_entity_id(value): + point_state = None + else: + point_state = _resolve_state(hass, value) if point_state is None: # We expect this and next value to be lat&lng diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index b2cb1ff100c..1a575d1eff7 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -1,4 +1,6 @@ """Test the condition helper.""" +from logging import ERROR + import pytest from homeassistant.exceptions import HomeAssistantError @@ -576,3 +578,18 @@ async def test_extract_devices(): ], } ) == {"abcd", "qwer", "abcd_not", "qwer_not", "abcd_or", "qwer_or"} + + +async def test_condition_template_error(hass, caplog): + """Test invalid template.""" + caplog.set_level(ERROR) + + test = await condition.async_from_config( + hass, {"condition": "template", "value_template": "{{ undefined.state }}"} + ) + + assert not test(hass) + assert len(caplog.records) == 1 + assert caplog.records[0].message.startswith( + "Error during template condition: UndefinedError:" + ) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index e30f85c9c38..b61d8a9365c 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -10,6 +10,7 @@ from homeassistant.components import sun from homeassistant.const import MATCH_ALL import homeassistant.core as ha from homeassistant.core import callback +from homeassistant.exceptions import TemplateError from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.event import ( async_call_later, @@ -22,6 +23,7 @@ from homeassistant.helpers.event import ( async_track_sunrise, async_track_sunset, async_track_template, + async_track_template_result, async_track_time_change, async_track_time_interval, async_track_utc_time_change, @@ -490,61 +492,481 @@ async def test_track_template(hass): assert len(wildcard_runs) == 2 assert len(wildercard_runs) == 2 - -async def test_track_same_state_simple_trigger(hass): - """Test track_same_change with trigger simple.""" - thread_runs = [] - callback_runs = [] - coroutine_runs = [] - period = timedelta(minutes=1) - - def thread_run_callback(): - thread_runs.append(1) - - async_track_same_state( - hass, - period, - thread_run_callback, - lambda _, _2, to_s: to_s.state == "on", - entity_ids="light.Bowl", - ) + template_iterate = Template("{{ (states.switch | length) > 0 }}", hass) + iterate_calls = [] @ha.callback - def callback_run_callback(): - callback_runs.append(1) + def iterate_callback(entity_id, old_state, new_state): + iterate_calls.append((entity_id, old_state, new_state)) - async_track_same_state( - hass, - period, - callback_run_callback, - callback(lambda _, _2, to_s: to_s.state == "on"), - entity_ids="light.Bowl", + async_track_template(hass, template_iterate, iterate_callback) + await hass.async_block_till_done() + + hass.states.async_set("switch.new", "on") + await hass.async_block_till_done() + + assert len(iterate_calls) == 1 + assert iterate_calls[0][0] == "switch.new" + assert iterate_calls[0][1] is None + assert iterate_calls[0][2].state == "on" + + +async def test_track_template_error(hass, caplog): + """Test tracking template with error.""" + template_error = Template("{{ (states.switch | lunch) > 0 }}", hass) + error_calls = [] + + @ha.callback + def error_callback(entity_id, old_state, new_state): + error_calls.append((entity_id, old_state, new_state)) + + async_track_template(hass, template_error, error_callback) + await hass.async_block_till_done() + + hass.states.async_set("switch.new", "on") + await hass.async_block_till_done() + + assert not error_calls + assert "lunch" in caplog.text + assert "TemplateAssertionError" in caplog.text + + caplog.clear() + + with patch.object(Template, "async_render") as render: + render.return_value = "ok" + + hass.states.async_set("switch.not_exist", "off") + await hass.async_block_till_done() + + assert "lunch" not in caplog.text + assert "TemplateAssertionError" not in caplog.text + + hass.states.async_set("switch.not_exist", "on") + await hass.async_block_till_done() + + assert "lunch" in caplog.text + assert "TemplateAssertionError" in caplog.text + + +async def test_track_template_result(hass): + """Test tracking template.""" + specific_runs = [] + wildcard_runs = [] + wildercard_runs = [] + + template_condition = Template("{{states.sensor.test.state}}", hass) + template_condition_var = Template( + "{{(states.sensor.test.state|int) + test }}", hass ) - async def coroutine_run_callback(): - coroutine_runs.append(1) + def specific_run_callback(event, template, old_result, new_result): + specific_runs.append(int(new_result)) - async_track_same_state( - hass, - period, - coroutine_run_callback, - callback(lambda _, _2, to_s: to_s.state == "on"), + async_track_template_result(hass, template_condition, specific_run_callback) + + @ha.callback + def wildcard_run_callback(event, template, old_result, new_result): + wildcard_runs.append((int(old_result or 0), int(new_result))) + + async_track_template_result(hass, template_condition, wildcard_run_callback) + + async def wildercard_run_callback(event, template, old_result, new_result): + wildercard_runs.append((int(old_result or 0), int(new_result))) + + async_track_template_result( + hass, template_condition_var, wildercard_run_callback, {"test": 5} ) - - # Adding state to state machine - hass.states.async_set("light.Bowl", "on") await hass.async_block_till_done() - assert len(thread_runs) == 0 - assert len(callback_runs) == 0 - assert len(coroutine_runs) == 0 - # change time to track and see if they trigger - future = dt_util.utcnow() + period - async_fire_time_changed(hass, future) + hass.states.async_set("sensor.test", 5) await hass.async_block_till_done() - assert len(thread_runs) == 1 - assert len(callback_runs) == 1 - assert len(coroutine_runs) == 1 + + assert specific_runs == [5] + assert wildcard_runs == [(0, 5)] + assert wildercard_runs == [(0, 10)] + + hass.states.async_set("sensor.test", 30) + await hass.async_block_till_done() + + assert specific_runs == [5, 30] + assert wildcard_runs == [(0, 5), (5, 30)] + assert wildercard_runs == [(0, 10), (10, 35)] + + hass.states.async_set("sensor.test", 30) + await hass.async_block_till_done() + + assert len(specific_runs) == 2 + assert len(wildcard_runs) == 2 + assert len(wildercard_runs) == 2 + + hass.states.async_set("sensor.test", 5) + await hass.async_block_till_done() + + assert len(specific_runs) == 3 + assert len(wildcard_runs) == 3 + assert len(wildercard_runs) == 3 + + hass.states.async_set("sensor.test", 5) + await hass.async_block_till_done() + + assert len(specific_runs) == 3 + assert len(wildcard_runs) == 3 + assert len(wildercard_runs) == 3 + + hass.states.async_set("sensor.test", 20) + await hass.async_block_till_done() + + assert len(specific_runs) == 4 + assert len(wildcard_runs) == 4 + assert len(wildercard_runs) == 4 + + +async def test_track_template_result_complex(hass): + """Test tracking template.""" + specific_runs = [] + template_complex_str = """ + +{% if states("sensor.domain") == "light" %} + {{ states.light | map(attribute='entity_id') | list }} +{% elif states("sensor.domain") == "lock" %} + {{ states.lock | map(attribute='entity_id') | list }} +{% elif states("sensor.domain") == "single_binary_sensor" %} + {{ states("binary_sensor.single") }} +{% else %} + {{ states | map(attribute='entity_id') | list }} +{% endif %} + +""" + template_complex = Template(template_complex_str, hass) + + def specific_run_callback(event, template, old_result, new_result): + specific_runs.append(new_result) + + hass.states.async_set("light.one", "on") + hass.states.async_set("lock.one", "locked") + + async_track_template_result(hass, template_complex, specific_run_callback) + await hass.async_block_till_done() + + hass.states.async_set("sensor.domain", "light") + await hass.async_block_till_done() + assert len(specific_runs) == 1 + assert specific_runs[0].strip() == "['light.one']" + + hass.states.async_set("sensor.domain", "lock") + await hass.async_block_till_done() + assert len(specific_runs) == 2 + assert specific_runs[1].strip() == "['lock.one']" + + hass.states.async_set("sensor.domain", "all") + await hass.async_block_till_done() + assert len(specific_runs) == 3 + assert "light.one" in specific_runs[2] + assert "lock.one" in specific_runs[2] + assert "sensor.domain" in specific_runs[2] + + hass.states.async_set("sensor.domain", "light") + await hass.async_block_till_done() + assert len(specific_runs) == 4 + assert specific_runs[3].strip() == "['light.one']" + + hass.states.async_set("light.two", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 5 + assert "light.one" in specific_runs[4] + assert "light.two" in specific_runs[4] + assert "sensor.domain" not in specific_runs[4] + + hass.states.async_set("light.three", "on") + await hass.async_block_till_done() + assert len(specific_runs) == 6 + assert "light.one" in specific_runs[5] + assert "light.two" in specific_runs[5] + assert "light.three" in specific_runs[5] + assert "sensor.domain" not in specific_runs[5] + + hass.states.async_set("sensor.domain", "lock") + await hass.async_block_till_done() + assert len(specific_runs) == 7 + assert specific_runs[6].strip() == "['lock.one']" + + hass.states.async_set("sensor.domain", "single_binary_sensor") + await hass.async_block_till_done() + assert len(specific_runs) == 8 + assert specific_runs[7].strip() == "unknown" + + hass.states.async_set("binary_sensor.single", "binary_sensor_on") + await hass.async_block_till_done() + assert len(specific_runs) == 9 + assert specific_runs[8].strip() == "binary_sensor_on" + + hass.states.async_set("sensor.domain", "lock") + await hass.async_block_till_done() + assert len(specific_runs) == 10 + assert specific_runs[9].strip() == "['lock.one']" + + +async def test_track_template_result_with_wildcard(hass): + """Test tracking template with a wildcard.""" + specific_runs = [] + template_complex_str = r""" + +{% for state in states %} + {% if state.entity_id | regex_match('.*\.office_') %} + {{ state.entity_id }}={{ state.state }} + {% endif %} +{% endfor %} + +""" + template_complex = Template(template_complex_str, hass) + + def specific_run_callback(event, template, old_result, new_result): + specific_runs.append(new_result) + + hass.states.async_set("cover.office_drapes", "closed") + hass.states.async_set("cover.office_window", "closed") + hass.states.async_set("cover.office_skylight", "open") + + async_track_template_result(hass, template_complex, specific_run_callback) + await hass.async_block_till_done() + + hass.states.async_set("cover.office_window", "open") + await hass.async_block_till_done() + assert len(specific_runs) == 1 + + assert "cover.office_drapes=closed" in specific_runs[0] + assert "cover.office_window=open" in specific_runs[0] + assert "cover.office_skylight=open" in specific_runs[0] + + +async def test_track_template_result_with_group(hass): + """Test tracking template with a group.""" + hass.states.async_set("sensor.power_1", 0) + hass.states.async_set("sensor.power_2", 200.2) + hass.states.async_set("sensor.power_3", 400.4) + hass.states.async_set("sensor.power_4", 800.8) + + assert await async_setup_component( + hass, + "group", + {"group": {"power_sensors": "sensor.power_1,sensor.power_2,sensor.power_3"}}, + ) + await hass.async_block_till_done() + + assert hass.states.get("group.power_sensors") + assert hass.states.get("group.power_sensors").state + + specific_runs = [] + template_complex_str = r""" + +{{ states.group.power_sensors.attributes.entity_id | expand | map(attribute='state')|map('float')|sum }} + +""" + template_complex = Template(template_complex_str, hass) + + def specific_run_callback(event, template, old_result, new_result): + specific_runs.append(new_result) + + async_track_template_result(hass, template_complex, specific_run_callback) + await hass.async_block_till_done() + + hass.states.async_set("sensor.power_1", 100.1) + await hass.async_block_till_done() + assert len(specific_runs) == 1 + + assert specific_runs[0] == str(100.1 + 200.2 + 400.4) + + hass.states.async_set("sensor.power_3", 0) + await hass.async_block_till_done() + assert len(specific_runs) == 2 + + assert specific_runs[1] == str(100.1 + 200.2 + 0) + + with patch( + "homeassistant.config.load_yaml_config_file", + return_value={ + "group": { + "power_sensors": "sensor.power_1,sensor.power_2,sensor.power_3,sensor.power_4", + } + }, + ): + await hass.services.async_call("group", "reload") + await hass.async_block_till_done() + + assert specific_runs[-1] == str(100.1 + 200.2 + 0 + 800.8) + + +async def test_track_template_result_iterator(hass): + """Test tracking template.""" + iterator_runs = [] + + @ha.callback + def iterator_callback(event, template, old_result, new_result): + iterator_runs.append(new_result) + + async_track_template_result( + hass, + Template( + """ + {% for state in states.sensor %} + {% if state.state == 'on' %} + {{ state.entity_id }}, + {% endif %} + {% endfor %} + """, + hass, + ), + iterator_callback, + ) + await hass.async_block_till_done() + + hass.states.async_set("sensor.test", 5) + await hass.async_block_till_done() + + assert iterator_runs == [""] + + filter_runs = [] + + @ha.callback + def filter_callback(event, template, old_result, new_result): + filter_runs.append(new_result) + + async_track_template_result( + hass, + Template( + """{{ states.sensor|selectattr("state","equalto","on") + |join(",", attribute="entity_id") }}""", + hass, + ), + filter_callback, + ) + await hass.async_block_till_done() + + hass.states.async_set("sensor.test", 6) + await hass.async_block_till_done() + + assert filter_runs == [""] + assert iterator_runs == [""] + + hass.states.async_set("sensor.new", "on") + await hass.async_block_till_done() + assert iterator_runs == ["", "sensor.new,"] + assert filter_runs == ["", "sensor.new"] + + +async def test_track_template_result_errors(hass, caplog): + """Test tracking template with errors in the template.""" + template_syntax_error = Template("{{states.switch", hass) + + template_not_exist = Template("{{states.switch.not_exist.state }}", hass) + + syntax_error_runs = [] + not_exist_runs = [] + + def syntax_error_listener(event, template, last_result, result): + syntax_error_runs.append((event, template, last_result, result)) + + async_track_template_result(hass, template_syntax_error, syntax_error_listener) + await hass.async_block_till_done() + + assert len(syntax_error_runs) == 0 + assert "TemplateSyntaxError" in caplog.text + + async_track_template_result( + hass, + template_not_exist, + lambda event, template, last_result, result: ( + not_exist_runs.append((event, template, last_result, result)) + ), + ) + await hass.async_block_till_done() + + assert len(syntax_error_runs) == 0 + assert len(not_exist_runs) == 0 + + hass.states.async_set("switch.not_exist", "off") + await hass.async_block_till_done() + + assert len(not_exist_runs) == 1 + assert not_exist_runs[0][0].data.get("entity_id") == "switch.not_exist" + assert not_exist_runs[0][1] == template_not_exist + assert not_exist_runs[0][2] is None + assert not_exist_runs[0][3] == "off" + + hass.states.async_set("switch.not_exist", "on") + await hass.async_block_till_done() + + assert len(syntax_error_runs) == 0 + assert len(not_exist_runs) == 2 + assert not_exist_runs[1][0].data.get("entity_id") == "switch.not_exist" + assert not_exist_runs[1][1] == template_not_exist + assert not_exist_runs[1][2] == "off" + assert not_exist_runs[1][3] == "on" + + with patch.object(Template, "async_render") as render: + render.side_effect = TemplateError("Test") + + hass.states.async_set("switch.not_exist", "off") + await hass.async_block_till_done() + + assert len(not_exist_runs) == 3 + assert not_exist_runs[2][0].data.get("entity_id") == "switch.not_exist" + assert not_exist_runs[2][1] == template_not_exist + assert not_exist_runs[2][2] == "on" + assert isinstance(not_exist_runs[2][3], TemplateError) + + +async def test_track_template_result_refresh_cancel(hass): + """Test cancelling and refreshing result.""" + template_refresh = Template("{{states.switch.test.state == 'on' and now() }}", hass) + + refresh_runs = [] + + def refresh_listener(event, template, last_result, result): + refresh_runs.append(result) + + info = async_track_template_result(hass, template_refresh, refresh_listener) + await hass.async_block_till_done() + + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert refresh_runs == ["False"] + + assert len(refresh_runs) == 1 + + info.async_refresh() + hass.states.async_set("switch.test", "on") + await hass.async_block_till_done() + + assert len(refresh_runs) == 2 + assert refresh_runs[0] != refresh_runs[1] + + info.async_remove() + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert len(refresh_runs) == 2 + + template_refresh = Template("{{ value }}", hass) + refresh_runs = [] + + info = async_track_template_result( + hass, template_refresh, refresh_listener, {"value": "duck"} + ) + await hass.async_block_till_done() + info.async_refresh() + await hass.async_block_till_done() + + assert refresh_runs == ["duck"] + + info.async_refresh() + await hass.async_block_till_done() + assert refresh_runs == ["duck"] + + info.async_refresh({"value": "dog"}) + await hass.async_block_till_done() + assert refresh_runs == ["duck", "dog"] async def test_track_same_state_simple_no_trigger(hass): diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index fa650b280c6..74cd19a165d 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -39,25 +39,22 @@ def render_to_info(hass, template_str, variables=None): def extract_entities(hass, template_str, variables=None): """Extract entities from a template.""" info = render_to_info(hass, template_str, variables) - # pylint: disable=protected-access - assert not hasattr(info, "_domains") - return info._entities + return info.entities def assert_result_info(info, result, entities=None, domains=None, all_states=False): """Check result info.""" assert info.result == result - # pylint: disable=protected-access - assert info._all_states == all_states + assert info.all_states == all_states assert info.filter_lifecycle("invalid_entity_name.somewhere") == all_states if entities is not None: - assert info._entities == frozenset(entities) + assert info.entities == frozenset(entities) assert all([info.filter(entity) for entity in entities]) assert not info.filter("invalid_entity_name.somewhere") else: - assert not info._entities + assert not info.entities if domains is not None: - assert info._domains == frozenset(domains) + assert info.domains == frozenset(domains) assert all([info.filter_lifecycle(domain + ".entity") for domain in domains]) else: assert not hasattr(info, "_domains") @@ -1256,7 +1253,7 @@ async def test_closest_function_home_vs_group_entity_id(hass): info = render_to_info(hass, '{{ closest("group.location_group").entity_id }}') assert_result_info( - info, "test_domain.object", ["test_domain.object", "group.location_group"] + info, "test_domain.object", {"group.location_group", "test_domain.object"} ) @@ -1281,12 +1278,12 @@ async def test_closest_function_home_vs_group_state(hass): info = render_to_info(hass, '{{ closest("group.location_group").entity_id }}') assert_result_info( - info, "test_domain.object", ["test_domain.object", "group.location_group"] + info, "test_domain.object", {"group.location_group", "test_domain.object"} ) info = render_to_info(hass, "{{ closest(states.group.location_group).entity_id }}") assert_result_info( - info, "test_domain.object", ["test_domain.object", "group.location_group"] + info, "test_domain.object", {"test_domain.object", "group.location_group"} ) @@ -1303,7 +1300,7 @@ async def test_expand(hass): info = render_to_info( hass, "{{ expand('test.object') | map(attribute='entity_id') | join(', ') }}" ) - assert_result_info(info, "test.object", []) + assert_result_info(info, "test.object", ["test.object"]) info = render_to_info( hass, @@ -1322,26 +1319,45 @@ async def test_expand(hass): hass, "{{ expand('group.new_group') | map(attribute='entity_id') | join(', ') }}", ) - assert_result_info(info, "test.object", ["group.new_group"]) + assert_result_info(info, "test.object", {"group.new_group", "test.object"}) info = render_to_info( hass, "{{ expand(states.group) | map(attribute='entity_id') | join(', ') }}" ) - assert_result_info(info, "test.object", ["group.new_group"], ["group"]) + assert_result_info( + info, "test.object", {"test.object", "group.new_group"}, ["group"] + ) info = render_to_info( hass, "{{ expand('group.new_group', 'test.object')" " | map(attribute='entity_id') | join(', ') }}", ) - assert_result_info(info, "test.object", ["group.new_group"]) + assert_result_info(info, "test.object", {"test.object", "group.new_group"}) info = render_to_info( hass, "{{ ['group.new_group', 'test.object'] | expand" " | map(attribute='entity_id') | join(', ') }}", ) - assert_result_info(info, "test.object", ["group.new_group"]) + assert_result_info(info, "test.object", {"test.object", "group.new_group"}) + + hass.states.async_set("sensor.power_1", 0) + hass.states.async_set("sensor.power_2", 200.2) + hass.states.async_set("sensor.power_3", 400.4) + await group.Group.async_create_group( + hass, "power sensors", ["sensor.power_1", "sensor.power_2", "sensor.power_3"] + ) + + info = render_to_info( + hass, + "{{ states.group.power_sensors.attributes.entity_id | expand | map(attribute='state')|map('float')|sum }}", + ) + assert_result_info( + info, + str(200.2 + 400.4), + {"group.power_sensors", "sensor.power_1", "sensor.power_2", "sensor.power_3"}, + ) def test_closest_function_to_coord(hass): @@ -1390,6 +1406,198 @@ def test_closest_function_to_coord(hass): assert tpl.async_render() == "test_domain.closest_zone" +def test_async_render_to_info_with_branching(hass): + """Test async_render_to_info function by domain.""" + hass.states.async_set("light.a", "off") + hass.states.async_set("light.b", "on") + hass.states.async_set("light.c", "off") + + info = render_to_info( + hass, + """ +{% if states.light.a == "on" %} + {{ states.light.b.state }} +{% else %} + {{ states.light.c.state }} +{% endif %} +""", + ) + assert_result_info(info, "off", {"light.a", "light.c"}) + + info = render_to_info( + hass, + """ + {% if states.light.a.state == "off" %} + {% set domain = "light" %} + {{ states[domain].b.state }} + {% endif %} +""", + ) + assert_result_info(info, "on", {"light.a", "light.b"}) + + +def test_async_render_to_info_with_complex_branching(hass): + """Test async_render_to_info function by domain.""" + hass.states.async_set("light.a", "off") + hass.states.async_set("light.b", "on") + hass.states.async_set("light.c", "off") + hass.states.async_set("vacuum.a", "off") + hass.states.async_set("device_tracker.a", "off") + hass.states.async_set("device_tracker.b", "off") + hass.states.async_set("lock.a", "off") + hass.states.async_set("sensor.a", "off") + hass.states.async_set("binary_sensor.a", "off") + + info = render_to_info( + hass, + """ +{% set domain = "vacuum" %} +{% if states.light.a == "on" %} + {{ states.light.b.state }} +{% elif states.light.a == "on" %} + {{ states.device_tracker }} +{% elif states.light.a == "on" %} + {{ states[domain] | list }} +{% elif states('light.b') == "on" %} + {{ states[otherdomain] | map(attribute='entity_id') | list }} +{% elif states.light.a == "on" %} + {{ states["nonexist"] | list }} +{% else %} + else +{% endif %} +""", + {"otherdomain": "sensor"}, + ) + + assert_result_info(info, "['sensor.a']", {"light.a", "light.b"}, {"sensor"}) + + +async def test_async_render_to_info_with_wildcard_matching_entity_id(hass): + """Test tracking template with a wildcard.""" + template_complex_str = r""" + +{% for state in states %} + {% if state.entity_id | regex_match('.*\.office_') %} + {{ state.entity_id }}={{ state.state }} + {% endif %} +{% endfor %} + +""" + hass.states.async_set("cover.office_drapes", "closed") + hass.states.async_set("cover.office_window", "closed") + hass.states.async_set("cover.office_skylight", "open") + info = render_to_info(hass, template_complex_str) + + assert not info.domains + assert info.entities == { + "cover.office_drapes", + "cover.office_window", + "cover.office_skylight", + } + assert info.all_states is True + + +async def test_async_render_to_info_with_wildcard_matching_state(hass): + """Test tracking template with a wildcard.""" + template_complex_str = """ + +{% for state in states %} + {% if state.state | regex_match('ope.*') %} + {{ state.entity_id }}={{ state.state }} + {% endif %} +{% endfor %} + +""" + hass.states.async_set("cover.office_drapes", "closed") + hass.states.async_set("cover.office_window", "closed") + hass.states.async_set("cover.office_skylight", "open") + hass.states.async_set("cover.x_skylight", "open") + hass.states.async_set("binary_sensor.door", "open") + + info = render_to_info(hass, template_complex_str) + + assert not info.domains + assert info.entities == { + "cover.x_skylight", + "binary_sensor.door", + "cover.office_drapes", + "cover.office_window", + "cover.office_skylight", + } + assert info.all_states is True + + hass.states.async_set("binary_sensor.door", "closed") + info = render_to_info(hass, template_complex_str) + + assert not info.domains + assert info.entities == { + "cover.x_skylight", + "binary_sensor.door", + "cover.office_drapes", + "cover.office_window", + "cover.office_skylight", + } + assert info.all_states is True + + template_cover_str = """ + +{% for state in states.cover %} + {% if state.state | regex_match('ope.*') %} + {{ state.entity_id }}={{ state.state }} + {% endif %} +{% endfor %} + +""" + hass.states.async_set("cover.x_skylight", "closed") + info = render_to_info(hass, template_cover_str) + + assert info.domains == {"cover"} + assert info.entities == { + "cover.x_skylight", + "cover.office_drapes", + "cover.office_window", + "cover.office_skylight", + } + assert info.all_states is False + + +def test_nested_async_render_to_info_case(hass): + """Test a deeply nested state with async_render_to_info.""" + + hass.states.async_set("input_select.picker", "vacuum.a") + hass.states.async_set("vacuum.a", "off") + + info = render_to_info( + hass, "{{ states[states['input_select.picker'].state].state }}", {} + ) + assert_result_info(info, "off", {"input_select.picker", "vacuum.a"}) + + +def test_result_as_boolean(hass): + """Test converting a template result to a boolean.""" + + template.result_as_boolean(True) is True + template.result_as_boolean(" 1 ") is True + template.result_as_boolean(" true ") is True + template.result_as_boolean(" TrUE ") is True + template.result_as_boolean(" YeS ") is True + template.result_as_boolean(" On ") is True + template.result_as_boolean(" Enable ") is True + template.result_as_boolean(1) is True + template.result_as_boolean(-1) is True + template.result_as_boolean(500) is True + + template.result_as_boolean(False) is False + template.result_as_boolean(" 0 ") is False + template.result_as_boolean(" false ") is False + template.result_as_boolean(" FaLsE ") is False + template.result_as_boolean(" no ") is False + template.result_as_boolean(" off ") is False + template.result_as_boolean(" disable ") is False + template.result_as_boolean(0) is False + template.result_as_boolean(None) is False + + def test_closest_function_to_entity_id(hass): """Test closest function to entity id.""" hass.states.async_set( @@ -1558,13 +1766,16 @@ def test_extract_entities_none_exclude_stuff(hass): assert ( template.extract_entities( - hass, "{{ closest(states.zone.far_away, states.test_domain).entity_id }}" + hass, + "{{ closest(states.zone.far_away, states.test_domain.xxx).entity_id }}", ) == MATCH_ALL ) assert ( - template.extract_entities(hass, '{{ distance("123", states.test_object_2) }}') + template.extract_entities( + hass, '{{ distance("123", states.test_object_2.user) }}' + ) == MATCH_ALL ) @@ -1673,6 +1884,42 @@ def test_generate_select(hass): ) +async def test_async_render_to_info_in_conditional(hass): + """Test extract entities function with none entities stuff.""" + template_str = """ +{{ states("sensor.xyz") == "dog" }} + """ + + tmp = template.Template(template_str, hass) + info = tmp.async_render_to_info() + assert_result_info(info, "False", ["sensor.xyz"], []) + + hass.states.async_set("sensor.xyz", "dog") + hass.states.async_set("sensor.cow", "True") + await hass.async_block_till_done() + + template_str = """ +{% if states("sensor.xyz") == "dog" %} + {{ states("sensor.cow") }} +{% else %} + {{ states("sensor.pig") }} +{% endif %} + """ + + tmp = template.Template(template_str, hass) + info = tmp.async_render_to_info() + assert_result_info(info, "True", ["sensor.xyz", "sensor.cow"], []) + + hass.states.async_set("sensor.xyz", "sheep") + hass.states.async_set("sensor.pig", "oink") + + await hass.async_block_till_done() + + tmp = template.Template(template_str, hass) + info = tmp.async_render_to_info() + assert_result_info(info, "oink", ["sensor.xyz", "sensor.pig"], []) + + async def test_extract_entities_match_entities(hass): """Test extract entities function with entities stuff.""" assert ( @@ -1739,8 +1986,8 @@ Hercules you power goes done!. hass, """ {{ -states.sensor.pick_temperature.state ~ „°C (“ ~ -states.sensor.pick_humidity.state ~ „ %“ +states.sensor.pick_temperature.state ~ "°C (" ~ +states.sensor.pick_humidity.state ~ " %" }} """, ) @@ -1771,11 +2018,15 @@ states.sensor.pick_humidity.state ~ „ %“ hass, "{{ expand('group.expand_group') | list | length }}" ) ) - assert ["test_domain.entity"] == template.Template( '{{ is_state("test_domain.entity", "on") }}', hass ).extract_entities() + # No expand, extract finds the group + assert template.extract_entities(hass, "{{ states('group.empty_group') }}") == [ + "group.empty_group" + ] + def test_extract_entities_with_variables(hass): """Test extract entities function with variables and entities stuff."""