|
| 1 | +import sys |
| 2 | +import subprocess |
| 3 | + |
1 | 4 | from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags,
|
2 | 5 | reset_array_api_strict_flags)
|
3 | 6 | from .._info import (capabilities, default_device, default_dtypes, devices,
|
4 | 7 | dtypes)
|
| 8 | +from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, |
| 9 | + ihfft, fftfreq, rfftfreq, fftshift, ifftshift) |
| 10 | +from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, |
| 11 | + matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv, |
| 12 | + qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm) |
5 | 13 |
|
6 | 14 | from .. import (asarray, unique_all, unique_counts, unique_inverse,
|
7 | 15 | unique_values, nonzero, repeat)
|
@@ -152,29 +160,29 @@ def test_boolean_indexing():
|
152 | 160 | pytest.raises(RuntimeError, lambda: a[mask])
|
153 | 161 |
|
154 | 162 | linalg_examples = {
|
155 |
| - 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), |
156 |
| - 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), |
157 |
| - 'det': lambda: xp.linalg.det(xp.eye(3)), |
158 |
| - 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), |
159 |
| - 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), |
160 |
| - 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), |
161 |
| - 'inv': lambda: xp.linalg.inv(xp.eye(3)), |
162 |
| - 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), |
163 |
| - 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), |
164 |
| - 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), |
165 |
| - 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), |
166 |
| - 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), |
167 |
| - 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
168 |
| - 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), |
169 |
| - 'qr': lambda: xp.linalg.qr(xp.eye(3)), |
170 |
| - 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), |
171 |
| - 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), |
172 |
| - 'svd': lambda: xp.linalg.svd(xp.eye(3)), |
173 |
| - 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), |
174 |
| - 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), |
175 |
| - 'trace': lambda: xp.linalg.trace(xp.eye(3)), |
176 |
| - 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
177 |
| - 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), |
| 163 | + 'cholesky': lambda: cholesky(xp.eye(3)), |
| 164 | + 'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), |
| 165 | + 'det': lambda: det(xp.eye(3)), |
| 166 | + 'diagonal': lambda: diagonal(xp.eye(3)), |
| 167 | + 'eigh': lambda: eigh(xp.eye(3)), |
| 168 | + 'eigvalsh': lambda: eigvalsh(xp.eye(3)), |
| 169 | + 'inv': lambda: inv(xp.eye(3)), |
| 170 | + 'matmul': lambda: matmul(xp.eye(3), xp.eye(3)), |
| 171 | + 'matrix_norm': lambda: matrix_norm(xp.eye(3)), |
| 172 | + 'matrix_power': lambda: matrix_power(xp.eye(3), 2), |
| 173 | + 'matrix_rank': lambda: matrix_rank(xp.eye(3)), |
| 174 | + 'matrix_transpose': lambda: matrix_transpose(xp.eye(3)), |
| 175 | + 'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
| 176 | + 'pinv': lambda: pinv(xp.eye(3)), |
| 177 | + 'qr': lambda: qr(xp.eye(3)), |
| 178 | + 'slogdet': lambda: slogdet(xp.eye(3)), |
| 179 | + 'solve': lambda: solve(xp.eye(3), xp.eye(3)), |
| 180 | + 'svd': lambda: svd(xp.eye(3)), |
| 181 | + 'svdvals': lambda: svdvals(xp.eye(3)), |
| 182 | + 'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)), |
| 183 | + 'trace': lambda: trace(xp.eye(3)), |
| 184 | + 'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), |
| 185 | + 'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])), |
178 | 186 | }
|
179 | 187 |
|
180 | 188 | assert set(linalg_examples) == set(xp.linalg.__all__)
|
@@ -210,20 +218,20 @@ def test_linalg(func_name):
|
210 | 218 | main_namespace_func()
|
211 | 219 |
|
212 | 220 | fft_examples = {
|
213 |
| - 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), |
214 |
| - 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), |
215 |
| - 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
216 |
| - 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
217 |
| - 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), |
218 |
| - 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), |
219 |
| - 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), |
220 |
| - 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
221 |
| - 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), |
222 |
| - 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), |
223 |
| - 'fftfreq': lambda: xp.fft.fftfreq(4), |
224 |
| - 'rfftfreq': lambda: xp.fft.rfftfreq(4), |
225 |
| - 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), |
226 |
| - 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), |
| 221 | + 'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])), |
| 222 | + 'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])), |
| 223 | + 'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 224 | + 'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 225 | + 'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])), |
| 226 | + 'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])), |
| 227 | + 'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])), |
| 228 | + 'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), |
| 229 | + 'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])), |
| 230 | + 'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])), |
| 231 | + 'fftfreq': lambda: fftfreq(4), |
| 232 | + 'rfftfreq': lambda: rfftfreq(4), |
| 233 | + 'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])), |
| 234 | + 'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])), |
227 | 235 | }
|
228 | 236 |
|
229 | 237 | assert set(fft_examples) == set(xp.fft.__all__)
|
@@ -276,3 +284,89 @@ def test_api_version_2023_12(func_name):
|
276 | 284 |
|
277 | 285 | set_array_api_strict_flags(api_version='2022.12')
|
278 | 286 | pytest.raises(RuntimeError, func)
|
| 287 | + |
| 288 | +def test_disabled_extensions(): |
| 289 | + # Test that xp.extension errors when an extension is disabled, and that |
| 290 | + # xp.__all__ is updated properly. |
| 291 | + |
| 292 | + # First test that things are correct on the initial import. Since we have |
| 293 | + # already called set_array_api_strict_flags many times throughout running |
| 294 | + # the tests, we have to test this in a subprocess. |
| 295 | + subprocess_tests = [('''\ |
| 296 | +import array_api_strict |
| 297 | +
|
| 298 | +array_api_strict.linalg # No error |
| 299 | +array_api_strict.fft # No error |
| 300 | +assert "linalg" in array_api_strict.__all__ |
| 301 | +assert "fft" in array_api_strict.__all__ |
| 302 | +assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__)) |
| 303 | +''', {}), |
| 304 | +# Test that the initial population of __all__ works correctly |
| 305 | +('''\ |
| 306 | +from array_api_strict import * # No error |
| 307 | +linalg # Should have been imported by the previous line |
| 308 | +fft |
| 309 | +''', {}), |
| 310 | +('''\ |
| 311 | +from array_api_strict import * # No error |
| 312 | +linalg # Should have been imported by the previous line |
| 313 | +assert 'fft' not in globals() |
| 314 | +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}), |
| 315 | +('''\ |
| 316 | +from array_api_strict import * # No error |
| 317 | +fft # Should have been imported by the previous line |
| 318 | +assert 'linalg' not in globals() |
| 319 | +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}), |
| 320 | +('''\ |
| 321 | +from array_api_strict import * # No error |
| 322 | +assert 'linalg' not in globals() |
| 323 | +assert 'fft' not in globals() |
| 324 | +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}), |
| 325 | +] |
| 326 | + for test, env in subprocess_tests: |
| 327 | + try: |
| 328 | + subprocess.run([sys.executable, '-c', test], check=True, |
| 329 | + capture_output=True, encoding='utf-8', env=env) |
| 330 | + except subprocess.CalledProcessError as e: |
| 331 | + print(e.stdout, end='') |
| 332 | + # Ensure the exception is shown in the output log |
| 333 | + raise AssertionError(e.stderr) |
| 334 | + |
| 335 | + assert 'linalg' in xp.__all__ |
| 336 | + assert 'fft' in xp.__all__ |
| 337 | + xp.linalg # No error |
| 338 | + xp.fft # No error |
| 339 | + ns = {} |
| 340 | + exec('from array_api_strict import *', ns) |
| 341 | + assert 'linalg' in ns |
| 342 | + assert 'fft' in ns |
| 343 | + |
| 344 | + set_array_api_strict_flags(enabled_extensions=('linalg',)) |
| 345 | + assert 'linalg' in xp.__all__ |
| 346 | + assert 'fft' not in xp.__all__ |
| 347 | + xp.linalg # No error |
| 348 | + pytest.raises(AttributeError, lambda: xp.fft) |
| 349 | + ns = {} |
| 350 | + exec('from array_api_strict import *', ns) |
| 351 | + assert 'linalg' in ns |
| 352 | + assert 'fft' not in ns |
| 353 | + |
| 354 | + set_array_api_strict_flags(enabled_extensions=('fft',)) |
| 355 | + assert 'linalg' not in xp.__all__ |
| 356 | + assert 'fft' in xp.__all__ |
| 357 | + pytest.raises(AttributeError, lambda: xp.linalg) |
| 358 | + xp.fft # No error |
| 359 | + ns = {} |
| 360 | + exec('from array_api_strict import *', ns) |
| 361 | + assert 'linalg' not in ns |
| 362 | + assert 'fft' in ns |
| 363 | + |
| 364 | + set_array_api_strict_flags(enabled_extensions=()) |
| 365 | + assert 'linalg' not in xp.__all__ |
| 366 | + assert 'fft' not in xp.__all__ |
| 367 | + pytest.raises(AttributeError, lambda: xp.linalg) |
| 368 | + pytest.raises(AttributeError, lambda: xp.fft) |
| 369 | + ns = {} |
| 370 | + exec('from array_api_strict import *', ns) |
| 371 | + assert 'linalg' not in ns |
| 372 | + assert 'fft' not in ns |
0 commit comments