diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 981e0988639..82839d1e0f8 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -52,26 +52,30 @@ ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool] async def async_from_config( - hass: HomeAssistant, config: ConfigType, config_validation: bool = True + hass: HomeAssistant, + config: Union[ConfigType, Template], + config_validation: bool = True, ) -> ConditionCheckerType: """Turn a condition configuration into a method. Should be run on the event loop. """ + if isinstance(config, Template): + # We got a condition template, wrap it in a configuration to pass along. + config = { + CONF_CONDITION: "template", + CONF_VALUE_TEMPLATE: config, + } + + condition = config.get(CONF_CONDITION) for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): - factory = getattr( - sys.modules[__name__], fmt.format(config.get(CONF_CONDITION)), None - ) + factory = getattr(sys.modules[__name__], fmt.format(condition), None) if factory: break if factory is None: - raise HomeAssistantError( - 'Invalid condition "{}" specified {}'.format( - config.get(CONF_CONDITION), config - ) - ) + raise HomeAssistantError(f'Invalid condition "{condition}" specified {config}') # Check for partials to properly determine if coroutine function check_factory = factory @@ -584,9 +588,12 @@ async def async_device_from_config( async def async_validate_condition_config( - hass: HomeAssistant, config: ConfigType -) -> ConfigType: + hass: HomeAssistant, config: Union[ConfigType, Template] +) -> Union[ConfigType, Template]: """Validate config.""" + if isinstance(config, Template): + return config + condition = config[CONF_CONDITION] if condition in ("and", "not", "or"): conditions = [] @@ -597,6 +604,7 @@ async def async_validate_condition_config( if condition == "device": config = cv.DEVICE_CONDITION_SCHEMA(config) + assert not isinstance(config, Template) platform = await async_get_device_automation_platform( hass, config[CONF_DOMAIN], "condition" ) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index dd5a8b6522c..cc8adea81da 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1018,20 +1018,25 @@ DEVICE_CONDITION_BASE_SCHEMA = vol.Schema( DEVICE_CONDITION_SCHEMA = DEVICE_CONDITION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) -CONDITION_SCHEMA: vol.Schema = key_value_schemas( - CONF_CONDITION, - { - "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, - "state": STATE_CONDITION_SCHEMA, - "sun": SUN_CONDITION_SCHEMA, - "template": TEMPLATE_CONDITION_SCHEMA, - "time": TIME_CONDITION_SCHEMA, - "zone": ZONE_CONDITION_SCHEMA, - "and": AND_CONDITION_SCHEMA, - "or": OR_CONDITION_SCHEMA, - "not": NOT_CONDITION_SCHEMA, - "device": DEVICE_CONDITION_SCHEMA, - }, +CONDITION_SCHEMA: vol.Schema = vol.Schema( + vol.Any( + key_value_schemas( + CONF_CONDITION, + { + "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, + "state": STATE_CONDITION_SCHEMA, + "sun": SUN_CONDITION_SCHEMA, + "template": TEMPLATE_CONDITION_SCHEMA, + "time": TIME_CONDITION_SCHEMA, + "zone": ZONE_CONDITION_SCHEMA, + "and": AND_CONDITION_SCHEMA, + "or": OR_CONDITION_SCHEMA, + "not": NOT_CONDITION_SCHEMA, + "device": DEVICE_CONDITION_SCHEMA, + }, + ), + dynamic_template, + ) ) TRIGGER_SCHEMA = vol.All( diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index d4d9c11fa71..604102e6af3 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -935,7 +935,10 @@ class Script: await asyncio.shield(self._async_stop(update_state)) async def _async_get_condition(self, config): - config_cache_key = frozenset((k, str(v)) for k, v in config.items()) + if isinstance(config, template.Template): + config_cache_key = config.template + else: + config_cache_key = frozenset((k, str(v)) for k, v in config.items()) cond = self._config_cache.get(config_cache_key) if not cond: cond = await condition.async_from_config(self._hass, config, False) diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 8bbe28d3003..3952e781952 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -232,6 +232,31 @@ async def test_two_conditions_with_and(hass, calls): assert len(calls) == 1 +async def test_shorthand_conditions_template(hass, calls): + """Test shorthand nation form in conditions.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": [{"platform": "event", "event_type": "test_event"}], + "condition": "{{ is_state('test.entity', 'hello') }}", + "action": {"service": "test.automation"}, + } + }, + ) + + hass.states.async_set("test.entity", "hello") + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + + hass.states.async_set("test.entity", "goodbye") + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + assert len(calls) == 1 + + async def test_automation_list_setting(hass, calls): """Event is not a valid condition.""" assert await async_setup_component( diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index afe1c294290..dcd652913e5 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -128,10 +128,7 @@ async def test_or_condition_with_template(hass): { "condition": "or", "conditions": [ - { - "condition": "template", - "value_template": '{{ states.sensor.temperature.state == "100" }}', - }, + {'{{ states.sensor.temperature.state == "100" }}'}, { "condition": "numeric_state", "entity_id": "sensor.temperature", diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index e997e3e92d6..d298283d11e 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -982,7 +982,8 @@ async def test_repeat_count(hass): @pytest.mark.parametrize("condition", ["while", "until"]) -async def test_repeat_conditional(hass, condition): +@pytest.mark.parametrize("direct_template", [False, True]) +async def test_repeat_conditional(hass, condition, direct_template): """Test repeat action w/ while option.""" event = "test_event" events = async_capture_events(hass, event) @@ -1004,15 +1005,23 @@ async def test_repeat_conditional(hass, condition): } } if condition == "while": - sequence["repeat"]["while"] = { - "condition": "template", - "value_template": "{{ not is_state('sensor.test', 'done') }}", - } + template = "{{ not is_state('sensor.test', 'done') }}" + if direct_template: + sequence["repeat"]["while"] = template + else: + sequence["repeat"]["while"] = { + "condition": "template", + "value_template": template, + } else: - sequence["repeat"]["until"] = { - "condition": "template", - "value_template": "{{ is_state('sensor.test', 'done') }}", - } + template = "{{ is_state('sensor.test', 'done') }}" + if direct_template: + sequence["repeat"]["until"] = template + else: + sequence["repeat"]["until"] = { + "condition": "template", + "value_template": template, + } script_obj = script.Script( hass, cv.SCRIPT_SCHEMA(sequence), "Test Name", "test_domain" ) @@ -1193,10 +1202,7 @@ async def test_choose(hass, var, result): "sequence": {"event": event, "event_data": {"choice": "first"}}, }, { - "conditions": { - "condition": "template", - "value_template": "{{ var == 2 }}", - }, + "conditions": "{{ var == 2 }}", "sequence": {"event": event, "event_data": {"choice": "second"}}, }, ],