From d30595edd8df91bc7235b8173d92f6df6e20e26d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 23 Oct 2025 09:51:29 -0400 Subject: [PATCH 1/4] Bump jax to 0.7.2 --- .dep-versions | 2 +- frontend/catalyst/jax_extras/patches.py | 30 ++++++++++++++++++++- frontend/test/pytest/test_jax_primitives.py | 10 ++++--- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/.dep-versions b/.dep-versions index 943b6e8691..36886deb2d 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,7 +1,7 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} -jax=5712de44e97c455faed1fd45532e821ca66d025a +jax=94233144f5469af28c065aa4263a6849338eeaa1 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 82f2150e1f..dd3281ad71 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -20,7 +20,7 @@ from functools import partial import jax -from jax._src.core import abstractify, standard_vma_rule +from jax._src.core import abstractify, standard_vma_rule, typeof, set_current_trace from jax._src.lax.slicing import ( _argnum_weak_type, _gather_dtype_rule, @@ -243,6 +243,21 @@ def patch_primitives(): This patch wraps the bind method to convert lists to tuples to make them hashable. """ + import numpy as np + + def patched_literal_int(_cls, value: int, _dtype: np.dtype): + return int.__new__(int, value) + + def patched_literal_float(_cls, value: float, _dtype: np.dtype): + return float.__new__(float, value) + + def patched_literal_complex(_cls, value: complex, _dtype: np.dtype): + return complex.__new__(complex, value) + + jax._src.literals.LiteralInt.__new__ = patched_literal_int + jax._src.literals.LiteralFloat.__new__ = patched_literal_float + jax._src.literals.LiteralComplex.__new__ = patched_literal_complex + def make_hashable(value): """Recursively convert lists to tuples to make them hashable.""" if isinstance(value, list): @@ -462,5 +477,18 @@ def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding= mlir.multi_broadcast_in_dim = patched_multi_broadcast_in_dim + def patched_bind_with_trace(self, trace, args, params): + try: + in_type = list(map(typeof, args)) + except Exception: # pylint: disable=broad-exception-caught + return trace.process_primitive(self, args, params) + else: + if self.is_high(*in_type, **params) and trace.requires_low: + with set_current_trace(trace): + return self.to_lojax(*args, **params) + return trace.process_primitive(self, args, params) + + core.Primitive.bind_with_trace = patched_bind_with_trace + except ImportError: pass diff --git a/frontend/test/pytest/test_jax_primitives.py b/frontend/test/pytest/test_jax_primitives.py index 44d5995bae..1f4a2caac5 100644 --- a/frontend/test/pytest/test_jax_primitives.py +++ b/frontend/test/pytest/test_jax_primitives.py @@ -21,7 +21,7 @@ import jax import pytest -from jax import make_jaxpr +from jax import make_jaxpr, core from jax._src.lib.mlir import ir from jax.interpreters.mlir import ir_constant, make_ir_context @@ -50,7 +50,9 @@ def test_extract_wire_type_error(self, test_input): ctx = make_ir_context() jax_ctx = JAXCTX(MODCTX(ctx)) with ir.Location.unknown(ctx): - index_value = ir_constant(test_input) + dtype = jax.numpy.dtype(type(test_input)) + aval = core.ShapedArray([], dtype) + index_value = ir_constant(test_input, aval=aval) qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg")) with pytest.raises(TypeError, match="Operator wires expected to be integers"): _qextract_lowering(jax_ctx, qreg_value, index_value) @@ -62,7 +64,9 @@ def test_insert_wire_type_error(self, test_input): ctx = make_ir_context() jax_ctx = JAXCTX(MODCTX(ctx)) with ir.Location.unknown(ctx): - index_value = ir_constant(test_input) + dtype = jax.numpy.dtype(type(test_input)) + aval = core.ShapedArray([], dtype) + index_value = ir_constant(test_input, aval=aval) qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg")) qbit_value = VALUE(ir.OpaqueType.get("quantum", "bit")) with pytest.raises(TypeError, match="Operator wires expected to be integers"): From 1cdf1c0976dbbfdddf6cb06966dc5ca6af211f7e Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 23 Oct 2025 10:30:16 -0400 Subject: [PATCH 2/4] fix literal array --- frontend/catalyst/jax_extras/patches.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index dd3281ad71..6f0b23c445 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -254,9 +254,16 @@ def patched_literal_float(_cls, value: float, _dtype: np.dtype): def patched_literal_complex(_cls, value: complex, _dtype: np.dtype): return complex.__new__(complex, value) + def patched_literal_array( + _cls, val: np.ndarray, weak_type: bool = False + ): # pylint: disable=unused-argument + arr = np.asarray(val) + return arr.view(np.ndarray) + jax._src.literals.LiteralInt.__new__ = patched_literal_int jax._src.literals.LiteralFloat.__new__ = patched_literal_float jax._src.literals.LiteralComplex.__new__ = patched_literal_complex + jax._src.literals.LiteralArray.__new__ = patched_literal_array def make_hashable(value): """Recursively convert lists to tuples to make them hashable.""" From cc18a5b605ad45e5e81709e7e633668c2cf0dd77 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 6 Nov 2025 13:54:44 -0500 Subject: [PATCH 3/4] fix --- frontend/catalyst/jax_extras/patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index d23c010553..7c51b81824 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -23,7 +23,7 @@ import jax._src.interpreters.mlir as mlir import jax._src.interpreters.partial_eval as pe from jax._src import config, core, source_info_util -from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof +from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof, set_current_trace from jax._src.interpreters import pxla from jax._src.interpreters.partial_eval import ( DynamicJaxprTracer, @@ -390,7 +390,7 @@ def patched_bind_with_trace(self, trace, args, params): return trace.process_primitive(self, args, params) else: if self.is_high(*in_type, **params) and trace.requires_low: - with set_current_trace(trace): + with core.set_current_trace(trace): return self.to_lojax(*args, **params) return trace.process_primitive(self, args, params) From bc2a71998d63ccaae3e403c35d59c3ea128fdd9d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 6 Nov 2025 13:55:17 -0500 Subject: [PATCH 4/4] fix --- frontend/catalyst/jax_extras/patches.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 7c51b81824..c96368ea7b 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -23,7 +23,7 @@ import jax._src.interpreters.mlir as mlir import jax._src.interpreters.partial_eval as pe from jax._src import config, core, source_info_util -from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof, set_current_trace +from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof from jax._src.interpreters import pxla from jax._src.interpreters.partial_eval import ( DynamicJaxprTracer, @@ -253,21 +253,26 @@ def _gather_shape_rule_dynamic( import numpy as np + def patched_literal_int(_cls, value: int, _dtype: np.dtype): return int.__new__(int, value) + def patched_literal_float(_cls, value: float, _dtype: np.dtype): return float.__new__(float, value) + def patched_literal_complex(_cls, value: complex, _dtype: np.dtype): return complex.__new__(complex, value) + def patched_literal_array( _cls, val: np.ndarray, weak_type: bool = False ): # pylint: disable=unused-argument arr = np.asarray(val) return arr.view(np.ndarray) + jax._src.literals.LiteralInt.__new__ = patched_literal_int jax._src.literals.LiteralFloat.__new__ = patched_literal_float jax._src.literals.LiteralComplex.__new__ = patched_literal_complex @@ -275,6 +280,7 @@ def patched_literal_array( # pylint: disable=protected-access original_drop_unused_vars = pe._drop_unused_vars + # pylint: disable=too-many-function-args def patched_drop_unused_vars(constvars, constvals, eqns=None, outvars=None): """Patched drop_unused_vars to ensure constvals is a list.""" @@ -394,8 +400,10 @@ def patched_bind_with_trace(self, trace, args, params): return self.to_lojax(*args, **params) return trace.process_primitive(self, args, params) + core.Primitive.bind_with_trace = patched_bind_with_trace + def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params): """Patched _dyn_shape_staging_rule for dynamic shape handling.""" eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info)