Skip to content

ENH: unary functions overhaul; better input validation #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 23, 2025

xref #145

  • Rewrite all unary functions with a generator
  • Disallow numpy generics in binary functions, clip, and where
  • Improve error message when the first argument of where is not an Array
  • Test for device mismatches in the inputs of binary functions, clip, and where
  • Test input-output device propagation in where

@@ -168,9 +231,6 @@ def _array_vals():
for d in _floating_dtypes:
yield asarray(1.0, dtype=d)

# Use the latest version of the standard so all functions are included
set_array_api_strict_flags(api_version="2024.12")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant with auto-applied fixture

@pytest.mark.parametrize("func_name", elementwise_function_input_types)
def test_elementwise_function_vs_numpy_generics(func_name):
"""
Test that NumPy generics are explicitly disallowed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor, trivial: instances of np.generic are typically referred to as numpy scalars. The fact that the type name is np.generic is... not very well known (for the better IMO). Anyway, this is not worth flushing the CI for.

Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Let keep this PR open for a while though, in case somebody has opinions on generating unary functions from a decorator. I personally think this is a good change, but there were concerns in #100


_ = func(a, a)
with pytest.raises(TypeError, match="neither Array nor Python scalars"):
func(a, b)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to be really thorough, can add also test func(b, a): unlike #103, functions we control.

res = xp.where(cond, 1, x2)
assert res.device == device
res = xp.where(cond, x1, 2)
assert res.device == device
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the following tests are great. I imaging we'll want to parrot them in array-api-tests at some point.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants