I'm trying to implement a separable MLP. Instead of `vmap`-ing a single `MLP(in_size=3, out_size="scalar", ...)`, I use three separate `MLP(in_size="scalar", out_size=latent_size, ...)`s. I `vmap` each MLP over its own coordinate batch, then take the outer product of the three outputs and sum over the latent dimension, which gives a scalar output for every point in the Cartesian product of the three coordinate batches.
I have implemented a custom JVP, which significantly speeds up `jacfwd` and scales very well with increasing latent size. However, `jacfwd(jacfwd)` is significantly slower than in the non-separable implementation and scales very poorly with increasing latent size. My MWE shows the following output:
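For reference, here is a minimal sketch of the separable forward pass described above, written in plain JAX. Assumptions: `init_mlp` and `mlp` are hypothetical stand-ins for `eqx.nn.MLP(in_size="scalar", out_size=latent_size, ...)` (a small tanh network, not Equinox's own module), and all sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, width=32, latent=16, depth=2):
    # Hypothetical stand-in for eqx.nn.MLP(in_size="scalar", out_size=latent):
    # a list of (W, b) pairs for a tanh MLP mapping a scalar to a latent vector.
    sizes = [1] + [width] * depth + [latent]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, wkey = jax.random.split(key)
        W = jax.random.normal(wkey, (d_out, d_in)) / d_in ** 0.5
        params.append((W, jnp.zeros(d_out)))
    return params

def mlp(params, x):
    # x is a scalar coordinate; returns an array of shape (latent,).
    h = jnp.reshape(x, (1,))
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return W @ h + b

def separable_forward(params_xyz, xs, ys, zs):
    px, py, pz = params_xyz
    # vmap each scalar-input network over its own coordinate batch.
    fx = jax.vmap(mlp, in_axes=(None, 0))(px, xs)  # (Nx, latent)
    fy = jax.vmap(mlp, in_axes=(None, 0))(py, ys)  # (Ny, latent)
    fz = jax.vmap(mlp, in_axes=(None, 0))(pz, zs)  # (Nz, latent)
    # Outer product over the three batches, summed over the latent axis.
    return jnp.einsum("ir,jr,kr->ijk", fx, fy, fz)  # (Nx, Ny, Nz)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
params_xyz = tuple(init_mlp(k) for k in keys)
xs = jnp.linspace(0.0, 1.0, 8)
ys = jnp.linspace(0.0, 1.0, 9)
zs = jnp.linspace(0.0, 1.0, 10)
u = separable_forward(params_xyz, xs, ys, zs)  # shape (8, 9, 10)
```

Each grid value `u[i, j, k]` equals `sum(mlp(px, xs[i]) * mlp(py, ys[j]) * mlp(pz, zs[k]))`, so the three networks are evaluated only `Nx + Ny + Nz` times instead of `Nx * Ny * Nz` times.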
Time taken for separable jacfwd: 0.0077 seconds
Time taken for non-separable jacfwd: 0.1727 seconds
Time taken for separable jacfwd(jacfwd): 1.2657 seconds
Time taken for non-separable jacfwd(jacfwd): 0.0311 seconds
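A sketch of what a hand-written JVP for this structure could look like, assuming the network parameters are held fixed (closed over as constants) and only the coordinate batches are differentiated. The rule is the product rule for the trilinear contraction. `make_separable`, `init_mlp` and `mlp` are hypothetical names, not the code from the MWE.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, width=32, latent=16, depth=2):
    # Hypothetical stand-in for eqx.nn.MLP(in_size="scalar", out_size=latent).
    sizes = [1] + [width] * depth + [latent]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, wkey = jax.random.split(key)
        params.append((jax.random.normal(wkey, (d_out, d_in)) / d_in ** 0.5, jnp.zeros(d_out)))
    return params

def mlp(params, x):
    # Scalar x -> latent vector.
    h = jnp.reshape(x, (1,))
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return W @ h + b

def make_separable(params_xyz):
    # Parameters are closed over and treated as constants (not differentiated).
    px, py, pz = params_xyz

    def feats(p, cs):
        return jax.vmap(lambda c: mlp(p, c))(cs)  # (N, latent)

    def feats_jvp(p, cs, dcs):
        # Each sub-network is scalar -> latent, so jax.jvp per point returns
        # the features and their tangents, both of shape (N, latent).
        return jax.vmap(lambda c, dc: jax.jvp(lambda t: mlp(p, t), (c,), (dc,)))(cs, dcs)

    @jax.custom_jvp
    def u(xs, ys, zs):
        return jnp.einsum("ir,jr,kr->ijk", feats(px, xs), feats(py, ys), feats(pz, zs))

    @u.defjvp
    def u_jvp(primals, tangents):
        (xs, ys, zs), (dxs, dys, dzs) = primals, tangents
        fx, dfx = feats_jvp(px, xs, dxs)
        fy, dfy = feats_jvp(py, ys, dys)
        fz, dfz = feats_jvp(pz, zs, dzs)
        out = jnp.einsum("ir,jr,kr->ijk", fx, fy, fz)
        # Product rule: differentiate one factor at a time.
        dout = (jnp.einsum("ir,jr,kr->ijk", dfx, fy, fz)
                + jnp.einsum("ir,jr,kr->ijk", fx, dfy, fz)
                + jnp.einsum("ir,jr,kr->ijk", fx, fy, dfz))
        return out, dout

    return u

keys = jax.random.split(jax.random.PRNGKey(0), 3)
params_xyz = tuple(init_mlp(k) for k in keys)
xs, ys, zs = (jnp.linspace(0.0, 1.0, n) for n in (3, 4, 5))
u = make_separable(params_xyz)
# Differentiating w.r.t. the whole xs batch gives dense arrays:
J = jax.jacfwd(u, argnums=0)(xs, ys, zs)  # (3, 4, 5, 3)
H = jax.jacfwd(jax.jacfwd(u, argnums=0), argnums=0)(xs, ys, zs)  # (3, 4, 5, 3, 3)
```

Higher-order derivatives still work here because JAX differentiates the JVP rule itself when `jacfwd` is nested; note that `J[i, ..., m]` is zero whenever `m != i`.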
Not only is the separable `jacfwd(jacfwd)` much slower than the non-separable one, but the non-separable `jacfwd(jacfwd)` is also faster than the non-separable `jacfwd` alone.
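One possible contributor, assuming the MWE applies `jacfwd(jacfwd)` to the whole grid function with respect to the coordinate batches (the MWE isn't reproduced here, so this is an assumption): the result then has shape `(Nx, Ny, Nz, Nx, Nx)`, although only entries whose three x-indices coincide are nonzero, and forward mode pushes roughly `Nx * Nx` tangents through the full contraction. A sketch of an alternative, assuming only per-point derivatives are needed: differentiate each scalar-input sub-network pointwise and reuse the same contraction, e.g. `u_xx[i, j, k] = sum_r fx''(x_i)_r * fy(y_j)_r * fz(z_k)_r`. All names are hypothetical; `u_point` is only a reference for checking against `jax.hessian`.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, width=32, latent=16, depth=2):
    # Hypothetical stand-in for eqx.nn.MLP(in_size="scalar", out_size=latent).
    sizes = [1] + [width] * depth + [latent]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, wkey = jax.random.split(key)
        params.append((jax.random.normal(wkey, (d_out, d_in)) / d_in ** 0.5, jnp.zeros(d_out)))
    return params

def mlp(params, x):
    # Scalar x -> latent vector.
    h = jnp.reshape(x, (1,))
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return W @ h + b

def axis_features(p, cs):
    # Value, first and second derivative of a scalar -> latent network at each point.
    f = lambda t: mlp(p, t)
    df = jax.jacfwd(f)    # scalar -> (latent,)
    d2f = jax.jacfwd(df)  # scalar -> (latent,)
    return jax.vmap(f)(cs), jax.vmap(df)(cs), jax.vmap(d2f)(cs)

def grid_second_derivatives(params_xyz, xs, ys, zs):
    (fx, dfx, d2fx), (fy, dfy, d2fy), (fz, dfz, d2fz) = [
        axis_features(p, c) for p, c in zip(params_xyz, (xs, ys, zs))]
    contract = lambda a, b, c: jnp.einsum("ir,jr,kr->ijk", a, b, c)
    # Every derivative reuses the same (Nx, Ny, Nz) contraction.
    return {
        "u": contract(fx, fy, fz),
        "u_xx": contract(d2fx, fy, fz),
        "u_yy": contract(fx, d2fy, fz),
        "u_zz": contract(fx, fy, d2fz),
        "u_xy": contract(dfx, dfy, fz),
        "u_xz": contract(dfx, fy, dfz),
        "u_yz": contract(fx, dfy, dfz),
    }

def u_point(params_xyz, v):
    # Reference: the same model evaluated at a single point v = (x, y, z).
    px, py, pz = params_xyz
    return jnp.sum(mlp(px, v[0]) * mlp(py, v[1]) * mlp(pz, v[2]))

keys = jax.random.split(jax.random.PRNGKey(0), 3)
params_xyz = tuple(init_mlp(k) for k in keys)
xs, ys, zs = (jnp.linspace(0.0, 1.0, n) for n in (8, 9, 10))
d = grid_second_derivatives(params_xyz, xs, ys, zs)
# Pointwise 3x3 Hessian at grid point (1, 2, 3) for comparison.
H = jax.hessian(u_point, argnums=1)(params_xyz, jnp.array([xs[1], ys[2], zs[3]]))
```

Here each sub-network is differentiated only at its own `N` points, so the cost no longer carries the extra `Nx * Nx` factor from differentiating against the whole batch.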
I'm not sure whether this is an Equinox-specific question or a more general JAX one; any input would be greatly appreciated.