Check correctly if package is loadable (#16121)

This commit is contained in:
Paulus Schoutsen 2018-08-22 12:17:14 +02:00 committed by Pascal Vizeli
parent 0009be595c
commit 2e6cb2235c
2 changed files with 37 additions and 9 deletions

View File

@ -32,7 +32,7 @@ def install_package(package: str, upgrade: bool = True,
""" """
# Not using 'import pip; pip.main([])' because it breaks the logger # Not using 'import pip; pip.main([])' because it breaks the logger
with INSTALL_LOCK: with INSTALL_LOCK:
if check_package_exists(package): if package_loadable(package):
return True return True
_LOGGER.info('Attempting install of %s', package) _LOGGER.info('Attempting install of %s', package)
@ -61,8 +61,8 @@ def install_package(package: str, upgrade: bool = True,
return True return True
def check_package_exists(package: str) -> bool: def package_loadable(package: str) -> bool:
"""Check if a package is installed globally or in lib_dir. """Check if a package is what will be loaded when we import it.
Returns True when the requirement is met. Returns True when the requirement is met.
Returns False when the package is not installed or doesn't meet req. Returns False when the package is not installed or doesn't meet req.
@ -73,8 +73,14 @@ def check_package_exists(package: str) -> bool:
# This is a zip file # This is a zip file
req = pkg_resources.Requirement.parse(urlparse(package).fragment) req = pkg_resources.Requirement.parse(urlparse(package).fragment)
env = pkg_resources.Environment() for path in sys.path:
return any(dist in req for dist in env[req.project_name]) for dist in pkg_resources.find_distributions(path):
# If the project name is the same, it will be the one that is
# loaded when we import it.
if dist.project_name == req.project_name:
return dist in req
return False
async def async_get_user_site(deps_dir: str) -> str: async def async_get_user_site(deps_dir: str) -> str:

View File

@ -30,8 +30,8 @@ def mock_sys():
@pytest.fixture @pytest.fixture
def mock_exists(): def mock_exists():
"""Mock check_package_exists.""" """Mock package_loadable."""
with patch('homeassistant.util.package.check_package_exists') as mock: with patch('homeassistant.util.package.package_loadable') as mock:
mock.return_value = False mock.return_value = False
yield mock yield mock
@ -193,12 +193,12 @@ def test_install_constraint(
def test_check_package_global(): def test_check_package_global():
"""Test for an installed package.""" """Test for an installed package."""
installed_package = list(pkg_resources.working_set)[0].project_name installed_package = list(pkg_resources.working_set)[0].project_name
assert package.check_package_exists(installed_package) assert package.package_loadable(installed_package)
def test_check_package_zip(): def test_check_package_zip():
"""Test for an installed zip package.""" """Test for an installed zip package."""
assert not package.check_package_exists(TEST_ZIP_REQ) assert not package.package_loadable(TEST_ZIP_REQ)
@asyncio.coroutine @asyncio.coroutine
@ -217,3 +217,25 @@ def test_async_get_user_site(mock_env_copy):
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL,
env=env) env=env)
assert ret == os.path.join(deps_dir, 'lib_dir') assert ret == os.path.join(deps_dir, 'lib_dir')
def test_package_loadable_installed_twice():
"""Test that a package is loadable when installed twice.
If a package is installed twice, only the first version will be imported.
Test that package_loadable will only compare with the first package.
"""
v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0')
v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v1]]):
assert not package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]):
assert not package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]):
assert package.package_loadable('hello==2.0.0')
with patch('pkg_resources.find_distributions', side_effect=[[v2]]):
assert package.package_loadable('hello==2.0.0')