Skip to content

Commit 60288f6

Browse files
adam-smnkrolfmoreltkarnafschlimb
authored
[examples][mlir] Basic MLIR compilation and execution example (#10)
Adds a simple end-to-end example demonstrating programmatic transform schedule creation, MLIR JIT compilation, execution, and numerical verification of the result. Additionally, a 'utils' submodule is added with basic tools that simplify the creation of ctypes arguments in the format expected by MLIR jitted functions. Co-authored-by: Rolf Morel <[email protected]> Co-authored-by: Tuomas Kärnä <[email protected]> Co-authored-by: Frank Schlimbach <[email protected]>
1 parent c89575c commit 60288f6

File tree

3 files changed

+268
-0
lines changed

3 files changed

+268
-0
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import torch
2+
import os
3+
4+
from mlir import ir
5+
from mlir.dialects import transform
6+
from mlir.dialects.transform import structured
7+
from mlir.dialects.transform import interpreter
8+
from mlir.execution_engine import ExecutionEngine
9+
from mlir.passmanager import PassManager
10+
11+
from lighthouse import utils as lh_utils
12+
13+
14+
def create_kernel(ctx: ir.Context) -> ir.Module:
    """
    Build the payload module holding the function to execute.

    The module defines a single function, `@add`, which performs an
    element-wise addition of two 16x32 f32 memrefs into a third one.

    Args:
        ctx: MLIR context.

    Returns:
        The parsed MLIR module.
    """
    payload_ir = r"""
    // Compute element-wise addition.
    func.func @add(%a: memref<16x32xf32>, %b: memref<16x32xf32>, %out: memref<16x32xf32>) {
        linalg.add ins(%a, %b : memref<16x32xf32>, memref<16x32xf32>)
                   outs(%out : memref<16x32xf32>)
        return
    }
    """
    # Parsing happens with `ctx` active so the module is owned by that context.
    with ctx:
        return ir.Module.parse(payload_ir)
33+
34+
35+
def create_schedule(ctx: ir.Context) -> ir.Module:
    """
    Build a module holding the transform schedule for the kernel.

    The schedule requests C interface wrappers for the payload function,
    lowers linalg operations to loops and cleans up the result, i.e. it
    provides a partial lowering to scalar operations.

    Args:
        ctx: MLIR context.

    Returns:
        A module containing the `__transform_main` named sequence.
    """
    with ctx, ir.Location.unknown(context=ctx):
        # For simplicity, generic handles are used instead of specific op types.
        any_op = transform.any_op_t()

        # Create the transform module and mark it as holding named sequences.
        schedule = ir.Module.create()
        module_attrs = schedule.operation.attributes
        module_attrs["transform.with_named_sequence"] = ir.UnitAttr.get()

        # Create the entry point transformation sequence.
        with ir.InsertionPoint(schedule.body):
            entry_seq = transform.NamedSequenceOp(
                sym_name="__transform_main",
                input_types=[any_op],
                result_types=[],
                arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
            )

        # Populate the schedule body.
        with ir.InsertionPoint(entry_seq.body):
            # Locate the kernel's function op.
            kernel_func = structured.MatchOp.match_op_names(
                entry_seq.bodyTarget, ["func.func"]
            )
            # C interface wrappers are required to make the function
            # executable after jitting.
            kernel_func = transform.apply_registered_pass(
                any_op, kernel_func, "llvm-request-c-wrappers"
            )

            # Locate the kernel's module op.
            kernel_mod = transform.get_parent_op(
                any_op, kernel_func, op_name="builtin.module", deduplicate=True
            )
            # Naive lowering to loops.
            kernel_mod = transform.apply_registered_pass(
                any_op, kernel_mod, "convert-linalg-to-loops"
            )

            # Cleanup: CSE followed by canonicalization patterns.
            transform.apply_cse(kernel_mod)
            canonicalize = transform.ApplyPatternsOp(kernel_mod)
            with ir.InsertionPoint(canonicalize.patterns):
                transform.apply_patterns_canonicalization()

            # Terminate the schedule.
            transform.yield_([])

    return schedule
89+
90+
91+
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    """
    Run a transform schedule on a kernel module.

    The kernel module is rewritten in-place.

    Args:
        kernel: A module with payload function.
        schedule: A module with transform schedule.
    """
    # The first operation of the schedule module is used as the entry point.
    entry_point = schedule.body.operations[0]
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=entry_point,
        transform_module=schedule,
    )
105+
106+
107+
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
    """
    Build the MLIR pass pipeline used after the transform schedule.

    The pipeline lowers operations further down to LLVM dialect and then
    runs cleanup passes.

    Args:
        ctx: MLIR context.

    Returns:
        A pass manager anchored on `builtin.module`.
    """
    # Passes are added in exactly this order: lowering first, cleanup last.
    lowering_passes = (
        "convert-scf-to-cf",
        "convert-to-llvm",
        "reconcile-unrealized-casts",
    )
    cleanup_passes = ("cse", "canonicalize")

    with ctx:
        # The pass manager applies passes to the whole module.
        pm = PassManager("builtin.module")
        for pass_name in lowering_passes + cleanup_passes:
            pm.add(pass_name)
    return pm
