Implement as_jax_op #1614
Conversation
Pull Request Overview

This PR implements the as_jax_op decorator, which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation capabilities (a rough usage sketch follows the summary below).

- Implements the JAXOp class for wrapping JAX functions as PyTensor operations
- Creates the as_jax_op decorator for easy conversion of JAX functions to PyTensor-compatible operations
- Adds comprehensive test coverage for various input/output patterns and data types
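Based on the PR description, a rough usage sketch of the decorator; the names `scaled_dot`, `x`, and `y` are illustrative, and the exact import path and call signature may differ from what the PR implements:

```python
import jax.numpy as jnp
import pytensor
import pytensor.tensor as pt
from pytensor import as_jax_op

# Hypothetical example function; any jit-able JAX function should work.
@as_jax_op
def scaled_dot(x, y):
    return jnp.sum(x * y)

x = pt.vector("x", shape=(3,))
y = pt.vector("y", shape=(3,))
z = scaled_dot(x, y)              # a PyTensor variable, usable in any graph
g = pytensor.grad(z, x)           # gradients flow through the wrapped function
f = pytensor.function([x, y], [z, g])
```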
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
pytensor/link/jax/ops.py | Core implementation of JAXOp class and as_jax_op decorator |
tests/link/jax/test_as_jax_op.py | Comprehensive test suite covering various use cases and data types |
pytensor/__init__.py | Exports as_jax_op function with fallback for missing dependencies |
pytensor/link/jax/dispatch/basic.py | JAX dispatch registration for the new JAXOp |
doc/library/index.rst | Documentation entry for the new functionality |
doc/environment.yml | Updates documentation environment to include JAX dependencies |
doc/conf.py | Adds Equinox to intersphinx mapping |
.github/workflows/test.yml | Updates CI to install Equinox dependency |
Codecov Report

❌ Patch coverage is 90.84%. Your patch check has failed because the patch coverage (90.84%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

    @@ Coverage Diff @@
    ##             main    #1614      +/-   ##
    ==========================================
    + Coverage   81.69%   81.72%   +0.02%
    ==========================================
      Files         230      231       +1
      Lines       52950    53092     +142
      Branches     9404     9419      +15
    ==========================================
    + Hits        43260    43389     +129
    - Misses       7256     7267      +11
    - Partials     2434     2436       +2
pytensor/link/jax/ops.py (outdated)

    if any(s is None for s in shape):
        _, shape = pt.basic.infer_static_shape(var.shape)
        if any(s is None for s in shape):
            raise ValueError(
Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism
PyTensor only needs to know the dtype and rank of the outputs
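For illustration, a rough sketch of what that suggestion could look like with jax.export's shape polymorphism (not code from the PR; the choice of jnp.cumsum is arbitrary):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# Describe the input only by rank and dtype, using a symbolic size for the
# unknown dimension, then export the jitted function once.
(n,) = export.symbolic_shape("n")
x_spec = jax.ShapeDtypeStruct((n,), dtype=np.float32)
exported = export.export(jax.jit(jnp.cumsum))(x_spec)
print(exported.out_avals)  # abstract outputs: dtype plus (symbolic) shape
```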
I don't think that would be reliable.
I think jax will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate jax symbolic shapes. (But only with object dims, not with string ones I think?).
The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.
Does it fail or does it infer they must match?
It fails:

    import numpy as np
    import jax
    from jax import export

    u, v = export.symbolic_shape("u, v")
    x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
    x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)
    export.export(jax.jit(lambda x, y: x + y))(x1, x2)
    # add got incompatible shapes for broadcasting: (u,), (v,).
I kept most of what was in the original PR, but made a few changes:
I think the biggest problem right now is that the … If we want to avoid having the user specify the output types, we need to call the function at least once. We can do that with … I'm not sure right now what the best way to handle this is.
Yes, sure. Sorry that I dropped the ball.
I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It did work for me for pymc models when initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, and what happens when the shape of shared variables is changed between runs by setting new values.
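For context, a small sketch of the static-shape recovery used in the ops.py snippet earlier in this thread (pt.basic.infer_static_shape); whether pt.cumsum actually drops the static shape may depend on the PyTensor version:

```python
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = pt.cumsum(x)

# Constant-fold the symbolic shape graph to recover statically known sizes.
_, static_shape = pt.basic.infer_static_shape(y.shape)
print(static_shape)  # ideally (3,); a None entry means that size stays unknown
```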
I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-changing variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve it, as it allows defining the interpolation function and ODE solver separately. However, I agree it was somewhat hackish and not easily parsable. And its usage can be reasonably well avoided if both the interpolation function and the ODE solver are defined in a single function.
No worries :-)
That works for cases where the shape is in principle known, but pytensor missed it during static shape inference. In the case where the input shapes only depend on random variables (which will usually be the case in pymc models), we … Ricardo realized that we can just eval the shapes once. It should also be fine if we find just some valid input shapes; if the shapes change later, that shouldn't be a problem. We only have to set the output shapes to None in that case, because we don't actually know what the output shapes should be.
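A minimal sketch of the "eval the shapes once" idea (illustrative only, not the PR's actual code path):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")  # static shape unknown at graph-construction time

# Compile a throwaway function that returns the symbolic shape and call it
# once with any valid value; the result is used only to trace the JAX function.
shape_fn = pytensor.function([x], x.shape)
concrete_shape = tuple(int(s) for s in shape_fn([0.0, 0.0, 0.0]))
print(concrete_shape)  # (3,)
```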
I thought it might have been something like that. I think that makes sense, but for now I'd rather get this in as-is. We can always revisit this later if we think we need it.
We also should decide on a name: I think …
pytensor/__init__.py (outdated)

    @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v):
    from pytensor.scan.views import foldl, foldr, map, reduce
    from pytensor.compile.builders import OpFromGraph

    try:
Why this eager import? It would be great if we didn't need to slow down PyTensor import by default.
I removed the try/except; that was already in the original function.
as_jax_op shouldn't try to import jax now; the ops.py file delays importing jax.
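A sketch of the deferred-import pattern being described (illustrative, not the PR's actual code):

```python
def as_jax_op(jax_func):
    # jax is imported only when the decorator is actually used, so a plain
    # `import pytensor` neither pays for nor requires the jax dependency.
    import jax

    def wrapper(*args, **kwargs):
        return jax.jit(jax_func)(*args, **kwargs)

    return wrapper
```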
pytensor/link/jax/ops.py (outdated)

    [array([1., 1.], dtype=float32)]
    """

    __props__ = ("input_types", "output_types", "jax_func", "name")
Why is name in __props__? Only stuff that affects the behavior of the Op should be there.
fixed
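For background on the point above: when an Op defines __props__, PyTensor derives __eq__ and __hash__ from those fields, so a purely cosmetic name would make otherwise-identical Ops compare unequal and prevent rewrites from merging them. A small sketch with a hypothetical Op (not from the PR):

```python
from pytensor.graph.op import Op


class ScaleOp(Op):
    __props__ = ("factor",)  # only behavior-relevant fields belong here

    def __init__(self, factor):
        self.factor = factor

    def make_node(self, x):
        raise NotImplementedError  # not needed for this equality demo

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError


assert ScaleOp(2) == ScaleOp(2)  # equal because __props__ values match
assert ScaleOp(2) != ScaleOp(3)
```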
    )

    # Create VJP operation
    vjp_op = JAXOp(
PyTensor won't be able to fuse duplicate applications of the grad (if they somehow exist), as it will create a new function under the hood and that is what is used for Op equality (whatever is in __props__ is used).

Not a blocker, just laying it out. Alternatively, if the grad were an Op parametrized by the original Op/func and the connected gradients, it could be merged. However, if running on the JAX backend, JAX itself may be able to avoid the duplication.

I would err on keeping it simple for now, like you did.
pytensor/link/jax/ops.py (outdated)

    [],
    resolved_input_shapes,
    on_unused_input="ignore",
    mode="FAST_COMPILE",
mode="FAST_COMPILE", | |
mode=Mode(linker="py", optimizer="fast_compile"), |
No reason for C code I guess
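If that suggestion is applied, Mode would presumably need to be imported in ops.py, along these lines (the variable name is illustrative):

```python
from pytensor.compile.mode import Mode

# A Python-linker, fast-compile mode: no C compilation for this throwaway function.
shape_eval_mode = Mode(linker="py", optimizer="fast_compile")
```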
I like … They do the same, except in this case we also know how to generate gradients and dispatch to jax?
Alternatively, …
I don't think I see the argument of it being like … Other options: … I think something that describes what it does, instead of how it does it, would be much friendlier?
+1 for …
… revisit #1120, which seems abandoned.
@jdehning I hope it is ok if I continue this PR?
📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/