Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 36 additions & 34 deletions mmengine/utils/package_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import subprocess
from importlib.metadata import PackageNotFoundError, distribution
from typing import Any


def is_installed(package: str) -> bool:
Expand All @@ -9,28 +11,19 @@ def is_installed(package: str) -> bool:
Args:
package (str): Name of package to be checked.
"""
# When executing `import mmengine.runner`,
# pkg_resources will be imported and it takes too much time.
# Therefore, import it in function scope to save time.
import importlib.util

import pkg_resources # type: ignore
from pkg_resources import get_distribution
# First check if it's an importable module
spec = importlib.util.find_spec(package)
if spec is not None and spec.origin is not None:
return True

# refresh the pkg_resources
# more datails at https://github.com/pypa/setuptools/issues/373
importlib.reload(pkg_resources)
# If not found as module, check if it's a distribution package
try:
get_distribution(package)
distribution(package)
return True
except pkg_resources.DistributionNotFound:
spec = importlib.util.find_spec(package)
if spec is None:
return False
elif spec.origin is not None:
return True
else:
return False
except PackageNotFoundError:
return False


def get_installed_path(package: str) -> str:
Expand All @@ -45,17 +38,21 @@ def get_installed_path(package: str) -> str:
"""
import importlib.util

from pkg_resources import DistributionNotFound, get_distribution

# if the package name is not the same as module name, module name should be
# inferred. For example, mmcv-full is the package name, but mmcv is module
# name. If we want to get the installed path of mmcv-full, we should concat
# the pkg.location and module name
# Try to get location from distribution package metadata
location = None
try:
pkg = get_distribution(package)
except DistributionNotFound as e:
# if the package is not installed, package path set in PYTHONPATH
# can be detected by `find_spec`
dist = distribution(package)
locate_result: Any = dist.locate_file('')
location = str(locate_result.parent)
except PackageNotFoundError:
pass

# If distribution package not found, try to find via importlib
if location is None:
spec = importlib.util.find_spec(package)
if spec is not None:
if spec.origin is not None:
Expand All @@ -67,28 +64,33 @@ def get_installed_path(package: str) -> str:
f'{package} is a namespace package, which is invalid '
'for `get_install_path`')
else:
raise e
raise PackageNotFoundError(f'Package {package} is not installed')

possible_path = osp.join(pkg.location, package) # type: ignore
# Check if package directory exists in the location
possible_path = osp.join(location, package)
if osp.exists(possible_path):
return possible_path
else:
return osp.join(pkg.location, package2module(package)) # type: ignore
return osp.join(location, package2module(package))


def package2module(package: str):
def package2module(package: str) -> str:
"""Infer module name from package.

Args:
package (str): Package to infer module name.
"""
from pkg_resources import get_distribution
pkg = get_distribution(package)
if pkg.has_metadata('top_level.txt'):
module_name = pkg.get_metadata('top_level.txt').split('\n')[0]
return module_name
else:
raise ValueError(f'can not infer the module name of {package}')
dist = distribution(package)

# In importlib.metadata,
# top-level modules are in dist.read_text('top_level.txt')
top_level_text = dist.read_text('top_level.txt')
if top_level_text is not None:
lines = top_level_text.strip().split('\n')
if lines:
module_name = lines[0].strip()
return module_name
raise ValueError(f'can not infer the module name of {package}')


def call_command(cmd: list) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/data/config/lazy_module_config/test_ast_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
from ._base_.default_runtime import default_scope as scope
from ._base_.scheduler import val_cfg
from rich.progress import Progress

start = Progress.start
1 change: 0 additions & 1 deletion tests/data/config/lazy_module_config/test_mix_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
chained = list(chain([1, 2], [3, 4]))
existed = ex(__file__)
cfgname = partial(basename, __file__)()

1 change: 0 additions & 1 deletion tests/data/config/lazy_module_config/toy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

param_scheduler.milestones = [2, 4]


train_dataloader = dict(
dataset=dict(type=ToyDataset),
sampler=dict(type=DefaultSampler, shuffle=True),
Expand Down
1 change: 1 addition & 0 deletions tests/data/config/py_config/test_custom_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
class A:
...


item_a = dict(a=A)
2 changes: 1 addition & 1 deletion tests/data/config/py_config/test_dump_pickle_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ def func():
dict_item5 = {'x/x': {'a.0': 233}}
dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]}
# Test windows path and escape.
str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in
str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in
str_item_8 = func()
9 changes: 2 additions & 7 deletions tests/data/config/py_config/test_get_external_cfg3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@
'mmdet::_base_/models/faster-rcnn_r50_fpn.py',
'mmdet::_base_/datasets/coco_detection.py',
'mmdet::_base_/schedules/schedule_1x.py',
'mmdet::_base_/default_runtime.py',
'./test_get_external_cfg_base.py'
'mmdet::_base_/default_runtime.py', './test_get_external_cfg_base.py'
]

custom_hooks = [dict(type='mmdet.DetVisualizationHook')]

model = dict(
roi_head=dict(
bbox_head=dict(
loss_cls=dict(_delete_=True, type='test.ToyLoss')
)
)
)
bbox_head=dict(loss_cls=dict(_delete_=True, type='test.ToyLoss'))))
10 changes: 8 additions & 2 deletions tests/test_utils/test_package_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import sys
from importlib.metadata import PackageNotFoundError

import pkg_resources # type: ignore
import pytest

from mmengine.utils import get_installed_path, is_installed
Expand All @@ -20,6 +20,12 @@ def test_is_installed():
assert is_installed('optim')
sys.path.pop()

assert is_installed('nonexistentpackage12345') is False
assert is_installed('os') is True # 'os' is a module name
assert is_installed('setuptools') is True
# Should work on both distribution and module name
assert is_installed('pillow') is True and is_installed('PIL') is True


def test_get_install_path():
# TODO: Windows CI may failed in unknown reason. Skip check the value
Expand All @@ -33,5 +39,5 @@ def test_get_install_path():
assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim')
sys.path.pop()

with pytest.raises(pkg_resources.DistributionNotFound):
with pytest.raises(PackageNotFoundError):
get_installed_path('unknown')
Loading