Add condition to trigger template entities (#119689)

* Add conditions to trigger template entities

* Add tests

* Fix ruff error

* Ruff

* Apply suggestions from code review

* Deduplicate

* Tweak name used in debug message

* Add and improve type annotations of modified code

* Adjust typing

* Adjust typing

* Add typing and remove unused parameter

* Adjust typing

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Adjust return type

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
chammp 2024-09-11 09:36:49 +02:00 committed by GitHub
parent 74834b2d88
commit b3377fe5fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 265 additions and 49 deletions

View File

@ -47,14 +47,7 @@ from homeassistant.core import (
split_entity_id,
valid_entity_id,
)
from homeassistant.exceptions import (
ConditionError,
ConditionErrorContainer,
ConditionErrorIndex,
HomeAssistantError,
ServiceNotFound,
TemplateError,
)
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.deprecation import (
@ -1146,39 +1139,14 @@ async def _async_process_if(
"""Process if checks."""
if_configs = config[CONF_CONDITION]
checks: list[condition.ConditionCheckerType] = []
for if_config in if_configs:
try:
checks.append(await condition.async_from_config(hass, if_config))
if_action = await condition.async_conditions_from_config(
hass, if_configs, LOGGER, name
)
except HomeAssistantError as ex:
LOGGER.warning("Invalid condition: %s", ex)
return None
def if_action(variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions."""
errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks):
try:
with trace_path(["condition", str(index)]):
if check(hass, variables) is False:
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex(
"condition", index=index, total=len(checks), error=ex
)
)
if errors:
LOGGER.warning(
"Error evaluating condition in '%s':\n%s",
name,
ConditionErrorContainer("condition", errors=errors),
)
return False
return True
result: IfAction = if_action # type: ignore[assignment]
result.config = if_configs

View File

@ -15,6 +15,7 @@ from homeassistant.config import async_log_schema_error, config_without_domain
from homeassistant.const import CONF_BINARY_SENSORS, CONF_SENSORS, CONF_UNIQUE_ID
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import async_validate_conditions_config
from homeassistant.helpers.trigger import async_validate_trigger_config
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_notify_setup_error
@ -28,7 +29,7 @@ from . import (
sensor as sensor_platform,
weather as weather_platform,
)
from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN
from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN
PACKAGE_MERGE_HINT = "list"
@ -36,6 +37,7 @@ CONFIG_SECTION_SCHEMA = vol.Schema(
{
vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): cv.CONDITIONS_SCHEMA,
vol.Optional(CONF_ACTION): cv.SCRIPT_SCHEMA,
vol.Optional(NUMBER_DOMAIN): vol.All(
cv.ensure_list, [number_platform.NUMBER_SCHEMA]
@ -83,6 +85,11 @@ async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> Conf
cfg[CONF_TRIGGER] = await async_validate_trigger_config(
hass, cfg[CONF_TRIGGER]
)
if CONF_CONDITION in cfg:
cfg[CONF_CONDITION] = await async_validate_conditions_config(
hass, cfg[CONF_CONDITION]
)
except vol.Invalid as err:
async_log_schema_error(err, DOMAIN, cfg, hass)
async_notify_setup_error(hass, DOMAIN)

View File

@ -7,6 +7,7 @@ CONF_ATTRIBUTE_TEMPLATES = "attribute_templates"
CONF_ATTRIBUTES = "attributes"
CONF_AVAILABILITY = "availability"
CONF_AVAILABILITY_TEMPLATE = "availability_template"
CONF_CONDITION = "condition"
CONF_MAX = "max"
CONF_MIN = "min"
CONF_OBJECT_ID = "object_id"

View File

@ -1,16 +1,18 @@
"""Data update coordinator for trigger based template entities."""
from collections.abc import Callable
from collections.abc import Callable, Mapping
import logging
from typing import TYPE_CHECKING, Any
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import Context, CoreState, callback
from homeassistant.helpers import discovery, trigger as trigger_helper
from homeassistant.helpers import condition, discovery, trigger as trigger_helper
from homeassistant.helpers.script import Script
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.trace import trace_get
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import CONF_ACTION, CONF_TRIGGER, DOMAIN, PLATFORMS
from .const import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORMS
_LOGGER = logging.getLogger(__name__)
@ -24,6 +26,7 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
"""Instantiate trigger data."""
super().__init__(hass, _LOGGER, name="Trigger Update Coordinator")
self.config = config
self._cond_func: Callable[[Mapping[str, Any] | None], bool] | None = None
self._unsub_start: Callable[[], None] | None = None
self._unsub_trigger: Callable[[], None] | None = None
self._script: Script | None = None
@ -73,6 +76,11 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
DOMAIN,
)
if CONF_CONDITION in self.config:
self._cond_func = await condition.async_conditions_from_config(
self.hass, self.config[CONF_CONDITION], _LOGGER, "template entity"
)
if start_event is not None:
self._unsub_start = None
@ -91,16 +99,43 @@ class TriggerUpdateCoordinator(DataUpdateCoordinator):
start_event is not None,
)
async def _handle_triggered_with_script(self, run_variables, context=None):
async def _handle_triggered_with_script(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
if not self._check_condition(run_variables):
return
# Create a context referring to the trigger context.
trigger_context_id = None if context is None else context.id
script_context = Context(parent_id=trigger_context_id)
if TYPE_CHECKING:
# This method is only called if there's a script
assert self._script is not None
if script_result := await self._script.async_run(run_variables, script_context):
run_variables = script_result.variables
self._handle_triggered(run_variables, context)
self._execute_update(run_variables, context)
async def _handle_triggered(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
if not self._check_condition(run_variables):
return
self._execute_update(run_variables, context)
def _check_condition(self, run_variables: TemplateVarsType) -> bool:
if not self._cond_func:
return True
condition_result = self._cond_func(run_variables)
if condition_result is False:
_LOGGER.debug(
"Conditions not met, aborting template trigger update. Condition summary: %s",
trace_get(clear=False),
)
return condition_result
@callback
def _handle_triggered(self, run_variables, context=None):
def _execute_update(
self, run_variables: TemplateVarsType, context: Context | None = None
) -> None:
self.async_set_updated_data(
{"run_variables": run_variables, "context": context}
)

View File

@ -8,6 +8,7 @@ from collections.abc import Callable, Container, Generator
from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta
import functools as ft
import logging
import re
import sys
from typing import Any, Protocol, cast
@ -1064,6 +1065,46 @@ async def async_validate_conditions_config(
return [await async_validate_condition_config(hass, cond) for cond in conditions]
async def async_conditions_from_config(
hass: HomeAssistant,
condition_configs: list[ConfigType],
logger: logging.Logger,
name: str,
) -> Callable[[TemplateVarsType], bool]:
"""AND all conditions."""
checks: list[ConditionCheckerType] = [
await async_from_config(hass, condition_config)
for condition_config in condition_configs
]
def check_conditions(variables: TemplateVarsType = None) -> bool:
"""AND all conditions."""
errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks):
try:
with trace_path(["condition", str(index)]):
if check(hass, variables) is False:
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex(
"condition", index=index, total=len(checks), error=ex
)
)
if errors:
logger.warning(
"Error evaluating condition in '%s':\n%s",
name,
ConditionErrorContainer("condition", errors=errors),
)
return False
return True
return check_conditions
@callback
def async_extract_entities(config: ConfigType | Template) -> set[str]:
"""Extract entities from a condition."""

View File

@ -1349,7 +1349,7 @@ async def _async_stop_scripts_at_shutdown(hass: HomeAssistant, event: Event) ->
)
type _VarsType = dict[str, Any] | MappingProxyType[str, Any]
type _VarsType = dict[str, Any] | Mapping[str, Any] | MappingProxyType[str, Any]
def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:

View File

@ -1207,6 +1207,124 @@ async def test_trigger_entity(
assert state.context is context
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer >= 42 }}",
}
],
"sensor": [
{
"name": "Enough Name",
"unique_id": "enough-id",
"state": "You had enough Beer.",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity(hass: HomeAssistant, start_ha) -> None:
"""Test conditional trigger entity works."""
state = hass.states.get("sensor.enough_name")
assert state is not None
assert state.state == STATE_UNKNOWN
hass.bus.async_fire("test_event", {"beer": 2})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state.state == STATE_UNKNOWN
hass.bus.async_fire("test_event", {"beer": 42})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state.state == "You had enough Beer."
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer / 0 == 'narf' }}",
}
],
"sensor": [
{
"name": "Enough Name",
"unique_id": "enough-id",
"state": "You had enough Beer.",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity_evaluation_error(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture, start_ha
) -> None:
"""Test trigger entity is not updated when condition evaluation fails."""
hass.bus.async_fire("test_event", {"beer": 1})
await hass.async_block_till_done()
state = hass.states.get("sensor.enough_name")
assert state is not None
assert state.state == STATE_UNKNOWN
assert "Error evaluating condition in 'template entity'" in caplog.text
@pytest.mark.parametrize(("count", "domain"), [(0, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{"condition": "template", "value_template": "{{ invalid"}
],
"sensor": [
{
"name": "Will Not Exist Name",
"state": "Unimportant",
}
],
},
],
},
],
)
async def test_trigger_conditional_entity_invalid_condition(
hass: HomeAssistant, start_ha
) -> None:
"""Test trigger entity is not created when condition is invalid."""
state = hass.states.get("sensor.will_not_exist_name")
assert state is None
@pytest.mark.parametrize(("count", "domain"), [(1, "template")])
@pytest.mark.parametrize(
"config",
@ -1903,6 +2021,52 @@ async def test_trigger_action(
assert events[0].context.parent_id == context.id
@pytest.mark.parametrize(("count", "domain"), [(1, template.DOMAIN)])
@pytest.mark.parametrize(
"config",
[
{
"template": [
{
"unique_id": "listening-test-event",
"trigger": {"platform": "event", "event_type": "test_event"},
"condition": [
{
"condition": "template",
"value_template": "{{ trigger.event.data.beer >= 42 }}",
}
],
"action": [
{"event": "test_event_by_action"},
],
"sensor": [
{
"name": "Not That Important",
"state": "Really not.",
}
],
},
],
},
],
)
async def test_trigger_conditional_action(hass: HomeAssistant, start_ha) -> None:
"""Test conditional trigger entity with an action works."""
event = "test_event_by_action"
events = async_capture_events(hass, event)
hass.bus.async_fire("test_event", {"beer": 1})
await hass.async_block_till_done()
assert len(events) == 0
hass.bus.async_fire("test_event", {"beer": 42})
await hass.async_block_till_done()
assert len(events) == 1
async def test_device_id(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,