Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stackless fixes #37

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 72 additions & 91 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,76 +141,51 @@ def _wrap_if_array(x: Union[ArrayLike, "Value"]) -> "Value":


class _QuaxTrace(core.Trace[_QuaxTracer]):
def pure(self, val: ArrayLike) -> _QuaxTracer:
if _is_value(val):
raise TypeError(
f"Encountered Quax value of type {type(val)}. These must be "
"transformed by passing them across a `quax.quaxify` boundary before "
"being used.\n"
"For example, the following is incorrect, as `SomeValue()` is not "
"explicitly passed across the API boundary:\n"
"```\n"
"def f(x):\n"
" return x + SomeValue()\n"
"\n"
"quax.quaxify(f)(AnotherValue())\n"
"```\n"
"This should instead be written as the following:\n"
"explicitly passed across the API boundary:\n"
"```\n"
"def f(x, y):\n"
" return x + y\n"
"\n"
"quax.quaxify(f)(AnotherValue(), SomeValue())\n"
"```\n"
"To better understand this, remember that the purpose of Quax is "
"take a JAX program (given as a function) that acts on arrays, and to "
"instead run it with array-ish types. But in the first example above, "
"the original program already has an array-ish type, even before the "
"`quaxify` is introduced."
)
if not eqx.is_array_like(val):
raise TypeError(f"{type(val)} is not a JAX type.")
return _QuaxTracer(self, _DenseArrayValue(val)) # pyright: ignore

def lift(self, tracer: core.Tracer) -> _QuaxTracer:
return _QuaxTracer(self, _DenseArrayValue(tracer))
def __init__(self, parent_trace, tag):
self.tag = tag
self.parent_trace = parent_trace

def sublift(self, tracer: _QuaxTracer) -> _QuaxTracer:
return tracer
def to_value(self, val):
if isinstance(val, _QuaxTracer) and val._trace.tag is self.tag:
return val.value
else:
return _DenseArrayValue(val)

def process_primitive(self, primitive, tracers, params):
values = [t.value for t in tracers]
# params = dict(params); params.pop('sharding', None)
values = [self.to_value(t) for t in tracers]
values = tuple(
x.array if isinstance(x, _DenseArrayValue) else x for x in values
)
try:
rule = _rules[primitive]
except KeyError:
out = _default_process(primitive, values, params)
with core.set_current_trace(self.parent_trace):
out = _default_process(primitive, values, params)
else:
try:
method, _ = rule.resolve_method(values)
except plum.NotFoundLookupError:
out = _default_process(primitive, values, params)
with core.set_current_trace(self.parent_trace):
out = _default_process(primitive, values, params)
else:
out = method(*values, **params)
with core.set_current_trace(self.parent_trace):
out = method(*values, **params)
if primitive.multiple_results:
return [_QuaxTracer(self, _wrap_if_array(x)) for x in out] # pyright: ignore
else:
return _QuaxTracer(self, _wrap_if_array(out)) # pyright: ignore

def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros):
in_values = [t.value for t in tracers]
in_values = [self.to_value(t) for t in tracers]
# Each `t.value` will be some `Value`, and thus a PyTree. Here we flatten the
# `Value`-ness away.
in_leaves, in_treedef = jtu.tree_flatten(in_values)
fun, out_treedef1 = _custom_jvp_fun_wrap(fun, self.main, in_treedef) # pyright: ignore
jvp, out_treedef2 = _custom_jvp_jvp_wrap(jvp, self.main, in_treedef) # pyright: ignore
with jax.ensure_compile_time_eval():
out_leaves = primitive.bind(
fun, jvp, *in_leaves, symbolic_zeros=symbolic_zeros
)
fun, out_treedef1 = _custom_jvp_fun_wrap(fun, self.tag, in_treedef) # pyright: ignore
jvp, out_treedef2 = _custom_jvp_jvp_wrap(jvp, self.tag, in_treedef) # pyright: ignore
out_leaves = primitive.bind_with_trace(
self.parent_trace, (fun, jvp, *in_leaves), dict(symbolic_zeros=symbolic_zeros),
)
_, out_treedef = lu.merge_linear_aux(out_treedef1, out_treedef2)
out_values = jtu.tree_unflatten(out_treedef, out_leaves)
return [_QuaxTracer(self, x) for x in out_values]
Expand All @@ -219,51 +194,59 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero


