From b01b33c30409eea2b3f19f18348a33879df1c369 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 11 Jun 2021 15:05:57 +0200 Subject: [PATCH] Add trigger condition (#51710) * Add trigger condition * Tweaks, add tests --- homeassistant/helpers/condition.py | 21 +++++ homeassistant/helpers/config_validation.py | 22 +++-- homeassistant/helpers/trigger.py | 5 +- tests/components/automation/test_init.py | 94 ++++++++++++++++++++++ tests/helpers/test_condition.py | 14 ++++ 5 files changed, 149 insertions(+), 7 deletions(-) diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index d3020b8d6d8..cea79c4fc8f 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -31,6 +31,7 @@ from homeassistant.const import ( CONF_DEVICE_ID, CONF_DOMAIN, CONF_ENTITY_ID, + CONF_ID, CONF_STATE, CONF_VALUE_TEMPLATE, CONF_WEEKDAY, @@ -930,6 +931,26 @@ async def async_device_from_config( ) +async def async_trigger_from_config( + hass: HomeAssistant, config: ConfigType, config_validation: bool = True +) -> ConditionCheckerType: + """Test a trigger condition.""" + if config_validation: + config = cv.TRIGGER_CONDITION_SCHEMA(config) + trigger_id = config[CONF_ID] + + @trace_condition_function + def trigger_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + """Validate trigger based if-condition.""" + return ( + variables is not None + and "trigger" in variables + and variables["trigger"].get("id") in trigger_id + ) + + return trigger_if + + async def async_validate_condition_config( hass: HomeAssistant, config: ConfigType | Template ) -> ConfigType | Template: diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index feb03cf04a2..e195c1ded31 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -45,6 +45,7 @@ from homeassistant.const import ( CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE, CONF_FOR, + CONF_ID, CONF_PLATFORM, CONF_REPEAT, CONF_SCAN_INTERVAL, @@ -1026,6 +1027,14 @@ TIME_CONDITION_SCHEMA = vol.All( has_at_least_one_key("before", "after", "weekday"), ) +TRIGGER_CONDITION_SCHEMA = vol.Schema( + { + **CONDITION_BASE_SCHEMA, + vol.Required(CONF_CONDITION): "trigger", + vol.Required(CONF_ID): vol.All(ensure_list, [string]), + } +) + ZONE_CONDITION_SCHEMA = vol.Schema( { **CONDITION_BASE_SCHEMA, @@ -1090,23 +1099,26 @@ CONDITION_SCHEMA: vol.Schema = vol.Schema( key_value_schemas( CONF_CONDITION, { + "and": AND_CONDITION_SCHEMA, + "device": DEVICE_CONDITION_SCHEMA, + "not": NOT_CONDITION_SCHEMA, "numeric_state": NUMERIC_STATE_CONDITION_SCHEMA, + "or": OR_CONDITION_SCHEMA, "state": STATE_CONDITION_SCHEMA, "sun": SUN_CONDITION_SCHEMA, "template": TEMPLATE_CONDITION_SCHEMA, "time": TIME_CONDITION_SCHEMA, + "trigger": TRIGGER_CONDITION_SCHEMA, "zone": ZONE_CONDITION_SCHEMA, - "and": AND_CONDITION_SCHEMA, - "or": OR_CONDITION_SCHEMA, - "not": NOT_CONDITION_SCHEMA, - "device": DEVICE_CONDITION_SCHEMA, }, ), dynamic_template, ) ) -TRIGGER_BASE_SCHEMA = vol.Schema({vol.Required(CONF_PLATFORM): str}) +TRIGGER_BASE_SCHEMA = vol.Schema( + {vol.Required(CONF_PLATFORM): str, vol.Optional(CONF_ID): str} +) TRIGGER_SCHEMA = vol.All( ensure_list, [TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)] diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 045c56d964c..64c5373d8f5 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -8,7 +8,7 @@ from typing import Any, Callable import voluptuous as vol -from homeassistant.const import CONF_PLATFORM +from homeassistant.const import CONF_ID, CONF_PLATFORM from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.typing import ConfigType @@ -74,7 +74,8 @@ async def async_initialize_triggers( triggers = [] for idx, conf in enumerate(trigger_config): platform = await _async_get_trigger_platform(hass, conf) - info = {**info, "trigger_id": f"{idx}"} + trigger_id = conf.get(CONF_ID, f"{idx}") + info = {**info, "trigger_id": trigger_id} triggers.append(platform.async_attach_trigger(hass, conf, action, info)) attach_results = await asyncio.gather(*triggers, return_exceptions=True) diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 5997be22644..80fe5c52abc 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1405,3 +1405,97 @@ async def test_trigger_service(hass, calls): assert len(calls) == 1 assert calls[0].data.get("trigger") == {"platform": None} assert calls[0].context.parent_id is context.id + + +async def test_trigger_condition_implicit_id(hass, calls): + """Test triggers.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": [ + {"platform": "event", "event_type": "test_event1"}, + {"platform": "event", "event_type": "test_event2"}, + {"platform": "event", "event_type": "test_event3"}, + ], + "action": { + "choose": [ + { + "conditions": {"condition": "trigger", "id": [0, "2"]}, + "sequence": { + "service": "test.automation", + "data": {"param": "one"}, + }, + }, + { + "conditions": {"condition": "trigger", "id": "1"}, + "sequence": { + "service": "test.automation", + "data": {"param": "two"}, + }, + }, + ] + }, + } + }, + ) + + hass.bus.async_fire("test_event1") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[-1].data.get("param") == "one" + + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[-1].data.get("param") == "two" + + hass.bus.async_fire("test_event3") + await hass.async_block_till_done() + assert len(calls) == 3 + assert calls[-1].data.get("param") == "one" + + +async def test_trigger_condition_explicit_id(hass, calls): + """Test triggers.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": [ + {"platform": "event", "event_type": "test_event1", "id": "one"}, + {"platform": "event", "event_type": "test_event2", "id": "two"}, + ], + "action": { + "choose": [ + { + "conditions": {"condition": "trigger", "id": "one"}, + "sequence": { + "service": "test.automation", + "data": {"param": "one"}, + }, + }, + { + "conditions": {"condition": "trigger", "id": "two"}, + "sequence": { + "service": "test.automation", + "data": {"param": "two"}, + }, + }, + ] + }, + } + }, + ) + + hass.bus.async_fire("test_event1") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[-1].data.get("param") == "one" + + hass.bus.async_fire("test_event2") + await hass.async_block_till_done() + assert len(calls) == 2 + assert calls[-1].data.get("param") == "two" diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index bd5e15ad11f..38a9367e36d 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -2829,3 +2829,17 @@ async def test_if_action_after_sunset_no_offset_kotzebue(hass, hass_ws_client, c "sun", {"result": True, "wanted_time_after": "2015-07-23T11:22:18.467277+00:00"}, ) + + +async def test_trigger(hass): + """Test trigger condition.""" + test = await condition.async_from_config( + hass, + {"alias": "Trigger Cond", "condition": "trigger", "id": "123456"}, + ) + + assert not test(hass) + assert not test(hass, {}) + assert not test(hass, {"other_var": "123456"}) + assert not test(hass, {"trigger": {"trigger_id": "123456"}}) + assert test(hass, {"trigger": {"id": "123456"}})