diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 1e67effb97f..1be157c789d 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -244,6 +244,20 @@ def template(value): raise vol.Invalid('invalid template ({})'.format(ex)) +def template_complex(value): + """Validate a complex jinja2 template.""" + if isinstance(value, list): + for idx, element in enumerate(value): + value[idx] = template_complex(element) + return value + if isinstance(value, dict): + for key, element in value.items(): + value[key] = template_complex(element) + return value + + return template(value) + + def time(value): """Validate time.""" time_val = dt_util.parse_time(value) @@ -310,7 +324,7 @@ SERVICE_SCHEMA = vol.All(vol.Schema({ vol.Exclusive('service', 'service name'): service, vol.Exclusive('service_template', 'service name'): template, vol.Optional('data'): dict, - vol.Optional('data_template'): {match_all: template}, + vol.Optional('data_template'): {match_all: template_complex}, vol.Optional(CONF_ENTITY_ID): entity_ids, }), has_at_least_one_key('service', 'service_template')) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index b594889fd77..21cfb0aab54 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -63,9 +63,21 @@ def call_from_config(hass, config, blocking=False, variables=None, domain, service_name = domain_service.split('.', 1) service_data = dict(config.get(CONF_SERVICE_DATA, {})) + def _data_template_creator(value): + """Recursive template creator helper function.""" + if isinstance(value, list): + for idx, element in enumerate(value): + value[idx] = _data_template_creator(element) + return value + if isinstance(value, dict): + for key, element in value.items(): + value[key] = _data_template_creator(element) + return value + return template.render(hass, value, variables) + if CONF_SERVICE_DATA_TEMPLATE in config: for key, value in config[CONF_SERVICE_DATA_TEMPLATE].items(): - service_data[key] = template.render(hass, value, variables) + service_data[key] = _data_template_creator(value) if CONF_SERVICE_ENTITY_ID in config: service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 637d5ead0b7..d9da2c51da7 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -299,9 +299,7 @@ def test_template(): """Test template validator.""" schema = vol.Schema(cv.template) - for value in ( - None, '{{ partial_print }', '{% if True %}Hello', {'dict': 'isbad'} - ): + for value in (None, '{{ partial_print }', '{% if True %}Hello', ['test']): with pytest.raises(vol.MultipleInvalid): schema(value) @@ -313,6 +311,24 @@ def test_template(): schema(value) +def test_template_complex(): + """Test template_complex validator.""" + schema = vol.Schema(cv.template_complex) + + for value in (None, '{{ partial_print }', '{% if True %}Hello'): + with pytest.raises(vol.MultipleInvalid): + schema(value) + + for value in ( + 1, 'Hello', + '{{ beer }}', + '{% if 1 == 1 %}Hello{% else %}World{% endif %}', + {'test': 1, 'test': '{{ beer }}'}, + ['{{ beer }}', 1] + ): + schema(value) + + def test_time_zone(): """Test time zone validation.""" schema = vol.Schema(cv.time_zone) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 5372b6a77d4..34f321776d6 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -43,6 +43,11 @@ class TestServiceHelpers(unittest.TestCase): 'entity_id': 'hello.world', 'data_template': { 'hello': '{{ \'goodbye\' }}', + 'data': { + 'value': '{{ \'complex\' }}', + 'simple': 'simple' + }, + 'list': ['{{ \'list\' }}', '2'], }, } runs = [] @@ -54,6 +59,9 @@ class TestServiceHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual('goodbye', runs[0].data['hello']) + self.assertEqual('complex', runs[0].data['data']['value']) + self.assertEqual('simple', runs[0].data['data']['simple']) + self.assertEqual('list', runs[0].data['list'][0]) def test_passing_variables_to_templates(self): """Test passing variables to templates."""