@lu.transformation_with_aux # pyright: ignore
def _custom_jvp_fun_wrap(main, in_treedef, *in_leaves):
trace = main.with_cur_sublevel()
def _custom_jvp_fun_wrap(tag, in_treedef, *in_leaves):
in_values = jtu.tree_unflatten(in_treedef, in_leaves)
in_tracers = [x if type(x) is SZ else _QuaxTracer(trace, x) for x in in_values]
out_tracers = yield in_tracers, {}
# The symbolic zero branch here will actually create a `quax.zero.Zero`!
out_tracers = [
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
for t in out_tracers
]
out_values = [trace.full_raise(t).value for t in out_tracers]
with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
in_tracers = [x if type(x) is SZ else _QuaxTracer(trace, x) for x in in_values]
with core.set_current_trace(trace):
out_tracers = yield in_tracers, {}
# The symbolic zero branch here will actually create a `quax.zero.Zero`!
out_tracers = [
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
for t in out_tracers
]
out_values = [trace.to_value(t) for t in out_tracers]
del out_tracers
del trace, in_tracers
out_leaves, out_treedef = jtu.tree_flatten(out_values)
yield out_leaves, out_treedef


@lu.transformation_with_aux # pyright: ignore
def _custom_jvp_jvp_wrap(main, in_treedef, *in_primals_and_tangents):
trace = main.with_cur_sublevel()
def _custom_jvp_jvp_wrap(tag, in_treedef, *in_primals_and_tangents):
in_primals = in_primals_and_tangents[: len(in_primals_and_tangents) // 2]
in_tangents = in_primals_and_tangents[len(in_primals_and_tangents) // 2 :]
in_primal_values = jtu.tree_unflatten(in_treedef, in_primals)
in_tangent_values = jtu.tree_unflatten(in_treedef, in_tangents)
# Calling `_QuaxTracer` directly here, not using `trace.{pure,lift}` as each `x` is
# a `Value`, not an array (=> pure) or tracer (=> lift).
in_tracers = [
_QuaxTracer(trace, x) for x in it.chain(in_primal_values, in_tangent_values)
]
out_tracers = yield in_tracers, {}
# The symbolic zero branch here will actually create a `quax.zero.Zero`!
out_tracers = [
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
for t in out_tracers
]
out_values = [trace.full_raise(t).value for t in out_tracers]
out_primal_values = out_values[: len(out_values) // 2]
out_tangent_values = out_values[len(out_values) // 2 :]
out_primal_values2 = []
out_tangent_values2 = []
assert len(out_primal_values) == len(out_tangent_values)
for primal, tangent in zip(out_primal_values, out_tangent_values):
if primal.__class__ != tangent.__class__:
primal = primal.materialise()
tangent = tangent.materialise()
out_primal_values2.append(primal)
out_tangent_values2.append(tangent)
with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
in_tracers = [
_QuaxTracer(trace, x) for x in it.chain(in_primal_values, in_tangent_values)
]
with core.set_current_trace(trace):
out_tracers = yield in_tracers, {}
# The symbolic zero branch here will actually create a `quax.zero.Zero`!
out_tracers = [
jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is SZ else t # pyright: ignore
for t in out_tracers
]
out_values = [trace.to_value(t) for t in out_tracers]
out_primal_values = out_values[: len(out_values) // 2]
out_tangent_values = out_values[len(out_values) // 2 :]
out_primal_values2 = []
out_tangent_values2 = []
assert len(out_primal_values) == len(out_tangent_values)
for primal, tangent in zip(out_primal_values, out_tangent_values):
if primal.__class__ != tangent.__class__:
primal = primal.materialise()
tangent = tangent.materialise()
out_primal_values2.append(primal)
out_tangent_values2.append(tangent)
del out_tracers
del trace, in_tracers
out_primals, out_primal_treedef = jtu.tree_flatten(out_primal_values2)
out_tangents, out_tangent_treedef = jtu.tree_flatten(out_tangent_values2)
if out_primal_treedef != out_tangent_treedef:
Expand Down Expand Up @@ -307,21 +290,20 @@ def __wrapped__(self) -> CT:
return self.fn

def __call__(self, *args, **kwargs):
with core.new_main(_QuaxTrace, dynamic=self.dynamic) as main:
trace = _QuaxTrace(main, core.cur_sublevel())
# Note that we do *not* wrap arraylikes here. We let that happen in
# `_QuaxTrace.{pure,lift}` as necessary. This means that we can do e.g.
# quaxify(jnp.moveaxis)(array, source=0, destination=-1).
dynamic, static = eqx.partition(
(self.fn, args, kwargs), self.filter_spec, is_leaf=_is_value
)
dynamic, static = eqx.partition(
(self.fn, args, kwargs), self.filter_spec, is_leaf=_is_value
)
tag = core.TraceTag()
with core.take_current_trace() as parent_trace:
trace = _QuaxTrace(parent_trace, tag)
dynamic = jtu.tree_map(
ft.partial(_wrap_tracer, trace),
dynamic,
is_leaf=_is_value,
)
fn, args, kwargs = eqx.combine(dynamic, static)
out = fn(*args, **kwargs)
with core.set_current_trace(trace):
out = fn(*args, **kwargs)
out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
return out

Expand Down Expand Up @@ -543,8 +525,7 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
else:
leaves, treedef = jtu.tree_flatten(args) # remove all Values
flat_fun = lambda x: fun(*jtu.tree_unflatten(treedef, x))
with jax.ensure_compile_time_eval(): # replace the dynamic QuaxTrace
return jax.jit(flat_fun)(leaves) # now we can call without Quax.
return jax.jit(flat_fun)(leaves) # now we can call without Quax.


@register(jax.lax.while_p)
Expand Down
2 changes: 1 addition & 1 deletion quax/examples/sparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _op_sparse_to_dense(x, y, op):


@quax.register(lax.broadcast_in_dim_p)
def _(value: BCOO, *, broadcast_dimensions, shape) -> BCOO:
def _(value: BCOO, *, broadcast_dimensions, shape, sharding=None) -> BCOO:
n_extra_batch_dims = len(shape) - value.ndim
if broadcast_dimensions != tuple(range(n_extra_batch_dims, len(shape))):
raise NotImplementedError(
Expand Down
9 changes: 5 additions & 4 deletions quax/examples/zero/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,24 @@ def materialise(self):

@quax.register(lax.broadcast_in_dim_p)
def _(
value: ArrayLike, *, broadcast_dimensions, shape
value: ArrayLike, *, broadcast_dimensions, shape, sharding=None
) -> Union[ArrayLike, quax.ArrayValue]:
aval = jax.core.get_aval(value)
if isinstance(aval, jax.core.ConcreteArray) and aval.shape == () and aval.val == 0:
if False and aval.shape == () and aval.val == 0:
return Zero(shape, np.result_type(value))
else:
# Avoid an infinite loop, by pushing a new interpreter to the dynamic
# interpreter stack.
with jax.ensure_compile_time_eval():
out = lax.broadcast_in_dim_p.bind(
value, broadcast_dimensions=broadcast_dimensions, shape=shape
value, broadcast_dimensions=broadcast_dimensions, shape=shape,
sharding=sharding
)
return out # pyright: ignore


@quax.register(lax.broadcast_in_dim_p)
def _(value: Zero, *, broadcast_dimensions, shape) -> Zero:
def _(value: Zero, *, broadcast_dimensions, shape, sharding=None) -> Zero:
del broadcast_dimensions
return Zero(shape, value.dtype)

Expand Down