Check correctly if package is loadable (#16121)

This commit is contained in:
Paulus Schoutsen 2018-08-22 12:17:14 +02:00 committed by Pascal Vizeli
parent 0009be595c
commit 2e6cb2235c
2 changed files with 37 additions and 9 deletions

View File

@ -32,7 +32,7 @@ def install_package(package: str, upgrade: bool = True,
"""
# Not using 'import pip; pip.main([])' because it breaks the logger
with INSTALL_LOCK:
if check_package_exists(package):
if package_loadable(package):
return True
_LOGGER.info('Attempting install of %s', package)
@ -61,8 +61,8 @@ def install_package(package: str, upgrade: bool = True,
return True
def check_package_exists(package: str) -> bool:
"""Check if a package is installed globally or in lib_dir.
def package_loadable(package: str) -> bool:
"""Check if a package is what will be loaded when we import it.
Returns True when the requirement is met.
Returns False when the package is not installed or doesn't meet req.
@ -73,8 +73,14 @@ def check_package_exists(package: str) -> bool:
# This is a zip file
req = pkg_resources.Requirement.parse(urlparse(package).fragment)
env = pkg_resources.Environment()
return any(dist in req for dist in env[req.project_name])
for path in sys.path:
for dist in pkg_resources.find_distributions(path):
# If the project name is the same, it will be the one that is
# loaded when we import it.
if dist.project_name == req.project_name:
return dist in req
return False
async def async_get_user_site(deps_dir: str) -> str:

View File

@ -30,8 +30,8 @@ def mock_sys():
@pytest.fixture
def mock_exists():
"""Mock check_package_exists."""
with patch('homeassistant.util.package.check_package_exists') as mock:
"""Mock package_loadable."""
with patch('homeassistant.util.package.package_loadable') as mock:
mock.return_value = False
yield mock
@ -193,12 +193,12 @@ def test_install_constraint(
def test_check_package_global():
"""Test for an installed package."""
installed_package = list(pkg_resources.working_set)[0].project_name
assert package.check_package_exists(installed_package)
assert package.package_loadable(installed_package)
def test_check_package_zip():
"""Test for an installed zip package."""
assert not package.check_package_exists(TEST_ZIP_REQ)
assert not package.package_loadable(TEST_ZIP_REQ)
@asyncio.coroutine
@ -217,3 +217,25 @@ def test_async_get_user_site(mock_env_copy):
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL,
env=env)
assert ret == os.path.join(deps_dir, 'lib_dir')
def test_package_loadable_installed_twice():
"""Test that a package is loadable when installed twice.
If a package is installed twice, only the first version will be imported.
Test that package_loadable will only compare with the first package.
"""
v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0')
v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v1]]):
assert not package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]):
assert not package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]):
assert package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v2]]):
assert package.package_loadable('hello==2.0.0')