126+
127+
128+
# The example's entry point.
def main():
    """
    Run the end-to-end example.

    Computes a reference result with PyTorch, builds and lowers the MLIR
    kernel, JIT-compiles it, executes it on the same inputs and prints
    whether the two results match.

    Paths to the MLIR runner shared libraries are read from the
    `LIGHTHOUSE_SHARED_LIBS` environment variable as a ':'-separated list.
    Empty entries are ignored, so an unset or empty variable results in no
    shared libraries being passed to the execution engine.
    """
    ### Baseline computation ###
    # Create inputs.
    a = torch.randn(16, 32, dtype=torch.float32)
    b = torch.randn(16, 32, dtype=torch.float32)

    # Compute baseline result to verify numerical correctness.
    out_ref = torch.add(a, b)

    ### MLIR payload preparation ###
    # Create payload kernel.
    ctx = ir.Context()
    kernel = create_kernel(ctx)

    # Create a transform schedule and apply initial lowering.
    schedule = create_schedule(ctx)
    apply_schedule(kernel, schedule)

    # Create a pass pipeline and lower the kernel to LLVM dialect.
    pm = create_pass_pipeline(ctx)
    pm.run(kernel.operation)

    ### Compilation ###
    # External shared libraries, containing MLIR runner utilities, are generally
    # required to execute the compiled module.
    # In this case, MLIR runner utils libraries are expected:
    #   - libmlir_runner_utils.so
    #   - libmlir_c_runner_utils.so
    #
    # Get paths to MLIR runner shared libraries through an environment variable.
    # The execution engine requires full paths to the libraries.
    # For example, the env variable can be set as:
    #   LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so
    #
    # Empty entries are dropped: splitting an unset or empty variable yields
    # [""], and stray or trailing ':' separators produce empty strings too,
    # none of which is a valid library path.
    shared_libs_env = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="")
    mlir_libs = [lib_path for lib_path in shared_libs_env.split(":") if lib_path]

    # JIT the kernel.
    eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)

    # Initialize the JIT engine.
    #
    # The deferred initialization executes global constructors that might have been
    # created by the module during engine creation (for example, when `gpu.module`
    # is present) or registered afterwards.
    #
    # Initialization is not strictly necessary in this case.
    # However, it is a good practice to perform it regardless.
    eng.initialize()

    # Get the kernel function.
    add_func = eng.lookup("add")

    ### Execution ###
    # Create an empty buffer to hold results.
    out = torch.empty_like(out_ref)

    # Execute the kernel.
    args = lh_utils.torch_to_packed_args([a, b, out])
    add_func(args)

    ### Verification ###
    # Check numerical correctness.
    if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01):
        print("Error! Result mismatch!")
    else:
        print("Result matched!")


if __name__ == "__main__":
    main()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""A collection of utility tools"""
2+
3+
from .runtime_args import (
4+
get_packed_arg,
5+
memref_to_ctype,
6+
memrefs_to_packed_args,
7+
torch_to_memref,
8+
torch_to_packed_args,
9+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import ctypes
2+
import torch
3+
4+
from mlir.runtime.np_to_memref import (
5+
get_ranked_memref_descriptor,
6+
)
7+
8+
9+
def get_packed_arg(ctypes_args) -> ctypes.Array:
    """
    Pack ctypes pointer arguments into the form expected by
    a jitted MLIR function's interface.

    Args:
        ctypes_args: A list of ctype pointer arguments.

    Returns:
        A ctypes array of `c_void_p` with one entry per input argument,
        in the same order. An empty input yields an empty array.
    """
    packed_args = (ctypes.c_void_p * len(ctypes_args))()
    for arg_idx, arg in enumerate(ctypes_args):
        # Each argument is stored as an untyped (void*) pointer.
        packed_args[arg_idx] = ctypes.cast(arg, ctypes.c_void_p)
    return packed_args
21+
22+
23+
def memref_to_ctype(memref_desc) -> ctypes._Pointer:
    """
    Wrap a memref descriptor into a ctype argument.

    Args:
        memref_desc: An MLIR memref descriptor.

    Returns:
        A pointer to a pointer to the descriptor.
    """
    # Two levels of indirection: first a pointer to the descriptor itself,
    # then a pointer to that pointer.
    desc_ptr = ctypes.pointer(memref_desc)
    return ctypes.pointer(desc_ptr)
31+
32+
33+
def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]:
    """
    Turn a list of memref descriptors into packed ctype arguments.

    Args:
        memref_descs: A list of memref descriptors.

    Returns:
        The packed arguments, one per descriptor, in input order.
    """
    # Convert every descriptor into a ctype argument, then pack them together.
    desc_pointers = list(map(memref_to_ctype, memref_descs))
    return get_packed_arg(desc_pointers)
42+
43+
44+
def torch_to_memref(input: torch.Tensor) -> ctypes.Structure:
    """
    Build a memref descriptor for a PyTorch tensor.

    Args:
        input: PyTorch tensor.

    Returns:
        A ranked memref descriptor created from the tensor's NumPy form.
    """
    # The descriptor helper works on NumPy arrays, so convert the tensor first.
    np_array = input.numpy()
    return get_ranked_memref_descriptor(np_array)
52+
53+
54+
def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]:
    """
    Turn a list of PyTorch tensors into packed ctype arguments.

    Args:
        inputs: A list of PyTorch tensors.

    Returns:
        The packed arguments, one per tensor, in input order.
    """
    # Build one memref descriptor per tensor, then pack the descriptors.
    descriptors = []
    for tensor in inputs:
        descriptors.append(torch_to_memref(tensor))
    return memrefs_to_packed_args(descriptors)

0 commit comments

Comments
 (0)