Skip to content

Commit 2e4e185

Browse files
awaelchliBorda
andauthored
Fix extras check in RequirementCache (#283)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent d42f7c0 commit 2e4e185

File tree

3 files changed

+73
-3
lines changed

3 files changed

+73
-3
lines changed

src/lightning_utilities/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
__version__ = "0.11.4"
3+
__version__ = "0.11.5"
44
__author__ = "Lightning AI et al."
55
__author_email__ = "[email protected]"
66
__license__ = "Apache-2.0"

src/lightning_utilities/core/imports.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import warnings
99
from functools import lru_cache
10-
from importlib.metadata import PackageNotFoundError
10+
from importlib.metadata import PackageNotFoundError, distribution
1111
from importlib.metadata import version as _version
1212
from importlib.util import find_spec
1313
from types import ModuleType
@@ -128,7 +128,9 @@ def _check_requirement(self) -> None:
128128
try:
129129
req = Requirement(self.requirement)
130130
pkg_version = Version(_version(req.name))
131-
self.available = req.specifier.contains(pkg_version)
131+
self.available = req.specifier.contains(pkg_version) and (
132+
not req.extras or self._check_extras_available(req)
133+
)
132134
except (PackageNotFoundError, InvalidVersion) as ex:
133135
self.available = False
134136
self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"
@@ -143,6 +145,9 @@ def _check_requirement(self) -> None:
143145
self.available = module_available(module)
144146
if self.available:
145147
self.message = f"Module {module!r} available"
148+
self.message = (
149+
f"Requirement {self.requirement!r} not met. HINT: Try running `pip install -U {self.requirement!r}`"
150+
)
146151

147152
def _check_module(self) -> None:
148153
assert self.module # noqa: S101; needed for typing
@@ -160,6 +165,34 @@ def _check_available(self) -> None:
160165
if getattr(self, "available", True) and self.module:
161166
self._check_module()
162167

168+
def _check_extras_available(self, requirement: Requirement) -> bool:
169+
if not requirement.extras:
170+
return True
171+
172+
extra_requirements = self._get_extra_requirements(requirement)
173+
174+
if not extra_requirements:
175+
# The specified extra is not found in the package metadata
176+
return False
177+
178+
# Verify each extra requirement is installed
179+
for extra_req in extra_requirements:
180+
try:
181+
extra_dist = distribution(extra_req.name)
182+
extra_installed_version = Version(extra_dist.version)
183+
if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version):
184+
return False
185+
except importlib.metadata.PackageNotFoundError:
186+
return False
187+
188+
return True
189+
190+
def _get_extra_requirements(self, requirement: Requirement) -> List[Requirement]:
191+
dist = distribution(requirement.name)
192+
# Get the required dependencies for the specified extras
193+
extra_requirements = dist.metadata.get_all("Requires-Dist") or []
194+
return [Requirement(r) for r in extra_requirements if any(extra in r for extra in requirement.extras)]
195+
163196
def __bool__(self) -> bool:
164197
"""Format as bool."""
165198
self._check_available()

tests/unittests/core/test_imports.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import operator
22
import re
3+
from unittest import mock
4+
from unittest.mock import Mock
35

46
import pytest
57
from lightning_utilities.core.imports import (
@@ -61,6 +63,41 @@ def test_requirement_cache():
6163
assert not cache
6264
assert "pip install -U 'this_module_is_not_installed" in str(cache)
6365

66+
cache = RequirementCache("pytest[not-valid-extra]")
67+
assert not cache
68+
assert "pip install -U 'pytest[not-valid-extra]" in str(cache)
69+
70+
71+
@mock.patch("lightning_utilities.core.imports.Requirement")
72+
@mock.patch("lightning_utilities.core.imports._version")
73+
@mock.patch("lightning_utilities.core.imports.distribution")
74+
def test_requirement_cache_with_extras(distribution_mock, version_mock, requirement_mock):
75+
requirement_mock().specifier.contains.return_value = True
76+
requirement_mock().name = "jsonargparse"
77+
requirement_mock().extras = []
78+
version_mock.return_value = "1.0.0"
79+
assert RequirementCache("jsonargparse>=1.0.0")
80+
81+
with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
82+
get_extra_req_mock.return_value = [
83+
# Extra packages, all versions satisfied
84+
Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
85+
Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=True))),
86+
]
87+
distribution_mock.return_value = Mock(version="0.10.0")
88+
requirement_mock().extras = ["signatures"]
89+
assert RequirementCache("jsonargparse[signatures]>=1.0.0")
90+
91+
with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock:
92+
get_extra_req_mock.return_value = [
93+
# Extra packages, but not all versions are satisfied
94+
Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))),
95+
Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=False))),
96+
]
97+
distribution_mock.return_value = Mock(version="0.10.0")
98+
requirement_mock().extras = ["signatures"]
99+
assert not RequirementCache("jsonargparse[signatures]>=1.0.0")
100+
64101

65102
def test_module_available_cache():
66103
assert RequirementCache(module="pytest")

0 commit comments

Comments
 (0)