mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 10:59:40 +00:00
Add choose script action (#37818)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
@@ -15,9 +15,12 @@ import homeassistant.components.scene as scene
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_ALIAS,
|
||||
CONF_CHOOSE,
|
||||
CONF_CONDITION,
|
||||
CONF_CONDITIONS,
|
||||
CONF_CONTINUE_ON_TIMEOUT,
|
||||
CONF_COUNT,
|
||||
CONF_DEFAULT,
|
||||
CONF_DELAY,
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
@@ -138,9 +141,9 @@ class _ScriptRun:
|
||||
if not self._stop.is_set():
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _config_cache(self):
|
||||
return self._script._config_cache # pylint: disable=protected-access
|
||||
async def _async_get_condition(self, config):
|
||||
# pylint: disable=protected-access
|
||||
return await self._script._async_get_condition(config)
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
@@ -404,14 +407,6 @@ class _ScriptRun:
|
||||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
async def _async_condition_step(self):
|
||||
"""Test if condition is matching."""
|
||||
self._script.last_action = self._action.get(
|
||||
@@ -434,16 +429,13 @@ class _ScriptRun:
|
||||
repeat_vars = {"repeat": {"first": iteration == 1, "index": iteration}}
|
||||
if extra_vars:
|
||||
repeat_vars["repeat"].update(extra_vars)
|
||||
task = self._hass.async_create_task(
|
||||
# pylint: disable=protected-access
|
||||
self._script._repeat_script[self._step].async_run(
|
||||
# Add repeat to variables. Override if it already exists in case of
|
||||
# nested calls.
|
||||
{**(self._variables or {}), **repeat_vars},
|
||||
self._context,
|
||||
)
|
||||
# pylint: disable=protected-access
|
||||
await self._async_run_script(
|
||||
self._script._get_repeat_script(self._step),
|
||||
# Add repeat to variables. Override if it already exists in case of
|
||||
# nested calls.
|
||||
{**(self._variables or {}), **repeat_vars},
|
||||
)
|
||||
await self._async_run_long_action(task)
|
||||
|
||||
if CONF_COUNT in repeat:
|
||||
count = repeat[CONF_COUNT]
|
||||
@@ -487,6 +479,27 @@ class _ScriptRun:
|
||||
):
|
||||
break
|
||||
|
||||
async def _async_choose_step(self):
|
||||
"""Choose a sequence."""
|
||||
# pylint: disable=protected-access
|
||||
choose_data = await self._script._async_get_choose_data(self._step)
|
||||
|
||||
for conditions, script in choose_data["choices"]:
|
||||
if all(condition(self._hass, self._variables) for condition in conditions):
|
||||
await self._async_run_script(script)
|
||||
return
|
||||
|
||||
if choose_data["default"]:
|
||||
await self._async_run_script(choose_data["default"])
|
||||
|
||||
async def _async_run_script(self, script, variables=None):
|
||||
"""Execute a script."""
|
||||
await self._async_run_long_action(
|
||||
self._hass.async_create_task(
|
||||
script.async_run(variables or self._variables, self._context)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class _QueuedScriptRun(_ScriptRun):
|
||||
"""Manage queued Script sequence run."""
|
||||
@@ -562,27 +575,15 @@ class Script:
|
||||
self.last_triggered: Optional[datetime] = None
|
||||
self.can_cancel = True
|
||||
|
||||
self._repeat_script = {}
|
||||
for step, action in enumerate(sequence):
|
||||
if cv.determine_script_action(action) == cv.SCRIPT_ACTION_REPEAT:
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step}")
|
||||
sub_script = Script(
|
||||
hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
self._repeat_script[step] = sub_script
|
||||
|
||||
self._runs: List[_ScriptRun] = []
|
||||
self._max_runs = max_runs
|
||||
if script_mode == SCRIPT_MODE_QUEUED:
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
|
||||
self._repeat_script: Dict[int, Script] = {}
|
||||
self._choose_data: Dict[
|
||||
int, List[Tuple[List[Callable[[HomeAssistant, Dict], bool]], Script]]
|
||||
] = {}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
|
||||
@@ -701,6 +702,78 @@ class Script:
|
||||
if self.is_running:
|
||||
await asyncio.shield(self._async_stop(update_state))
|
||||
|
||||
async def _async_get_condition(self, config):
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
||||
def _prep_repeat_script(self, step):
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Repeat at step {step+1}")
|
||||
sub_script = Script(
|
||||
self._hass,
|
||||
action[CONF_REPEAT][CONF_SEQUENCE],
|
||||
f"{self.name}: {step_name}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(self._chain_change_listener, sub_script)
|
||||
return sub_script
|
||||
|
||||
def _get_repeat_script(self, step):
|
||||
sub_script = self._repeat_script.get(step)
|
||||
if not sub_script:
|
||||
sub_script = self._prep_repeat_script(step)
|
||||
self._repeat_script[step] = sub_script
|
||||
return sub_script
|
||||
|
||||
async def _async_prep_choose_data(self, step):
|
||||
action = self.sequence[step]
|
||||
step_name = action.get(CONF_ALIAS, f"Choose at step {step+1}")
|
||||
choices = []
|
||||
for idx, choice in enumerate(action[CONF_CHOOSE], start=1):
|
||||
conditions = [
|
||||
await self._async_get_condition(config)
|
||||
for config in choice.get(CONF_CONDITIONS, [])
|
||||
]
|
||||
sub_script = Script(
|
||||
self._hass,
|
||||
choice[CONF_SEQUENCE],
|
||||
f"{self.name}: {step_name}: choice {idx}",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
sub_script.change_listener = partial(
|
||||
self._chain_change_listener, sub_script
|
||||
)
|
||||
choices.append((conditions, sub_script))
|
||||
|
||||
if CONF_DEFAULT in action:
|
||||
default_script = Script(
|
||||
self._hass,
|
||||
action[CONF_DEFAULT],
|
||||
f"{self.name}: {step_name}: default",
|
||||
script_mode=SCRIPT_MODE_PARALLEL,
|
||||
logger=self._logger,
|
||||
)
|
||||
default_script.change_listener = partial(
|
||||
self._chain_change_listener, default_script
|
||||
)
|
||||
else:
|
||||
default_script = None
|
||||
|
||||
return {"choices": choices, "default": default_script}
|
||||
|
||||
async def _async_get_choose_data(self, step):
|
||||
choose_data = self._choose_data.get(step)
|
||||
if not choose_data:
|
||||
choose_data = await self._async_prep_choose_data(step)
|
||||
self._choose_data[step] = choose_data
|
||||
return choose_data
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
if self.name:
|
||||
msg = f"%s: {msg}"
|
||||
|
||||
Reference in New Issue
Block a user