[Frontend] Avoid generating MLIR twice for the same function when computing its gradient. (#1172)

**Context:** The `Function` class is used to wrap Python functions that are
passed as parameters to the `grad` function. A side effect is that each
evaluation of the grad function creates a new `Function` instance. Therefore,
even though the StableHLO lowering of each function is correctly cached, every
gradient evaluation lowers the function again. E.g.,

```python
@qjit(target="mlir")
def test_gradient_used_twice(x: float):

    def identity(x):
        return x

    diff_identity = grad(identity)
    return diff_identity(x) + diff_identity(x)

```

will generate functions `@identity` and `@identity_0` in MLIR.

```mlir
  func.func public @jit_test_gradient_used_twice(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %0 = gradient.grad "auto" @identity(%arg0) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<f64>) -> tensor<f64>
    %1 = gradient.grad "auto" @identity_0(%arg0) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<f64>) -> tensor<f64>
    %2 = stablehlo.add %0, %1 : tensor<f64>
    return %2 : tensor<f64>
  }
  func.func private @identity(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    return %arg0 : tensor<f64>
  }
  func.func private @identity_0(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
    return %arg0 : tensor<f64>
  }

```

**Description of the Change:** This patch adds a cache to the `Function`
class, so that an arbitrary number of calls to the gradient function results
in only a single lowering. For the example above, only `@identity` is emitted
and the duplicate `@identity_0` disappears.
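
In short, the cache keys on the wrapped Python function object, so wrapping the
same function twice yields the same `Function` instance. A minimal standalone
sketch of the pattern (mirroring the `jax_tracer.py` hunk below, with all other
Catalyst machinery omitted):

```python
class Function:
    """Simplified stand-in for Catalyst's Function wrapper (details omitted)."""

    CACHE = {}

    def __new__(cls, fn):
        # Reuse the existing wrapper if this function was wrapped before.
        if cached_instance := cls.CACHE.get(fn):
            return cached_instance
        new_instance = super().__new__(cls)
        cls.CACHE[fn] = new_instance
        return new_instance

    def __init__(self, fn):
        self.fn = fn


def identity(x):
    return x


# Both calls return the same object, so lowering machinery that caches per
# Function instance sees one entry and emits a single MLIR function.
assert Function(identity) is Function(identity)
```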

**Benefits:** Faster compile times, especially when not using AutoGraph and
calling the gradient in a loop.

**Possible Drawbacks:**

**Related GitHub Issues:**
erick-xanadu authored Oct 4, 2024
1 parent 7f0d3c3 commit 3bea1b6
Showing 3 changed files with 73 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -202,9 +202,13 @@
* Cached primitive lowerings are used instead of a custom cache structure.
[(#1159)](https://github.com/PennyLaneAI/catalyst/pull/1159)

* Calling gradients twice (with the same `GradParams`) will now only lower to a single MLIR function.
[(#1172)](https://github.com/PennyLaneAI/catalyst/pull/1172)

* Samples on lightning.qubit/kokkos can now be seeded with `qjit(seed=...)`.
[(#1164)](https://github.com/PennyLaneAI/catalyst/pull/1164)


<h3>Breaking changes</h3>

* Remove `static_size` field from `AbstractQreg` class.
9 changes: 9 additions & 0 deletions frontend/catalyst/jax_tracer.py
@@ -158,6 +158,15 @@ class Function:
AssertionError: Invalid function type.
"""

CACHE = {}

def __new__(cls, fn):
if cached_instance := cls.CACHE.get(fn):
return cached_instance
new_instance = super().__new__(cls)
cls.CACHE[fn] = new_instance
return new_instance

@debug_logger_init
def __init__(self, fn):
self.fn = fn
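
One subtlety of the hunk above: returning a cached instance from `__new__`
does not skip `__init__`; Python still invokes `__init__` on the returned
instance, which here merely re-binds `self.fn` to the same function. A minimal
demonstration with hypothetical names (not Catalyst code):

```python
class Memo:
    """Illustrates the __new__/__init__ interplay used by the Function cache."""

    CACHE = {}
    init_calls = 0

    def __new__(cls, fn):
        if cached := cls.CACHE.get(fn):
            return cached
        instance = super().__new__(cls)
        cls.CACHE[fn] = instance
        return instance

    def __init__(self, fn):
        Memo.init_calls += 1  # runs on every construction, cached or not
        self.fn = fn


def f(x):
    return x


a, b = Memo(f), Memo(f)
assert a is b                 # a single instance per wrapped function
assert Memo.init_calls == 2   # __init__ ran twice; re-binding fn is harmless
```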
60 changes: 60 additions & 0 deletions frontend/test/lit/test_gradient.py
@@ -140,3 +140,63 @@ def circuit(params):


print(grad_hoist_constant.mlir)


# CHECK-LABEL: @test_gradient_used_twice
@qjit(target="mlir")
def test_gradient_used_twice(x: float):
"""This tests that calling the return of grad
more than once does not define multiple functions.
"""

# CHECK-NOT: @identity_0
# CHECK-LABEL: @identity
# CHECK-NOT: @identity_0
def identity(x):
return x

diff_identity = grad(identity)
return diff_identity(x) + diff_identity(x)


print(test_gradient_used_twice.mlir)


# CHECK-LABEL: @test_gradient_taken_twice
@qjit(target="mlir")
def test_gradient_taken_twice(x: float):
"""This tests that calling grad
more than once does not define multiple functions.
"""

# CHECK-NOT: @identity_0
# CHECK-LABEL: @identity
# CHECK-NOT: @identity_0
def identity(x):
return x

diff_identity0 = grad(identity)
diff_identity1 = grad(identity)
return diff_identity0(x) + diff_identity1(x)


print(test_gradient_taken_twice.mlir)


# CHECK-LABEL: @test_higher_order_used_twice
@qjit(target="mlir")
def test_higher_order_used_twice(x: float):
"""Test that a single function is generated when using higher order derivatives"""

# CHECK-NOT: @identity_0
# CHECK-LABEL: @identity
# CHECK-NOT: @identity_0
def identity(x):
return x

dydx_identity = grad(identity, method="fd")
dy2dx2_identity = grad(dydx_identity, method="fd")
return dy2dx2_identity(x) + dy2dx2_identity(x)


print(test_higher_order_used_twice.mlir)
