Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.7.1
jax=0.7.2
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
Expand Down
43 changes: 42 additions & 1 deletion frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import jax._src.interpreters.mlir as mlir
import jax._src.interpreters.partial_eval as pe
from jax._src import config, core, source_info_util
from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule
from jax._src.core import JaxprEqnContext, Var, abstractify, standard_vma_rule, typeof
from jax._src.interpreters import pxla
from jax._src.interpreters.partial_eval import (
DynamicJaxprTracer,
Expand Down Expand Up @@ -251,6 +251,32 @@ def _gather_shape_rule_dynamic(
vma_rule=partial(standard_vma_rule, "gather"),
)

import numpy as np


def patched_literal_int(_cls, value: int, _dtype: np.dtype):
return int.__new__(int, value)


def patched_literal_float(_cls, value: float, _dtype: np.dtype):
return float.__new__(float, value)


def patched_literal_complex(_cls, value: complex, _dtype: np.dtype):
return complex.__new__(complex, value)


def patched_literal_array(
_cls, val: np.ndarray, weak_type: bool = False
): # pylint: disable=unused-argument
arr = np.asarray(val)
return arr.view(np.ndarray)


jax._src.literals.LiteralInt.__new__ = patched_literal_int
jax._src.literals.LiteralFloat.__new__ = patched_literal_float
jax._src.literals.LiteralComplex.__new__ = patched_literal_complex
jax._src.literals.LiteralArray.__new__ = patched_literal_array
# pylint: disable=protected-access
original_drop_unused_vars = pe._drop_unused_vars

Expand Down Expand Up @@ -363,6 +389,21 @@ def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding=
return out


def patched_bind_with_trace(self, trace, args, params):
try:
in_type = list(map(typeof, args))
except Exception: # pylint: disable=broad-exception-caught
return trace.process_primitive(self, args, params)
else:
if self.is_high(*in_type, **params) and trace.requires_low:
with core.set_current_trace(trace):
return self.to_lojax(*args, **params)
return trace.process_primitive(self, args, params)


core.Primitive.bind_with_trace = patched_bind_with_trace


def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params):
"""Patched _dyn_shape_staging_rule for dynamic shape handling."""
eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info)
Expand Down
10 changes: 7 additions & 3 deletions frontend/test/pytest/test_jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import jax
import pytest
from jax import make_jaxpr
from jax import make_jaxpr, core
from jax._src.lib.mlir import ir
from jax.interpreters.mlir import ir_constant, make_ir_context

Expand Down Expand Up @@ -50,7 +50,9 @@ def test_extract_wire_type_error(self, test_input):
ctx = make_ir_context()
jax_ctx = JAXCTX(MODCTX(ctx))
with ir.Location.unknown(ctx):
index_value = ir_constant(test_input)
dtype = jax.numpy.dtype(type(test_input))
aval = core.ShapedArray([], dtype)
index_value = ir_constant(test_input, aval=aval)
qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg"))
with pytest.raises(TypeError, match="Operator wires expected to be integers"):
_qextract_lowering(jax_ctx, qreg_value, index_value)
Expand All @@ -62,7 +64,9 @@ def test_insert_wire_type_error(self, test_input):
ctx = make_ir_context()
jax_ctx = JAXCTX(MODCTX(ctx))
with ir.Location.unknown(ctx):
index_value = ir_constant(test_input)
dtype = jax.numpy.dtype(type(test_input))
aval = core.ShapedArray([], dtype)
index_value = ir_constant(test_input, aval=aval)
qreg_value = VALUE(ir.OpaqueType.get("quantum", "reg"))
qbit_value = VALUE(ir.OpaqueType.get("quantum", "bit"))
with pytest.raises(TypeError, match="Operator wires expected to be integers"):
Expand Down
Loading