Skip to content

Commit 43b9088

Browse files
committed
Make extensions give AttributeError when they are disabled
This is how the test suite and presumably some other codes detect if extensions are enabled or not. This also dynamically updates __all__ whenever extensions are enabled or disabled.
1 parent dd01b12 commit 43b9088

File tree

5 files changed

+160
-47
lines changed

5 files changed

+160
-47
lines changed

Diff for: array_api_strict/__init__.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
1717
"""
1818

19+
__all__ = []
20+
1921
# Warning: __array_api_version__ could change globally with
2022
# set_array_api_strict_flags(). This should always be accessed as an
2123
# attribute, like xp.__array_api_version__, or using
2224
# array_api_strict.get_array_api_strict_flags()['api_version'].
2325
from ._flags import API_VERSION as __array_api_version__
2426

25-
__all__ = ["__array_api_version__"]
27+
__all__ += ["__array_api_version__"]
2628

2729
from ._constants import e, inf, nan, pi, newaxis
2830

@@ -266,19 +268,10 @@
266268
"__array_namespace_info__",
267269
]
268270

269-
# linalg is an extension in the array API spec, which is a sub-namespace. Only
270-
# a subset of functions in it are imported into the top-level namespace.
271-
from . import linalg
272-
273-
__all__ += ["linalg"]
274-
275271
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
276272

277273
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
278274

279-
from . import fft
280-
__all__ += ["fft"]
281-
282275
from ._manipulation_functions import (
283276
concat,
284277
expand_dims,
@@ -330,3 +323,22 @@
330323
from . import _version
331324
__version__ = _version.get_versions()['version']
332325
del _version
326+
327+
328+
# Extensions can be enabled or disabled dynamically. In order to make
329+
# "array_api_strict.linalg" give an AttributeError when it is disabled, we
330+
# use __getattr__. Note that linalg and fft are dynamically added and removed
331+
# from __all__ in set_array_api_strict_flags.
332+
333+
def __getattr__(name):
334+
if name in ['linalg', 'fft']:
335+
if name in get_array_api_strict_flags()['enabled_extensions']:
336+
if name == 'linalg':
337+
from . import _linalg
338+
return _linalg
339+
elif name == 'fft':
340+
from . import _fft
341+
return _fft
342+
else:
343+
raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict")
344+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
File renamed without changes.

Diff for: array_api_strict/_flags.py

+7
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def set_array_api_strict_flags(
161161
else:
162162
ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION])
163163

164+
array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
165+
set(array_api_strict.__all__) -
166+
set(default_extensions))
167+
164168
# We have to do this separately or it won't get added as the docstring
165169
set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format(
166170
supported_versions=supported_versions,
@@ -321,6 +325,9 @@ def set_flags_from_environment():
321325
if enabled_extensions == [""]:
322326
enabled_extensions = []
323327
set_array_api_strict_flags(enabled_extensions=enabled_extensions)
328+
else:
329+
# Needed at first import to add linalg and fft to __all__
330+
set_array_api_strict_flags(enabled_extensions=default_extensions)
324331

325332
set_flags_from_environment()
326333

File renamed without changes.

Diff for: array_api_strict/tests/test_flags.py

+131-37
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import sys
2+
import subprocess
3+
14
from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
25
reset_array_api_strict_flags)
36
from .._info import (capabilities, default_device, default_dtypes, devices,
47
dtypes)
8+
from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft,
9+
ihfft, fftfreq, rfftfreq, fftshift, ifftshift)
10+
from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv,
11+
matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv,
12+
qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm)
513

614
from .. import (asarray, unique_all, unique_counts, unique_inverse,
715
unique_values, nonzero, repeat)
@@ -152,29 +160,29 @@ def test_boolean_indexing():
152160
pytest.raises(RuntimeError, lambda: a[mask])
153161

154162
linalg_examples = {
155-
'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)),
156-
'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])),
157-
'det': lambda: xp.linalg.det(xp.eye(3)),
158-
'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)),
159-
'eigh': lambda: xp.linalg.eigh(xp.eye(3)),
160-
'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)),
161-
'inv': lambda: xp.linalg.inv(xp.eye(3)),
162-
'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)),
163-
'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)),
164-
'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2),
165-
'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)),
166-
'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)),
167-
'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
168-
'pinv': lambda: xp.linalg.pinv(xp.eye(3)),
169-
'qr': lambda: xp.linalg.qr(xp.eye(3)),
170-
'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)),
171-
'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)),
172-
'svd': lambda: xp.linalg.svd(xp.eye(3)),
173-
'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)),
174-
'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)),
175-
'trace': lambda: xp.linalg.trace(xp.eye(3)),
176-
'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
177-
'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])),
163+
'cholesky': lambda: cholesky(xp.eye(3)),
164+
'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])),
165+
'det': lambda: det(xp.eye(3)),
166+
'diagonal': lambda: diagonal(xp.eye(3)),
167+
'eigh': lambda: eigh(xp.eye(3)),
168+
'eigvalsh': lambda: eigvalsh(xp.eye(3)),
169+
'inv': lambda: inv(xp.eye(3)),
170+
'matmul': lambda: matmul(xp.eye(3), xp.eye(3)),
171+
'matrix_norm': lambda: matrix_norm(xp.eye(3)),
172+
'matrix_power': lambda: matrix_power(xp.eye(3), 2),
173+
'matrix_rank': lambda: matrix_rank(xp.eye(3)),
174+
'matrix_transpose': lambda: matrix_transpose(xp.eye(3)),
175+
'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
176+
'pinv': lambda: pinv(xp.eye(3)),
177+
'qr': lambda: qr(xp.eye(3)),
178+
'slogdet': lambda: slogdet(xp.eye(3)),
179+
'solve': lambda: solve(xp.eye(3), xp.eye(3)),
180+
'svd': lambda: svd(xp.eye(3)),
181+
'svdvals': lambda: svdvals(xp.eye(3)),
182+
'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)),
183+
'trace': lambda: trace(xp.eye(3)),
184+
'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])),
185+
'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])),
178186
}
179187

180188
assert set(linalg_examples) == set(xp.linalg.__all__)
@@ -210,20 +218,20 @@ def test_linalg(func_name):
210218
main_namespace_func()
211219

212220
fft_examples = {
213-
'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])),
214-
'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])),
215-
'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])),
216-
'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])),
217-
'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])),
218-
'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])),
219-
'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])),
220-
'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])),
221-
'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])),
222-
'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])),
223-
'fftfreq': lambda: xp.fft.fftfreq(4),
224-
'rfftfreq': lambda: xp.fft.rfftfreq(4),
225-
'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])),
226-
'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])),
221+
'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])),
222+
'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])),
223+
'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])),
224+
'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])),
225+
'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])),
226+
'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])),
227+
'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])),
228+
'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])),
229+
'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])),
230+
'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])),
231+
'fftfreq': lambda: fftfreq(4),
232+
'rfftfreq': lambda: rfftfreq(4),
233+
'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])),
234+
'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])),
227235
}
228236

229237
assert set(fft_examples) == set(xp.fft.__all__)
@@ -276,3 +284,89 @@ def test_api_version_2023_12(func_name):
276284

277285
set_array_api_strict_flags(api_version='2022.12')
278286
pytest.raises(RuntimeError, func)
287+
288+
def test_disabled_extensions():
289+
# Test that xp.extension errors when an extension is disabled, and that
290+
# xp.__all__ is updated properly.
291+
292+
# First test that things are correct on the initial import. Since we have
293+
# already called set_array_api_strict_flags many times throughout running
294+
# the tests, we have to test this in a subprocess.
295+
subprocess_tests = [('''\
296+
import array_api_strict
297+
298+
array_api_strict.linalg # No error
299+
array_api_strict.fft # No error
300+
assert "linalg" in array_api_strict.__all__
301+
assert "fft" in array_api_strict.__all__
302+
assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__))
303+
''', {}),
304+
# Test that the initial population of __all__ works correctly
305+
('''\
306+
from array_api_strict import * # No error
307+
linalg # Should have been imported by the previous line
308+
fft
309+
''', {}),
310+
('''\
311+
from array_api_strict import * # No error
312+
linalg # Should have been imported by the previous line
313+
assert 'fft' not in globals()
314+
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}),
315+
('''\
316+
from array_api_strict import * # No error
317+
fft # Should have been imported by the previous line
318+
assert 'linalg' not in globals()
319+
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}),
320+
('''\
321+
from array_api_strict import * # No error
322+
assert 'linalg' not in globals()
323+
assert 'fft' not in globals()
324+
''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}),
325+
]
326+
for test, env in subprocess_tests:
327+
try:
328+
subprocess.run([sys.executable, '-c', test], check=True,
329+
capture_output=True, encoding='utf-8', env=env)
330+
except subprocess.CalledProcessError as e:
331+
print(e.stdout, end='')
332+
# Ensure the exception is shown in the output log
333+
raise AssertionError(e.stderr)
334+
335+
assert 'linalg' in xp.__all__
336+
assert 'fft' in xp.__all__
337+
xp.linalg # No error
338+
xp.fft # No error
339+
ns = {}
340+
exec('from array_api_strict import *', ns)
341+
assert 'linalg' in ns
342+
assert 'fft' in ns
343+
344+
set_array_api_strict_flags(enabled_extensions=('linalg',))
345+
assert 'linalg' in xp.__all__
346+
assert 'fft' not in xp.__all__
347+
xp.linalg # No error
348+
pytest.raises(AttributeError, lambda: xp.fft)
349+
ns = {}
350+
exec('from array_api_strict import *', ns)
351+
assert 'linalg' in ns
352+
assert 'fft' not in ns
353+
354+
set_array_api_strict_flags(enabled_extensions=('fft',))
355+
assert 'linalg' not in xp.__all__
356+
assert 'fft' in xp.__all__
357+
pytest.raises(AttributeError, lambda: xp.linalg)
358+
xp.fft # No error
359+
ns = {}
360+
exec('from array_api_strict import *', ns)
361+
assert 'linalg' not in ns
362+
assert 'fft' in ns
363+
364+
set_array_api_strict_flags(enabled_extensions=())
365+
assert 'linalg' not in xp.__all__
366+
assert 'fft' not in xp.__all__
367+
pytest.raises(AttributeError, lambda: xp.linalg)
368+
pytest.raises(AttributeError, lambda: xp.fft)
369+
ns = {}
370+
exec('from array_api_strict import *', ns)
371+
assert 'linalg' not in ns
372+
assert 'fft' not in ns

0 commit comments

Comments
 (0)