From 2e6cb2235c29bbed8b8fe29262019de8cf132e2a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 22 Aug 2018 12:17:14 +0200 Subject: [PATCH] Check correctly if package is loadable (#16121) --- homeassistant/util/package.py | 16 +++++++++++----- tests/util/test_package.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/homeassistant/util/package.py b/homeassistant/util/package.py index 9433046e688..feefa65c0f6 100644 --- a/homeassistant/util/package.py +++ b/homeassistant/util/package.py @@ -32,7 +32,7 @@ def install_package(package: str, upgrade: bool = True, """ # Not using 'import pip; pip.main([])' because it breaks the logger with INSTALL_LOCK: - if check_package_exists(package): + if package_loadable(package): return True _LOGGER.info('Attempting install of %s', package) @@ -61,8 +61,8 @@ def install_package(package: str, upgrade: bool = True, return True -def check_package_exists(package: str) -> bool: - """Check if a package is installed globally or in lib_dir. +def package_loadable(package: str) -> bool: + """Check if a package is what will be loaded when we import it. Returns True when the requirement is met. Returns False when the package is not installed or doesn't meet req. @@ -73,8 +73,14 @@ def check_package_exists(package: str) -> bool: # This is a zip file req = pkg_resources.Requirement.parse(urlparse(package).fragment) - env = pkg_resources.Environment() - return any(dist in req for dist in env[req.project_name]) + for path in sys.path: + for dist in pkg_resources.find_distributions(path): + # If the project name is the same, it will be the one that is + # loaded when we import it. + if dist.project_name == req.project_name: + return dist in req + + return False async def async_get_user_site(deps_dir: str) -> str: diff --git a/tests/util/test_package.py b/tests/util/test_package.py index ab9f9f0ad2c..19e85a094ee 100644 --- a/tests/util/test_package.py +++ b/tests/util/test_package.py @@ -30,8 +30,8 @@ def mock_sys(): @pytest.fixture def mock_exists(): - """Mock check_package_exists.""" - with patch('homeassistant.util.package.check_package_exists') as mock: + """Mock package_loadable.""" + with patch('homeassistant.util.package.package_loadable') as mock: mock.return_value = False yield mock @@ -193,12 +193,12 @@ def test_install_constraint( def test_check_package_global(): """Test for an installed package.""" installed_package = list(pkg_resources.working_set)[0].project_name - assert package.check_package_exists(installed_package) + assert package.package_loadable(installed_package) def test_check_package_zip(): """Test for an installed zip package.""" - assert not package.check_package_exists(TEST_ZIP_REQ) + assert not package.package_loadable(TEST_ZIP_REQ) @asyncio.coroutine @@ -217,3 +217,25 @@ def test_async_get_user_site(mock_env_copy): stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, env=env) assert ret == os.path.join(deps_dir, 'lib_dir') + + +def test_package_loadable_installed_twice(): + """Test that a package is loadable when installed twice. + + If a package is installed twice, only the first version will be imported. + Test that package_loadable will only compare with the first package. + """ + v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0') + v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v1]]): + assert not package.package_loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]): + assert not package.package_loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]): + assert package.package_loadable('hello==2.0.0') + + with patch('pkg_resources.find_distributions', side_effect=[[v2]]): + assert package.package_loadable('hello==2.0.0')