Question about efficiency of higher-order derivatives when defining custom jvp #813

Open
leonard-gleyzer opened this issue Aug 25, 2024 · 1 comment

leonard-gleyzer commented Aug 25, 2024

Hello,

I'm trying to implement a separable MLP. Instead of vmap-ing a single MLP(in_size=3, out_size="scalar", ...), I use three separate MLP(in_size="scalar", out_size=latent_size, ...) networks: I vmap each MLP over its own coordinate batch, then take the outer product of the three outputs and sum over the latent dimension to get the final scalar outputs on the Cartesian product of the three coordinate batches.
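
Schematically, the full-grid evaluation I have in mind is a contraction over the latent dimension (a sketch of the idea only; the MWE below contracts matched batch entries rather than forming the full grid):

outer = jnp.einsum(
    "il,jl,kl->ijk",
    jax.vmap(mlp1)(x_batch),
    jax.vmap(mlp2)(y_batch),
    jax.vmap(mlp3)(z_batch),
)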

I have implemented a custom jvp, which significantly speeds up jacfwd and scales very well with increasing latent size. However, jacfwd(jacfwd) is significantly slower than the non-separable implementation and scales very poorly with increasing latent size.

MWE:

import functools as ft
import time
import warnings

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr


warnings.filterwarnings("ignore")


@eqx.filter_custom_jvp
def f(x_batch__y_batch__z_batch, *, mlp1__mlp2__mlp3):
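    # Separable forward pass: evaluate each scalar-input MLP on its own coordinate
    # batch, then contract the three (n, latent_size) outputs over the latent dimension.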
    mlp1, mlp2, mlp3 = mlp1__mlp2__mlp3
    x_batch, y_batch, z_batch = x_batch__y_batch__z_batch
    mlp1_out = jax.vmap(mlp1)(x_batch)
    mlp2_out = jax.vmap(mlp2)(y_batch)
    mlp3_out = jax.vmap(mlp3)(z_batch)
    return jnp.sum(mlp1_out * mlp2_out * mlp3_out, axis=1)


@f.def_jvp
def f_jvp(primals, tangents, *, mlp1__mlp2__mlp3):
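    # Product rule across the three separable factors: one jax.jvp per MLP gives
    # each factor and its tangent in a single forward pass.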
    mlp1, mlp2, mlp3 = mlp1__mlp2__mlp3
    ((x_batch, y_batch, z_batch),) = primals
    ((x_batch_dot, y_batch_dot, z_batch_dot),) = tangents

    @ft.partial(jax.vmap, in_axes=(None, 0, 0))
    def jvp_mlp(mlp, batch, batch_dot):
        out, out_dot = jax.jvp(mlp, (batch,), (batch_dot,))
        return out, out_dot

    mlp1_out, mlp1_out_dot = jvp_mlp(mlp1, x_batch, x_batch_dot)
    mlp2_out, mlp2_out_dot = jvp_mlp(mlp2, y_batch, y_batch_dot)
    mlp3_out, mlp3_out_dot = jvp_mlp(mlp3, z_batch, z_batch_dot)

    primals_out = jnp.sum(mlp1_out * mlp2_out * mlp3_out, axis=1)
    tangents_out = jnp.sum(
        mlp1_out_dot * mlp2_out * mlp3_out
        + mlp1_out * mlp2_out_dot * mlp3_out
        + mlp1_out * mlp2_out * mlp3_out_dot,
        axis=1,
    )

    return primals_out, tangents_out


latent_size = 64
width_size = 20
depth = 6
activation = jax.nn.tanh

key = jr.PRNGKey(0)
key1, key2, key3, key4 = jr.split(key, 4)
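
# Three scalar-input MLPs with latent_size outputs for the separable model;
# mlp4 (in_size=3, scalar output) is the non-separable baseline.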

mlp1 = eqx.nn.MLP(
    in_size="scalar",
    out_size=latent_size,
    width_size=width_size,
    depth=depth,
    activation=activation,
    key=key1,
)
mlp2 = eqx.nn.MLP(
    in_size="scalar",
    out_size=latent_size,
    width_size=width_size,
    depth=depth,
    activation=activation,
    key=key2,
)

mlp3 = eqx.nn.MLP(
    in_size="scalar",
    out_size=latent_size,
    width_size=width_size,
    depth=depth,
    activation=activation,
    key=key3,
)

mlp4 = eqx.nn.MLP(
    in_size=3,
    out_size="scalar",
    width_size=width_size,
    depth=depth,
    key=key4,
)

n = 75
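# Per-coordinate batches of n points each.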
x_batch = jnp.arange(n, dtype=float)
y_batch = jnp.arange(n, dtype=float)
z_batch = jnp.arange(n, dtype=float)

x_, y_, z_ = jnp.meshgrid(x_batch, y_batch, z_batch, indexing="ij")
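# Flatten the n**3-point Cartesian-product grid into an (n**3, 3) array for the
# non-separable baseline.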
xyz_batch = jnp.column_stack((x_.ravel(), y_.ravel(), z_.ravel()))
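
# For each of the four benchmarks below: one warm-up call to trigger compilation,
# then time a second call.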


eqx.filter_jit(eqx.filter_jacfwd(f))(
    (x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3)
)
start_time = time.time()
eqx.filter_jit(eqx.filter_jacfwd(f))(
    (x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3)
)
end_time = time.time()
print(f"\nTime taken for separable jacfwd: {end_time - start_time:.4f} seconds")

eqx.filter_jit(jax.vmap(eqx.filter_jacfwd(mlp4)))(xyz_batch)
start_time = time.time()
eqx.filter_jit(jax.vmap(eqx.filter_jacfwd(mlp4)))(xyz_batch)
end_time = time.time()
print(f"Time taken for non-separable jacfwd: {end_time - start_time:.4f} seconds\n")


eqx.filter_jit(eqx.filter_jacfwd(eqx.filter_jacfwd(f)))(
    (x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3)
)
start_time = time.time()
eqx.filter_jit(eqx.filter_jacfwd(eqx.filter_jacfwd(f)))(
    (x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3)
)
end_time = time.time()
print(f"Time taken for separable jacfwd(jacfwd): {end_time - start_time:.4f} seconds")


eqx.filter_jit(jax.vmap(eqx.filter_jacfwd(eqx.filter_jacfwd(mlp4))))(xyz_batch)
start_time = time.time()
eqx.filter_jit(jax.vmap(eqx.filter_jacfwd(eqx.filter_jacfwd(mlp4))))(xyz_batch)
end_time = time.time()
print(
    f"Time taken for non-separable jacfwd(jacfwd): {end_time - start_time:.4f} seconds\n"
)

Running this shows the following output:

Time taken for separable jacfwd: 0.0077 seconds
Time taken for non-separable jacfwd: 0.1727 seconds

Time taken for separable jacfwd(jacfwd): 1.2657 seconds
Time taken for non-separable jacfwd(jacfwd): 0.0311 seconds

Not only is jacfwd(jacfwd) significantly slower in the separable implementation than in the non-separable one, but the non-separable jacfwd(jacfwd) is also significantly faster than the non-separable jacfwd.

I'm not sure whether this is an Equinox-specific question or a more general JAX one, but any input you may have would be greatly appreciated.
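
For reference, if asynchronous dispatch is a concern, this is roughly how I would tighten the timing: reuse the objects from the MWE above, jit once so both calls go through the same wrapped function, and block on the result before reading the clock, using jax.block_until_ready (available in recent JAX versions). This is a sketch, not what produced the numbers above:

sep_jacfwd = eqx.filter_jit(eqx.filter_jacfwd(f))

# Warm-up call: triggers compilation.
jax.block_until_ready(
    sep_jacfwd((x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3))
)

start_time = time.time()
out = sep_jacfwd((x_batch, y_batch, z_batch), mlp1__mlp2__mlp3=(mlp1, mlp2, mlp3))
jax.block_until_ready(out)  # wait for the computation to actually finish
end_time = time.time()
print(f"separable jacfwd (blocked): {end_time - start_time:.4f} seconds")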

patrick-kidger (Owner) commented:

Indeed, I think this is a general-JAX question! There's probably no special Equinox insight I can offer here I'm afraid :)

patrick-kidger added the question label on Aug 26, 2024