[matmul] Add transpose B matrix coverage for CDNA3 (iree-org#16558)
This commit adds transpose B matrix coverage in the matmul test suite.
This enables adding such tests for the CDNA3 MFMA CodeGen pipeline.

ci-extra: test_gpu
antiagainst authored Feb 28, 2024
1 parent 09deadf commit d7de68a
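For context, "transpose B" here means the RHS is laid out as NxK and contracted along its second dimension. A minimal NumPy sketch of the semantics the new tests exercise (sizes are illustrative and not taken from the test suite):

import numpy as np

# Hypothetical sizes, chosen only for illustration.
m, k, n = 64, 32, 128
lhs = np.random.rand(m, k).astype(np.float16)   # LHS stays MxK
rhs = np.random.rand(n, k).astype(np.float16)   # transposed-B layout: NxK

# linalg.matmul_transpose_b computes acc = lhs @ rhs^T, accumulating in f32.
acc = lhs.astype(np.float32) @ rhs.T.astype(np.float32)
assert acc.shape == (m, n)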
Showing 4 changed files with 202 additions and 93 deletions.
29 changes: 27 additions & 2 deletions tests/e2e/matmul/BUILD.bazel
@@ -430,9 +430,8 @@ iree_generated_e2e_matmul_test(
###########################################################################

# Testing CDNA3 + matrix core path.
# v_mfma_f32_16x16x16_f16
iree_generated_e2e_matmul_test(
name = "e2e_matmul_rocm_f16_large_cdna3_matrixcore",
name = "e2e_matmul_rocm_f16_large_cdna3_mfma",
compiler_flags = [
"--iree-rocm-target-chip=gfx942",
],
@@ -456,6 +455,32 @@ iree_generated_e2e_matmul_test(
test_runner = "//tools:iree-e2e-matmul-test",
)

iree_generated_e2e_matmul_test(
name = "e2e_matmul_rocm_f16_large_cdna3_mfma_tb",
compiler_flags = [
"--iree-rocm-target-chip=gfx942",
],
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
"--acc_type=f32",
"--transpose_rhs",
"--shapes=gpu_large_aligned",
"--compilation_info=LLVMGPUVectorDistribute",
],
tags = [
"noasan",
"nomsan",
"notsan",
"noubsan",
"requires-gpu-cdna3",
],
target_backends_and_drivers = [
("rocm", "rocm"),
],
test_runner = "//tools:iree-e2e-matmul-test",
)

###########################################################################
##
## Vulkan backend
29 changes: 28 additions & 1 deletion tests/e2e/matmul/CMakeLists.txt
@@ -978,7 +978,7 @@ iree_generated_e2e_matmul_test(

iree_generated_e2e_matmul_test(
NAME
e2e_matmul_rocm_f16_large_cdna3_matrixcore
e2e_matmul_rocm_f16_large_cdna3_mfma
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
@@ -1002,6 +1002,33 @@ iree_generated_e2e_matmul_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_matmul_test(
NAME
e2e_matmul_rocm_f16_large_cdna3_mfma_tb
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--transpose_rhs"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistribute"
TEST_RUNNER
iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"rocm"
COMPILER_FLAGS
"--iree-rocm-target-chip=gfx942"
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_matmul_test(
NAME
e2e_matmul_vulkan_i8_large_valhall
120 changes: 82 additions & 38 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -420,14 +420,24 @@ class TestInputMatricesShapes:
# Helper for generate_function. Generates TestInputMatricesShapes, i.e.
# converts from the runtime shape dimensions in TestShape and given dynamicity to
# the set of shapes to be used in a test function's input tensors.
def generate_shapes(shape: TestShape, dynamicity: Dynamicity):
def generate_shapes(shape: TestShape, transpose_rhs: bool, dynamicity: Dynamicity):
lhs_rows = shape_dim(shape.m, dynamicity)
lhs_cols = shape_dim(shape.k, dynamicity)
acc_rows = shape_dim(shape.m, dynamicity)
acc_cols = shape_dim(shape.n, dynamicity)
if transpose_rhs:
rhs_rows = shape_dim(shape.n, dynamicity)
rhs_cols = shape_dim(shape.k, dynamicity)
else:
rhs_rows = shape_dim(shape.k, dynamicity)
rhs_cols = shape_dim(shape.n, dynamicity)
shapes = TestInputMatricesShapes(
lhs_rows=shape_dim(shape.m, dynamicity),
lhs_cols=shape_dim(shape.k, dynamicity),
rhs_rows=shape_dim(shape.k, dynamicity),
rhs_cols=shape_dim(shape.n, dynamicity),
acc_rows=shape_dim(shape.m, dynamicity),
acc_cols=shape_dim(shape.n, dynamicity),
lhs_rows=lhs_rows,
lhs_cols=lhs_cols,
rhs_rows=rhs_rows,
rhs_cols=rhs_cols,
acc_rows=acc_rows,
acc_cols=acc_cols,
)
return shapes
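
A simplified, standalone sketch of the shape selection above (static shapes only, ignoring Dynamicity; the helper name is illustrative, not the generator's API):

def pick_shapes(m, k, n, transpose_rhs):
    # LHS is always MxK and the accumulator MxN; only the RHS layout changes.
    lhs = (m, k)
    rhs = (n, k) if transpose_rhs else (k, n)
    acc = (m, n)
    return lhs, rhs, acc

# With --transpose_rhs, a 64x32x128 problem uses a 128x32 RHS instead of 32x128.
print(pick_shapes(64, 32, 128, transpose_rhs=True))   # ((64, 32), (128, 32), (64, 128))
print(pick_shapes(64, 32, 128, transpose_rhs=False))  # ((64, 32), (32, 128), (64, 128))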

@@ -443,12 +453,12 @@ def generate_function_name(
):
input_t = lhs_rhs_type.value
acc_t = acc_type.value
lhs_m = int_or_DYN(shapes.lhs_rows)
lhs_k = int_or_DYN(shapes.lhs_cols)
rhs_k = int_or_DYN(shapes.rhs_rows)
rhs_n = int_or_DYN(shapes.rhs_cols)
acc_m = int_or_DYN(shapes.acc_rows)
acc_n = int_or_DYN(shapes.acc_cols)
lhs_r = int_or_DYN(shapes.lhs_rows)
lhs_c = int_or_DYN(shapes.lhs_cols)
rhs_r = int_or_DYN(shapes.rhs_rows)
rhs_c = int_or_DYN(shapes.rhs_cols)
acc_r = int_or_DYN(shapes.acc_rows)
acc_c = int_or_DYN(shapes.acc_cols)

info = ""
if compilation_info:
@@ -462,8 +472,8 @@ def generate_function_name(

matmul_kind = "matmul_accumulate" if accumulate else "matmul"
return (
f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_"
+ f"{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}"
f"{matmul_kind}_{lhs_r}x{lhs_c}x{input_t}_times_"
+ f"{rhs_r}x{rhs_c}x{input_t}_into_{acc_r}x{acc_c}x{acc_t}{info}"
)


@@ -477,28 +487,34 @@ class MLIRFunction:


# Generates a test function in the generated MLIR code.
# The generated function will take the same arguments as linalg.matmul and
# will just call linalg.matmul with them, returning its result.
# The generated function will take the same arguments as linalg.matmul variants
# and will just call linalg.matmul variants with them, returning its result.
def generate_function(
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shape: TestShape,
transpose_rhs: bool,
dynamicity: Dynamicity,
compilation_info: typing.Optional[CompilationInfo] = None,
):
shapes = generate_shapes(shape, dynamicity)
shapes = generate_shapes(shape, transpose_rhs, dynamicity)
func_name = generate_function_name(
lhs_rhs_type, acc_type, shapes, shape.accumulate, compilation_info
)
lhs_m = int_or_question_mark(shapes.lhs_rows)
lhs_k = int_or_question_mark(shapes.lhs_cols)
rhs_k = int_or_question_mark(shapes.rhs_rows)
rhs_n = int_or_question_mark(shapes.rhs_cols)
acc_m = int_or_question_mark(shapes.acc_rows)
acc_n = int_or_question_mark(shapes.acc_cols)
lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>"
rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>"
lhs_r = int_or_question_mark(shapes.lhs_rows)
lhs_c = int_or_question_mark(shapes.lhs_cols)
rhs_r = int_or_question_mark(shapes.rhs_rows)
rhs_c = int_or_question_mark(shapes.rhs_cols)
acc_r = int_or_question_mark(shapes.acc_rows)
acc_c = int_or_question_mark(shapes.acc_cols)
lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"

if transpose_rhs:
op_name = "linalg.matmul_transpose_b"
else:
op_name = "linalg.matmul"

# Compilation info is optional; prints empty string by default.
func_definition = ""
@@ -537,13 +553,13 @@ def generate_function(
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
else:
literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0"
if acc_m == "?":
if acc_r == "?":
signature = f"({lhs_tensor_type}, {rhs_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view) -> !hal.buffer_view"
func_definition = func_definition + (
@@ -555,7 +571,7 @@ def generate_function(
f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
@@ -567,7 +583,7 @@ def generate_function(
f" %init_acc = tensor.empty() : {acc_tensor_type}\n"
f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n"
f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
f" return %result: {acc_tensor_type}\n"
f"}}\n"
)
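
A small, runnable sketch (separate from the generator, compilation info attribute omitted) of the op line this branch emits for a hypothetical static 64x32x128 f16-to-f32 case with --transpose_rhs set:

# Illustrative only: mirrors the f-string above for one concrete static shape.
lhs_type = "tensor<64x32xf16>"    # MxK
rhs_type = "tensor<128x32xf16>"   # NxK, because the RHS is transposed
acc_type = "tensor<64x128xf32>"   # MxN
op_name = "linalg.matmul_transpose_b"
print(
    f"%result = {op_name} ins(%lhs, %rhs: {lhs_type}, {rhs_type}) "
    f"outs(%acc: {acc_type}) -> {acc_type}"
)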
@@ -635,6 +651,7 @@ def generate_call(
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shape: TestShape,
transpose_rhs: bool,
):
global call_id
func_name = f"{function.name}_{shape.m}_{shape.k}_{shape.n}"
@@ -652,8 +669,16 @@ def generate_call(
" %device = hal.devices.get %device_index : !hal.device\n"
)

op = op + generate_random_matrix("lhs", [shape.m, shape.k], lhs_rhs_type)
op = op + generate_random_matrix("rhs", [shape.k, shape.n], lhs_rhs_type)
lhs_shape = [shape.m, shape.k]
if transpose_rhs:
rhs_shape = [shape.n, shape.k]
transpose_rhs = 1
else:
rhs_shape = [shape.k, shape.n]
transpose_rhs = 0

op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -674,7 +699,8 @@ def generate_call(
f" %m = arith.constant {shape.m} : i64\n"
f" %k = arith.constant {shape.k} : i64\n"
f" %n = arith.constant {shape.n} : i64\n"
f" call @matmul_test.check_matmul_results(%device, %m, %k, %n, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
f" %transpose_rhs = arith.constant {transpose_rhs} : i32\n"
f" call @matmul_test.check_matmul_results(%device, %m, %k, %n, %transpose_rhs, %lhs, %rhs, %acc, %result) : (!hal.device, i64, i64, i64, i32, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
)

op = op + " return\n"
@@ -688,6 +714,7 @@ def generate(
lhs_rhs_type: MatrixElemTypeId,
acc_type: MatrixElemTypeId,
shapes_id: ShapesId,
transpose_rhs: bool,
compilation_info_id: CompilationInfoId,
):
functions = {}
@@ -699,7 +726,12 @@ def generate(
for shape in get_test_shapes(shapes_id):
for dynamicity in get_dynamicities(shapes_id):
function = generate_function(
lhs_rhs_type, acc_type, shape, dynamicity, compilation_info
lhs_rhs_type,
acc_type,
shape,
transpose_rhs,
dynamicity,
compilation_info,
)
# Different testcases may differ only by runtime parameters but
# share the same code. For example, dynamic-shapes testcases
@@ -708,7 +740,11 @@ def generate(
# to calls, but unconditionally to function_definitions.
if function.name not in functions:
functions[function.name] = function
calls.append(generate_call(function, lhs_rhs_type, acc_type, shape))
calls.append(
generate_call(
function, lhs_rhs_type, acc_type, shape, transpose_rhs
)
)

return (functions, calls)

@@ -749,6 +785,13 @@ def parse_arguments():
help="Collection of matrix shapes to test",
required=True,
)
parser.add_argument(
"--transpose_rhs",
action="store_true",
help="Whether to transpose RHS",
default=False,
required=False,
)
parser.add_argument(
"--compilation_info",
type=str,
@@ -790,7 +833,7 @@ def write_calls_file(functions, calls, filename, requirements):
# Declare the custom module that generates arguments.
module_definition = module_definition + (
"func.func private @matmul_test.generate_random_matrix(%device: !hal.device, %dim0: i64, %dim1: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
"func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
"func.func private @matmul_test.check_matmul_results(%device: !hal.device, %m: i64, %k: i64, %n: i64, %transpose_rhs: i32, %lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view, %actual_result: !hal.buffer_view)\n"
"\n"
)

@@ -827,8 +870,9 @@ def main(args):
acc_type = infer_acc_type(lhs_rhs_type, acc_type)
shapes_id = ShapesId(args.shapes)
compilation_info_id = CompilationInfoId(args.compilation_info)

(functions, calls) = generate(
lhs_rhs_type, acc_type, shapes_id, compilation_info_id
lhs_rhs_type, acc_type, shapes_id, args.transpose_rhs, compilation_info_id
)

write_code_file(functions, args.output_matmuls_mlir)