From b524cc9c56d320e11bdc63ca952fd06f71df672e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 5 Oct 2020 12:53:12 +0200 Subject: [PATCH] Allow any value when triggering on state attribute (#41261) --- .../homeassistant/triggers/state.py | 37 +++++++++++--- homeassistant/helpers/condition.py | 13 +++-- homeassistant/helpers/config_validation.py | 50 +++++++++++++------ .../homeassistant/triggers/test_state.py | 29 +++++++++++ tests/helpers/test_condition.py | 29 ++++++++++- 5 files changed, 131 insertions(+), 27 deletions(-) diff --git a/homeassistant/components/homeassistant/triggers/state.py b/homeassistant/components/homeassistant/triggers/state.py index a7377ffe43e..915856951d2 100644 --- a/homeassistant/components/homeassistant/triggers/state.py +++ b/homeassistant/components/homeassistant/triggers/state.py @@ -1,7 +1,7 @@ """Offer state listening automation rules.""" from datetime import timedelta import logging -from typing import Dict, Optional +from typing import Any, Dict, Optional import voluptuous as vol @@ -25,18 +25,43 @@ CONF_ENTITY_ID = "entity_id" CONF_FROM = "from" CONF_TO = "to" -TRIGGER_SCHEMA = vol.Schema( +BASE_SCHEMA = { + vol.Required(CONF_PLATFORM): "state", + vol.Required(CONF_ENTITY_ID): cv.entity_ids, + vol.Optional(CONF_FOR): cv.positive_time_period_template, + vol.Optional(CONF_ATTRIBUTE): cv.match_all, +} + +TRIGGER_STATE_SCHEMA = vol.Schema( { - vol.Required(CONF_PLATFORM): "state", - vol.Required(CONF_ENTITY_ID): cv.entity_ids, + **BASE_SCHEMA, # These are str on purpose. Want to catch YAML conversions vol.Optional(CONF_FROM): vol.Any(str, [str]), vol.Optional(CONF_TO): vol.Any(str, [str]), - vol.Optional(CONF_FOR): cv.positive_time_period_template, - vol.Optional(CONF_ATTRIBUTE): cv.match_all, } ) +TRIGGER_ATTRIBUTE_SCHEMA = vol.Schema( + { + **BASE_SCHEMA, + vol.Optional(CONF_FROM): cv.match_all, + vol.Optional(CONF_TO): cv.match_all, + } +) + + +def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name + """Validate trigger.""" + if not isinstance(value, dict): + raise vol.Invalid("Expected a dictionary") + + # We use this approach instead of vol.Any because + # this gives better error messages. + if CONF_ATTRIBUTE in value: + return TRIGGER_ATTRIBUTE_SCHEMA(value) + + return TRIGGER_STATE_SCHEMA(value) + async def async_attach_trigger( hass: HomeAssistant, diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index f67b9a4b0ab..c982b58d8d9 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -297,7 +297,7 @@ def async_numeric_state_from_config( def state( hass: HomeAssistant, entity: Union[None, str, State], - req_state: Union[str, List[str]], + req_state: Any, for_period: Optional[timedelta] = None, attribute: Optional[str] = None, ) -> bool: @@ -314,17 +314,20 @@ def state( assert isinstance(entity, State) if attribute is None: - value = entity.state + value: Any = entity.state else: - value = str(entity.attributes.get(attribute)) + value = entity.attributes.get(attribute) - if isinstance(req_state, str): + if not isinstance(req_state, list): req_state = [req_state] is_state = False for req_state_value in req_state: state_value = req_state_value - if INPUT_ENTITY_ID.match(req_state_value) is not None: + if ( + isinstance(req_state_value, str) + and INPUT_ENTITY_ID.match(req_state_value) is not None + ): state_entity = hass.states.get(req_state_value) if not state_entity: continue diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 282e63e6440..08d23a98da6 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -929,22 +929,44 @@ NUMERIC_STATE_CONDITION_SCHEMA = vol.All( has_at_least_one_key(CONF_BELOW, CONF_ABOVE), ) -STATE_CONDITION_SCHEMA = vol.All( - vol.Schema( - { - vol.Required(CONF_CONDITION): "state", - vol.Required(CONF_ENTITY_ID): entity_ids, - vol.Optional(CONF_ATTRIBUTE): str, - vol.Required(CONF_STATE): vol.Any(str, [str]), - vol.Optional(CONF_FOR): positive_time_period, - # To support use_trigger_value in automation - # Deprecated 2016/04/25 - vol.Optional("from"): str, - } - ), - key_dependency("for", "state"), +STATE_CONDITION_BASE_SCHEMA = { + vol.Required(CONF_CONDITION): "state", + vol.Required(CONF_ENTITY_ID): entity_ids, + vol.Optional(CONF_ATTRIBUTE): str, + vol.Optional(CONF_FOR): positive_time_period, + # To support use_trigger_value in automation + # Deprecated 2016/04/25 + vol.Optional("from"): str, +} + +STATE_CONDITION_STATE_SCHEMA = vol.Schema( + { + **STATE_CONDITION_BASE_SCHEMA, + vol.Required(CONF_STATE): vol.Any(str, [str]), + } ) +STATE_CONDITION_ATTRIBUTE_SCHEMA = vol.Schema( + { + **STATE_CONDITION_BASE_SCHEMA, + vol.Required(CONF_STATE): match_all, + } +) + + +def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name + """Validate a state condition.""" + if not isinstance(value, dict): + raise vol.Invalid("Expected a dictionary") + + if CONF_ATTRIBUTE in value: + validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value) + else: + validated = STATE_CONDITION_STATE_SCHEMA(value) + + return key_dependency("for", "state")(validated) + + SUN_CONDITION_SCHEMA = vol.All( vol.Schema( { diff --git a/tests/components/homeassistant/triggers/test_state.py b/tests/components/homeassistant/triggers/test_state.py index 688115dc400..61fa991e0f4 100644 --- a/tests/components/homeassistant/triggers/test_state.py +++ b/tests/components/homeassistant/triggers/test_state.py @@ -1240,3 +1240,32 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10)) await hass.async_block_till_done() assert len(calls) == 1 + + +async def test_attribute_if_fires_on_entity_change_with_both_filters_boolean( + hass, calls +): + """Test for firing if both filters are match attribute.""" + hass.states.async_set("test.entity", "bla", {"happening": False}) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": { + "platform": "state", + "entity_id": "test.entity", + "from": False, + "to": True, + "attribute": "happening", + }, + "action": {"service": "test.automation"}, + } + }, + ) + await hass.async_block_till_done() + + hass.states.async_set("test.entity", "bla", {"happening": True}) + await hass.async_block_till_done() + assert len(calls) == 1 diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index 71770d21186..001c59e2f2c 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -422,7 +422,7 @@ async def test_state_attribute(hass): "condition": "state", "entity_id": "sensor.temperature", "attribute": "attribute1", - "state": "200", + "state": 200, }, ], }, @@ -435,7 +435,7 @@ async def test_state_attribute(hass): assert test(hass) hass.states.async_set("sensor.temperature", 100, {"attribute1": "200"}) - assert test(hass) + assert not test(hass) hass.states.async_set("sensor.temperature", 100, {"attribute1": 201}) assert not test(hass) @@ -444,6 +444,31 @@ async def test_state_attribute(hass): assert not test(hass) +async def test_state_attribute_boolean(hass): + """Test with boolean state attribute in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "state", + "entity_id": "sensor.temperature", + "attribute": "happening", + "state": False, + }, + ) + + hass.states.async_set("sensor.temperature", 100, {"happening": 200}) + assert not test(hass) + + hass.states.async_set("sensor.temperature", 100, {"happening": True}) + assert not test(hass) + + hass.states.async_set("sensor.temperature", 100, {"no_happening": 201}) + assert not test(hass) + + hass.states.async_set("sensor.temperature", 100, {"happening": False}) + assert test(hass) + + async def test_state_using_input_entities(hass): """Test state conditions using input_* entities.""" await async_setup_component(