Skip to content

Commit 29e69c2

Browse files
authored
[SIMT][XeVM] Add xevm dialect and conversions (#1047)
1 parent 7752bf1 commit 29e69c2

File tree

62 files changed

+5233
-7
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+5233
-7
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
with:
5959
path: |
6060
${{ github.workspace }}/mlir
61-
key: ${{ runner.os }}-build-llvm-${{ env.LLVM_CACHE_NUMBER }}-${{ env.LLVM_SHA }}
61+
key: ${{ runner.os }}-build-llvm-${{ env.LLVM_CACHE_NUMBER }}-${{ env.LLVM_SHA }}-nocache
6262

6363
- name: Build LLVM-MLIR
6464
if: steps.cache-llvm-mlir.outputs.cache-hit != 'true'
@@ -75,7 +75,7 @@ jobs:
7575
-DLLVM_ENABLE_ASSERTIONS=ON \
7676
-DLLVM_USE_LINKER=gold \
7777
-DLLVM_INSTALL_UTILS=ON \
78-
-DLLVM_TARGETS_TO_BUILD=X86 \
78+
-DLLVM_TARGETS_TO_BUILD="X86;SPIRV" \
7979
-DLLVM_ENABLE_BINDINGS=OFF \
8080
-DLLVM_ENABLE_ZSTD=OFF \
8181
-DCMAKE_INSTALL_PREFIX=${{ github.workspace }}/mlir

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ if(LINUX)
3737
message(STATUS "Building for target: LINUX")
3838
endif()
3939

40+
41+
string(FIND "${LLVM_TARGETS_TO_BUILD}" "SPIRV" SPIRV_FOUND)
42+
if(SPIRV_FOUND GREATER -1)
43+
message(STATUS "SPIRV is part of the LLVM targets")
44+
set(IMEX_SPIRV_BACKEND_ENABLED 1)
45+
else()
46+
message(STATUS "SPIRV is not part of the LLVM targets")
47+
set(IMEX_SPIRV_BACKEND_ENABLED 0)
48+
endif()
49+
50+
# option(IMEX_BUILD_SPIRV_BACKEND "Append SPIRV to LLVM_TARGETS_TO_BUILD" ON)
51+
# if(IMEX_BUILD_SPIRV_BACKEND)
52+
# set(LLVM_TARGETS_TO_BUILD "${LLVM_TARGETS_TO_BUILD};SPIRV" CACHE STRING "LLVM targets to build" FORCE)
53+
# message(STATUS "IMEX adds SPIRV target to LLVM, LLVM_TARGETS_TO_BUILD = ${LLVM_TARGETS_TO_BUILD}")
54+
# endif()
55+
4056
# Expected LLVM SHA
4157
file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/build_tools/llvm_version.txt EXPECTED_LLVM_SHA)
4258
message(STATUS "Expected llvm sha: \"${EXPECTED_LLVM_SHA}\"")

include/imex/Conversion/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
#include <imex/Conversion/NDArrayToLinalg/NDArrayToLinalg.h>
2727
#include <imex/Conversion/RegionParallelLoopToGpu/RegionParallelLoopToGpu.h>
2828
#include <imex/Conversion/XeGPUToVC/XeGPUToVC.h>
29+
#include <imex/Conversion/XeGPUToXeVM/XeGPUToXeVM.h>
2930
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
31+
#include <imex/Conversion/XeVMToLLVM/XeVMToLLVM.h>
3032

3133
namespace imex {
3234

include/imex/Conversion/Passes.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,5 +466,32 @@ def ConvertArithToVC : Pass<"convert-arith-to-vc", "::mlir::gpu::GPUModuleOp"> {
466466
let constructor = "imex::createConvertArithToVCPass()";
467467
}
468468

469+
//===----------------------------------------------------------------------===//
470+
// XeVMToLLVM
471+
//===----------------------------------------------------------------------===//
472+
473+
def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
474+
let summary = "Convert XeVM to LLVM dialect";
475+
let dependentDialects = [
476+
"imex::xevm::XeVMDialect",
477+
];
478+
}
479+
480+
//===----------------------------------------------------------------------===//
481+
// XeGPUToXeVM
482+
//===----------------------------------------------------------------------===//
483+
484+
def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
485+
let summary = "Convert XeGPU to XeVM dialect";
486+
let dependentDialects = [
487+
"::mlir::xegpu::XeGPUDialect",
488+
"::imex::xevm::XeVMDialect",
489+
"::mlir::vector::VectorDialect",
490+
"::mlir::memref::MemRefDialect",
491+
"::mlir::arith::ArithDialect",
492+
];
493+
}
494+
495+
469496

