diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 4bb51a9975f..6cd1916d4c2 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -108,7 +108,8 @@ def async_track_template(hass, template, action, variables=None): already_triggered = False return async_track_state_change( - hass, template.extract_entities(), template_condition_listener) + hass, template.extract_entities(variables), + template_condition_listener) track_template = threaded_listener_factory(async_track_template) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index bafaf4d0fdb..7154e990563 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -129,7 +129,7 @@ class Script(): self.hass.async_add_job(self.async_run(variables)) self._async_listener.append(async_track_template( - self.hass, wait_template, async_script_wait)) + self.hass, wait_template, async_script_wait, variables)) self._cur = cur + 1 if self._change_listener: diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index a390568e9c6..6f83688623a 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -25,8 +25,8 @@ DATE_STR_FORMAT = "%Y-%m-%d %H:%M:%S" _RE_NONE_ENTITIES = re.compile(r"distance\(|closest\(", re.I | re.M) _RE_GET_ENTITIES = re.compile( - r"(?:(?:states\.|(?:is_state|is_state_attr|states)\(.)([\w]+\.[\w]+))", - re.I | re.M + r"(?:(?:states\.|(?:is_state|is_state_attr|states)" + r"\((?:[\ \'\"]?))([\w]+\.[\w]+)|([\w]+))", re.I | re.M ) @@ -43,14 +43,27 @@ def attach(hass, obj): obj.hass = hass -def extract_entities(template): +def extract_entities(template, variables=None): """Extract all entities for state_changed listener from template string.""" if template is None or _RE_NONE_ENTITIES.search(template): return MATCH_ALL extraction = _RE_GET_ENTITIES.findall(template) - if extraction: - return list(set(extraction)) + extraction_final = [] + + for result in extraction: + if result[0] == 'trigger.entity_id' and 'trigger' in variables and \ + 'entity_id' in variables['trigger']: + extraction_final.append(variables['trigger']['entity_id']) + elif result[0]: + extraction_final.append(result[0]) + + if variables and result[1] in variables and \ + isinstance(variables[result[1]], str): + extraction_final.append(variables[result[1]]) + + if extraction_final: + return list(set(extraction_final)) return MATCH_ALL @@ -77,9 +90,9 @@ class Template(object): except jinja2.exceptions.TemplateSyntaxError as err: raise TemplateError(err) - def extract_entities(self): + def extract_entities(self, variables=None): """Extract all entities for state_changed listener.""" - return extract_entities(self.template) + return extract_entities(self.template, variables) def render(self, variables=None, **kwargs): """Render given template.""" diff --git a/tests/components/automation/test_numeric_state.py b/tests/components/automation/test_numeric_state.py index 0a7db4a122d..cb36a91dddb 100644 --- a/tests/components/automation/test_numeric_state.py +++ b/tests/components/automation/test_numeric_state.py @@ -704,3 +704,37 @@ class TestAutomationNumericState(unittest.TestCase): fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=10)) self.hass.block_till_done() self.assertEqual(1, len(self.calls)) + + def test_wait_template_with_trigger(self): + """Test using wait template with 'trigger.entity_id'.""" + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'numeric_state', + 'entity_id': 'test.entity', + 'above': 10, + }, + 'action': [ + {'wait_template': + "{{ states(trigger.entity_id) | int < 10 }}"}, + {'service': 'test.automation', + 'data_template': { + 'some': + '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', 'to_state.state')) + }} + ], + } + }) + + self.hass.block_till_done() + self.calls = [] + + self.hass.states.set('test.entity', '12') + self.hass.block_till_done() + self.hass.states.set('test.entity', '8') + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'numeric_state - test.entity - 12', + self.calls[0].data['some']) diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index 2fd6c8415db..1f245d1cf5c 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -506,3 +506,38 @@ class TestAutomationState(unittest.TestCase): }, 'action': {'service': 'test.automation'}, }}) + + def test_wait_template_with_trigger(self): + """Test using wait template with 'trigger.entity_id'.""" + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'state', + 'entity_id': 'test.entity', + 'to': 'world', + }, + 'action': [ + {'wait_template': + "{{ is_state(trigger.entity_id, 'hello') }}"}, + {'service': 'test.automation', + 'data_template': { + 'some': + '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', 'from_state.state', + 'to_state.state')) + }} + ], + } + }) + + self.hass.block_till_done() + self.calls = [] + + self.hass.states.set('test.entity', 'world') + self.hass.block_till_done() + self.hass.states.set('test.entity', 'hello') + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'state - test.entity - hello - world', + self.calls[0].data['some']) diff --git a/tests/components/automation/test_template.py b/tests/components/automation/test_template.py index 5cc47687665..937fa16988a 100644 --- a/tests/components/automation/test_template.py +++ b/tests/components/automation/test_template.py @@ -399,3 +399,38 @@ class TestAutomationTemplate(unittest.TestCase): self.hass.states.set('test.entity', 'world') self.hass.block_till_done() self.assertEqual(0, len(self.calls)) + + def test_wait_template_with_trigger(self): + """Test using wait template with 'trigger.entity_id'.""" + assert setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'trigger': { + 'platform': 'template', + 'value_template': + "{{ states.test.entity.state == 'world' }}", + }, + 'action': [ + {'wait_template': + "{{ is_state(trigger.entity_id, 'hello') }}"}, + {'service': 'test.automation', + 'data_template': { + 'some': + '{{ trigger.%s }}' % '}} - {{ trigger.'.join(( + 'platform', 'entity_id', 'from_state.state', + 'to_state.state')) + }} + ], + } + }) + + self.hass.block_till_done() + self.calls = [] + + self.hass.states.set('test.entity', 'world') + self.hass.block_till_done() + self.hass.states.set('test.entity', 'hello') + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual( + 'template - test.entity - hello - world', + self.calls[0].data['some']) diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index d9ef7bc5a2b..b6e3ea17e1a 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -345,6 +345,41 @@ class TestScriptHelper(unittest.TestCase): assert not script_obj.is_running assert len(events) == 1 + def test_wait_template_variables(self): + """Test the wait template with variables.""" + event = 'test_event' + events = [] + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + self.hass.bus.listen(event, record_event) + + self.hass.states.set('switch.test', 'on') + + script_obj = script.Script(self.hass, cv.SCRIPT_SCHEMA([ + {'event': event}, + {'wait_template': "{{is_state(data, 'off')}}"}, + {'event': event}])) + + script_obj.run({ + 'data': 'switch.test' + }) + self.hass.block_till_done() + + assert script_obj.is_running + assert script_obj.can_cancel + assert script_obj.last_action == event + assert len(events) == 1 + + self.hass.states.set('switch.test', 'off') + self.hass.block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 + def test_passing_variables_to_script(self): """Test if we can pass variables to script.""" calls = [] diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index a32b2dc13a1..a214d69f80a 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -683,7 +683,7 @@ class TestHelpersTemplate(unittest.TestCase): MATCH_ALL, template.extract_entities(""" {% for state in states.sensor %} - {{ state.entity_id }}={{ state.state }}, + {{ state.entity_id }}={{ state.state }},d {% endfor %} """)) @@ -753,6 +753,35 @@ is_state_attr('device_tracker.phone_2', 'battery', 40) " %}true{% endif %}" ))) + def test_extract_entities_with_variables(self): + """Test extract entities function with variables and entities stuff.""" + self.assertEqual( + ['input_boolean.switch'], + template.extract_entities( + "{{ is_state('input_boolean.switch', 'off') }}", {})) + + self.assertEqual( + ['trigger.entity_id'], + template.extract_entities( + "{{ is_state(trigger.entity_id, 'off') }}", {})) + + self.assertEqual( + MATCH_ALL, + template.extract_entities( + "{{ is_state(data, 'off') }}", {})) + + self.assertEqual( + ['input_boolean.switch'], + template.extract_entities( + "{{ is_state(data, 'off') }}", + {'data': 'input_boolean.switch'})) + + self.assertEqual( + ['input_boolean.switch'], + template.extract_entities( + "{{ is_state(trigger.entity_id, 'off') }}", + {'trigger': {'entity_id': 'input_boolean.switch'}})) + @asyncio.coroutine def test_state_with_unit(hass):