diff --git a/homeassistant/const.py b/homeassistant/const.py index f3ed345682e..cbb236a426c 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -47,6 +47,7 @@ CONF_BINARY_SENSORS = "binary_sensors" CONF_BRIGHTNESS = "brightness" CONF_BROADCAST_ADDRESS = "broadcast_address" CONF_BROADCAST_PORT = "broadcast_port" +CONF_CHOOSE = "choose" CONF_CLIENT_ID = "client_id" CONF_CLIENT_SECRET = "client_secret" CONF_CODE = "code" @@ -59,6 +60,7 @@ CONF_COMMAND_OPEN = "command_open" CONF_COMMAND_STATE = "command_state" CONF_COMMAND_STOP = "command_stop" CONF_CONDITION = "condition" +CONF_CONDITIONS = "conditions" CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout" CONF_COUNT = "count" CONF_COVERS = "covers" @@ -66,6 +68,7 @@ CONF_CURRENCY = "currency" CONF_CUSTOMIZE = "customize" CONF_CUSTOMIZE_DOMAIN = "customize_domain" CONF_CUSTOMIZE_GLOB = "customize_glob" +CONF_DEFAULT = "default" CONF_DELAY = "delay" CONF_DELAY_TIME = "delay_time" CONF_DEVICE = "device" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 0e20dea718b..9be584403bd 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -38,9 +38,12 @@ from homeassistant.const import ( CONF_ABOVE, CONF_ALIAS, CONF_BELOW, + CONF_CHOOSE, CONF_CONDITION, + CONF_CONDITIONS, CONF_CONTINUE_ON_TIMEOUT, CONF_COUNT, + CONF_DEFAULT, CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, @@ -930,7 +933,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema( AND_CONDITION_SCHEMA = vol.Schema( { vol.Required(CONF_CONDITION): "and", - vol.Required("conditions"): vol.All( + vol.Required(CONF_CONDITIONS): vol.All( ensure_list, # pylint: disable=unnecessary-lambda [lambda value: CONDITION_SCHEMA(value)], @@ -941,7 +944,7 @@ AND_CONDITION_SCHEMA = vol.Schema( OR_CONDITION_SCHEMA = vol.Schema( { vol.Required(CONF_CONDITION): "or", - vol.Required("conditions"): vol.All( + vol.Required(CONF_CONDITIONS): vol.All( ensure_list, # pylint: disable=unnecessary-lambda [lambda value: CONDITION_SCHEMA(value)], @@ -952,7 +955,7 @@ OR_CONDITION_SCHEMA = vol.Schema( NOT_CONDITION_SCHEMA = vol.Schema( { vol.Required(CONF_CONDITION): "not", - vol.Required("conditions"): vol.All( + vol.Required(CONF_CONDITIONS): vol.All( ensure_list, # pylint: disable=unnecessary-lambda [lambda value: CONDITION_SCHEMA(value)], @@ -1031,6 +1034,24 @@ _SCRIPT_REPEAT_SCHEMA = vol.Schema( } ) +_SCRIPT_CHOOSE_SCHEMA = vol.Schema( + { + vol.Optional(CONF_ALIAS): string, + vol.Required(CONF_CHOOSE): vol.All( + ensure_list, + [ + { + vol.Required(CONF_CONDITIONS): vol.All( + ensure_list, [CONDITION_SCHEMA] + ), + vol.Required(CONF_SEQUENCE): SCRIPT_SCHEMA, + } + ], + ), + vol.Optional(CONF_DEFAULT): SCRIPT_SCHEMA, + } +) + SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_CHECK_CONDITION = "condition" @@ -1039,6 +1060,7 @@ SCRIPT_ACTION_CALL_SERVICE = "call_service" SCRIPT_ACTION_DEVICE_AUTOMATION = "device" SCRIPT_ACTION_ACTIVATE_SCENE = "scene" SCRIPT_ACTION_REPEAT = "repeat" +SCRIPT_ACTION_CHOOSE = "choose" def determine_script_action(action: dict) -> str: @@ -1064,6 +1086,9 @@ def determine_script_action(action: dict) -> str: if CONF_REPEAT in action: return SCRIPT_ACTION_REPEAT + if CONF_CHOOSE in action: + return SCRIPT_ACTION_CHOOSE + return SCRIPT_ACTION_CALL_SERVICE @@ -1076,4 +1101,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA, SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA, + SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 268028eec3d..a14106053fe 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -15,9 +15,12 @@ import homeassistant.components.scene as scene from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ALIAS, + CONF_CHOOSE, CONF_CONDITION, + CONF_CONDITIONS, CONF_CONTINUE_ON_TIMEOUT, CONF_COUNT, + CONF_DEFAULT, CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, @@ -138,9 +141,9 @@ class _ScriptRun: if not self._stop.is_set(): self._script._changed() # pylint: disable=protected-access - @property - def _config_cache(self): - return self._script._config_cache # pylint: disable=protected-access + async def _async_get_condition(self, config): + # pylint: disable=protected-access + return await self._script._async_get_condition(config) def _log(self, msg, *args, level=logging.INFO): self._script._log(msg, *args, level=level) # pylint: disable=protected-access @@ -404,14 +407,6 @@ class _ScriptRun: self._action[CONF_EVENT], event_data, context=self._context ) - async def _async_get_condition(self, config): - config_cache_key = frozenset((k, str(v)) for k, v in config.items()) - cond = self._config_cache.get(config_cache_key) - if not cond: - cond = await condition.async_from_config(self._hass, config, False) - self._config_cache[config_cache_key] = cond - return cond - async def _async_condition_step(self): """Test if condition is matching.""" self._script.last_action = self._action.get( @@ -434,16 +429,13 @@ class _ScriptRun: repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}} if extra_vars: repeat_vars["repeat"].update(extra_vars) - task = self._hass.async_create_task( - # pylint: disable=protected-access - self._script._repeat_script[self._step].async_run( - # Add repeat to variables. Override if it already exists in case of - # nested calls. - {**(self._variables or {}), **repeat_vars}, - self._context, - ) + # pylint: disable=protected-access + await self._async_run_script( + self._script._get_repeat_script(self._step), + # Add repeat to variables. Override if it already exists in case of + # nested calls. + {**(self._variables or {}), **repeat_vars}, ) - await self._async_run_long_action(task) if CONF_COUNT in repeat: count = repeat[CONF_COUNT] @@ -487,6 +479,27 @@ class _ScriptRun: ): break + async def _async_choose_step(self): + """Choose a sequence.""" + # pylint: disable=protected-access + choose_data = await self._script._async_get_choose_data(self._step) + + for conditions, script in choose_data["choices"]: + if all(condition(self._hass, self._variables) for condition in conditions): + await self._async_run_script(script) + return + + if choose_data["default"]: + await self._async_run_script(choose_data["default"]) + + async def _async_run_script(self, script, variables=None): + """Execute a script.""" + await self._async_run_long_action( + self._hass.async_create_task( + script.async_run(variables or self._variables, self._context) + ) + ) + class _QueuedScriptRun(_ScriptRun): """Manage queued Script sequence run.""" @@ -562,27 +575,15 @@ class Script: self.last_triggered: Optional[datetime] = None self.can_cancel = True - self._repeat_script = {} - for step, action in enumerate(sequence): - if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT: - step_name = action.get(CONF_ALIAS, f"Repeat at step {step}") - sub_script = Script( - hass, - action[CONF_REPEAT][CONF_SEQUENCE], - f"{name}: {step_name}", - script_mode=SCRIPT_MODE_PARALLEL, - logger=self._logger, - ) - sub_script.change_listener = partial( - self._chain_change_listener, sub_script - ) - self._repeat_script[step] = sub_script - self._runs: List[_ScriptRun] = [] self._max_runs = max_runs if script_mode == SCRIPT_MODE_QUEUED: self._queue_lck = asyncio.Lock() self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} + self._repeat_script: Dict[int, Script] = {} + self._choose_data: Dict[ + int, List[Tuple[List[Callable[[HomeAssistant, Dict], bool]], Script]] + ] = {} self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None @@ -701,6 +702,78 @@ class Script: if self.is_running: await asyncio.shield(self._async_stop(update_state)) + async def _async_get_condition(self, config): + config_cache_key = frozenset((k, str(v)) for k, v in config.items()) + cond = self._config_cache.get(config_cache_key) + if not cond: + cond = await condition.async_from_config(self._hass, config, False) + self._config_cache[config_cache_key] = cond + return cond + + def _prep_repeat_script(self, step): + action = self.sequence[step] + step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}") + sub_script = Script( + self._hass, + action[CONF_REPEAT][CONF_SEQUENCE], + f"{self.name}: {step_name}", + script_mode=SCRIPT_MODE_PARALLEL, + logger=self._logger, + ) + sub_script.change_listener = partial(self._chain_change_listener, sub_script) + return sub_script + + def _get_repeat_script(self, step): + sub_script = self._repeat_script.get(step) + if not sub_script: + sub_script = self._prep_repeat_script(step) + self._repeat_script[step] = sub_script + return sub_script + + async def _async_prep_choose_data(self, step): + action = self.sequence[step] + step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}") + choices = [] + for idx, choice in enumerate(action[CONF_CHOOSE], start=1): + conditions = [ + await self._async_get_condition(config) + for config in choice.get(CONF_CONDITIONS, []) + ] + sub_script = Script( + self._hass, + choice[CONF_SEQUENCE], + f"{self.name}: {step_name}: choice {idx}", + script_mode=SCRIPT_MODE_PARALLEL, + logger=self._logger, + ) + sub_script.change_listener = partial( + self._chain_change_listener, sub_script + ) + choices.append((conditions, sub_script)) + + if CONF_DEFAULT in action: + default_script = Script( + self._hass, + action[CONF_DEFAULT], + f"{self.name}: {step_name}: default", + script_mode=SCRIPT_MODE_PARALLEL, + logger=self._logger, + ) + default_script.change_listener = partial( + self._chain_change_listener, default_script + ) + else: + default_script = None + + return {"choices": choices, "default": default_script} + + async def _async_get_choose_data(self, step): + choose_data = self._choose_data.get(step) + if not choose_data: + choose_data = await self._async_prep_choose_data(step) + self._choose_data[step] = choose_data + return choose_data + def _log(self, msg, *args, level=logging.INFO): if self.name: msg = f"%s: {msg}" diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 4ba407dd046..2e196b1f4ed 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -877,6 +877,41 @@ async def test_repeat_conditional(hass, condition): assert event.data.get("index") == str(index + 1) +@pytest.mark.parametrize("var,result", [(1, "first"), (2, "second"), (3, "default")]) +async def test_choose(hass, var, result): + """Test choose action.""" + event = "test_event" + events = async_capture_events(hass, event) + sequence = cv.SCRIPT_SCHEMA( + { + "choose": [ + { + "conditions": { + "condition": "template", + "value_template": "{{ var == 1 }}", + }, + "sequence": {"event": event, "event_data": {"choice": "first"}}, + }, + { + "conditions": { + "condition": "template", + "value_template": "{{ var == 2 }}", + }, + "sequence": {"event": event, "event_data": {"choice": "second"}}, + }, + ], + "default": {"event": event, "event_data": {"choice": "default"}}, + } + ) + script_obj = script.Script(hass, sequence) + + await script_obj.async_run({"var": var}) + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].data["choice"] == result + + async def test_last_triggered(hass): """Test the last_triggered.""" event = "test_event"