Yaml secret fallback to parent folders (#2878)

* Move secret cache out of loader so it can be referenced by other folders
* Unit test to verify secrets from another folder work & see if it overrides parent secret
* Clear secret cache after load
This commit is contained in:
Teagan Glenn 2016-08-20 13:39:56 -06:00 committed by Johann Kellerman
parent ca75e66c1a
commit 8d1a9d86ea
3 changed files with 93 additions and 30 deletions

View File

@ -19,6 +19,7 @@ import homeassistant.config as conf_util
import homeassistant.core as core import homeassistant.core as core
import homeassistant.loader as loader import homeassistant.loader as loader
import homeassistant.util.package as pkg_util import homeassistant.util.package as pkg_util
from homeassistant.util.yaml import clear_secret_cache
from homeassistant.const import EVENT_COMPONENT_LOADED, PLATFORM_FORMAT from homeassistant.const import EVENT_COMPONENT_LOADED, PLATFORM_FORMAT
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -308,6 +309,8 @@ def from_config_file(config_path: str,
config_dict = conf_util.load_yaml_config_file(config_path) config_dict = conf_util.load_yaml_config_file(config_path)
except HomeAssistantError: except HomeAssistantError:
return None return None
finally:
clear_secret_cache()
return from_config_dict(config_dict, hass, enable_log=False, return from_config_dict(config_dict, hass, enable_log=False,
skip_pip=skip_pip) skip_pip=skip_pip)

View File

@ -2,6 +2,7 @@
import glob import glob
import logging import logging
import os import os
import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Union, List, Dict from typing import Union, List, Dict
@ -16,6 +17,7 @@ from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_SECRET_NAMESPACE = 'homeassistant' _SECRET_NAMESPACE = 'homeassistant'
_SECRET_YAML = 'secrets.yaml' _SECRET_YAML = 'secrets.yaml'
__SECRET_CACHE = {} # type: Dict
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
@ -43,6 +45,11 @@ def load_yaml(fname: str) -> Union[List, Dict]:
raise HomeAssistantError(exc) raise HomeAssistantError(exc)
def clear_secret_cache() -> None:
"""Clear the secrete cache."""
__SECRET_CACHE.clear()
def _include_yaml(loader: SafeLineLoader, def _include_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node) -> Union[List, Dict]: node: yaml.nodes.Node) -> Union[List, Dict]:
"""Load another YAML file and embeds it using the !include tag. """Load another YAML file and embeds it using the !include tag.
@ -140,40 +147,44 @@ def _env_var_yaml(loader: SafeLineLoader,
raise HomeAssistantError(node.value) raise HomeAssistantError(node.value)
def _load_secret_yaml(secret_path: str) -> Dict:
"""Load the secrets yaml from path."""
_LOGGER.debug('Loading %s', os.path.join(secret_path, _SECRET_YAML))
secrets = {}
if os.path.isfile(os.path.join(secret_path, _SECRET_YAML)):
secrets = load_yaml(
os.path.join(secret_path, _SECRET_YAML))
if 'logger' in secrets:
logger = str(secrets['logger']).lower()
if logger == 'debug':
_LOGGER.setLevel(logging.DEBUG)
else:
_LOGGER.error("secrets.yaml: 'logger: debug' expected,"
" but 'logger: %s' found", logger)
del secrets['logger']
return secrets
# pylint: disable=protected-access # pylint: disable=protected-access
def _secret_yaml(loader: SafeLineLoader, def _secret_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node): node: yaml.nodes.Node):
"""Load secrets and embed it into the configuration YAML.""" """Load secrets and embed it into the configuration YAML."""
# Create secret cache on loader and load secrets.yaml secret_path = os.path.dirname(loader.name)
if not hasattr(loader, '_SECRET_CACHE'): while os.path.exists(secret_path) and not secret_path == os.path.dirname(
loader._SECRET_CACHE = {} sys.path[0]):
secrets = __SECRET_CACHE.get(secret_path,
_load_secret_yaml(secret_path))
if node.value in secrets:
_LOGGER.debug('Secret %s retrieved from secrets.yaml in '
'folder %s', node.value, secret_path)
return secrets[node.value]
next_path = os.path.dirname(secret_path)
secret_path = os.path.join(os.path.dirname(loader.name), _SECRET_YAML) if not next_path or next_path == secret_path:
if secret_path not in loader._SECRET_CACHE: # Somehow we got past the .homeassistant configuration folder...
if os.path.isfile(secret_path): break
loader._SECRET_CACHE[secret_path] = load_yaml(secret_path)
secrets = loader._SECRET_CACHE[secret_path]
if 'logger' in secrets:
logger = str(secrets['logger']).lower()
if logger == 'debug':
_LOGGER.setLevel(logging.DEBUG)
else:
_LOGGER.error("secrets.yaml: 'logger: debug' expected,"
" but 'logger: %s' found", logger)
del secrets['logger']
else:
loader._SECRET_CACHE[secret_path] = None
secrets = loader._SECRET_CACHE[secret_path]
# Retrieve secret, first from secrets.yaml, then from keyring secret_path = next_path
if secrets is not None and node.value in secrets:
_LOGGER.debug('Secret %s retrieved from secrets.yaml.', node.value)
return secrets[node.value]
for sname, sdict in loader._SECRET_CACHE.items():
if node.value in sdict:
_LOGGER.debug('Secret %s retrieved from secrets.yaml in other '
'folder %s', node.value, sname)
return sdict[node.value]
if keyring: if keyring:
# do ome keyring stuff # do ome keyring stuff

View File

@ -3,6 +3,7 @@ import io
import unittest import unittest
import os import os
import tempfile import tempfile
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml from homeassistant.util import yaml
import homeassistant.config as config_util import homeassistant.config as config_util
from tests.common import get_test_config_dir from tests.common import get_test_config_dir
@ -165,9 +166,16 @@ class TestSecrets(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Create & load secrets file.""" """Create & load secrets file."""
config_dir = get_test_config_dir() config_dir = get_test_config_dir()
yaml.clear_secret_cache()
self._yaml_path = os.path.join(config_dir, self._yaml_path = os.path.join(config_dir,
config_util.YAML_CONFIG_FILE) config_util.YAML_CONFIG_FILE)
self._secret_path = os.path.join(config_dir, 'secrets.yaml') self._secret_path = os.path.join(config_dir, yaml._SECRET_YAML)
self._sub_folder_path = os.path.join(config_dir, 'subFolder')
if not os.path.exists(self._sub_folder_path):
os.makedirs(self._sub_folder_path)
self._unrelated_path = os.path.join(config_dir, 'unrelated')
if not os.path.exists(self._unrelated_path):
os.makedirs(self._unrelated_path)
load_yaml(self._secret_path, load_yaml(self._secret_path,
'http_pw: pwhttp\n' 'http_pw: pwhttp\n'
@ -185,7 +193,11 @@ class TestSecrets(unittest.TestCase):
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Clean up secrets.""" """Clean up secrets."""
for path in [self._yaml_path, self._secret_path]: yaml.clear_secret_cache()
for path in [self._yaml_path, self._secret_path,
os.path.join(self._sub_folder_path, 'sub.yaml'),
os.path.join(self._sub_folder_path, yaml._SECRET_YAML),
os.path.join(self._unrelated_path, yaml._SECRET_YAML)]:
if os.path.isfile(path): if os.path.isfile(path):
os.remove(path) os.remove(path)
@ -199,6 +211,43 @@ class TestSecrets(unittest.TestCase):
'password': 'pw1'} 'password': 'pw1'}
self.assertEqual(expected, self._yaml['component']) self.assertEqual(expected, self._yaml['component'])
def test_secrets_from_parent_folder(self):
"""Test loading secrets from parent foler."""
expected = {'api_password': 'pwhttp'}
self._yaml = load_yaml(os.path.join(self._sub_folder_path, 'sub.yaml'),
'http:\n'
' api_password: !secret http_pw\n'
'component:\n'
' username: !secret comp1_un\n'
' password: !secret comp1_pw\n'
'')
self.assertEqual(expected, self._yaml['http'])
def test_secret_overrides_parent(self):
"""Test loading current directory secret overrides the parent."""
expected = {'api_password': 'override'}
load_yaml(os.path.join(self._sub_folder_path, yaml._SECRET_YAML),
'http_pw: override')
self._yaml = load_yaml(os.path.join(self._sub_folder_path, 'sub.yaml'),
'http:\n'
' api_password: !secret http_pw\n'
'component:\n'
' username: !secret comp1_un\n'
' password: !secret comp1_pw\n'
'')
self.assertEqual(expected, self._yaml['http'])
def test_secrets_from_unrelated_fails(self):
"""Test loading secrets from unrelated folder fails."""
load_yaml(os.path.join(self._unrelated_path, yaml._SECRET_YAML),
'test: failure')
with self.assertRaises(HomeAssistantError):
load_yaml(os.path.join(self._sub_folder_path, 'sub.yaml'),
'http:\n'
' api_password: !secret test')
def test_secrets_keyring(self): def test_secrets_keyring(self):
"""Test keyring fallback & get_password.""" """Test keyring fallback & get_password."""
yaml.keyring = None # Ensure its not there yaml.keyring = None # Ensure its not there