diff --git a/.dep-versions b/.dep-versions index bde589f8a..f6223e173 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,7 +1,7 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} -jax=0.7.1 +jax=0.7.2 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 diff --git a/Makefile b/Makefile index a3d0daf25..403a4cc6b 100644 --- a/Makefile +++ b/Makefile @@ -118,7 +118,7 @@ frontend: $(PYTHON) -m pip uninstall -y pennylane $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) # TODO: remove after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged. - $(PYTHON) -m pip install git+https://github.com/PennyLaneAI/pennylane@bump-jax-to-0.7.0 + $(PYTHON) -m pip install git+https://github.com/PennyLaneAI/pennylane@bump-jax-0.7.2 rm -r frontend/pennylane_catalyst.egg-info .PHONY: mlir llvm stablehlo enzyme dialects runtime oqc diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index cce08e7ff..2e61c9bda 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -23,7 +23,7 @@ import jax._src.interpreters.mlir as mlir import jax._src.interpreters.partial_eval as pe from jax._src import config, core, source_info_util -from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule +from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof from jax._src.interpreters import pxla from jax._src.interpreters.partial_eval import ( DynamicJaxprTracer, @@ -251,6 +251,32 @@ def _gather_shape_rule_dynamic( vma_rule=partial(standard_vma_rule, "gather"), ) +# import numpy as np + +# def patched_literal_int(_cls, value: int, _dtype: np.dtype): +# return int.__new__(int, value) + + +# def patched_literal_float(_cls, value: float, _dtype: np.dtype): +# return float.__new__(float, value) + + +# def patched_literal_complex(_cls, value: complex, _dtype: np.dtype): +# return complex.__new__(complex, value) + + +# def patched_literal_array( +# _cls, val: np.ndarray, weak_type: bool = False +# ): # pylint: disable=unused-argument +# arr = np.asarray(val) +# return arr.view(np.ndarray) + + +# jax._src.literals.LiteralInt.__new__ = patched_literal_int +# jax._src.literals.LiteralFloat.__new__ = patched_literal_float +# jax._src.literals.LiteralComplex.__new__ = patched_literal_complex +# jax._src.literals.LiteralArray.__new__ = patched_literal_array + # pylint: disable=protected-access original_drop_unused_vars = pe._drop_unused_vars @@ -363,6 +389,21 @@ def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding= return out +def patched_bind_with_trace(self, trace, args, params): + try: + in_type = list(map(typeof, args)) + except Exception: # pylint: disable=broad-exception-caught + return trace.process_primitive(self, args, params) + else: + if self.is_high(*in_type, **params) and trace.requires_low: + with core.set_current_trace(trace): + return self.to_lojax(*args, **params) + return trace.process_primitive(self, args, params) + + +core.Primitive.bind_with_trace = patched_bind_with_trace + + def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params): """Patched _dyn_shape_staging_rule for dynamic shape handling.""" eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info) diff --git a/frontend/test/pytest/test_jax_primitives.py b/frontend/test/pytest/test_jax_primitives.py index 44d5995ba..1f4a2caac 100644 --- a/frontend/test/pytest/test_jax_primitives.py +++ b/frontend/test/pytest/test_jax_primitives.py @@ -21,7 +21,7 @@ import jax import pytest -from jax import make_jaxpr +from jax import make_jaxpr, core from jax._src.lib.mlir import ir from jax.interpreters.mlir import ir_constant, make_ir_context @@ -50,7 +50,9 @@ def test_extract_wire_type_error(self, test_input): ctx = make_ir_context() jax_ctx = JAXCTX(MODCTX(ctx)) with ir.Location.unknown(ctx): - index_value = ir_constant(test_input) + dtype = jax.numpy.dtype(type(test_input)) + aval = core.ShapedArray([], dtype) + index_value = ir_constant(test_input, aval=aval) qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg")) with pytest.raises(TypeError, match="Operator wires expected to be integers"): _qextract_lowering(jax_ctx, qreg_value, index_value) @@ -62,7 +64,9 @@ def test_insert_wire_type_error(self, test_input): ctx = make_ir_context() jax_ctx = JAXCTX(MODCTX(ctx)) with ir.Location.unknown(ctx): - index_value = ir_constant(test_input) + dtype = jax.numpy.dtype(type(test_input)) + aval = core.ShapedArray([], dtype) + index_value = ir_constant(test_input, aval=aval) qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg")) qbit_value = VALUE(ir.OpaqueType.get("quantum", "bit")) with pytest.raises(TypeError, match="Operator wires expected to be integers"):