66 changes: 54 additions & 12 deletions src/device/intrinsics/wmma.jl
@@ -4,6 +4,7 @@ module WMMA
import ..LLVM
using ..CUDA: AS
using Core: LLVMPtr
using BFloat16s: BFloat16

################################################################################
# CONSTANTS
@@ -15,6 +16,7 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
"bf16" => BFloat16,
"f32" => Float32
)

@@ -24,6 +26,7 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"bf16" => UInt32,
"f32" => Float32
)

@@ -41,6 +44,10 @@ const map_frag_sizes = Dict(
"a.f16.m16n16k16" => 8,
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,

"a.bf16.m16n16k16" => 4,
"a.bf16.m8n32k16" => 2,
"a.bf16.m32n8k16" => 8,
# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -53,6 +60,10 @@
"b.f16.m16n16k16" => 8,
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,

"b.bf16.m16n16k16" => 4,
"b.bf16.m8n32k16" => 8,
"b.bf16.m32n8k16" => 2,
# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
@@ -96,10 +107,13 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
# BFloat16 (requires Ampere+, only f32 accumulator supported)
const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops)
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
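# Illustrative cross-check of the bf16 fragment sizes above: each UInt32
# register packs two BFloat16 values, so the per-thread register count is
# (tile elements) ÷ 32 threads ÷ 2.
for (m, n, k) in [(16, 16, 16), (8, 32, 16), (32, 8, 16)]
    a_regs = (m * k ÷ 32) ÷ 2    # A is an m×k tile -> 4, 2, 8
    b_regs = (k * n ÷ 32) ÷ 2    # B is a  k×n tile -> 4, 8, 2
    println("m$(m)n$(n)k$(k): a.bf16 => $a_regs, b.bf16 => $b_regs")
end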
@@ -319,12 +333,12 @@ for ops in all_wmma_ops,
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

# Name of the LLVM intrinsic
# If integer/sub-byte/bit A/B types, name is determined by A/B types
if d_elem_type == "s32"
# If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
if d_elem_type == "s32" || a_elem_type == "bf16"
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
else # Name defined by D/C types
else # f16: Name defined by D/C types
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
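# Illustration of the naming rule above for one concrete bf16 configuration
# (the row/col layouts here are example values, not the only ones generated):
shape, a_layout, b_layout = "m16n16k16", "row", "col"
a_elem_type = "bf16"
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
func_name = Symbol(join(["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type], "_"))
# llvm_intr == "llvm.nvvm.wmma.m16n16k16.mma.row.col.bf16"
# func_name == :llvm_wmma_mma_row_col_m16n16k16_bf16
# (an f16 op with f32 accumulator would instead end in ".f32.f32" / "_f32_f32")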
@@ -393,6 +407,28 @@ end
@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...)
@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1]

# BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
@inline function unpack_bf16(x::UInt32)
lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
hi = reinterpret(BFloat16, UInt16(x >> 16))
return (lo, hi)
end

@inline function pack_bf16(lo::BFloat16, hi::BFloat16)
return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
end

@inline function flatten_bf16(x::NTuple{N, UInt32}) where N
ntuple(i -> begin
lo, hi = unpack_bf16(x[(i+1)÷2])
isodd(i) ? lo : hi
end, Val(2N))
end

@inline function unflatten_bf16(x::NTuple{N, BFloat16}) where N
ntuple(i -> pack_bf16(x[2i-1], x[2i]), Val(N÷2))
end
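# Host-side round trip of the packing scheme above (a REPL sketch; assumes only
# BFloat16s.jl and the helpers defined here; the values are arbitrary):
using BFloat16s: BFloat16
lo, hi = BFloat16(1.5), BFloat16(-2.0)
packed = pack_bf16(lo, hi)                          # low half in the low 16 bits
@assert unpack_bf16(packed) == (lo, hi)
@assert flatten_bf16((packed, packed)) == (lo, hi, lo, hi)
@assert unflatten_bf16((lo, hi, lo, hi)) == (packed, packed)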

################################################################################
# HIGH LEVEL (CUDA-STYLE API)
################################################################################
@@ -513,6 +549,8 @@ const map_layout_ty_to_str = Dict(
const map_num_elems = Dict(
("a", Float16) => 16,
("b", Float16) => 16,
("a", BFloat16) => 8,
("b", BFloat16) => 8,
("c", Float16) => 8,
("c", Float32) => 8,
("d", Float16) => 8,
@@ -614,8 +652,9 @@ for mat in ["a", "b", "c"]
# Name of the Julia wrapper
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_"))

_flatten = T == BFloat16 ? flatten_bf16 : flatten
return quote
x = flatten($wrapper(addr, stride))
x = $_flatten($wrapper(addr, stride))
return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x)
end
end
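# Illustrative sketch of what the bf16 branch above yields for an A fragment of
# a 16x16x16 operation (the register values are made up; 0x3f803f80 packs two
# BFloat16(1.0) values; Fragment/RowMajor/MatrixA are the high-level API types):
regs = ntuple(_ -> 0x3f803f80, 4)        # 4 packed UInt32 registers from the LLVM wrapper
vals = flatten_bf16(regs)                # NTuple{8, BFloat16}, all equal to 1.0
frag = Fragment{16, 16, 16, 8, BFloat16, RowMajor, MatrixA}(vals)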
@@ -656,19 +695,22 @@ mma
b_layout = get_hl_layout(B_L)
shape = get_hl_shape(M, N, K)

_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape)
_, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)

names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
# bf16 uses input type in intrinsic name, f16 uses d/c types
A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
wrapper = Symbol(join(filter(!isempty, names), "_"))


# Name of the Julia wrapper
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))

return quote
a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x)
b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
a_unfl = $a_unfl_expr
b_unfl = $b_unfl_expr
c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)

x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
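# For orientation, roughly the quote returned above for a BFloat16 16x16x16
# col x col configuration (the wrapper name follows the naming scheme earlier
# in this file; the exact shape of the expansion shown here is an assumption):
expected = quote
    a_unfl = unflatten_bf16(a.x)                      # NTuple{8, BFloat16} -> NTuple{4, UInt32}
    b_unfl = unflatten_bf16(b.x)
    c_unfl = unflatten(NTuple{8, Float32}, c.x)
    x = flatten(llvm_wmma_mma_col_col_m16n16k16_bf16(a_unfl, b_unfl, c_unfl))
end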
153 changes: 149 additions & 4 deletions test/core/device/intrinsics/wmma.jl
@@ -2,12 +2,15 @@ if capability(device()) >= v"7.0"

using CUDA.WMMA

using BFloat16s: BFloat16

map_ptx_to_jl_frag = Dict(
"u8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
"s8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
"u32" => UInt32(42),
"s32" => Int32(42),
"f16" => ntuple(i -> VecElement{Float16}(42), 2),
"bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
"f32" => Float32(42)
)
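# Worked example for the bf16 entry above (BFloat16 comes from BFloat16s.jl,
# imported at the top of this file): BFloat16(42) has bit pattern 0x4228, the
# upper half of Float32(42)'s pattern 0x42280000, so two packed copies give 0x42284228.
@assert reinterpret(UInt16, BFloat16(42)) == 0x4228
@assert reinterpret(UInt32, fill(BFloat16(42), 2))[1] == 0x42284228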
# Return specific matrix shape given operation configuration
@@ -48,6 +51,10 @@ end
startswith(elem_type, "u"))
continue
end
# Skip BFloat16 WMMA on pre-Ampere devices
if capability(device()) < v"8.0" && elem_type == "bf16"
continue
end

shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])

@@ -115,6 +122,10 @@ end
startswith(elem_type, "u"))
continue
end
# Skip BFloat16 WMMA on pre-Ampere devices
if capability(device()) < v"8.0" && elem_type == "bf16"
continue
end

shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])

@@ -175,6 +186,10 @@ end
startswith(ab_elem_type, "u"))
continue
end
# Skip BFloat16 WMMA on pre-Ampere devices
if capability(device()) < v"8.0" && ab_elem_type == "bf16"
continue
end

# Type-dependent variables
d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
@@ -187,9 +202,9 @@ end
lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_$(shape)_global_stride_$(ab_elem_type)"))
ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_$(shape)_global_stride_$(ab_elem_type)"))
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
# Account for half and int/subint mma different naming conventions
# Int/subint mma functions are distinguished by the a/b element type
mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
# Account for half and int/subint/bf16 mma different naming conventions
# Int/subint and bf16 mma functions are distinguished by the a/b element type
mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
mma_func = getfield(Main, mma_sym)
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
@@ -227,6 +242,8 @@ end
# Alter test depending on a/b element Type
if ab_ty == Float16
@test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
elseif ab_ty == BFloat16
@test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol=Base.rtoldefault(BFloat16)
else # Cast a and b to prevent UInt8 rollover of resultant data
@test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
end
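# Rough error budget for the bf16 comparison above (a sketch; assumes the usual
# definitions eps(BFloat16) ≈ 2^-7 and Base.rtoldefault(T) == sqrt(eps(T))):
eps_bf16  = 2.0^-7            # 8 significand bits (7 stored)
rtol_bf16 = sqrt(eps_bf16)    # ≈ 0.088, the tolerance the ≈ check uses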
@@ -256,12 +273,20 @@ end
@test WMMA.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)
@test WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)
end

@testset "BFloat16 packing/unpacking" begin
bf_vals = ntuple(i -> BFloat16(i), 8)
packed = WMMA.unflatten_bf16(bf_vals)
@test length(packed) == 4
unpacked = WMMA.flatten_bf16(packed)
@test unpacked == bf_vals
end
end

################################################################################

@testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
ty = [Float16, Float32]
ty = [Float16, Float32, BFloat16]
@test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
@test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
end
Expand Down Expand Up @@ -331,6 +356,126 @@ end

################################################################################

if capability(device()) >= v"8.0"
@testset "CUDA C-style API (BFloat16)" begin
@testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
b_layout in [ColMajor, RowMajor],
c_layout in [ColMajor, RowMajor],
d_layout in [ColMajor, RowMajor],
do_mac in [true, false]

a = rand(BFloat16, (16, 16))
b = rand(BFloat16, (16, 16))
c = rand(Float32, (16, 16))
d = Array{Float32}(undef, (16, 16))

a_dev = CuArray(a)
b_dev = CuArray(b)
c_dev = CuArray(c)
d_dev = CuArray(d)

# Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
# scalar ops which aren't available on all architectures, so we skip scaling
@eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
conf = Config{16, 16, 16, Float32}

a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)

if $do_mac
c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
else
c_frag = fill_c(Float32(0), conf)
end

d_frag = mma(a_frag, b_frag, c_frag, conf)

store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)

return
end

@cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
d = Array(d_dev)

new_a = (a_layout == ColMajor) ? a : transpose(a)
new_b = (b_layout == ColMajor) ? b : transpose(b)
new_c = (c_layout == ColMajor) ? c : transpose(c)
new_d = (d_layout == ColMajor) ? d : transpose(d)

if do_mac
@test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
else
@test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
end
end
end
end

# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
if capability(device()) >= v"8.9"
@testset "CUDA C-style API (BFloat16 with scaling)" begin
@testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
b_layout in [ColMajor, RowMajor],
c_layout in [ColMajor, RowMajor],
d_layout in [ColMajor, RowMajor],
do_mac in [true, false]

a = rand(BFloat16, (16, 16))
b = rand(BFloat16, (16, 16))
c = rand(Float32, (16, 16))
d = Array{Float32}(undef, (16, 16))

a_dev = CuArray(a)
b_dev = CuArray(b)
c_dev = CuArray(c)
d_dev = CuArray(d)

alpha = rand(BFloat16)
beta = rand(Float32)

@eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
conf = Config{16, 16, 16, Float32}

a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)

if $do_mac
c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
else
c_frag = fill_c(Float32(0), conf)
end

a_frag = alpha .* a_frag
c_frag = beta .* c_frag

d_frag = mma(a_frag, b_frag, c_frag, conf)

store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)

return
end

@cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
d = Array(d_dev)

new_a = (a_layout == ColMajor) ? a : transpose(a)
new_b = (b_layout == ColMajor) ? b : transpose(b)
new_c = (c_layout == ColMajor) ? c : transpose(c)
new_d = (d_layout == ColMajor) ? d : transpose(d)

if do_mac
@test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
else
@test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
end
end
end
end

################################################################################

@testset "Codegen addressing" begin
@testset "Global" begin
function kernel(d)