[Frontend] Avoid generating MLIR twice for the same function when computing its gradient. (#1172)

**Context:** The `Function` class is used to wrap Python functions that are passed as parameters to the `grad` function. A side effect is that each evaluation of `grad` creates a new `Function` instance. Therefore, even though the StableHLO lowering of each function is correctly cached, a new function is lowered every time a gradient evaluation is performed. E.g.,

```python
@qjit(target="mlir")
def test_gradient_used_twice(x: float):
    def identity(x):
        return x

    diff_identity = grad(identity)
    return diff_identity(x) + diff_identity(x)
```

will generate both `@identity` and `@identity_0` in MLIR:

```mlir
func.func public @jit_test_gradient_used_twice(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
  %0 = gradient.grad "auto" @identity(%arg0) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<f64>) -> tensor<f64>
  %1 = gradient.grad "auto" @identity_0(%arg0) {diffArgIndices = dense<0> : tensor<1xi64>} : (tensor<f64>) -> tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}
func.func private @identity(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  return %arg0 : tensor<f64>
}
func.func private @identity_0(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.linkage = #llvm.linkage<internal>} {
  return %arg0 : tensor<f64>
}
```

**Description of the Change:** This patch adds a cache to the `Function` class, so that any number of calls to the gradient function results in a single lowering.

**Benefits:** Faster compile times, especially when not using AutoGraph and calling the gradient in a loop.

**Possible Drawbacks:**

**Related GitHub Issues:**
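The caching idea can be sketched in a few lines of plain Python. This is a hypothetical, simplified model of the fix, not Catalyst's actual `Function` implementation: the wrapper class keeps a class-level dictionary mapping each wrapped callable to its existing instance, so repeated `grad` calls on the same function reuse one wrapper (and hence would trigger only one lowering).

```python
class Function:
    """Toy stand-in for the frontend wrapper class (assumed names).

    Imagine that constructing a new instance triggers an MLIR lowering;
    the cache ensures the same Python callable maps to one instance.
    """

    _cache = {}  # maps wrapped callable -> existing Function instance

    def __new__(cls, fn):
        # Reuse the cached wrapper when this callable was wrapped before.
        if fn in cls._cache:
            return cls._cache[fn]
        inst = super().__new__(cls)
        cls._cache[fn] = inst
        return inst

    def __init__(self, fn):
        self.fn = fn


def grad(fn):
    # Previously each call built a brand-new Function (a new lowering);
    # with the cache, both calls below share one wrapper object.
    return Function(fn)


def identity(x):
    return x


assert grad(identity) is grad(identity)  # one wrapper, one lowering
```

With this pattern, the two `diff_identity(x)` calls in the example above would both resolve to the same cached wrapper, so the MLIR module contains a single `@identity` rather than `@identity` and `@identity_0`.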