diff --git a/homeassistant/requirements.py b/homeassistant/requirements.py index b9b5e137d5c..a3d168d22e7 100644 --- a/homeassistant/requirements.py +++ b/homeassistant/requirements.py @@ -3,12 +3,17 @@ import asyncio from functools import partial import logging import os +import sys from typing import Any, Dict, List, Optional +from urllib.parse import urlparse + +import pkg_resources import homeassistant.util.package as pkg_util from homeassistant.core import HomeAssistant DATA_PIP_LOCK = 'pip_lock' +DATA_PKG_CACHE = 'pkg_cache' CONSTRAINT_FILE = 'package_constraints.txt' _LOGGER = logging.getLogger(__name__) @@ -23,12 +28,20 @@ async def async_process_requirements(hass: HomeAssistant, name: str, if pip_lock is None: pip_lock = hass.data[DATA_PIP_LOCK] = asyncio.Lock(loop=hass.loop) + pkg_cache = hass.data.get(DATA_PKG_CACHE) + if pkg_cache is None: + pkg_cache = hass.data[DATA_PKG_CACHE] = PackageLoadable(hass) + pip_install = partial(pkg_util.install_package, **pip_kwargs(hass.config.config_dir)) async with pip_lock: for req in requirements: + if await pkg_cache.loadable(req): + continue + ret = await hass.async_add_executor_job(pip_install, req) + if not ret: _LOGGER.error("Not initializing %s because could not install " "requirement %s", name, req) @@ -45,3 +58,50 @@ def pip_kwargs(config_dir: Optional[str]) -> Dict[str, Any]: if not (config_dir is None or pkg_util.is_virtual_env()): kwargs['target'] = os.path.join(config_dir, 'deps') return kwargs + + +class PackageLoadable: + """Class to check if a package is loadable, with built-in cache.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the PackageLoadable class.""" + self.dist_cache = {} # type: Dict[str, pkg_resources.Distribution] + self.hass = hass + + async def loadable(self, package: str) -> bool: + """Check if a package is what will be loaded when we import it. + + Returns True when the requirement is met. + Returns False when the package is not installed or doesn't meet req. + """ + dist_cache = self.dist_cache + + try: + req = pkg_resources.Requirement.parse(package) + except ValueError: + # This is a zip file. We no longer use this in Home Assistant, + # leaving it in for custom components. + req = pkg_resources.Requirement.parse(urlparse(package).fragment) + + req_proj_name = req.project_name.lower() + dist = dist_cache.get(req_proj_name) + + if dist is not None: + return dist in req + + for path in sys.path: + # We read the whole mount point as we're already here + # Caching it on first call makes subsequent calls a lot faster. + await self.hass.async_add_executor_job(self._fill_cache, path) + + dist = dist_cache.get(req_proj_name) + if dist is not None: + return dist in req + + return False + + def _fill_cache(self, path: str) -> None: + """Add packages from a path to the cache.""" + dist_cache = self.dist_cache + for dist in pkg_resources.find_distributions(path): + dist_cache.setdefault(dist.project_name.lower(), dist) diff --git a/homeassistant/util/package.py b/homeassistant/util/package.py index 3f12fc223b8..422809f7594 100644 --- a/homeassistant/util/package.py +++ b/homeassistant/util/package.py @@ -4,17 +4,11 @@ import logging import os from subprocess import PIPE, Popen import sys -import threading -from urllib.parse import urlparse from typing import Optional -import pkg_resources - _LOGGER = logging.getLogger(__name__) -INSTALL_LOCK = threading.Lock() - def is_virtual_env() -> bool: """Return if we run in a virtual environtment.""" @@ -31,58 +25,30 @@ def install_package(package: str, upgrade: bool = True, Return boolean if install successful. """ # Not using 'import pip; pip.main([])' because it breaks the logger - with INSTALL_LOCK: - if package_loadable(package): - return True + _LOGGER.info('Attempting install of %s', package) + env = os.environ.copy() + args = [sys.executable, '-m', 'pip', 'install', '--quiet', package] + if upgrade: + args.append('--upgrade') + if constraints is not None: + args += ['--constraint', constraints] + if target: + assert not is_virtual_env() + # This only works if not running in venv + args += ['--user'] + env['PYTHONUSERBASE'] = os.path.abspath(target) + if sys.platform != 'win32': + # Workaround for incompatible prefix setting + # See http://stackoverflow.com/a/4495175 + args += ['--prefix='] + process = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env) + _, stderr = process.communicate() + if process.returncode != 0: + _LOGGER.error("Unable to install package %s: %s", + package, stderr.decode('utf-8').lstrip().strip()) + return False - _LOGGER.info('Attempting install of %s', package) - env = os.environ.copy() - args = [sys.executable, '-m', 'pip', 'install', '--quiet', package] - if upgrade: - args.append('--upgrade') - if constraints is not None: - args += ['--constraint', constraints] - if target: - assert not is_virtual_env() - # This only works if not running in venv - args += ['--user'] - env['PYTHONUSERBASE'] = os.path.abspath(target) - if sys.platform != 'win32': - # Workaround for incompatible prefix setting - # See http://stackoverflow.com/a/4495175 - args += ['--prefix='] - process = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env) - _, stderr = process.communicate() - if process.returncode != 0: - _LOGGER.error("Unable to install package %s: %s", - package, stderr.decode('utf-8').lstrip().strip()) - return False - - return True - - -def package_loadable(package: str) -> bool: - """Check if a package is what will be loaded when we import it. - - Returns True when the requirement is met. - Returns False when the package is not installed or doesn't meet req. - """ - try: - req = pkg_resources.Requirement.parse(package) - except ValueError: - # This is a zip file - req = pkg_resources.Requirement.parse(urlparse(package).fragment) - - req_proj_name = req.project_name.lower() - - for path in sys.path: - for dist in pkg_resources.find_distributions(path): - # If the project name is the same, it will be the one that is - # loaded when we import it. - if dist.project_name.lower() == req_proj_name: - return dist in req - - return False + return True async def async_get_user_site(deps_dir: str) -> str: diff --git a/tests/test_requirements.py b/tests/test_requirements.py index e3ef797df4d..71ae80f22e4 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -1,11 +1,22 @@ """Test requirements module.""" import os -from unittest import mock +from unittest.mock import patch, call from homeassistant import loader, setup -from homeassistant.requirements import CONSTRAINT_FILE +from homeassistant.requirements import ( + CONSTRAINT_FILE, PackageLoadable, async_process_requirements) -from tests.common import get_test_home_assistant, MockModule +import pkg_resources + +from tests.common import get_test_home_assistant, MockModule, mock_coro + +RESOURCE_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), '..', 'resources')) + +TEST_NEW_REQ = 'pyhelloworld3==1.0.0' + +TEST_ZIP_REQ = 'file://{}#{}' \ + .format(os.path.join(RESOURCE_DIR, 'pyhelloworld3.zip'), TEST_NEW_REQ) class TestRequirements: @@ -23,11 +34,9 @@ class TestRequirements: """Clean up.""" self.hass.stop() - @mock.patch('os.path.dirname') - @mock.patch('homeassistant.util.package.is_virtual_env', - return_value=True) - @mock.patch('homeassistant.util.package.install_package', - return_value=True) + @patch('os.path.dirname') + @patch('homeassistant.util.package.is_virtual_env', return_value=True) + @patch('homeassistant.util.package.install_package', return_value=True) def test_requirement_installed_in_venv( self, mock_install, mock_venv, mock_dirname): """Test requirement installed in virtual environment.""" @@ -39,15 +48,13 @@ class TestRequirements: MockModule('comp', requirements=['package==0.0.1'])) assert setup.setup_component(self.hass, 'comp') assert 'comp' in self.hass.config.components - assert mock_install.call_args == mock.call( + assert mock_install.call_args == call( 'package==0.0.1', constraints=os.path.join('ha_package_path', CONSTRAINT_FILE)) - @mock.patch('os.path.dirname') - @mock.patch('homeassistant.util.package.is_virtual_env', - return_value=False) - @mock.patch('homeassistant.util.package.install_package', - return_value=True) + @patch('os.path.dirname') + @patch('homeassistant.util.package.is_virtual_env', return_value=False) + @patch('homeassistant.util.package.install_package', return_value=True) def test_requirement_installed_in_deps( self, mock_install, mock_venv, mock_dirname): """Test requirement installed in deps directory.""" @@ -58,6 +65,61 @@ class TestRequirements: MockModule('comp', requirements=['package==0.0.1'])) assert setup.setup_component(self.hass, 'comp') assert 'comp' in self.hass.config.components - assert mock_install.call_args == mock.call( + assert mock_install.call_args == call( 'package==0.0.1', target=self.hass.config.path('deps'), constraints=os.path.join('ha_package_path', CONSTRAINT_FILE)) + + +async def test_install_existing_package(hass): + """Test an install attempt on an existing package.""" + with patch('homeassistant.util.package.install_package', + return_value=mock_coro(True)) as mock_inst: + assert await async_process_requirements( + hass, 'test_component', ['hello==1.0.0']) + + assert len(mock_inst.mock_calls) == 1 + + with patch('homeassistant.requirements.PackageLoadable.loadable', + return_value=mock_coro(True)), \ + patch( + 'homeassistant.util.package.install_package') as mock_inst: + assert await async_process_requirements( + hass, 'test_component', ['hello==1.0.0']) + + assert len(mock_inst.mock_calls) == 0 + + +async def test_check_package_global(hass): + """Test for an installed package.""" + installed_package = list(pkg_resources.working_set)[0].project_name + assert await PackageLoadable(hass).loadable(installed_package) + + +async def test_check_package_zip(hass): + """Test for an installed zip package.""" + assert not await PackageLoadable(hass).loadable(TEST_ZIP_REQ) + + +async def test_package_loadable_installed_twice(hass): + """Test that a package is loadable when installed twice. + + If a package is installed twice, only the first version will be imported. + Test that package_loadable will only compare with the first package. + """ + v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0') + v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v1]]): + assert not await PackageLoadable(hass).loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]): + assert not await PackageLoadable(hass).loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]): + assert await PackageLoadable(hass).loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v2]]): + assert await PackageLoadable(hass).loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v2]]): + assert await PackageLoadable(hass).loadable('Hello==2.0.0') diff --git a/tests/util/test_package.py b/tests/util/test_package.py index 1e93a078bd9..5422140c232 100644 --- a/tests/util/test_package.py +++ b/tests/util/test_package.py @@ -6,18 +6,12 @@ import sys from subprocess import PIPE from unittest.mock import MagicMock, call, patch -import pkg_resources import pytest import homeassistant.util.package as package -RESOURCE_DIR = os.path.abspath( - os.path.join(os.path.dirname(__file__), '..', 'resources')) -TEST_EXIST_REQ = 'pip>=7.0.0' TEST_NEW_REQ = 'pyhelloworld3==1.0.0' -TEST_ZIP_REQ = 'file://{}#{}' \ - .format(os.path.join(RESOURCE_DIR, 'pyhelloworld3.zip'), TEST_NEW_REQ) @pytest.fixture @@ -28,14 +22,6 @@ def mock_sys(): yield sys_mock -@pytest.fixture -def mock_exists(): - """Mock package_loadable.""" - with patch('homeassistant.util.package.package_loadable') as mock: - mock.return_value = False - yield mock - - @pytest.fixture def deps_dir(): """Return path to deps directory.""" @@ -89,20 +75,10 @@ def mock_async_subprocess(): return async_popen -def test_install_existing_package(mock_exists, mock_popen): - """Test an install attempt on an existing package.""" - mock_exists.return_value = True - assert package.install_package(TEST_EXIST_REQ) - assert mock_exists.call_count == 1 - assert mock_exists.call_args == call(TEST_EXIST_REQ) - assert mock_popen.return_value.communicate.call_count == 0 - - -def test_install(mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): +def test_install(mock_sys, mock_popen, mock_env_copy, mock_venv): """Test an install attempt on a package that doesn't exist.""" env = mock_env_copy() assert package.install_package(TEST_NEW_REQ, False) - assert mock_exists.call_count == 1 assert mock_popen.call_count == 1 assert ( mock_popen.call_args == @@ -115,11 +91,10 @@ def test_install(mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): def test_install_upgrade( - mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): + mock_sys, mock_popen, mock_env_copy, mock_venv): """Test an upgrade attempt on a package.""" env = mock_env_copy() assert package.install_package(TEST_NEW_REQ) - assert mock_exists.call_count == 1 assert mock_popen.call_count == 1 assert ( mock_popen.call_args == @@ -131,8 +106,7 @@ def test_install_upgrade( assert mock_popen.return_value.communicate.call_count == 1 -def test_install_target( - mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): +def test_install_target(mock_sys, mock_popen, mock_env_copy, mock_venv): """Test an install with a target.""" target = 'target_folder' env = mock_env_copy() @@ -144,7 +118,6 @@ def test_install_target( TEST_NEW_REQ, '--user', '--prefix='] assert package.install_package(TEST_NEW_REQ, False, target=target) - assert mock_exists.call_count == 1 assert mock_popen.call_count == 1 assert ( mock_popen.call_args == @@ -153,15 +126,14 @@ def test_install_target( assert mock_popen.return_value.communicate.call_count == 1 -def test_install_target_venv( - mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): +def test_install_target_venv(mock_sys, mock_popen, mock_env_copy, mock_venv): """Test an install with a target in a virtual environment.""" target = 'target_folder' with pytest.raises(AssertionError): package.install_package(TEST_NEW_REQ, False, target=target) -def test_install_error(caplog, mock_sys, mock_exists, mock_popen, mock_venv): +def test_install_error(caplog, mock_sys, mock_popen, mock_venv): """Test an install with a target.""" caplog.set_level(logging.WARNING) mock_popen.return_value.returncode = 1 @@ -171,14 +143,12 @@ def test_install_error(caplog, mock_sys, mock_exists, mock_popen, mock_venv): assert record.levelname == 'ERROR' -def test_install_constraint( - mock_sys, mock_exists, mock_popen, mock_env_copy, mock_venv): +def test_install_constraint(mock_sys, mock_popen, mock_env_copy, mock_venv): """Test install with constraint file on not installed package.""" env = mock_env_copy() constraints = 'constraints_file.txt' assert package.install_package( TEST_NEW_REQ, False, constraints=constraints) - assert mock_exists.call_count == 1 assert mock_popen.call_count == 1 assert ( mock_popen.call_args == @@ -190,17 +160,6 @@ def test_install_constraint( assert mock_popen.return_value.communicate.call_count == 1 -def test_check_package_global(): - """Test for an installed package.""" - installed_package = list(pkg_resources.working_set)[0].project_name - assert package.package_loadable(installed_package) - - -def test_check_package_zip(): - """Test for an installed zip package.""" - assert not package.package_loadable(TEST_ZIP_REQ) - - @asyncio.coroutine def test_async_get_user_site(mock_env_copy): """Test async get user site directory.""" @@ -217,28 +176,3 @@ def test_async_get_user_site(mock_env_copy): stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, env=env) assert ret == os.path.join(deps_dir, 'lib_dir') - - -def test_package_loadable_installed_twice(): - """Test that a package is loadable when installed twice. - - If a package is installed twice, only the first version will be imported. - Test that package_loadable will only compare with the first package. - """ - v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0') - v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0') - - with patch('pkg_resources.find_distributions', side_effect=[[v1]]): - assert not package.package_loadable('hello==2.0.0') - - with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]): - assert not package.package_loadable('hello==2.0.0') - - with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]): - assert package.package_loadable('hello==2.0.0') - - with patch('pkg_resources.find_distributions', side_effect=[[v2]]): - assert package.package_loadable('hello==2.0.0') - - with patch('pkg_resources.find_distributions', side_effect=[[v2]]): - assert package.package_loadable('Hello==2.0.0')