Don't use template in cover device condition (#59408)

This commit is contained in:
Erik Montnemery 2021-11-09 12:43:21 +01:00 committed by GitHub
parent b5ce84cd89
commit 23fad60769
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 53 deletions

View File

@ -18,12 +18,7 @@ from homeassistant.const import (
STATE_OPENING, STATE_OPENING,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import ( from homeassistant.helpers import condition, config_validation as cv, entity_registry
condition,
config_validation as cv,
entity_registry,
template,
)
from homeassistant.helpers.config_validation import DEVICE_CONDITION_BASE_SCHEMA from homeassistant.helpers.config_validation import DEVICE_CONDITION_BASE_SCHEMA
from homeassistant.helpers.entity import get_supported_features from homeassistant.helpers.entity import get_supported_features
from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.typing import ConfigType, TemplateVarsType
@ -148,22 +143,19 @@ def async_condition_from_config(
return test_is_state return test_is_state
if config[CONF_TYPE] == "is_position": if config[CONF_TYPE] == "is_position":
position = "current_position" position_attr = "current_position"
if config[CONF_TYPE] == "is_tilt_position": if config[CONF_TYPE] == "is_tilt_position":
position = "current_tilt_position" position_attr = "current_tilt_position"
min_pos = config.get(CONF_ABOVE) min_pos = config.get(CONF_ABOVE)
max_pos = config.get(CONF_BELOW) max_pos = config.get(CONF_BELOW)
value_template = template.Template( # type: ignore
f"{{{{ state.attributes.{position} }}}}"
)
@callback @callback
def template_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: def check_numeric_state(
"""Validate template based if-condition.""" hass: HomeAssistant, variables: TemplateVarsType = None
value_template.hass = hass ) -> bool:
"""Return whether the criteria are met."""
return condition.async_numeric_state( return condition.async_numeric_state(
hass, config[ATTR_ENTITY_ID], max_pos, min_pos, value_template hass, config[ATTR_ENTITY_ID], max_pos, min_pos, attribute=position_attr
) )
return template_if return check_numeric_state

View File

@ -15,6 +15,7 @@ from homeassistant.const import (
STATE_CLOSING, STATE_CLOSING,
STATE_OPEN, STATE_OPEN,
STATE_OPENING, STATE_OPENING,
STATE_UNAVAILABLE,
) )
from homeassistant.helpers import device_registry from homeassistant.helpers import device_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -353,7 +354,7 @@ async def test_if_state(hass, calls):
assert calls[3].data["some"] == "is_closing - event - test_event4" assert calls[3].data["some"] == "is_closing - event - test_event4"
async def test_if_position(hass, calls, enable_custom_integrations): async def test_if_position(hass, calls, caplog, enable_custom_integrations):
"""Test for position conditions.""" """Test for position conditions."""
platform = getattr(hass.components, f"test.{DOMAIN}") platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init() platform.init()
@ -368,20 +369,28 @@ async def test_if_position(hass, calls, enable_custom_integrations):
automation.DOMAIN: [ automation.DOMAIN: [
{ {
"trigger": {"platform": "event", "event_type": "test_event1"}, "trigger": {"platform": "event", "event_type": "test_event1"},
"condition": [
{
"condition": "device",
"domain": DOMAIN,
"device_id": "",
"entity_id": ent.entity_id,
"type": "is_position",
"above": 45,
}
],
"action": { "action": {
"service": "test.automation", "choose": {
"data_template": { "conditions": {
"some": "is_pos_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}" "condition": "device",
"domain": DOMAIN,
"device_id": "",
"entity_id": ent.entity_id,
"type": "is_position",
"above": 45,
},
"sequence": {
"service": "test.automation",
"data_template": {
"some": "is_pos_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}"
},
},
},
"default": {
"service": "test.automation",
"data_template": {
"some": "is_pos_not_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}"
},
}, },
}, },
}, },
@ -427,8 +436,13 @@ async def test_if_position(hass, calls, enable_custom_integrations):
] ]
}, },
) )
caplog.clear()
hass.bus.async_fire("test_event1") hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 3 assert len(calls) == 3
@ -440,11 +454,14 @@ async def test_if_position(hass, calls, enable_custom_integrations):
ent.entity_id, STATE_CLOSED, attributes={"current_position": 45} ent.entity_id, STATE_CLOSED, attributes={"current_position": 45}
) )
hass.bus.async_fire("test_event1") hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 4 assert len(calls) == 5
assert calls[3].data["some"] == "is_pos_lt_90 - event - test_event2" assert calls[3].data["some"] == "is_pos_not_gt_45 - event - test_event1"
assert calls[4].data["some"] == "is_pos_lt_90 - event - test_event2"
hass.states.async_set( hass.states.async_set(
ent.entity_id, STATE_CLOSED, attributes={"current_position": 90} ent.entity_id, STATE_CLOSED, attributes={"current_position": 90}
@ -453,11 +470,20 @@ async def test_if_position(hass, calls, enable_custom_integrations):
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 5 assert len(calls) == 6
assert calls[4].data["some"] == "is_pos_gt_45 - event - test_event1" assert calls[5].data["some"] == "is_pos_gt_45 - event - test_event1"
hass.states.async_set(ent.entity_id, STATE_UNAVAILABLE, attributes={})
hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
assert len(calls) == 7
assert calls[6].data["some"] == "is_pos_not_gt_45 - event - test_event1"
for record in caplog.records:
assert record.levelname in ("DEBUG", "INFO")
async def test_if_tilt_position(hass, calls, enable_custom_integrations): async def test_if_tilt_position(hass, calls, caplog, enable_custom_integrations):
"""Test for tilt position conditions.""" """Test for tilt position conditions."""
platform = getattr(hass.components, f"test.{DOMAIN}") platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init() platform.init()
@ -472,20 +498,28 @@ async def test_if_tilt_position(hass, calls, enable_custom_integrations):
automation.DOMAIN: [ automation.DOMAIN: [
{ {
"trigger": {"platform": "event", "event_type": "test_event1"}, "trigger": {"platform": "event", "event_type": "test_event1"},
"condition": [
{
"condition": "device",
"domain": DOMAIN,
"device_id": "",
"entity_id": ent.entity_id,
"type": "is_tilt_position",
"above": 45,
}
],
"action": { "action": {
"service": "test.automation", "choose": {
"data_template": { "conditions": {
"some": "is_pos_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}" "condition": "device",
"domain": DOMAIN,
"device_id": "",
"entity_id": ent.entity_id,
"type": "is_tilt_position",
"above": 45,
},
"sequence": {
"service": "test.automation",
"data_template": {
"some": "is_pos_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}"
},
},
},
"default": {
"service": "test.automation",
"data_template": {
"some": "is_pos_not_gt_45 - {{ trigger.platform }} - {{ trigger.event.event_type }}"
},
}, },
}, },
}, },
@ -531,8 +565,13 @@ async def test_if_tilt_position(hass, calls, enable_custom_integrations):
] ]
}, },
) )
caplog.clear()
hass.bus.async_fire("test_event1") hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 3 assert len(calls) == 3
@ -544,18 +583,32 @@ async def test_if_tilt_position(hass, calls, enable_custom_integrations):
ent.entity_id, STATE_CLOSED, attributes={"current_tilt_position": 45} ent.entity_id, STATE_CLOSED, attributes={"current_tilt_position": 45}
) )
hass.bus.async_fire("test_event1") hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 4 assert len(calls) == 5
assert calls[3].data["some"] == "is_pos_lt_90 - event - test_event2" assert calls[3].data["some"] == "is_pos_not_gt_45 - event - test_event1"
assert calls[4].data["some"] == "is_pos_lt_90 - event - test_event2"
hass.states.async_set( hass.states.async_set(
ent.entity_id, STATE_CLOSED, attributes={"current_tilt_position": 90} ent.entity_id, STATE_CLOSED, attributes={"current_tilt_position": 90}
) )
hass.bus.async_fire("test_event1") hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
hass.bus.async_fire("test_event2") hass.bus.async_fire("test_event2")
await hass.async_block_till_done()
hass.bus.async_fire("test_event3") hass.bus.async_fire("test_event3")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 5 assert len(calls) == 6
assert calls[4].data["some"] == "is_pos_gt_45 - event - test_event1" assert calls[5].data["some"] == "is_pos_gt_45 - event - test_event1"
hass.states.async_set(ent.entity_id, STATE_UNAVAILABLE, attributes={})
hass.bus.async_fire("test_event1")
await hass.async_block_till_done()
assert len(calls) == 7
assert calls[6].data["some"] == "is_pos_not_gt_45 - event - test_event1"
for record in caplog.records:
assert record.levelname in ("DEBUG", "INFO")