Replies: 2 comments 1 reply
-
vjp seems to be the closest thing to what you are asking. It returns your loss value as well as a function that can calculate vector-Jacobian products (VJPs). As noted in the docs, the gradient is a special case of a VJP, so you can use the returned function to compute the gradient. The problem for your use case is that vjp does not directly expose the intermediates. I am not sure whether you can save the returned function directly (e.g. with pickle or another tool), but this can be sidestepped: the function returned by vjp appears to be a partial function, with some of its arguments also being partial functions themselves. You can rewrite the vjp so that it returns a stateless function and its fixed arguments separately; then you can save the fixed arguments and later restore them to call the stateless function.
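A minimal sketch of that idea, assuming the callable returned by `jax.vjp` keeps behaving like a `functools.partial` (an internal detail that may change between JAX versions): split it into its underlying function and its captured arguments, save the arguments, and reassemble later.

```python
import jax
import jax.numpy as jnp

def loss_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)
loss, vjp_fn = jax.vjp(loss_fn, x)

# The returned callable currently behaves like a functools.partial:
# `.func` is the "stateless" part, `.args`/`.keywords` hold the captured state
# (including the intermediates, wrapped in further partials).
stateless_fn = vjp_fn.func
saved_args = vjp_fn.args          # these are what you would try to save/restore
saved_kwargs = vjp_fn.keywords

# Later: rebuild the callable from the stateless function and the saved state,
# then use it like the original vjp function (gradient = VJP with cotangent 1).
restored_vjp = lambda ct: stateless_fn(*saved_args, ct, **saved_kwargs)
(grad_x,) = restored_vjp(jnp.ones_like(loss))
```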
-
Building on the answer by @lamflokas, you are essentially looking for something like the following:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x, y):
    return jnp.cos(x) * (1 + jnp.sin(y))

@jax.jit
def f_fwd(x, y):
    z, vjp = jax.vjp(f, x, y)
    # extract the constants from the bound VJP
    intermediate_values = vjp.args[0].args[0]
    return z, intermediate_values

# somehow gain access to the `unbound_vjp`
_, abstract_vjp = jax.vjp(f, 0., 0.)

@jax.jit
def f_bwd(intermediate_values, ct):
    # re-assemble constants/intermediate values into a bound VJP
    vjp = type(abstract_vjp)(
        abstract_vjp.func,
        type(abstract_vjp.args[0])(
            abstract_vjp.args[0].func,
            intermediate_values,
            *abstract_vjp.args[0].args[1:],
        ),
        *abstract_vjp.args[1:],
    )
    return vjp(ct)

# demonstration
for x in np.linspace(0, np.pi, 31):
    for y in np.linspace(0, np.pi, 31):
        ctgt = 1.
        # forward pass in normal `jax.vjp`; you get access to the backward
        # pass function only *after* you pass in the input
        ve, vjp = jax.vjp(f, x, y)
        # backward pass
        gxe, gye = vjp(ctgt)
        # compiled forward pass returning the "intermediate values" for later
        # re-use in a backward pass
        va, ivals = f_fwd(x, y)
        # pre-compiled backward pass using those intermediate values
        gxa, gya = f_bwd(ivals, ctgt)
        # verify we get the same results
        expected = np.r_[ve, gxe, gye]
        actual = np.r_[va, gxa, gya]
        assert np.allclose(actual, expected), \
            f"{x=:.3f} {y=:.3f}: {np.round(expected, 3)} <=> {np.round(actual, 3)} ({ivals})"
```

The challenge here is that the structure of the bound VJP (the `vjp.args[0].args[0]` access) is an internal implementation detail of JAX, so this layout is not guaranteed to stay the same across versions.
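One way to avoid depending on that layout, sketched here under the assumption that `jax.closure_convert` hoists the arrays captured by the VJP closure, is to let JAX pull the residuals out as an explicit list instead of reaching into the nested partials by hand:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
    return jnp.cos(x) * (1 + jnp.sin(y))

def f_fwd(x, y):
    z, vjp = jax.vjp(f, x, y)
    # closure_convert hoists the JAX arrays closed over by `vjp` into
    # `residuals`; `pure_vjp` is a function of (cotangent, *residuals) only.
    pure_vjp, residuals = jax.closure_convert(vjp, z)
    return z, residuals, pure_vjp   # residuals are plain arrays -> easy to save

def f_bwd(pure_vjp, residuals, ct):
    return pure_vjp(ct, *residuals)

# check against a plain jax.vjp round trip
x, y = jnp.float32(0.3), jnp.float32(0.7)
z, residuals, pure_vjp = f_fwd(x, y)
ct = jnp.ones_like(z)
gxa, gya = f_bwd(pure_vjp, residuals, ct)
_, vjp_ref = jax.vjp(f, x, y)
gxe, gye = vjp_ref(ct)
assert np.allclose([gxa, gya], [gxe, gye])
```

The hoisted residuals are ordinary arrays, so they are straightforward to store or move between devices or processes.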
-
Say I have a model forward function `loss_fn`, which takes the input data and returns the computed loss. Currently JAX is able to perform the following:

- `loss_fn(input)` computes the forward function
- `grad(loss_fn)(input)` computes the gradient at the input, and also evaluates `loss_fn`
- `value_and_grad(loss_fn)(input)` returns both the loss and the gradients, and also evaluates `loss_fn`

But now I would like to separate the execution of the forward and backward computation: evaluate the function first, and then use the intermediate variables as well as the loss to compute the backward pass. Is it possible to save all the activations of the function, and avoid re-evaluating the original function, by passing/storing the computed activations to perform backward propagation?
Something like the following would be helpful, especially when training models in a pipeline parallel fashion:
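Roughly, a split along these lines, sketched here with `jax.vjp` as the closest existing mechanism (the `pullback` closure is what holds the activations, but it is not obvious how to store it or ship it between processes, which is exactly the problem):

```python
import jax
import jax.numpy as jnp

def forward(loss_fn, *inputs):
    # run the forward pass once; `pullback` carries the saved activations
    loss, pullback = jax.vjp(loss_fn, *inputs)
    return loss, pullback

def backward(pullback, cotangent):
    # possibly much later / on another pipeline stage: reuse the activations
    return pullback(cotangent)

def loss_fn(x):
    return jnp.sum(x ** 2)

loss, pullback = forward(loss_fn, jnp.arange(3.0))
(grad_x,) = backward(pullback, jnp.ones_like(loss))
```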