|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 |
|
| 3 | +import sys |
3 | 4 | from pathlib import Path |
4 | 5 |
|
5 | 6 | from mlir import ir, passmanager |
6 | 7 | from lighthouse.ingress import torch as torch_ingress |
7 | 8 |
|
8 | 9 |
|
9 | 10 | kernels_as_pytorch_folder = Path(__file__).parent / "KernelBench" / "KernelBench" |
| 11 | + |
| 12 | +if not (kernels_as_pytorch_folder.exists() and kernels_as_pytorch_folder.is_dir()): |
| 13 | + print( |
| 14 | + "ERROR: KernelBench repo not found.\n" |
| 15 | + "NOTE: Pull in dependency with: git submodule update " |
| 16 | + + str(kernels_as_pytorch_folder.parent.relative_to(Path.cwd())) |
| 17 | + + "", |
| 18 | + file=sys.stderr, |
| 19 | + ) |
| 20 | + sys.exit(1) |
| 21 | + |
| 22 | + |
10 | 23 | kernels_as_pytorch_level1 = kernels_as_pytorch_folder / "level1" |
11 | 24 | kernels_as_pytorch_level2 = kernels_as_pytorch_folder / "level2" |
12 | 25 |
|
|
105 | 118 | pm = passmanager.PassManager(context=ctx) |
106 | 119 | pm.add("linalg-specialize-generic-ops") |
107 | 120 |
|
| 121 | +print("Output directory:", kernels_as_mlir_folder) |
108 | 122 | for pytorch_level, mlir_level in ( |
109 | 123 | (kernels_as_pytorch_level1, kernels_as_mlir_level1), |
110 | 124 | (kernels_as_pytorch_level2, kernels_as_mlir_level2), |
|
133 | 147 | mlir_kernel = torch_ingress.import_from_file( |
134 | 148 | kernel_pytorch_file, ir_context=ctx |
135 | 149 | ) |
| 150 | + assert isinstance(mlir_kernel, ir.Module) |
136 | 151 |
|
137 | | - before_clean_up = "//" + str(mlir_kernel)[:-1].replace("\n", "\n//") + "\n" |
138 | 152 | try: |
139 | 153 | pm.run(mlir_kernel.operation) # cleanup |
140 | 154 | except Exception as e: |
141 | | - print(f"Error: got the following error cleaning up {kernel_name}") |
| 155 | + print(f"Error: got the following error cleaning up '{kernel_name}'") |
142 | 156 | raise e |
143 | 157 |
|
144 | 158 | with kernel_as_mlir_path.open("w") as f: |
145 | | - print("// Torch-MLIR output:", file=f) |
146 | | - print(before_clean_up, file=f) |
147 | | - print("// MLIR output after clean-up:", file=f) |
| 159 | + print("// MLIR output after conversion and clean-up:", file=f) |
148 | 160 | print(mlir_kernel, file=f) |
0 commit comments