diff --git a/homeassistant/const.py b/homeassistant/const.py index 7a58ce111f8..6f0ba1d6ec6 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -140,6 +140,7 @@ CONF_DOMAIN: Final = "domain" CONF_DOMAINS: Final = "domains" CONF_EFFECT: Final = "effect" CONF_ELEVATION: Final = "elevation" +CONF_ELSE: Final = "else" CONF_EMAIL: Final = "email" CONF_ENTITIES: Final = "entities" CONF_ENTITY_CATEGORY: Final = "entity_category" @@ -165,6 +166,7 @@ CONF_HS: Final = "hs" CONF_ICON: Final = "icon" CONF_ICON_TEMPLATE: Final = "icon_template" CONF_ID: Final = "id" +CONF_IF: Final = "if" CONF_INCLUDE: Final = "include" CONF_INTERNAL_URL: Final = "internal_url" CONF_IP_ADDRESS: Final = "ip_address" @@ -232,6 +234,7 @@ CONF_STRUCTURE: Final = "structure" CONF_SWITCHES: Final = "switches" CONF_TARGET: Final = "target" CONF_TEMPERATURE_UNIT: Final = "temperature_unit" +CONF_THEN: Final = "then" CONF_TIMEOUT: Final = "timeout" CONF_TIME_ZONE: Final = "time_zone" CONF_TOKEN: Final = "token" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 899d038cbe3..e00fa0d5c8a 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -42,6 +42,7 @@ from homeassistant.const import ( CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, + CONF_ELSE, CONF_ENTITY_ID, CONF_ENTITY_NAMESPACE, CONF_ERROR, @@ -50,6 +51,7 @@ from homeassistant.const import ( CONF_EVENT_DATA_TEMPLATE, CONF_FOR, CONF_ID, + CONF_IF, CONF_MATCH, CONF_PLATFORM, CONF_REPEAT, @@ -61,6 +63,7 @@ from homeassistant.const import ( CONF_STATE, CONF_STOP, CONF_TARGET, + CONF_THEN, CONF_TIMEOUT, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC, @@ -1420,6 +1423,15 @@ _SCRIPT_WAIT_FOR_TRIGGER_SCHEMA = vol.Schema( } ) +_SCRIPT_IF_SCHEMA = vol.Schema( + { + **SCRIPT_ACTION_BASE_SCHEMA, + vol.Required(CONF_IF): vol.All(ensure_list, [CONDITION_SCHEMA]), + vol.Required(CONF_THEN): SCRIPT_SCHEMA, + vol.Optional(CONF_ELSE): SCRIPT_SCHEMA, + } +) + _SCRIPT_SET_SCHEMA = vol.Schema( { **SCRIPT_ACTION_BASE_SCHEMA, @@ -1454,6 +1466,7 @@ SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger" SCRIPT_ACTION_VARIABLES = "variables" SCRIPT_ACTION_STOP = "stop" SCRIPT_ACTION_ERROR = "error" +SCRIPT_ACTION_IF = "if" def determine_script_action(action: dict[str, Any]) -> str: @@ -1488,6 +1501,9 @@ def determine_script_action(action: dict[str, Any]) -> str: if CONF_VARIABLES in action: return SCRIPT_ACTION_VARIABLES + if CONF_IF in action: + return SCRIPT_ACTION_IF + if CONF_SERVICE in action or CONF_SERVICE_TEMPLATE in action: return SCRIPT_ACTION_CALL_SERVICE @@ -1514,6 +1530,7 @@ ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_VARIABLES: _SCRIPT_SET_SCHEMA, SCRIPT_ACTION_STOP: _SCRIPT_STOP_SCHEMA, SCRIPT_ACTION_ERROR: _SCRIPT_ERROR_SCHEMA, + SCRIPT_ACTION_IF: _SCRIPT_IF_SCHEMA, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index f6109e9d6f8..caff3d19f4b 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -33,10 +33,12 @@ from homeassistant.const import ( CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, + CONF_ELSE, CONF_ERROR, CONF_EVENT, CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, + CONF_IF, CONF_MODE, CONF_REPEAT, CONF_SCENE, @@ -44,6 +46,7 @@ from homeassistant.const import ( CONF_SERVICE, CONF_STOP, CONF_TARGET, + CONF_THEN, CONF_TIMEOUT, CONF_UNTIL, CONF_VARIABLES, @@ -295,6 +298,15 @@ async def async_validate_action_config( hass, choose_conf[CONF_SEQUENCE] ) + elif action_type == cv.SCRIPT_ACTION_IF: + config[CONF_IF] = await condition.async_validate_conditions_config( + hass, config[CONF_IF] + ) + config[CONF_THEN] = await async_validate_actions_config(hass, config[CONF_THEN]) + if CONF_ELSE in config: + config[CONF_ELSE] = await async_validate_actions_config( + hass, config[CONF_ELSE] + ) else: raise ValueError(f"No validation for {action_type}") @@ -780,6 +792,31 @@ class _ScriptRun: with trace_path(["default"]): await self._async_run_script(choose_data["default"]) + async def _async_if_step(self) -> None: + """If sequence.""" + # pylint: disable=protected-access + if_data = await self._script._async_get_if_data(self._step) + + test_conditions = False + try: + with trace_path("if"): + test_conditions = self._test_conditions( + if_data["if_conditions"], "if", "condition" + ) + except exceptions.ConditionError as ex: + _LOGGER.warning("Error in 'if' evaluation:\n%s", ex) + + if test_conditions: + trace_set_result(choice="then") + with trace_path("then"): + await self._async_run_script(if_data["if_then"]) + return + + if if_data["if_else"] is not None: + trace_set_result(choice="else") + with trace_path("else"): + await self._async_run_script(if_data["if_else"]) + async def _async_wait_for_trigger_step(self): """Wait for a trigger event.""" if CONF_TIMEOUT in self._action: @@ -970,6 +1007,12 @@ class _ChooseData(TypedDict): default: Script | None +class _IfData(TypedDict): + if_conditions: list[ConditionCheckerType] + if_then: Script + if_else: Script | None + + class Script: """Representation of a script.""" @@ -1031,6 +1074,7 @@ class Script: self._config_cache: dict[set[tuple], Callable[..., bool]] = {} self._repeat_script: dict[int, Script] = {} self._choose_data: dict[int, _ChooseData] = {} + self._if_data: dict[int, _IfData] = {} self._referenced_entities: set[str] | None = None self._referenced_devices: set[str] | None = None self._referenced_areas: set[str] | None = None @@ -1070,6 +1114,10 @@ class Script: script.update_logger(self._logger) if choose_data["default"] is not None: choose_data["default"].update_logger(self._logger) + for if_data in self._if_data.values(): + if_data["if_then"].update_logger(self._logger) + if if_data["if_else"] is not None: + if_data["if_else"].update_logger(self._logger) def _changed(self) -> None: if self._change_listener_job: @@ -1125,6 +1173,11 @@ class Script: if CONF_DEFAULT in step: Script._find_referenced_areas(referenced, step[CONF_DEFAULT]) + elif action == cv.SCRIPT_ACTION_IF: + Script._find_referenced_areas(referenced, step[CONF_THEN]) + if CONF_ELSE in step: + Script._find_referenced_areas(referenced, step[CONF_ELSE]) + @property def referenced_devices(self): """Return a set of referenced devices.""" @@ -1162,6 +1215,13 @@ class Script: if CONF_DEFAULT in step: Script._find_referenced_devices(referenced, step[CONF_DEFAULT]) + elif action == cv.SCRIPT_ACTION_IF: + for cond in step[CONF_IF]: + referenced |= condition.async_extract_devices(cond) + Script._find_referenced_devices(referenced, step[CONF_THEN]) + if CONF_ELSE in step: + Script._find_referenced_devices(referenced, step[CONF_ELSE]) + @property def referenced_entities(self): """Return a set of referenced entities.""" @@ -1200,6 +1260,13 @@ class Script: if CONF_DEFAULT in step: Script._find_referenced_entities(referenced, step[CONF_DEFAULT]) + elif action == cv.SCRIPT_ACTION_IF: + for cond in step[CONF_IF]: + referenced |= condition.async_extract_entities(cond) + Script._find_referenced_entities(referenced, step[CONF_THEN]) + if CONF_ELSE in step: + Script._find_referenced_entities(referenced, step[CONF_ELSE]) + def run( self, variables: _VarsType | None = None, context: Context | None = None ) -> None: @@ -1411,6 +1478,58 @@ class Script: self._choose_data[step] = choose_data return choose_data + async def _async_prep_if_data(self, step: int) -> _IfData: + """Prepare data for an if statement.""" + action = self.sequence[step] + step_name = action.get(CONF_ALIAS, f"If at step {step+1}") + + conditions = [ + await self._async_get_condition(config) for config in action[CONF_IF] + ] + + then_script = Script( + self._hass, + action[CONF_THEN], + f"{self.name}: {step_name}", + self.domain, + running_description=self.running_description, + script_mode=SCRIPT_MODE_PARALLEL, + max_runs=self.max_runs, + logger=self._logger, + top_level=False, + ) + then_script.change_listener = partial(self._chain_change_listener, then_script) + + if CONF_ELSE in action: + else_script = Script( + self._hass, + action[CONF_ELSE], + f"{self.name}: {step_name}", + self.domain, + running_description=self.running_description, + script_mode=SCRIPT_MODE_PARALLEL, + max_runs=self.max_runs, + logger=self._logger, + top_level=False, + ) + else_script.change_listener = partial( + self._chain_change_listener, else_script + ) + else: + else_script = None + + return _IfData( + if_conditions=conditions, + if_then=then_script, + if_else=else_script, + ) + + async def _async_get_if_data(self, step: int) -> _IfData: + if not (if_data := self._if_data.get(step)): + if_data = await self._async_prep_if_data(step) + self._if_data[step] = if_data + return if_data + def _log( self, msg: str, *args: Any, level: int = logging.INFO, **kwargs: Any ) -> None: diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index d5b7f8048f4..5a6849df3ef 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -23,7 +23,13 @@ from homeassistant.const import ( CONF_DOMAIN, SERVICE_TURN_ON, ) -from homeassistant.core import SERVICE_CALL_LIMIT, Context, CoreState, callback +from homeassistant.core import ( + SERVICE_CALL_LIMIT, + Context, + CoreState, + HomeAssistant, + callback, +) from homeassistant.exceptions import ConditionError, ServiceNotFound from homeassistant.helpers import ( config_validation as cv, @@ -2510,6 +2516,172 @@ async def test_multiple_runs_repeat_choose(hass, caplog, action): assert len(events) == max_runs +async def test_if_warning( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test warning on if.""" + event = "test_event" + events = async_capture_events(hass, event) + + sequence = cv.SCRIPT_SCHEMA( + { + "if": { + "condition": "numeric_state", + "entity_id": "test.entity", + "value_template": "{{ undefined_a + undefined_b }}", + "above": 1, + }, + "then": {"event": event, "event_data": {"if": "then"}}, + "else": {"event": event, "event_data": {"if": "else"}}, + } + ) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + hass.states.async_set("test.entity", "9") + await hass.async_block_till_done() + + caplog.clear() + caplog.set_level(logging.WARNING) + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(caplog.record_tuples) == 1 + assert caplog.record_tuples[0][1] == logging.WARNING + + assert len(events) == 1 + assert events[0].data["if"] == "else" + + +@pytest.mark.parametrize( + "var,if_result,choice", [(1, True, "then"), (2, False, "else")] +) +async def test_if( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + var: int, + if_result: bool, + choice: str, +) -> None: + """Test if action.""" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + { + "if": { + "alias": "if condition", + "condition": "template", + "value_template": "{{ var == 1 }}", + }, + "then": { + "alias": "if then", + "event": "test_event", + "event_data": {"if": "then"}, + }, + "else": { + "alias": "if else", + "event": "test_event", + "event_data": {"if": "else"}, + }, + } + ) + + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + await script_obj.async_run(MappingProxyType({"var": var}), Context()) + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["if"] == choice + assert f"Test Name: If at step 1: Executing step if {choice}" in caplog.text + + expected_trace = { + "0": [{"result": {"choice": choice}}], + "0/if": [{"result": {"result": if_result}}], + "0/if/condition/0": [{"result": {"result": var == 1, "entities": []}}], + f"0/{choice}/0": [ + {"result": {"event": "test_event", "event_data": {"if": choice}}} + ], + } + assert_action_trace(expected_trace) + + +async def test_if_condition_validation( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test if we can use conditions in if actions which validate late.""" + registry = er.async_get(hass) + entry = registry.async_get_or_create( + "test", "hue", "1234", suggested_object_id="entity" + ) + assert entry.entity_id == "test.entity" + events = async_capture_events(hass, "test_event") + sequence = cv.SCRIPT_SCHEMA( + [ + {"event": "test_event"}, + { + "if": { + "condition": "state", + "entity_id": entry.id, + "state": "hello", + }, + "then": { + "event": "test_event", + "event_data": {"if": "then"}, + }, + }, + ] + ) + sequence = await script.async_validate_actions_config(hass, sequence) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + + hass.states.async_set("test.entity", "hello") + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + caplog.clear() + assert len(events) == 2 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{"result": {"choice": "then"}}], + "1/if": [{"result": {"result": True}}], + "1/if/condition/0": [{"result": {"result": True}}], + "1/if/condition/0/entity_id/0": [ + {"result": {"result": True, "state": "hello", "wanted_state": "hello"}} + ], + "1/then/0": [ + {"result": {"event": "test_event", "event_data": {"if": "then"}}} + ], + } + ) + + hass.states.async_set("test.entity", "goodbye") + + await script_obj.async_run(context=Context()) + await hass.async_block_till_done() + + assert len(events) == 3 + + assert_action_trace( + { + "0": [{"result": {"event": "test_event", "event_data": {}}}], + "1": [{}], + "1/if": [{"result": {"result": False}}], + "1/if/condition/0": [{"result": {"result": False}}], + "1/if/condition/0/entity_id/0": [ + { + "result": { + "result": False, + "state": "goodbye", + "wanted_state": "hello", + } + } + ], + }, + ) + + async def test_last_triggered(hass): """Test the last_triggered.""" event = "test_event" @@ -2694,6 +2866,21 @@ async def test_referenced_areas(hass): }, {"event": "test_event"}, {"delay": "{{ delay_period }}"}, + { + "if": [], + "then": [ + { + "service": "test.script", + "data": {"area_id": "area_if_then"}, + } + ], + "else": [ + { + "service": "test.script", + "data": {"area_id": "area_if_else"}, + } + ], + }, ] ), "Test Name", @@ -2707,6 +2894,8 @@ async def test_referenced_areas(hass): "area_in_target", "area_service_list", "area_service_not_list", + "area_if_then", + "area_if_else", # 'area_service_template', # no area extraction from template } # Test we cache results. @@ -2784,6 +2973,21 @@ async def test_referenced_entities(hass): }, {"event": "test_event"}, {"delay": "{{ delay_period }}"}, + { + "if": [], + "then": [ + { + "service": "test.script", + "data": {"entity_id": "light.if_then"}, + } + ], + "else": [ + { + "service": "test.script", + "data": {"entity_id": "light.if_else"}, + } + ], + }, ] ), "Test Name", @@ -2800,6 +3004,8 @@ async def test_referenced_entities(hass): "light.entity_in_target", "light.service_list", "light.service_not_list", + "light.if_then", + "light.if_else", # "light.service_template", # no entity extraction from template "scene.hello", "sensor.condition", @@ -2872,6 +3078,21 @@ async def test_referenced_devices(hass): } ], }, + { + "if": [], + "then": [ + { + "service": "test.script", + "data": {"device_id": "if-then"}, + } + ], + "else": [ + { + "service": "test.script", + "data": {"device_id": "if-else"}, + } + ], + }, ] ), "Test Name", @@ -2890,6 +3111,8 @@ async def test_referenced_devices(hass): "target-list-id-1", "target-list-id-2", "target-string-id", + "if-then", + "if-else", } # Test we cache results. assert script_obj.referenced_devices is script_obj.referenced_devices @@ -3510,6 +3733,17 @@ async def test_validate_action_config(hass): cv.SCRIPT_ACTION_VARIABLES: {"variables": {"hello": "world"}}, cv.SCRIPT_ACTION_STOP: {"stop": "Stop it right there buddy..."}, cv.SCRIPT_ACTION_ERROR: {"error": "Stand up, and try again!"}, + cv.SCRIPT_ACTION_IF: { + "if": [ + { + "condition": "state", + "entity_id": "light.kitchen", + "state": "on", + }, + ], + "then": [templated_device_action("if_then_event")], + "else": [templated_device_action("if_else_event")], + }, } expected_templates = { cv.SCRIPT_ACTION_CHECK_CONDITION: None, @@ -3517,6 +3751,7 @@ async def test_validate_action_config(hass): cv.SCRIPT_ACTION_REPEAT: [["repeat", "sequence", 0]], cv.SCRIPT_ACTION_CHOOSE: [["choose", 0, "sequence", 0], ["default", 0]], cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: None, + cv.SCRIPT_ACTION_IF: None, } for key in cv.ACTION_TYPE_SCHEMAS: