diff --git a/.github/README.md b/.github/README.md index c6b248f4b0..b31590caf9 100644 --- a/.github/README.md +++ b/.github/README.md @@ -196,6 +196,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi Loop strength reduction is known to cause up to 10% performance changes for certain kernels with register pressure. - `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit. +- `TRITON_WRITE_IR_METADATA=1` writes metadata about IR generated for compiled kernels. This metadata is needed to compare between IRs generated across different versions of triton. - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 60f0c2e616..ade0c725ba 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,6 +11,7 @@ from ..tools.disasm import get_sass, get_spvdis # TODO: this shouldn't be here from .code_generator import ast_to_ttir +from .write_ir_metadata import write_metadata from pathlib import Path import re import functools @@ -302,6 +303,9 @@ def compile(src, target=None, options=None): metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) + write_ir_metadata = os.environ.get("TRITON_WRITE_IR_METADATA", "0") == "1" + if write_ir_metadata: + write_metadata(fn_cache_manager.key, src) # Compilation completed, disabling multithreading in context. # This is needed to safely finalize threads pool inside context: if current process forks before # python GC deletes context object, thread pool in child process will be invalid, which could diff --git a/python/triton/compiler/write_ir_metadata.py b/python/triton/compiler/write_ir_metadata.py new file mode 100644 index 0000000000..2e822b4c1c --- /dev/null +++ b/python/triton/compiler/write_ir_metadata.py @@ -0,0 +1,52 @@ +import subprocess +import os +import hashlib +import json +import inspect +from ..runtime.cache import default_dump_dir + + +def get_commit_hash(): + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_path = subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], cwd=script_dir).decode().strip() + commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_path).decode().strip() + return commit_hash + + +def write_metadata(fn_cache_manager_key, src): + """ + Write metadata about IRs generated in triton + """ + metadata_dir = default_dump_dir() + "/IR_metadata" + constants = src.constants + kernel_src = src.fn.src + kernel_name = src.name + func = src.fn.fn + src_file = inspect.getfile(func) + line_no = inspect.getsourcelines(func)[1] + commit_hash = get_commit_hash() + metadata_filename = metadata_dir + "/" + commit_hash + ".json" + if not os.path.exists(metadata_filename): + os.makedirs(metadata_dir, exist_ok=True) + data = {} + else: + # Load existing data from the JSON file + try: + with open(metadata_filename, 'r') as f: + data = json.load(f) + except json.JSONDecodeError: + print(f"Warning: {metadata_filename} is corrupted. Creating a new file.") + data = {} + constants = str(constants) + metadata_key = kernel_src + constants + hashed_key = hashlib.sha256(metadata_key.encode("utf-8")).hexdigest() + value_dict = { + "fn_cache_manager_key": fn_cache_manager_key, "kernel_name": kernel_name, "definition_location": + f"{src_file}:{line_no}", "constants": constants + } + data[hashed_key] = value_dict + try: + with open(metadata_filename, 'w') as f: + json.dump(data, f, indent=4) + except Exception as e: + print(f"Error writing to {metadata_filename}: {e}")