Load YAML config into an ordered dict

This commit is contained in:
Paulus Schoutsen 2016-01-23 22:37:15 -08:00
parent e5497d89f4
commit de08f0afaa
6 changed files with 71 additions and 44 deletions

View File

@ -223,7 +223,7 @@ def from_config_file(config_path, hass=None, verbose=False, daemon=False,
enable_logging(hass, verbose, daemon, log_rotate_days) enable_logging(hass, verbose, daemon, log_rotate_days)
config_dict = config_util.load_config_file(config_path) config_dict = config_util.load_yaml_config_file(config_path)
return from_config_dict(config_dict, hass, enable_log=False, return from_config_dict(config_dict, hass, enable_log=False,
skip_pip=skip_pip) skip_pip=skip_pip)

View File

@ -12,6 +12,7 @@ from homeassistant.const import (
CONF_LATITUDE, CONF_LONGITUDE, CONF_TEMPERATURE_UNIT, CONF_NAME, CONF_LATITUDE, CONF_LONGITUDE, CONF_TEMPERATURE_UNIT, CONF_NAME,
CONF_TIME_ZONE) CONF_TIME_ZONE)
import homeassistant.util.location as loc_util import homeassistant.util.location as loc_util
from homeassistant.util.yaml import load_yaml
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -113,40 +114,9 @@ def find_config_file(config_dir):
return config_path if os.path.isfile(config_path) else None return config_path if os.path.isfile(config_path) else None
def load_config_file(config_path):
""" Loads given config file. """
return load_yaml_config_file(config_path)
def load_yaml_config_file(config_path): def load_yaml_config_file(config_path):
""" Parse a YAML configuration file. """ """ Parse a YAML configuration file. """
import yaml conf_dict = load_yaml(config_path)
def parse(fname):
""" Parse a YAML file. """
try:
with open(fname, encoding='utf-8') as conf_file:
# If configuration file is empty YAML returns None
# We convert that to an empty dict
return yaml.load(conf_file) or {}
except yaml.YAMLError:
error = 'Error reading YAML configuration file {}'.format(fname)
_LOGGER.exception(error)
raise HomeAssistantError(error)
def yaml_include(loader, node):
"""
Loads another YAML file and embeds it using the !include tag.
Example:
device_tracker: !include device_tracker.yaml
"""
fname = os.path.join(os.path.dirname(loader.name), node.value)
return parse(fname)
yaml.add_constructor('!include', yaml_include)
conf_dict = parse(config_path)
if not isinstance(conf_dict, dict): if not isinstance(conf_dict, dict):
_LOGGER.error( _LOGGER.error(

View File

@ -11,7 +11,6 @@ import logging
import signal import signal
import threading import threading
import enum import enum
import re
import functools as ft import functools as ft
from collections import namedtuple from collections import namedtuple
@ -26,6 +25,7 @@ from homeassistant.exceptions import (
import homeassistant.util as util import homeassistant.util as util
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
import homeassistant.util.location as location import homeassistant.util.location as location
from homeassistant.helpers.entity import valid_entity_id
import homeassistant.helpers.temperature as temp_helper import homeassistant.helpers.temperature as temp_helper
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
@ -42,9 +42,6 @@ SERVICE_CALL_LIMIT = 10 # seconds
# will be added for each component that polls devices. # will be added for each component that polls devices.
MIN_WORKER_THREAD = 2 MIN_WORKER_THREAD = 2
# Pattern for validating entity IDs (format: <domain>.<entity>)
ENTITY_ID_PATTERN = re.compile(r"^(?P<domain>\w+)\.(?P<entity>\w+)$")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Temporary to support deprecated methods # Temporary to support deprecated methods
@ -339,7 +336,7 @@ class State(object):
def __init__(self, entity_id, state, attributes=None, last_changed=None, def __init__(self, entity_id, state, attributes=None, last_changed=None,
last_updated=None): last_updated=None):
"""Initialize a new state.""" """Initialize a new state."""
if not ENTITY_ID_PATTERN.match(entity_id): if not valid_entity_id(entity_id):
raise InvalidEntityFormatError(( raise InvalidEntityFormatError((
"Invalid entity id encountered: {}. " "Invalid entity id encountered: {}. "
"Format should be <domain>.<object_id>").format(entity_id)) "Format should be <domain>.<object_id>").format(entity_id))

View File

@ -6,6 +6,7 @@ Provides ABC for entities in HA.
""" """
from collections import defaultdict from collections import defaultdict
import re
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
@ -17,6 +18,14 @@ from homeassistant.const import (
# Dict mapping entity_id to a boolean that overwrites the hidden property # Dict mapping entity_id to a boolean that overwrites the hidden property
_OVERWRITE = defaultdict(dict) _OVERWRITE = defaultdict(dict)
# Pattern for validating entity IDs (format: <domain>.<entity>)
ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
def valid_entity_id(entity_id):
"""Test if an entity ID is a valid format."""
return ENTITY_ID_PATTERN.match(entity_id) is not None
class Entity(object): class Entity(object):
""" ABC for Home Assistant entities. """ """ ABC for Home Assistant entities. """

View File

@ -0,0 +1,50 @@
"""
YAML utility functions.
"""
from collections import OrderedDict
import logging
import os
import yaml
from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__)
def load_yaml(fname):
"""Load a YAML file."""
try:
with open(fname, encoding='utf-8') as conf_file:
# If configuration file is empty YAML returns None
# We convert that to an empty dict
return yaml.load(conf_file) or {}
except yaml.YAMLError:
error = 'Error reading YAML configuration file {}'.format(fname)
_LOGGER.exception(error)
raise HomeAssistantError(error)
def _include_yaml(loader, node):
"""
Loads another YAML file and embeds it using the !include tag.
Example:
device_tracker: !include device_tracker.yaml
"""
fname = os.path.join(os.path.dirname(loader.name), node.value)
return load_yaml(fname)
def _ordered_dict(loader, node):
"""
Loads YAML mappings into an ordered dict to preserve key order.
"""
loader.flatten_mapping(node)
return OrderedDict(loader.construct_pairs(node))
yaml.add_constructor('!include', _include_yaml)
yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
_ordered_dict)

View File

@ -94,13 +94,14 @@ class TestConfig(unittest.TestCase):
with self.assertRaises(HomeAssistantError): with self.assertRaises(HomeAssistantError):
config_util.load_yaml_config_file(YAML_PATH) config_util.load_yaml_config_file(YAML_PATH)
def test_load_config_loads_yaml_config(self): def test_load_yaml_config_preserves_key_order(self):
""" Test correct YAML config loading. """
with open(YAML_PATH, 'w') as f: with open(YAML_PATH, 'w') as f:
f.write('hello: world') f.write('hello: 0\n')
f.write('world: 1\n')
self.assertEqual({'hello': 'world'}, self.assertEqual(
config_util.load_config_file(YAML_PATH)) [('hello', 0), ('world', 1)],
list(config_util.load_yaml_config_file(YAML_PATH).items()))
@mock.patch('homeassistant.util.location.detect_location_info', @mock.patch('homeassistant.util.location.detect_location_info',
mock_detect_location_info) mock_detect_location_info)
@ -109,7 +110,7 @@ class TestConfig(unittest.TestCase):
""" Test that detect location sets the correct config keys. """ """ Test that detect location sets the correct config keys. """
config_util.ensure_config_exists(CONFIG_DIR) config_util.ensure_config_exists(CONFIG_DIR)
config = config_util.load_config_file(YAML_PATH) config = config_util.load_yaml_config_file(YAML_PATH)
self.assertIn(DOMAIN, config) self.assertIn(DOMAIN, config)