470497
#endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- XeGPUToXeVM.h - Convert XeVM to LLVM dialect -------------*- C++
2+
//-*-===//
3+
//
4+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
10+
#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class DialectRegistry;
16+
class LLVMTypeConverter;
17+
class RewritePatternSet;
18+
class Pass;
19+
} // namespace mlir
20+
21+
namespace imex {
22+
#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
23+
#include "imex/Conversion/Passes.h.inc"
24+
25+
void populateXeGPUToXeVMConversionPatterns(
26+
mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
27+
28+
} // namespace imex
29+
30+
#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===-- Mangling.h - Mangle intrinsics -------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_XEVMTOLLVM_MANGLING_H
9+
#define MLIR_CONVERSION_XEVMTOLLVM_MANGLING_H
10+
11+
#include "mlir/IR/BuiltinTypes.h"
12+
#include "mlir/IR/Types.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
#include "llvm/ADT/STLExtras.h"
16+
#include "llvm/ADT/TypeSwitch.h"
17+
#include "llvm/Support/raw_ostream.h"
18+
19+
#include <string>
20+
namespace mlir {
21+
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
22+
return TypeSwitch<Type, std::string>(ty)
23+
.Case([isUnsigned](VectorType ty) -> std::string {
24+
return "Dv" + std::to_string(ty.getNumElements()) + "_" +
25+
getTypeMangling(ty.getElementType(), isUnsigned);
26+
})
27+
.Case([](Float16Type) -> std::string { return "Dh"; })
28+
.Case([](Float32Type) -> std::string { return "f"; })
29+
.Case([](Float64Type) -> std::string { return "d"; })
30+
.Case([isUnsigned](IntegerType ty) -> std::string {
31+
switch (ty.getWidth()) {
32+
case 8:
33+
return isUnsigned ? "h" : "c";
34+
case 16:
35+
return isUnsigned ? "t" : "s";
36+
case 32:
37+
return isUnsigned ? "j" : "i";
38+
case 64:
39+
return isUnsigned ? "m" : "l";
40+
default:
41+
llvm_unreachable("unhandled integer type");
42+
}
43+
});
44+
}
45+
} // namespace mlir
46+
47+
#endif // MLIR_CONVERSION_XEVMTOLLVM_MANGLING_H
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
9+
#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
class DialectRegistry;
15+
class LLVMTypeConverter;
16+
class RewritePatternSet;
17+
class Pass;
18+
} // namespace mlir
19+
20+
namespace imex {
21+
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
22+
#include "imex/Conversion/Passes.h.inc"
23+
24+
void populateXeVMToLLVMConversionPatterns(mlir::RewritePatternSet &patterns);
25+
26+
void registerConvertXeVMToLLVMInterface(mlir::DialectRegistry &registry);
27+
} // namespace imex
28+
29+
#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_

include/imex/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ add_subdirectory(DistRuntime)
22
add_subdirectory(NDArray)
33
add_subdirectory(Region)
44
add_subdirectory(GPUX)
5+
add_subdirectory(LLVMIR)
56
add_subdirectory(XeTile)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
add_mlir_dialect(XeVMOps xevm)
2+
add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm)
3+
set(LLVM_TARGET_DEFINITIONS XeVMOps.td)
4+
mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions)
5+
mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls)
6+
mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs)
7+
mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm)
8+
mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm)
9+
add_public_tablegen_target(MLIRXeVMConversionsIncGen)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
10+
#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/IR/Dialect.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
17+
18+
#include <imex/Dialect/LLVMIR/XeVMOpsEnums.h.inc>
19+
20+
namespace imex::xevm {
21+
22+
enum class XeVMAddrSpace : uint32_t {
23+
kPrivate = 0, // OpenCL Workitem address space, SPIRV function
24+
kGlobal = 1, // OpenCL Global memory, SPIRV crossworkgroup
25+
kConstant = 2, // OpenCL Constant memory, SPIRV uniform constant
26+
kShared = 3, // OpenCL Local memory, SPIRV workgroup
27+
kGeneric = 4 // OpenCL Generic memory, SPIRV generic
28+
};
29+
30+
} // namespace imex::xevm
31+
32+
#define GET_ATTRDEF_CLASSES
33+
#include <imex/Dialect/LLVMIR/XeVMOpsAttributes.h.inc>
34+
35+
#define GET_OP_CLASSES
36+
#include <imex/Dialect/LLVMIR/XeVMOps.h.inc>
37+
38+
#include <imex/Dialect/LLVMIR/XeVMOpsDialect.h.inc>
39+
40+
#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */

0 commit comments

Comments
 (0)