mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 03:37:07 +00:00
Handle requirements for scripts (#2765)
This commit is contained in:
parent
a03691455b
commit
aadf6a7750
@ -232,7 +232,7 @@ def from_config_dict(config: Dict[str, Any],
|
|||||||
if config_dir is not None:
|
if config_dir is not None:
|
||||||
config_dir = os.path.abspath(config_dir)
|
config_dir = os.path.abspath(config_dir)
|
||||||
hass.config.config_dir = config_dir
|
hass.config.config_dir = config_dir
|
||||||
_mount_local_lib_path(config_dir)
|
mount_local_lib_path(config_dir)
|
||||||
|
|
||||||
core_config = config.get(core.DOMAIN, {})
|
core_config = config.get(core.DOMAIN, {})
|
||||||
|
|
||||||
@ -300,7 +300,7 @@ def from_config_file(config_path: str,
|
|||||||
# Set config dir to directory holding config file
|
# Set config dir to directory holding config file
|
||||||
config_dir = os.path.abspath(os.path.dirname(config_path))
|
config_dir = os.path.abspath(os.path.dirname(config_path))
|
||||||
hass.config.config_dir = config_dir
|
hass.config.config_dir = config_dir
|
||||||
_mount_local_lib_path(config_dir)
|
mount_local_lib_path(config_dir)
|
||||||
|
|
||||||
enable_logging(hass, verbose, log_rotate_days)
|
enable_logging(hass, verbose, log_rotate_days)
|
||||||
|
|
||||||
@ -371,11 +371,6 @@ def _ensure_loader_prepared(hass: core.HomeAssistant) -> None:
|
|||||||
loader.prepare(hass)
|
loader.prepare(hass)
|
||||||
|
|
||||||
|
|
||||||
def _mount_local_lib_path(config_dir: str) -> None:
|
|
||||||
"""Add local library to Python Path."""
|
|
||||||
sys.path.insert(0, os.path.join(config_dir, 'deps'))
|
|
||||||
|
|
||||||
|
|
||||||
def _log_exception(ex, domain, config):
|
def _log_exception(ex, domain, config):
|
||||||
"""Generate log exception for config validation."""
|
"""Generate log exception for config validation."""
|
||||||
message = 'Invalid config for [{}]: '.format(domain)
|
message = 'Invalid config for [{}]: '.format(domain)
|
||||||
@ -391,3 +386,11 @@ def _log_exception(ex, domain, config):
|
|||||||
config.__line__ or '?')
|
config.__line__ or '?')
|
||||||
|
|
||||||
_LOGGER.error(message)
|
_LOGGER.error(message)
|
||||||
|
|
||||||
|
|
||||||
|
def mount_local_lib_path(config_dir: str) -> str:
|
||||||
|
"""Add local library to Python Path."""
|
||||||
|
deps_dir = os.path.join(config_dir, 'deps')
|
||||||
|
if deps_dir not in sys.path:
|
||||||
|
sys.path.insert(0, os.path.join(config_dir, 'deps'))
|
||||||
|
return deps_dir
|
||||||
|
@ -1,9 +1,15 @@
|
|||||||
"""Home Assistant command line scripts."""
|
"""Home Assistant command line scripts."""
|
||||||
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from homeassistant.config import get_default_config_dir
|
||||||
|
from homeassistant.util.package import install_package
|
||||||
|
from homeassistant.bootstrap import mount_local_lib_path
|
||||||
|
|
||||||
|
|
||||||
def run(args: str) -> int:
|
def run(args: List) -> int:
|
||||||
"""Run a script."""
|
"""Run a script."""
|
||||||
scripts = []
|
scripts = []
|
||||||
path = os.path.dirname(__file__)
|
path = os.path.dirname(__file__)
|
||||||
@ -26,4 +32,21 @@ def run(args: str) -> int:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
script = importlib.import_module('homeassistant.scripts.' + args[0])
|
script = importlib.import_module('homeassistant.scripts.' + args[0])
|
||||||
|
|
||||||
|
config_dir = extract_config_dir()
|
||||||
|
deps_dir = mount_local_lib_path(config_dir)
|
||||||
|
for req in getattr(script, 'REQUIREMENTS', []):
|
||||||
|
if not install_package(req, target=deps_dir):
|
||||||
|
print('Aborting scipt, could not install dependency', req)
|
||||||
|
return 1
|
||||||
|
|
||||||
return script.run(args[1:]) # type: ignore
|
return script.run(args[1:]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def extract_config_dir(args=None) -> str:
|
||||||
|
"""Extract the config dir from the arguments or get the default."""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-c', '--config', default=None)
|
||||||
|
args = parser.parse_known_args(args)[0]
|
||||||
|
return (os.path.join(os.getcwd(), args.config) if args.config
|
||||||
|
else get_default_config_dir())
|
||||||
|
1
tests/scripts/__init__.py
Normal file
1
tests/scripts/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for the scripts."""
|
19
tests/scripts/test_init.py
Normal file
19
tests/scripts/test_init.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""Test script init."""
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import homeassistant.scripts as scripts
|
||||||
|
|
||||||
|
|
||||||
|
class TestScripts(unittest.TestCase):
|
||||||
|
"""Tests homeassistant.scripts module."""
|
||||||
|
|
||||||
|
@patch('homeassistant.scripts.get_default_config_dir',
|
||||||
|
return_value='/default')
|
||||||
|
def test_config_per_platform(self, mock_def):
|
||||||
|
"""Test config per platform method."""
|
||||||
|
self.assertEquals(scripts.get_default_config_dir(), '/default')
|
||||||
|
self.assertEqual(scripts.extract_config_dir(), '/default')
|
||||||
|
self.assertEqual(scripts.extract_config_dir(['']), '/default')
|
||||||
|
self.assertEqual(scripts.extract_config_dir(['-c', '/arg']), '/arg')
|
||||||
|
self.assertEqual(scripts.extract_config_dir(['--config', '/a']), '/a')
|
@ -3,7 +3,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import homeassistant.bootstrap as bootstrap
|
from homeassistant.bootstrap import mount_local_lib_path
|
||||||
import homeassistant.util.package as package
|
import homeassistant.util.package as package
|
||||||
|
|
||||||
RESOURCE_DIR = os.path.abspath(
|
RESOURCE_DIR = os.path.abspath(
|
||||||
@ -21,7 +21,7 @@ class TestPackageUtil(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Create local library for testing."""
|
"""Create local library for testing."""
|
||||||
self.tmp_dir = tempfile.TemporaryDirectory()
|
self.tmp_dir = tempfile.TemporaryDirectory()
|
||||||
self.lib_dir = os.path.join(self.tmp_dir.name, 'deps')
|
self.lib_dir = mount_local_lib_path(self.tmp_dir.name)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Stop everything that was started."""
|
"""Stop everything that was started."""
|
||||||
@ -49,8 +49,6 @@ class TestPackageUtil(unittest.TestCase):
|
|||||||
self.assertTrue(package.check_package_exists(
|
self.assertTrue(package.check_package_exists(
|
||||||
TEST_NEW_REQ, self.lib_dir))
|
TEST_NEW_REQ, self.lib_dir))
|
||||||
|
|
||||||
bootstrap._mount_local_lib_path(self.tmp_dir.name)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pyhelloworld3
|
import pyhelloworld3
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user