From ba73ac12ba5aa4374922bea9db23a1044e56510b Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Mon, 15 Jun 2020 22:54:19 +0200 Subject: [PATCH] Add support for multiple entity_ids in conditions (#36817) --- homeassistant/helpers/condition.py | 30 +++-- homeassistant/helpers/config_validation.py | 6 +- tests/helpers/test_condition.py | 131 +++++++++++++++++++++ 3 files changed, 153 insertions(+), 14 deletions(-) diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 535de0304a0..b05445eff27 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -238,7 +238,7 @@ def async_numeric_state_from_config( """Wrap action method with state based condition.""" if config_validation: config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config) - entity_id = config.get(CONF_ENTITY_ID) + entity_ids = config.get(CONF_ENTITY_ID, []) below = config.get(CONF_BELOW) above = config.get(CONF_ABOVE) value_template = config.get(CONF_VALUE_TEMPLATE) @@ -250,8 +250,11 @@ def async_numeric_state_from_config( if value_template is not None: value_template.hass = hass - return async_numeric_state( - hass, entity_id, below, above, value_template, variables + return all( + async_numeric_state( + hass, entity_id, below, above, value_template, variables + ) + for entity_id in entity_ids ) return if_numeric_state @@ -288,13 +291,15 @@ def state_from_config( """Wrap action method with state based condition.""" if config_validation: config = cv.STATE_CONDITION_SCHEMA(config) - entity_id = config.get(CONF_ENTITY_ID) + entity_ids = config.get(CONF_ENTITY_ID, []) req_state = cast(str, config.get(CONF_STATE)) for_period = config.get("for") def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" - return state(hass, entity_id, req_state, for_period) + return all( + state(hass, entity_id, req_state, for_period) for entity_id in entity_ids + ) return if_state @@ -506,12 +511,12 @@ def zone_from_config( """Wrap action method with zone based condition.""" if config_validation: config = cv.ZONE_CONDITION_SCHEMA(config) - entity_id = config.get(CONF_ENTITY_ID) + entity_ids = config.get(CONF_ENTITY_ID, []) zone_entity_id = config.get(CONF_ZONE) def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Test if condition.""" - return zone(hass, zone_entity_id, entity_id) + return all(zone(hass, zone_entity_id, entity_id) for entity_id in entity_ids) return if_in_zone @@ -556,7 +561,7 @@ async def async_validate_condition_config( @callback def async_extract_entities(config: ConfigType) -> Set[str]: """Extract entities from a condition.""" - referenced = set() + referenced: Set[str] = set() to_process = deque([config]) while to_process: @@ -567,10 +572,13 @@ def async_extract_entities(config: ConfigType) -> Set[str]: to_process.extend(config["conditions"]) continue - entity_id = config.get(CONF_ENTITY_ID) + entity_ids = config.get(CONF_ENTITY_ID) - if entity_id is not None: - referenced.add(entity_id) + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + + if entity_ids is not None: + referenced.update(entity_ids) return referenced diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 30cda4e4540..69cc422da0a 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -844,7 +844,7 @@ NUMERIC_STATE_CONDITION_SCHEMA = vol.All( vol.Schema( { vol.Required(CONF_CONDITION): "numeric_state", - vol.Required(CONF_ENTITY_ID): entity_id, + vol.Required(CONF_ENTITY_ID): entity_ids, CONF_BELOW: vol.Coerce(float), CONF_ABOVE: vol.Coerce(float), vol.Optional(CONF_VALUE_TEMPLATE): template, @@ -857,7 +857,7 @@ STATE_CONDITION_SCHEMA = vol.All( vol.Schema( { vol.Required(CONF_CONDITION): "state", - vol.Required(CONF_ENTITY_ID): entity_id, + vol.Required(CONF_ENTITY_ID): entity_ids, vol.Required(CONF_STATE): str, vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta), # To support use_trigger_value in automation @@ -905,7 +905,7 @@ TIME_CONDITION_SCHEMA = vol.All( ZONE_CONDITION_SCHEMA = vol.Schema( { vol.Required(CONF_CONDITION): "zone", - vol.Required(CONF_ENTITY_ID): entity_id, + vol.Required(CONF_ENTITY_ID): entity_ids, "zone": entity_id, # To support use_trigger_value in automation # Deprecated 2016/04/25 diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index c4b87b667fa..5d81c110635 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -266,6 +266,123 @@ async def test_if_numeric_state_not_raise_on_unavailable(hass): assert len(logwarn.mock_calls) == 0 +async def test_state_multiple_entities(hass): + """Test with multiple entities in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "and", + "conditions": [ + { + "condition": "state", + "entity_id": ["sensor.temperature_1", "sensor.temperature_2"], + "state": "100", + }, + ], + }, + ) + + hass.states.async_set("sensor.temperature_1", 100) + hass.states.async_set("sensor.temperature_2", 100) + assert test(hass) + + hass.states.async_set("sensor.temperature_1", 101) + hass.states.async_set("sensor.temperature_2", 100) + assert not test(hass) + + hass.states.async_set("sensor.temperature_1", 100) + hass.states.async_set("sensor.temperature_2", 101) + assert not test(hass) + + +async def test_numeric_state_multiple_entities(hass): + """Test with multiple entities in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "and", + "conditions": [ + { + "condition": "numeric_state", + "entity_id": ["sensor.temperature_1", "sensor.temperature_2"], + "below": 50, + }, + ], + }, + ) + + hass.states.async_set("sensor.temperature_1", 49) + hass.states.async_set("sensor.temperature_2", 49) + assert test(hass) + + hass.states.async_set("sensor.temperature_1", 50) + hass.states.async_set("sensor.temperature_2", 49) + assert not test(hass) + + hass.states.async_set("sensor.temperature_1", 49) + hass.states.async_set("sensor.temperature_2", 50) + assert not test(hass) + + +async def test_zone_multiple_entities(hass): + """Test with multiple entities in condition.""" + test = await condition.async_from_config( + hass, + { + "condition": "and", + "conditions": [ + { + "condition": "zone", + "entity_id": ["device_tracker.person_1", "device_tracker.person_2"], + "zone": "zone.home", + }, + ], + }, + ) + + hass.states.async_set( + "zone.home", + "zoning", + {"name": "home", "latitude": 2.1, "longitude": 1.1, "radius": 10}, + ) + + hass.states.async_set( + "device_tracker.person_1", + "home", + {"friendly_name": "person_1", "latitude": 2.1, "longitude": 1.1}, + ) + hass.states.async_set( + "device_tracker.person_2", + "home", + {"friendly_name": "person_2", "latitude": 2.1, "longitude": 1.1}, + ) + assert test(hass) + + hass.states.async_set( + "device_tracker.person_1", + "home", + {"friendly_name": "person_1", "latitude": 20.1, "longitude": 10.1}, + ) + hass.states.async_set( + "device_tracker.person_2", + "home", + {"friendly_name": "person_2", "latitude": 2.1, "longitude": 1.1}, + ) + assert not test(hass) + + hass.states.async_set( + "device_tracker.person_1", + "home", + {"friendly_name": "person_1", "latitude": 2.1, "longitude": 1.1}, + ) + hass.states.async_set( + "device_tracker.person_2", + "home", + {"friendly_name": "person_2", "latitude": 20.1, "longitude": 10.1}, + ) + assert not test(hass) + + async def test_extract_entities(): """Test extracting entities.""" assert condition.async_extract_entities( @@ -312,6 +429,16 @@ async def test_extract_entities(): }, ], }, + { + "condition": "state", + "entity_id": ["sensor.temperature_7", "sensor.temperature_8"], + "state": "100", + }, + { + "condition": "numeric_state", + "entity_id": ["sensor.temperature_9", "sensor.temperature_10"], + "below": 110, + }, ], } ) == { @@ -321,6 +448,10 @@ async def test_extract_entities(): "sensor.temperature_4", "sensor.temperature_5", "sensor.temperature_6", + "sensor.temperature_7", + "sensor.temperature_8", + "sensor.temperature_9", + "sensor.temperature_10", }