# compiler.py
from __future__ import annotations
import hashlib
import json
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends.compiler import GPUTarget
from .. import __version__
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..tools.disasm import get_sass, get_spvdis
# TODO: this shouldn't be here
from .code_generator import ast_to_ttir
from .write_ir_metadata import write_metadata
from pathlib import Path
import re
import functools
import os
import sysconfig

# Regex for extracting a kernel prototype from PTX:
# - \.(?:visible|extern)\s+ : match a .visible or .extern linkage directive
#   and any following whitespace
# - \.(?:entry|func)\s+ : match a .entry or .func directive and any following whitespace
# - (\w+) : match the kernel name (letters, digits, or underscores) and capture it
#   as group 1
# - \s*\(([^)]*)\) : match a pair of parentheses and capture their contents
#   (the parameter list) as group 2
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
    "ptx": ptx_prototype_pattern,
}

ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
    "ptx": ptx_arg_type_pattern,
}
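
# Example (illustrative only; `add_kernel` is a hypothetical name): the prototype regex
# matches a PTX line such as
#
#     .visible .entry add_kernel(.param .u64 add_kernel_param_0, .param .u32 add_kernel_param_1)
#
# capturing "add_kernel" as group 1 and the parameter list as group 2, from which
# ptx_arg_type_pattern then extracts the parameter types ["u64", "u32"].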


def convert_type_repr(x):
    # Currently we only capture the pointer type and assume the pointer is on global memory.
    # TODO: Capture and support shared memory space
    match = re.search(r'!tt\.ptr<([^,]+)', x)
    tma = re.search(r'tt.nv_tma_desc = 1', x)
    if tma is not None:
        return 'nvTmaDesc'
    x = re.sub(r' {[^}]+}', '', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x
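
# Example (illustrative only):
#
#     convert_type_repr("f32")                                    # -> "f32"
#     convert_type_repr("!tt.ptr<f32, 1>")                        # -> "*f32"
#     convert_type_repr("!tt.ptr<f32, 1> {tt.nv_tma_desc = 1}")   # -> "nvTmaDesc"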


class ASTSource:

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        self.fn = fn
        self.ext = "ttir"
        self.name = fn.__name__
        self.signature = signature
        self.constants = dict()
        if constexprs is not None:
            for k, v in constexprs.items():
                k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
                assert isinstance(k, tuple)
                self.constants[k] = v
        self.attrs = attrs or dict()
        if isinstance(self.signature, str):
            self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
        else:
            for k in self.signature.keys():
                if not isinstance(k, str):
                    raise TypeError("Signature keys must be string")

    def hash(self):
        sorted_sig = [v for k, v in sorted(self.signature.items())]
        get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
        constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
        key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, module_map, context):
        return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                           module_map=module_map)

    def parse_options(self):
        return dict()
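
# Example (a minimal sketch; `add_kernel` is a hypothetical @triton.jit function):
#
#     src = ASTSource(fn=add_kernel, signature={"x_ptr": "*fp32", "n_elements": "i32"},
#                     constexprs={"BLOCK_SIZE": 1024})
#     print(src.hash())  # stable key combining the function, signature, constants, and attrs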


class IRSource:

    def __init__(self, path, context, backend):
        self.path = path
        path = Path(path)
        self.ext = path.suffix[1:]
        self.src = path.read_text()
        ir.load_dialects(context)
        backend.load_dialects(context)
        # We don't have an easy-to-use PTX parser, so keep this regex for now.
        # TODO - replace with a proper parser
        if self.ext == "ptx":
            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
            self.name = match.group(1)
            signature = match.group(2)
            types = re.findall(arg_type_pattern[self.ext], signature)
            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
        else:
            self.module = ir.parse_mlir_module(self.path, context)
            fn_name = self.module.get_entry_func_name()
            self.name = "@" + fn_name
            funcOp = self.module.get_function(fn_name)
            func_ty = self.module.get_function_signature(funcOp)
            self.signature = {k: ty for k, ty in enumerate(func_ty)}

    def hash(self):
        return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

    def make_ir(self, options, codegen_fns, module_map, context):
        self.module.context = context
        return self.module

    def parse_options(self):
        if self.ext == "ttgir":
            num_warps = self.module.get_int_attr("ttg.num-warps")
            assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
            return {'num_warps': num_warps}
        return dict()
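
# Example (illustrative only): an IRSource is normally constructed by compile() when it
# is handed a file path instead of an ASTSource, e.g.
#
#     kernel = compile("add_kernel.ttgir")  # hypothetical path; the ttgir module must
#                                           # carry a ttg.num-warps attribute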


@functools.lru_cache()
def triton_key():
    import pkgutil
    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents += [hashlib.sha256(f.read()).hexdigest()]
    # compiler
    path_prefixes = [
        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
    ]
    for path, prefix in path_prefixes:
        for lib in pkgutil.walk_packages([path], prefix=prefix):
            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                contents += [hashlib.sha256(f.read()).hexdigest()]
    # backend
    libtriton_hash = hashlib.sha256()
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
        while True:
            chunk = f.read(1024**2)
            if not chunk:
                break
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())
    # language
    language_path = os.path.join(TRITON_PATH, 'language')
    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.sha256(f.read()).hexdigest()]
    return f'{__version__}' + '-'.join(contents)


def parse(full_name, ext, context):
    if ext in ("ttir", "ttgir"):
        module = ir.parse_mlir_module(full_name, context)
        module.context = context
        return module
    if ext in ("llir", "ptx", "amdgcn"):
        return Path(full_name).read_text()
    if ext in ("cubin", "hsaco", "spv"):
        return Path(full_name).read_bytes()


def filter_traceback(e: BaseException):
    """
    Removes code_generator.py and related files from tracebacks.
    These are uninteresting to the user -- "just show me *my* code!"
    """
    if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
        return
    if e.__cause__ is not None:
        filter_traceback(e.__cause__)
    if e.__context__ is not None:
        filter_traceback(e.__context__)
    # If a user has a file that matches one of these, they're out of luck.
    BAD_FILES = [
        "/triton/compiler/code_generator.py",
        "/ast.py",
    ]
    BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]
    tb = e.__traceback__
    frames = []
    while tb is not None:
        if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
            frames.append(tb)
        tb = tb.tb_next
    for (cur_frame, next_frame) in zip(frames, frames[1:]):
        cur_frame.tb_next = next_frame
    if not frames:
        e.__traceback__ = None
    else:
        frames[-1].tb_next = None
        e.__traceback__ = frames[0]


def compile(src, target=None, options=None):
    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    # create backend
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        context = ir.context()
        src = IRSource(src, context, backend)

    extra_options = src.parse_options()
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager
    env_vars = get_cache_invalidating_env_vars()
    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding, hash only the source so the hash is independent of Triton
    # core changes, making it easier to track kernels by hash.
    enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
    enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
    store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    # Pre-truncate the file name here to avoid hitting the 255-character limit on common platforms.
    # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}".
    # A PID string can be 5 characters long and a UUID string is typically 36 characters, so
    # truncate the file name to 150 characters to be safe.
    file_name = src.name[:150]
    metadata_filename = f"{file_name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
    if not always_compile and metadata_path is not None:
        # cache hit!
        return CompiledKernel(src, metadata_group, hash)

    # initialize metadata
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    # run compilation pipeline and populate metadata
    stages = dict()
    backend.add_stages(stages, options)
    first_stage = list(stages.keys()).index(src.ext)
    # When the source is an IR file, skip the passes of the stage that produced it.
    # This makes it easier to write IR-level tests.
    if ir_source:
        first_stage += 1

    # For IRSource, we have already grabbed the context + called both
    # ir.load_dialects and backend.load_dialects.
    if not isinstance(src, IRSource):
        context = ir.context()
        ir.load_dialects(context)
        backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    try:
        module = src.make_ir(options, codegen_fns, module_map, context)
    except Exception as e:
        filter_traceback(e)
        raise
    use_ir_loc = os.environ.get("USE_IR_LOC", None)
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{file_name}.{ext}"
        if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
            print(f"\nOverriding kernel with file {full_name}")
            next_module = parse(full_name, ext, context)
        # If TRITON_STORE_BINARY_ONLY is 1, only store the binary (cubin/hsaco/spv) and json
        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")):
            metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
        # Use an env variable (USE_IR_LOC=<ext>) to re-anchor IR locations to the dumped file.
        if use_ir_loc == ext:
            ir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ir_full_name)
            print(f"Creating new locations for {ir_full_name}")
        module = next_module
    # write-back metadata
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)
    write_ir_metadata = os.environ.get("TRITON_WRITE_IR_METADATA", "0") == "1"
    if write_ir_metadata:
        write_metadata(fn_cache_manager.key, src)
    # Compilation has completed, so disable multithreading in the context.
    # This is needed to safely finalize the thread pool inside the context: if the current
    # process forks before the Python GC deletes the context object, the thread pool in the
    # child process will be invalid, which could lead to a crash or hang in the child.
    #
    # However, disabling multithreading causes the code to hang if the ASAN pass is enabled,
    # likely because llvm-symbolizer forks a process.
    # TODO: Reconcile the difference between the ASAN and non-ASAN paths with respect to
    # enabling multithreading in the MLIR context.
    if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
        context.disable_multithreading()
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)
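
# Example (a minimal sketch; `src` is an ASTSource as constructed above):
#
#     kernel = compile(src, options={"num_warps": 4})   # first call populates the cache
#     kernel = compile(src, options={"num_warps": 4})   # second call is a cache hit
#     print(kernel.asm["ttgir"])                        # inspect any stored intermediate IR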


def make_backend(target):
    actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
    if len(actives) != 1:
        raise RuntimeError(
            f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
    return actives[0](target)


class LazyDict:

    def __init__(self, data):
        self.data = data
        self.extras = []

    def get(self) -> dict:
        for func, args in self.extras:
            self.data = self.data | func(*args)
        self.extras.clear()
        return self.data

    def add(self, func, args):
        self.extras.append((func, args))
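
# Example (illustrative only): extra entries are merged in lazily, on the first get().
#
#     d = LazyDict({"name": "add_kernel"})
#     d.add(lambda grid: {"grid": grid}, ((1024, 1, 1), ))
#     d.get()  # -> {"name": "add_kernel", "grid": (1024, 1, 1)}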


class AsmDict(dict):

    def __missing__(self, key):
        if key == "sass":
            value = get_sass(self["cubin"])
        elif key == "spvdis":
            value = get_spvdis(self["spv"])
        else:
            raise KeyError("Unknown key: '%s'" % key)
        self[key] = value
        return value
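
# Example (illustrative only; `cubin_bytes` is hypothetical): disassembly is computed on
# first access and then cached in the dict.
#
#     asm = AsmDict({"cubin": cubin_bytes})
#     sass = asm["sass"]   # calls get_sass(asm["cubin"]) once; later lookups hit the cache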


class CompiledKernel:

    # Hooks for external tools to monitor the execution of triton kernels
    # TODO: move out of this namespace since it's a runtime thing
    launch_enter_hook = None
    launch_exit_hook = None

    def __init__(self, src, metadata_group, hash):
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        self.asm = AsmDict({
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        })
        self.kernel = self.asm[binary_ext]
        # Binaries are lazily initialized because doing so involves runtime work
        # (e.g., checking the amount of shared memory on the current device).
        self.module = None
        self.function = None

    def _init_handles(self):
        if self.module is not None:
            return
        device = driver.active.get_current_device()
        # create launcher
        self.run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
        if self.metadata.shared > max_shared:
            raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
        if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
            # Use the Blackwell max tmem size for now; this should be moved into device properties.
            max_tmem_size = 512  # tmem size in number of columns
            if self.metadata.tmem_size > max_tmem_size:
                raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, device)

    def __getattribute__(self, name):
        if name == 'run':
            self._init_handles()
        return super().__getattribute__(name)

    def launch_metadata(self, grid, stream, *args):
        if CompiledKernel.launch_enter_hook is None:
            return None
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        arg_dict = {}
        for arg_idx, arg_name in enumerate(self.src.fn.arg_names):
            arg_dict[arg_name] = args[arg_idx]
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
                     CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)

        return runner
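
# Example (a minimal sketch; `kernel`, `x_ptr`, and `n_elements` are hypothetical):
#
#     grid = (triton.cdiv(n_elements, 1024), 1, 1)
#     kernel[grid](x_ptr, n_elements)  # the runner picks the current stream unless stream= is given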