diff --git a/quax/_core.py b/quax/_core.py index 132e7e4..6046d06 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -141,76 +141,51 @@ def _wrap_if_array(x: Union[ArrayLike, "Value"]) -> "Value": class _QuaxTrace(core.Trace[_QuaxTracer]): - def pure(self, val: ArrayLike) -> _QuaxTracer: - if _is_value(val): - raise TypeError( - f"Encountered Quax value of type {type(val)}. These must be " - "transformed by passing them across a `quax.quaxify` boundary before " - "being used.\n" - "For example, the following is incorrect, as `SomeValue()` is not " - "explicitly passed across the API boundary:\n" - "```\n" - "def f(x):\n" - " return x + SomeValue()\n" - "\n" - "quax.quaxify(f)(AnotherValue())\n" - "```\n" - "This should instead be written as the following:\n" - "explicitly passed across the API boundary:\n" - "```\n" - "def f(x, y):\n" - " return x + y\n" - "\n" - "quax.quaxify(f)(AnotherValue(), SomeValue())\n" - "```\n" - "To better understand this, remember that the purpose of Quax is " - "take a JAX program (given as a function) that acts on arrays, and to " - "instead run it with array-ish types. But in the first example above, " - "the original program already has an array-ish type, even before the " - "`quaxify` is introduced." - ) - if not eqx.is_array_like(val): - raise TypeError(f"{type(val)} is not a JAX type.") - return _QuaxTracer(self, _DenseArrayValue(val)) # pyright: ignore - - def lift(self, tracer: core.Tracer) -> _QuaxTracer: - return _QuaxTracer(self, _DenseArrayValue(tracer)) + def __init__(self, parent_trace, tag): + self.tag = tag + self.parent_trace = parent_trace - def sublift(self, tracer: _QuaxTracer) -> _QuaxTracer: - return tracer + def to_value(self, val): + if isinstance(val, _QuaxTracer) and val._trace.tag is self.tag: + return val.value + else: + return _DenseArrayValue(val) def process_primitive(self, primitive, tracers, params): - values = [t.value for t in tracers] + # params = dict(params); params.pop('sharding', None) + values = [self.to_value(t) for t in tracers] values = tuple( x.array if isinstance(x, _DenseArrayValue) else x for x in values ) try: rule = _rules[primitive] except KeyError: - out = _default_process(primitive, values, params) + with core.set_current_trace(self.parent_trace): + out = _default_process(primitive, values, params) else: try: method, _ = rule.resolve_method(values) except plum.NotFoundLookupError: - out = _default_process(primitive, values, params) + with core.set_current_trace(self.parent_trace): + out = _default_process(primitive, values, params) else: - out = method(*values, **params) + with core.set_current_trace(self.parent_trace): + out = method(*values, **params) if primitive.multiple_results: return [_QuaxTracer(self, _wrap_if_array(x)) for x in out] # pyright: ignore else: return _QuaxTracer(self, _wrap_if_array(out)) # pyright: ignore def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): - in_values = [t.value for t in tracers] + in_values = [self.to_value(t) for t in tracers] # Each `t.value` will be some `Value`, and thus a PyTree. Here we flatten the # `Value`-ness away. in_leaves, in_treedef = jtu.tree_flatten(in_values) - fun, out_treedef1 = _custom_jvp_fun_wrap(fun, self.main, in_treedef) # pyright: ignore - jvp, out_treedef2 = _custom_jvp_jvp_wrap(jvp, self.main, in_treedef) # pyright: ignore - with jax.ensure_compile_time_eval(): - out_leaves = primitive.bind( - fun, jvp, *in_leaves, symbolic_zeros=symbolic_zeros - ) + fun, out_treedef1 = _custom_jvp_fun_wrap(fun, self.tag, in_treedef) # pyright: ignore + jvp, out_treedef2 = _custom_jvp_jvp_wrap(jvp, self.tag, in_treedef) # pyright: ignore + out_leaves = primitive.bind_with_trace( + self.parent_trace, (fun, jvp, *in_leaves), dict(symbolic_zeros=symbolic_zeros), + ) _, out_treedef = lu.merge_linear_aux(out_treedef1, out_treedef2) out_values = jtu.tree_unflatten(out_treedef, out_leaves) return [_QuaxTracer(self, x) for x in out_values] @@ -219,51 +194,59 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero @lu.transformation_with_aux # pyright: ignore -def _custom_jvp_fun_wrap(main, in_treedef, *in_leaves): - trace = main.with_cur_sublevel() +def _custom_jvp_fun_wrap(tag, in_treedef, *in_leaves): in_values = jtu.tree_unflatten(in_treedef, in_leaves) - in_tracers = [x if type(x) is SZ else _QuaxTracer(trace, x) for x in in_values] - out_tracers = yield in_tracers, {} - # The symbolic zero branch here will actually create a `quax.zero.Zero`! - out_tracers = [ - jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore - for t in out_tracers - ] - out_values = [trace.full_raise(t).value for t in out_tracers] + with core.take_current_trace() as parent_trace: + trace = _QuaxTrace(parent_trace, tag) + in_tracers = [x if type(x) is SZ else _QuaxTracer(trace, x) for x in in_values] + with core.set_current_trace(trace): + out_tracers = yield in_tracers, {} + # The symbolic zero branch here will actually create a `quax.zero.Zero`! + out_tracers = [ + jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore + for t in out_tracers + ] + out_values = [trace.to_value(t) for t in out_tracers] + del out_tracers + del trace, in_tracers out_leaves, out_treedef = jtu.tree_flatten(out_values) yield out_leaves, out_treedef @lu.transformation_with_aux # pyright: ignore -def _custom_jvp_jvp_wrap(main, in_treedef, *in_primals_and_tangents): - trace = main.with_cur_sublevel() +def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents): in_primals = in_primals_and_tangents[: len(in_primals_and_tangents) // 2] in_tangents = in_primals_and_tangents[len(in_primals_and_tangents) // 2 :] in_primal_values = jtu.tree_unflatten(in_treedef, in_primals) in_tangent_values = jtu.tree_unflatten(in_treedef, in_tangents) # Calling `_QuaxTracer` directly here, not using `trace.{pure,lift}` as each `x` is # a `Value`, not an array (=> pure) or tracer (=> lift). - in_tracers = [ - _QuaxTracer(trace, x) for x in it.chain(in_primal_values, in_tangent_values) - ] - out_tracers = yield in_tracers, {} - # The symbolic zero branch here will actually create a `quax.zero.Zero`! - out_tracers = [ - jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore - for t in out_tracers - ] - out_values = [trace.full_raise(t).value for t in out_tracers] - out_primal_values = out_values[: len(out_values) // 2] - out_tangent_values = out_values[len(out_values) // 2 :] - out_primal_values2 = [] - out_tangent_values2 = [] - assert len(out_primal_values) == len(out_tangent_values) - for primal, tangent in zip(out_primal_values, out_tangent_values): - if primal.__class__ != tangent.__class__: - primal = primal.materialise() - tangent = tangent.materialise() - out_primal_values2.append(primal) - out_tangent_values2.append(tangent) + with core.take_current_trace() as parent_trace: + trace = _QuaxTrace(parent_trace, tag) + in_tracers = [ + _QuaxTracer(trace, x) for x in it.chain(in_primal_values, in_tangent_values) + ] + with core.set_current_trace(trace): + out_tracers = yield in_tracers, {} + # The symbolic zero branch here will actually create a `quax.zero.Zero`! + out_tracers = [ + jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore + for t in out_tracers + ] + out_values = [trace.to_value(t) for t in out_tracers] + out_primal_values = out_values[: len(out_values) // 2] + out_tangent_values = out_values[len(out_values) // 2 :] + out_primal_values2 = [] + out_tangent_values2 = [] + assert len(out_primal_values) == len(out_tangent_values) + for primal, tangent in zip(out_primal_values, out_tangent_values): + if primal.__class__ != tangent.__class__: + primal = primal.materialise() + tangent = tangent.materialise() + out_primal_values2.append(primal) + out_tangent_values2.append(tangent) + del out_tracers + del trace, in_tracers out_primals, out_primal_treedef = jtu.tree_flatten(out_primal_values2) out_tangents, out_tangent_treedef = jtu.tree_flatten(out_tangent_values2) if out_primal_treedef != out_tangent_treedef: @@ -307,21 +290,20 @@ def __wrapped__(self) -> CT: return self.fn def __call__(self, *args, **kwargs): - with core.new_main(_QuaxTrace, dynamic=self.dynamic) as main: - trace = _QuaxTrace(main, core.cur_sublevel()) - # Note that we do *not* wrap arraylikes here. We let that happen in - # `_QuaxTrace.{pure,lift}` as necessary. This means that we can do e.g. - # quaxify(jnp.moveaxis)(array, source=0, destination=-1). - dynamic, static = eqx.partition( - (self.fn, args, kwargs), self.filter_spec, is_leaf=_is_value - ) + dynamic, static = eqx.partition( + (self.fn, args, kwargs), self.filter_spec, is_leaf=_is_value + ) + tag = core.TraceTag() + with core.take_current_trace() as parent_trace: + trace = _QuaxTrace(parent_trace, tag) dynamic = jtu.tree_map( ft.partial(_wrap_tracer, trace), dynamic, is_leaf=_is_value, ) fn, args, kwargs = eqx.combine(dynamic, static) - out = fn(*args, **kwargs) + with core.set_current_trace(trace): + out = fn(*args, **kwargs) out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out) return out @@ -543,8 +525,7 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): else: leaves, treedef = jtu.tree_flatten(args) # remove all Values flat_fun = lambda x: fun(*jtu.tree_unflatten(treedef, x)) - with jax.ensure_compile_time_eval(): # replace the dynamic QuaxTrace - return jax.jit(flat_fun)(leaves) # now we can call without Quax. + return jax.jit(flat_fun)(leaves) # now we can call without Quax. @register(jax.lax.while_p) diff --git a/quax/examples/sparse/_core.py b/quax/examples/sparse/_core.py index 185c2dd..4664eef 100644 --- a/quax/examples/sparse/_core.py +++ b/quax/examples/sparse/_core.py @@ -78,7 +78,7 @@ def _op_sparse_to_dense(x, y, op): @quax.register(lax.broadcast_in_dim_p) -def _(value: BCOO, *, broadcast_dimensions, shape) -> BCOO: +def _(value: BCOO, *, broadcast_dimensions, shape, sharding=None) -> BCOO: n_extra_batch_dims = len(shape) - value.ndim if broadcast_dimensions != tuple(range(n_extra_batch_dims, len(shape))): raise NotImplementedError( diff --git a/quax/examples/zero/_core.py b/quax/examples/zero/_core.py index b8087c3..51a9bfd 100644 --- a/quax/examples/zero/_core.py +++ b/quax/examples/zero/_core.py @@ -42,23 +42,19 @@ def materialise(self): @quax.register(lax.broadcast_in_dim_p) def _( - value: ArrayLike, *, broadcast_dimensions, shape + value: ArrayLike, *, broadcast_dimensions, shape, sharding=None ) -> Union[ArrayLike, quax.ArrayValue]: - aval = jax.core.get_aval(value) - if isinstance(aval, jax.core.ConcreteArray) and aval.shape == () and aval.val == 0: - return Zero(shape, np.result_type(value)) - else: - # Avoid an infinite loop, by pushing a new interpreter to the dynamic - # interpreter stack. - with jax.ensure_compile_time_eval(): - out = lax.broadcast_in_dim_p.bind( - value, broadcast_dimensions=broadcast_dimensions, shape=shape - ) - return out # pyright: ignore + # Avoid an infinite loop using ensure_compile_time_eval. + with jax.ensure_compile_time_eval(): + out = lax.broadcast_in_dim_p.bind( + value, broadcast_dimensions=broadcast_dimensions, shape=shape, + sharding=sharding + ) + return out # pyright: ignore @quax.register(lax.broadcast_in_dim_p) -def _(value: Zero, *, broadcast_dimensions, shape) -> Zero: +def _(value: Zero, *, broadcast_dimensions, shape, sharding=None) -> Zero: del broadcast_dimensions return Zero(shape, value.dtype)