
Conversation

@aseyboldt (Member) commented Sep 16, 2025

revisit #1120, which seems abandoned.

@jdehning I hope it is ok if I continue this PR?


📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

@aseyboldt force-pushed the as-jax-opt2 branch 3 times, most recently from 10dfa2e to ead4ac7 on September 16, 2025 15:08

Copilot AI left a comment

Pull Request Overview

This PR implements the as_jax_op decorator, which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation; a usage sketch follows the summary below.

  • Implements JAXOp class for wrapping JAX functions as PyTensor operations
  • Creates as_jax_op decorator for easy conversion of JAX functions to PyTensor-compatible operations
  • Adds comprehensive test coverage for various input/output patterns and data types
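
Illustrative usage (a sketch based on this PR's description; the exact decorator signature, dtype handling, and final import path may differ):

```python
import jax.numpy as jnp
import pytensor
import pytensor.tensor as pt
from pytensor.link.jax.ops import as_jax_op  # module added in this PR


@as_jax_op
def scaled_logistic(x, scale):
    # Plain JAX code; JAX traces it, PyTensor only sees a wrapping Op.
    return scale / (1 + jnp.exp(-x))


# Inputs need static shapes so the output types can be inferred.
x = pt.tensor("x", shape=(3,))
scale = pt.scalar("scale")

y = scaled_logistic(x, scale)   # a regular PyTensor variable
g = pytensor.grad(y.sum(), x)   # gradients flow through the wrapped function
```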

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `pytensor/link/jax/ops.py` | Core implementation of the JAXOp class and as_jax_op decorator |
| `tests/link/jax/test_as_jax_op.py` | Comprehensive test suite covering various use cases and data types |
| `pytensor/__init__.py` | Exports as_jax_op with a fallback for missing dependencies |
| `pytensor/link/jax/dispatch/basic.py` | JAX dispatch registration for the new JAXOp |
| `doc/library/index.rst` | Documentation entry for the new functionality |
| `doc/environment.yml` | Updates the documentation environment to include JAX dependencies |
| `doc/conf.py` | Adds Equinox to the intersphinx mapping |
| `.github/workflows/test.yml` | Updates CI to install the Equinox dependency |

codecov bot commented Sep 16, 2025

Codecov Report

❌ Patch coverage is 90.84507% with 13 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.72%. Comparing base (1dc982c) to head (f832126).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pytensor/link/jax/ops.py | 90.57% | 11 Missing and 2 partials ⚠️ |

❌ Your patch check has failed because the patch coverage (90.84%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1614      +/-   ##
==========================================
+ Coverage   81.69%   81.72%   +0.02%     
==========================================
  Files         230      231       +1     
  Lines       52950    53092     +142     
  Branches     9404     9419      +15     
==========================================
+ Hits        43260    43389     +129     
- Misses       7256     7267      +11     
- Partials     2434     2436       +2     
| Files with missing lines | Coverage Δ |
| --- | --- |
| pytensor/link/jax/dispatch/basic.py | 83.52% <100.00%> (+0.81%) ⬆️ |
| pytensor/link/jax/ops.py | 90.57% <90.57%> (ø) |

```python
if any(s is None for s in shape):
    _, shape = pt.basic.infer_static_shape(var.shape)
    if any(s is None for s in shape):
        raise ValueError(
```
Member

Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism

PyTensor only needs to know the dtype and rank of the outputs

Member Author

I don't think that would be reliable.
I think jax will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate jax symbolic shapes. (But only with object dims, not with string ones I think?).

Member Author

The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.

Member

Does it fail or does it infer they must match?

Member Author

It fails:

```python
import jax
import numpy as np
from jax import export

u, v = export.symbolic_shape("u, v")

x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)

export.export(jax.jit(lambda x, y: x + y))(x1, x2)
# add got incompatible shapes for broadcasting: (u,), (v,).
```
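
For contrast, the export goes through when both inputs share the same symbolic dimension (continuing the snippet above), since the shapes are then provably compatible:

```python
(u,) = export.symbolic_shape("u")

x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((u,), dtype=np.int32)

export.export(jax.jit(lambda x, y: x + y))(x1, x2)  # no error
```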

@aseyboldt (Member Author)

I kept most of what was in the original PR, but made a few changes:

  • There is no longer a separate Op for the gradient; the gradient is just another JAXOp.
  • I kept support for JAX pytree inputs and outputs; I think those are quite valuable. For instance, when we have a neural network in a model, or want to solve an ODE, it is much nicer if we don't have to take apart all the pytrees by hand everywhere. I did remove the wrapping of returned functions, though. That led to some trouble when the pytrees contain callables that should not be wrapped, and it seems a bit hackish overall. I also can't think of a use case where we would really need it; if one comes along, maybe we can revisit the idea.

I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

If we want to avoid having the user specify the output types, we need to call the function at least once. We can do that with jax.eval_shape, but that still needs static shape info.

I'm not sure right now what the best way to handle this is.
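
For concreteness, this is the kind of inference I mean; a small sketch with jax.eval_shape, which only works once concrete input shapes are available:

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x, y):
    return jnp.outer(x, y), x.sum()


# With static input shapes, the output shapes and dtypes can be read off
# without actually running the computation:
out = jax.eval_shape(
    f,
    jax.ShapeDtypeStruct((3,), np.float32),
    jax.ShapeDtypeStruct((4,), np.float32),
)
# (ShapeDtypeStruct(shape=(3, 4), dtype=float32),
#  ShapeDtypeStruct(shape=(), dtype=float32))
```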

@jdehning

> revisit #1120, which seems abandoned.
>
> @jdehning I hope it is ok if I continue this PR?
>
> 📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

Yes, sure. Sorry that I dropped the ball.

@jdehning

> I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It worked for me for PyMC models where the initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, or what happens when the shape of a shared variable is changed between runs by setting new values.

@jdehning

> I did remove the wrapping of returned functions, though. That led to some trouble when the pytrees contain callables that should not be wrapped, and it seems a bit hackish overall. I also can't think of a use case where we would really need it; if one comes along, maybe we can revisit the idea.

I wrote it for ODEs that depend on time-dependent parameters: we need a function that takes a time point and returns time-varying values that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve that, as it allows defining the interpolation function and the ODE solver separately. However, I agree it was somewhat hackish and not easily parsable, and its use can be reasonably well avoided if the interpolation function and the ODE solver are defined in a single function.
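
Roughly, the single-function version looks like this (an illustrative sketch; the solver choice and the names are just for the example):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def solve(params, t_knots, ts, y0):
    # Interpolation and ODE right-hand side live in one function,
    # so no callable has to cross the wrapping boundary.
    def rhs(y, t, params):
        rate = jnp.interp(t, t_knots, params)  # time-dependent parameter
        return -rate * y

    return odeint(rhs, y0, ts, params)
```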

@aseyboldt (Member Author)

> Yes, sure. Sorry that I dropped the ball.

No worries :-)

> I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR.

That works for cases where the shape is in principle known, but pytensor missed it during the static shape inference.
It does not help if the shape really is dynamic, which is the case if you use dimensions in a pymc model.

In the case where the input shapes only depend on random variables (which will usually be the case in PyMC models), Ricardo and I realized that we can just eval the shapes once. It should also be fine if we only find some valid input shapes; if the shapes change later, that shouldn't be a problem. We only have to set the output shapes to None in that case, because we don't actually know what the output shapes should be.
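
A toy illustration of the "eval the shapes once" idea (hypothetical example, not code from this PR):

```python
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)
z = srng.normal(size=(4,))
y = z[z > 0]  # static shape of y is unknown; it depends on the draw

# Evaluating the shape graph once gives *some* valid concrete shape,
# which is enough to trace the JAX function and infer output dtypes and ranks.
example_shape = tuple(y.shape.eval())
```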

> I wrote it for ODEs that depend on time-dependent parameters: we need a function that takes a time point and returns time-varying values that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve that, as it allows defining the interpolation function and the ODE solver separately. However, I agree it was somewhat hackish and not easily parsable, and its use can be reasonably well avoided if the interpolation function and the ODE solver are defined in a single function.

I thought it might have been something like that. I think that makes sense, but for now I'd rather get this in as-is. We can always revisit this later if we think we need it.

@aseyboldt (Member Author)

We should also decide on a name: I think wrap_jax is maybe better than as_jax_op?
Or possibly jax_to_pytensor or jax_bridge?

@aseyboldt marked this pull request as ready for review September 16, 2025 18:47
@@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v):
from pytensor.scan.views import foldl, foldr, map, reduce
from pytensor.compile.builders import OpFromGraph

try:
Member

Why this eager import? It would be great if we didn't need to slow down PyTensor import by default.

Member Author

I removed the try/except, which was already there in the original version.
as_jax_op shouldn't try to import jax now; the ops.py file delays importing jax.

[array([1., 1.], dtype=float32)]
"""

__props__ = ("input_types", "output_types", "jax_func", "name")
Member

Why is name in __props__? Only things that affect the behavior of the Op should be in there.

Member Author

fixed

)

# Create VJP operation
vjp_op = JAXOp(
Member

PyTensor won't be able to merge duplicate applications of the grad (if they somehow exist), because a new function is created under the hood and that is what's used for Op equality (whatever is in __props__ is used).

Not a blocker, just laying it out. Alternatively, if the grad were an Op parametrized by the original Op/func and the connected gradients, it could be merged.

However, if running on the JAX backend, JAX itself may be able to avoid the duplication.

I would err on keeping it simple for now, like you did.
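
For reference, a minimal sketch of how __props__ drives Op equality (and hence merging); ScaleOp here is a made-up example, not part of this PR:

```python
import numpy as np
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable


class ScaleOp(Op):
    # __eq__ and __hash__ are derived from __props__, so two ScaleOp(2.0)
    # instances compare equal and duplicate applications can be merged.
    __props__ = ("alpha",)

    def __init__(self, alpha):
        self.alpha = alpha

    def make_node(self, x):
        x = as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = np.asarray(self.alpha * x, dtype=x.dtype)


assert ScaleOp(2.0) == ScaleOp(2.0)
assert ScaleOp(2.0) != ScaleOp(3.0)
```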

[],
resolved_input_shapes,
on_unused_input="ignore",
mode="FAST_COMPILE",
Member

Suggested change:

```diff
- mode="FAST_COMPILE",
+ mode=Mode(linker="py", optimizer="fast_compile"),
```

No reason for C code I guess
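
For reference, a toy example of what the suggested mode amounts to:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

x = pt.vector("x")
# Pure-Python linker plus only the fast_compile rewrites; no C compilation.
f = pytensor.function([x], 2 * x, mode=Mode(linker="py", optimizer="fast_compile"))
```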

@ricardoV94 (Member)

I like as_jax_op in line with as_op?

They do the same except in this case we also know how to generate gradients and dispatch to jax?
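
For comparison, the existing as_op (in pytensor.compile.ops) wraps an arbitrary Python function, but the output types have to be given up front and no gradient is defined:

```python
import pytensor.tensor as pt
from pytensor.compile.ops import as_op


@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def double(x):
    # Receives and returns plain NumPy arrays.
    return 2 * x
```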

@ricardoV94 (Member)

Alternatively, jax_as_op?

@aseyboldt (Member Author)

as_jax_op sounds like it should do the opposite of what it does, i.e. export something to jax. jax_as_op solves that problem...

I don't think Op is a name we should put into the public interface that much. Most people who use PyMC won't even know what we mean by it. The function also doesn't return an Op; it just happens to create one internally, and doesn't even show it to the user.

I see the argument that it is like as_op, but I don't quite like that name either, for the same reasons, and I don't think many people use it in the first place.

Other options: pytensorize_jax or jax_to_pytensor?

I think something that describes what it does, instead of how it does it, would be much friendlier.

@jdehning

+1 for jax_to_pytensor. It is, to me, the most easily understandable.
