Allow any value when triggering on state attribute (#41261)

This commit is contained in:
Paulus Schoutsen 2020-10-05 12:53:12 +02:00
parent a6d50ba89b
commit b524cc9c56
5 changed files with 131 additions and 27 deletions

View File

@ -1,7 +1,7 @@
"""Offer state listening automation rules.""" """Offer state listening automation rules."""
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Dict, Optional from typing import Any, Dict, Optional
import voluptuous as vol import voluptuous as vol
@ -25,18 +25,43 @@ CONF_ENTITY_ID = "entity_id"
CONF_FROM = "from" CONF_FROM = "from"
CONF_TO = "to" CONF_TO = "to"
TRIGGER_SCHEMA = vol.Schema( BASE_SCHEMA = {
{
vol.Required(CONF_PLATFORM): "state", vol.Required(CONF_PLATFORM): "state",
vol.Required(CONF_ENTITY_ID): cv.entity_ids, vol.Required(CONF_ENTITY_ID): cv.entity_ids,
vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
}
TRIGGER_STATE_SCHEMA = vol.Schema(
{
**BASE_SCHEMA,
# These are str on purpose. Want to catch YAML conversions # These are str on purpose. Want to catch YAML conversions
vol.Optional(CONF_FROM): vol.Any(str, [str]), vol.Optional(CONF_FROM): vol.Any(str, [str]),
vol.Optional(CONF_TO): vol.Any(str, [str]), vol.Optional(CONF_TO): vol.Any(str, [str]),
vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
} }
) )
TRIGGER_ATTRIBUTE_SCHEMA = vol.Schema(
{
**BASE_SCHEMA,
vol.Optional(CONF_FROM): cv.match_all,
vol.Optional(CONF_TO): cv.match_all,
}
)
def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
"""Validate trigger."""
if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary")
# We use this approach instead of vol.Any because
# this gives better error messages.
if CONF_ATTRIBUTE in value:
return TRIGGER_ATTRIBUTE_SCHEMA(value)
return TRIGGER_STATE_SCHEMA(value)
async def async_attach_trigger( async def async_attach_trigger(
hass: HomeAssistant, hass: HomeAssistant,

View File

@ -297,7 +297,7 @@ def async_numeric_state_from_config(
def state( def state(
hass: HomeAssistant, hass: HomeAssistant,
entity: Union[None, str, State], entity: Union[None, str, State],
req_state: Union[str, List[str]], req_state: Any,
for_period: Optional[timedelta] = None, for_period: Optional[timedelta] = None,
attribute: Optional[str] = None, attribute: Optional[str] = None,
) -> bool: ) -> bool:
@ -314,17 +314,20 @@ def state(
assert isinstance(entity, State) assert isinstance(entity, State)
if attribute is None: if attribute is None:
value = entity.state value: Any = entity.state
else: else:
value = str(entity.attributes.get(attribute)) value = entity.attributes.get(attribute)
if isinstance(req_state, str): if not isinstance(req_state, list):
req_state = [req_state] req_state = [req_state]
is_state = False is_state = False
for req_state_value in req_state: for req_state_value in req_state:
state_value = req_state_value state_value = req_state_value
if INPUT_ENTITY_ID.match(req_state_value) is not None: if (
isinstance(req_state_value, str)
and INPUT_ENTITY_ID.match(req_state_value) is not None
):
state_entity = hass.states.get(req_state_value) state_entity = hass.states.get(req_state_value)
if not state_entity: if not state_entity:
continue continue

View File

@ -929,22 +929,44 @@ NUMERIC_STATE_CONDITION_SCHEMA = vol.All(
has_at_least_one_key(CONF_BELOW, CONF_ABOVE), has_at_least_one_key(CONF_BELOW, CONF_ABOVE),
) )
STATE_CONDITION_SCHEMA = vol.All( STATE_CONDITION_BASE_SCHEMA = {
vol.Schema(
{
vol.Required(CONF_CONDITION): "state", vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids, vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Optional(CONF_ATTRIBUTE): str, vol.Optional(CONF_ATTRIBUTE): str,
vol.Required(CONF_STATE): vol.Any(str, [str]),
vol.Optional(CONF_FOR): positive_time_period, vol.Optional(CONF_FOR): positive_time_period,
# To support use_trigger_value in automation # To support use_trigger_value in automation
# Deprecated 2016/04/25 # Deprecated 2016/04/25
vol.Optional("from"): str, vol.Optional("from"): str,
}
STATE_CONDITION_STATE_SCHEMA = vol.Schema(
{
**STATE_CONDITION_BASE_SCHEMA,
vol.Required(CONF_STATE): vol.Any(str, [str]),
} }
),
key_dependency("for", "state"),
) )
STATE_CONDITION_ATTRIBUTE_SCHEMA = vol.Schema(
{
**STATE_CONDITION_BASE_SCHEMA,
vol.Required(CONF_STATE): match_all,
}
)
def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
"""Validate a state condition."""
if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary")
if CONF_ATTRIBUTE in value:
validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
else:
validated = STATE_CONDITION_STATE_SCHEMA(value)
return key_dependency("for", "state")(validated)
SUN_CONDITION_SCHEMA = vol.All( SUN_CONDITION_SCHEMA = vol.All(
vol.Schema( vol.Schema(
{ {

View File

@ -1240,3 +1240,32 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop(
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10)) async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
async def test_attribute_if_fires_on_entity_change_with_both_filters_boolean(
hass, calls
):
"""Test for firing if both filters are match attribute."""
hass.states.async_set("test.entity", "bla", {"happening": False})
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "state",
"entity_id": "test.entity",
"from": False,
"to": True,
"attribute": "happening",
},
"action": {"service": "test.automation"},
}
},
)
await hass.async_block_till_done()
hass.states.async_set("test.entity", "bla", {"happening": True})
await hass.async_block_till_done()
assert len(calls) == 1

View File

@ -422,7 +422,7 @@ async def test_state_attribute(hass):
"condition": "state", "condition": "state",
"entity_id": "sensor.temperature", "entity_id": "sensor.temperature",
"attribute": "attribute1", "attribute": "attribute1",
"state": "200", "state": 200,
}, },
], ],
}, },
@ -435,7 +435,7 @@ async def test_state_attribute(hass):
assert test(hass) assert test(hass)
hass.states.async_set("sensor.temperature", 100, {"attribute1": "200"}) hass.states.async_set("sensor.temperature", 100, {"attribute1": "200"})
assert test(hass) assert not test(hass)
hass.states.async_set("sensor.temperature", 100, {"attribute1": 201}) hass.states.async_set("sensor.temperature", 100, {"attribute1": 201})
assert not test(hass) assert not test(hass)
@ -444,6 +444,31 @@ async def test_state_attribute(hass):
assert not test(hass) assert not test(hass)
async def test_state_attribute_boolean(hass):
"""Test with boolean state attribute in condition."""
test = await condition.async_from_config(
hass,
{
"condition": "state",
"entity_id": "sensor.temperature",
"attribute": "happening",
"state": False,
},
)
hass.states.async_set("sensor.temperature", 100, {"happening": 200})
assert not test(hass)
hass.states.async_set("sensor.temperature", 100, {"happening": True})
assert not test(hass)
hass.states.async_set("sensor.temperature", 100, {"no_happening": 201})
assert not test(hass)
hass.states.async_set("sensor.temperature", 100, {"happening": False})
assert test(hass)
async def test_state_using_input_entities(hass): async def test_state_using_input_entities(hass):
"""Test state conditions using input_* entities.""" """Test state conditions using input_* entities."""
await async_setup_component( await async_setup_component(