Speed of function that was jitted with eqx.filter_jit is slower than Flax's jit equivalent on CPU #897

Open
Artur-Galstyan opened this issue Nov 17, 2024 · 1 comment
Labels
question User queries

Comments

@Artur-Galstyan

Hi there,

I was comparing the performance of different ML libraries (basically to see whether JAX is faster than PyTorch or TensorFlow - the consensus is that JAX is faster, but I wanted numbers) and I noticed that the function produced by eqx.filter_jit was surprisingly slow. It's easier to show in code, so here is the very simple benchmark:

The pip installs to reproduce this: pip install equinox tqdm polars

First, the timeit util - not the main point of the question, but included for completeness:
import time
from functools import wraps

import polars as pl
from tqdm import tqdm


def timeit(
    func, print_time: bool = False, name: str | None = None, n_repeats: int = 10
):
    @wraps(func)
    def wrapper(*args, **kwargs):
        df = []
        result = func(*args, **kwargs)
        for _ in tqdm(range(n_repeats)):
            start = time.time()
            func(*args, **kwargs)
            end = time.time()
            df.append(end - start)
        if print_time:
            print(
                f"{func.__name__ if name is None else name} took {sum(df) / n_repeats:.2f} seconds on average"
            )

        return result, pl.DataFrame(
            {
                "function": func.__name__ if name is None else name,
                "avg_time": sum(df) / n_repeats,
                "n_repeats": n_repeats,
                "min": min(df),
                "max": max(df),
            }
        )

    return wrapper
And here is the benchmark itself:

import equinox as eqx
import jax


class EquinoxLinearModule(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, in_features: int, out_features: int):
        self.linear = eqx.nn.Linear(in_features, out_features, key=jax.random.key(0))

    def __call__(self, x):
        return self.linear(x)


def benchmark_equinox_linear(min_power: int = 1, max_power: int = 4):
    powers = [i for i in range(min_power, max_power + 1)]
    res = []
    for power in powers:
        lin = EquinoxLinearModule(10**power, 10**power)
        lin_jit = eqx.filter_jit(lin)  # very slow
        # lin_jit = eqx.filter_jit(lin.__call__)  # equally slow
        # lin_jit = jax.jit(lin)  # error
        # lin_jit = jax.jit(lin.__call__)  # extremely fast
        x = jax.random.normal(jax.random.PRNGKey(0), (10**power, 10**power))
        func = lambda: lin_jit(x)
        _, df = timeit(func, name=f"eqx-lin-{10}^{power}")()
        res.append(df)
    res = pl.concat(res, how="vertical")
    return res
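
For reference, the Flax benchmark that produced the flax-lin-10^4 numbers below is not included in the post. A rough sketch of what such a counterpart could look like (assuming flax.linen.Dense and the timeit helper above; not the exact code used for the reported timings):

import flax.linen as nn
import jax
import polars as pl


def benchmark_flax_linear(min_power: int = 1, max_power: int = 4):
    res = []
    for power in range(min_power, max_power + 1):
        model = nn.Dense(features=10**power)
        x = jax.random.normal(jax.random.PRNGKey(0), (10**power, 10**power))
        params = model.init(jax.random.key(0), x)  # initialise the Dense kernel/bias
        apply_jit = jax.jit(model.apply)  # jit the pure apply function
        func = lambda: apply_jit(params, x)
        _, df = timeit(func, name=f"flax-lin-{10}^{power}")()
        res.append(df)
    return pl.concat(res, how="vertical")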

eqx.filter_jit gives this result:

┌──────────────┬──────────┬───────────┬──────────┬──────────┐
│ function     ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---          ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str          ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ eqx-lin-10^4 ┆ 2.937645 ┆ 10        ┆ 2.888101 ┆ 3.034286 │
└──────────────┴──────────┴───────────┴──────────┴──────────┘

whereas jax.jit(lin.__call__) and the Flax version give these results, respectively (which are basically equivalent):

┌──────────────┬──────────┬───────────┬──────────┬──────────┐
│ function     ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---          ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str          ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ eqx-lin-10^4 ┆ 0.000013 ┆ 10        ┆ 0.000007 ┆ 0.000039 │
└──────────────┴──────────┴───────────┴──────────┴──────────┘
┌───────────────┬──────────┬───────────┬──────────┬──────────┐
│ function      ┆ avg_time ┆ n_repeats ┆ min      ┆ max      │
│ ---           ┆ ---      ┆ ---       ┆ ---      ┆ ---      │
│ str           ┆ f64      ┆ i64       ┆ f64      ┆ f64      │
╞═══════════════╪══════════╪═══════════╪══════════╪══════════╡
│ flax-lin-10^4 ┆ 0.000068 ┆ 10        ┆ 0.000028 ┆ 0.000394 │
└───────────────┴──────────┴───────────┴──────────┴──────────┘

So this left me a bit confused, because I always thought that eqx.filter_jit was just a thin wrapper around jax.jit, which wouldn't explain such a large difference. My tests were performed on a MacBook M1, on the CPU.

@patrick-kidger

patrick-kidger commented Nov 17, 2024

You've forgotten to call .block_until_ready(). Equinox will actually call this for you automatically - this is the relevant line in its source:

marker.block_until_ready()

But the others don't do this by default.

Equinox does this so that runtime errors are correctly surfaced during the JIT'd call, and not at some later point (or possibly not at all if the program stops before then).

I tested your benchmark with this addition, using both jax.jit and eqx.filter_jit, and get comparable timings.
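
For example, the jax.jit variant of the benchmark above could be made comparable with something along these lines (a sketch, not code from this thread; it reuses lin, x, power, and timeit from the benchmark, and jax.block_until_ready waits for every array in the output):

# Block on the result so the timing includes the actual computation,
# not just JAX's asynchronous dispatch.
lin_jit = jax.jit(lin.__call__)
func = lambda: jax.block_until_ready(lin_jit(x))
_, df = timeit(func, name=f"jax-jit-lin-{10}^{power}")()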

@patrick-kidger patrick-kidger added the question User queries label Nov 17, 2024