diff --git a/homeassistant/util/yaml.py b/homeassistant/util/yaml.py index d70a8f1e3e0..58458986063 100644 --- a/homeassistant/util/yaml.py +++ b/homeassistant/util/yaml.py @@ -46,7 +46,7 @@ def _include_yaml(loader, node): def _include_dir_named_yaml(loader, node): - """Load multiple files from dir.""" + """Load multiple files from dir as a dict.""" mapping = OrderedDict() files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') for fname in glob.glob(files): @@ -55,12 +55,34 @@ def _include_dir_named_yaml(loader, node): return mapping +def _include_dir_merge_named_yaml(loader, node): + """Load multiple files from dir as a merged dict.""" + mapping = OrderedDict() + files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') + for fname in glob.glob(files): + loaded_yaml = load_yaml(fname) + if isinstance(loaded_yaml, dict): + mapping.update(loaded_yaml) + return mapping + + def _include_dir_list_yaml(loader, node): - """Load multiple files from dir.""" + """Load multiple files from dir as a list.""" files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') return [load_yaml(f) for f in glob.glob(files)] +def _include_dir_merge_list_yaml(loader, node): + """Load multiple files from dir as a merged list.""" + files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') + merged_list = [] + for fname in glob.glob(files): + loaded_yaml = load_yaml(fname) + if isinstance(loaded_yaml, list): + merged_list.extend(loaded_yaml) + return merged_list + + def _ordered_dict(loader, node): """Load YAML mappings into an ordered dict to preserve key order.""" loader.flatten_mapping(node) @@ -102,4 +124,8 @@ yaml.SafeLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _ordered_dict) yaml.SafeLoader.add_constructor('!env_var', _env_var_yaml) yaml.SafeLoader.add_constructor('!include_dir_list', _include_dir_list_yaml) +yaml.SafeLoader.add_constructor('!include_dir_merge_list', + _include_dir_merge_list_yaml) yaml.SafeLoader.add_constructor('!include_dir_named', _include_dir_named_yaml) +yaml.SafeLoader.add_constructor('!include_dir_merge_named', + _include_dir_merge_named_yaml) diff --git a/tests/util/test_yaml.py b/tests/util/test_yaml.py index 106cb01264e..244f9323334 100644 --- a/tests/util/test_yaml.py +++ b/tests/util/test_yaml.py @@ -2,6 +2,7 @@ import io import unittest import os +import tempfile from homeassistant.util import yaml @@ -53,3 +54,84 @@ class TestYaml(unittest.TestCase): pass else: assert 0 + + def test_include_yaml(self): + """Test include yaml.""" + with tempfile.NamedTemporaryFile() as include_file: + include_file.write(b"value") + include_file.seek(0) + conf = "key: !include {}".format(include_file.name) + with io.StringIO(conf) as f: + doc = yaml.yaml.safe_load(f) + assert doc["key"] == "value" + + def test_include_dir_list(self): + """Test include dir list yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_1 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_1.write(b"one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_2.write(b"two") + file_2.close() + conf = "key: !include_dir_list {}".format(include_dir) + with io.StringIO(conf) as f: + doc = yaml.yaml.safe_load(f) + assert sorted(doc["key"]) == sorted(["one", "two"]) + + def test_include_dir_named(self): + """Test include dir named yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_1 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_1.write(b"one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_2.write(b"two") + file_2.close() + conf = "key: !include_dir_named {}".format(include_dir) + correct = {} + correct[os.path.splitext(os.path.basename(file_1.name))[0]] = "one" + correct[os.path.splitext(os.path.basename(file_2.name))[0]] = "two" + with io.StringIO(conf) as f: + doc = yaml.yaml.safe_load(f) + assert doc["key"] == correct + + def test_include_dir_merge_list(self): + """Test include dir merge list yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_1 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_1.write(b"- one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_2.write(b"- two\n- three") + file_2.close() + conf = "key: !include_dir_merge_list {}".format(include_dir) + with io.StringIO(conf) as f: + doc = yaml.yaml.safe_load(f) + assert sorted(doc["key"]) == sorted(["one", "two", "three"]) + + def test_include_dir_merge_named(self): + """Test include dir merge named yaml.""" + with tempfile.TemporaryDirectory() as include_dir: + file_1 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_1.write(b"key1: one") + file_1.close() + file_2 = tempfile.NamedTemporaryFile(dir=include_dir, + suffix=".yaml", delete=False) + file_2.write(b"key2: two\nkey3: three") + file_2.close() + conf = "key: !include_dir_merge_named {}".format(include_dir) + with io.StringIO(conf) as f: + doc = yaml.yaml.safe_load(f) + assert doc["key"] == { + "key1": "one", + "key2": "two", + "key3": "three" + }