128 changes: 128 additions & 0 deletions examples/elementwise/elementwise_bilinearInterpolation.py
@@ -0,0 +1,128 @@
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

# parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
# parser.add_argument("--m", type=int, default=1024, help="Matrix M dimension")
# parser.add_argument("--n", type=int, default=1024, help="Matrix N dimension")
# args = parser.parse_args()

# M = args.m
# N = args.n
src0 = torch.arange(1, 513, dtype=torch.float16).reshape(1, -1).npu()
src0offset_int = torch.arange(0, 1024, 32, dtype=torch.int64).reshape(1, -1).npu()
src0offset = src0offset_int.to(dtype=torch.uint32)
src1 = torch.arange(2, 18, dtype=torch.float16).reshape(1, -1).npu()
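# Test data: src0 holds 512 consecutive fp16 values, src0offset holds byte offsets in
# steps of 32 bytes (one 16-element fp16 block each), and src1 holds 16 fp16 weights.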

hRepeat = 2
mask1 = 128
repeatMode = False
dstBlkStride = 1
vROffset = 128
vRepeat = 2
mask0 = 0
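
# The kernel below exercises the T.bilinear_interpolation vector intrinsic on a single
# core/vector unit; fun_ref further down spells out the gather/weight/accumulate
# semantics this test expects, and is used to check the kernel output.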


@tilelang.jit(out_idx=[-1])
def bilinear_interpolation(mask, h_repeat, repeat_mode,
                           dst_blk_stride, v_r_offset, v_repeat):
    m_num = 1
    n_num = 1

    VEC_NUM = 1

    @T.prim_func
    def main(
            src0: T.Tensor((src0.shape[0], src0.shape[1]), "float16"),
            src0_offset: T.Tensor((src0offset.shape[0], src0offset.shape[1]), "uint32"),
            src1: T.Tensor((src1.shape[0], src1.shape[1]), "float16"),
            dst: T.Tensor((src0.shape[0], src0.shape[1] // 2), "float16"),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            bx = cid // n_num
            by = cid % n_num

            src0_ub = T.alloc_ub((src0.shape[0] // VEC_NUM, src0.shape[1]), "float16")
            src0_offset_ub = T.alloc_ub((src0offset.shape[0] // VEC_NUM, src0offset.shape[1]), "uint32")
            src1_ub = T.alloc_ub((src1.shape[0] // VEC_NUM, src1.shape[1]), "float16")
            dst_ub = T.alloc_ub((src0.shape[0] // VEC_NUM, src0.shape[1] // 2), "float16")
            shared_tmp_buffer_ub = T.alloc_ub((src0.shape[0], src0.shape[1]), "uint8")

            with T.Scope("V"):
                T.copy(src0[0, 0], src0_ub)
                T.copy(src0_offset[0, 0], src0_offset_ub)
                T.copy(src1[0, 0], src1_ub)

                T.barrier_all()
                T.bilinear_interpolation(dst_ub, src0_ub, src0_offset_ub, src1_ub, mask, h_repeat,
                                         repeat_mode, dst_blk_stride, v_r_offset, v_repeat, shared_tmp_buffer_ub)
                T.barrier_all()

                T.copy(dst_ub, dst[0, 0])

    return main


func = bilinear_interpolation(mask1, hRepeat, repeatMode, dstBlkStride, vROffset, vRepeat)

torch.manual_seed(0)

torch.npu.synchronize()
print("init successful!")

c = func(src0, src0offset, src1)

# Compute ref_c (reference output)
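# Reference semantics (what this test expects the intrinsic above to produce): for each
# of the vRepeat repeats, gather hRepeat * 8 blocks of 16 fp16 values from src0 at the
# offsets given in src0offset (offset // 32 selects a block), scale each block by a
# weight from src1 (per-offset when repeatMode is True, per-row otherwise), and
# accumulate the hRepeat rows of 128 elements into one 128-element result row.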

def fun_ref(a, b, c, hRepeat, vRepeat, repeatMode, vROffset):
    a = a.flatten()
    b = b.flatten()
    c = c.flatten()
    re = []

    if repeatMode:
        for k in range(vRepeat):
            s = torch.zeros(128, dtype=torch.float16).npu()  # initialize the accumulator
            r = torch.zeros(128 * hRepeat, dtype=torch.float16).npu()
            for i in range(hRepeat):
                for j in range(8):
                    idx = b[k * 8 * hRepeat + i * 8 + j].to(torch.int64) // 32
                    r[i * 128 + j * 16 : i * 128 + (j + 1) * 16] = a[idx * 16 : (idx + 1) * 16] * c[k * 8 * hRepeat + i * 8 + j]
                s += r[i * 128 : (i + 1) * 128]
            re.append(s)
    else:
        for k in range(vRepeat):
            s = torch.zeros(128, dtype=torch.float16).npu()
            r = torch.zeros(128 * hRepeat, dtype=torch.float16).npu()
            for i in range(hRepeat):
                for j in range(8):
                    idx = b[k * 8 * hRepeat + i * 8 + j].to(torch.int64) // 32
                    r[i * 128 + j * 16 : i * 128 + (j + 1) * 16] = a[idx * 16 : (idx + 1) * 16] * c[k * hRepeat + i]
                s += r[i * 128 : (i + 1) * 128]
            re.append(s)
    return torch.cat(re, dim=0).flatten()

if vROffset > 128:
    outsize = vRepeat * vROffset
else:
    outsize = vRepeat * 128

out = fun_ref(src0, src0offset, src1, hRepeat, vRepeat, repeatMode, vROffset)

out_real = torch.zeros(outsize, dtype=torch.float16).npu()
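# Assemble the expected output: with mask0 == 0 only mask1 is used, so copy the first
# mask1 lanes of each 128-lane repeat into out_real at a stride of vROffset elements
# (the trailing l = mask1 % 16 lanes are handled separately after the inner loop).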
if mask0 == 0:
    for i in range(vRepeat):
        n = mask1 // 16
        l = mask1 % 16
        for j in range(n):
            out_real[i * vROffset + j * 16 : i * vROffset + (j + 1) * 16] = out[i * 128 + j * 16 : i * 128 + (j + 1) * 16]
        out_real[i * vROffset + n * 16 : i * vROffset + n * 16 + l] = out[i * 128 + n * 16 : i * 128 + n * 16 + l]

ref_c = out_real[:vRepeat * 128].unsqueeze(0)

torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")
82 changes: 82 additions & 0 deletions examples/elementwise/elementwise_wholereducemax.py
@@ -0,0 +1,82 @@
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=1024, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=1024, help="Matrix N dimension")
args = parser.parse_args()

M = 2
N = 512
block_M = 2
block_N = 128
mask = 64
repeatTimes = 2
dstRepStride = 1
srcBlkStride = 1
srcRepStride = 4


@tilelang.jit(out_idx=[-1])
def wholereducemax(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride, dtype="float16"):
    m_num = M // block_M
    n_num = N // block_N

    VEC_NUM = 2

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, 2 * N // mask), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            bx = cid // n_num
            by = cid % n_num

            a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
            b_ub = T.alloc_ub((block_M // VEC_NUM, 2 * block_N // mask), dtype)
            with T.Scope("V"):
                T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub)

                T.barrier_all()
                T.wholereducemax(b_ub, a_ub, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)
                T.barrier_all()

                T.copy(b_ub, B[bx * block_M + vid * block_M // VEC_NUM, by * 2 * block_N // mask])

    return main


func = wholereducemax(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)

torch.manual_seed(0)

a = torch.randn(M, N, dtype=torch.float16).npu()

torch.npu.synchronize()
print("init successful!")

b = func(a)

num_groups = M * N // mask
ref_b = torch.zeros((1, 2 * num_groups)).to(torch.float16)
a_flag = a.reshape(-1)
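# Reference: for every group of `mask` elements, the expected output is a
# (max value, index) pair; the index is produced as a uint16 whose bit pattern is
# reinterpreted as float16 (the layout this test expects T.wholereducemax to write).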
for i in range(num_groups):
    start = i * mask
    end = start + mask
    group = a_flag[start:end]
    max_val = torch.max(group).item()
    max_idx_in_group = torch.argmax(group).item()
    result = torch.tensor([max_idx_in_group], dtype=torch.uint16).view(torch.float16).float().item()
    ref_b[0, 2 * i] = max_val
    ref_b[0, 2 * i + 1] = result
ref_b = ref_b.reshape(M, 2 * N // mask)
ref_b = ref_b.npu().to(dtype=torch.float16)

torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")
82 changes: 82 additions & 0 deletions examples/elementwise/elementwise_wholereducemin.py
@@ -0,0 +1,82 @@
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=1024, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=1024, help="Matrix N dimension")
args = parser.parse_args()

M = 2
N = 512
block_M = 2
block_N = 128
mask = 64
repeatTimes = 2
dstRepStride = 1
srcBlkStride = 1
srcRepStride = 4


@tilelang.jit(out_idx=[-1])
def wholereducemin(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride, dtype="float16"):
    m_num = M // block_M
    n_num = N // block_N

    VEC_NUM = 2

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, 2 * N // mask), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            bx = cid // n_num
            by = cid % n_num

            a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
            b_ub = T.alloc_ub((block_M // VEC_NUM, 2 * block_N // mask), dtype)
            with T.Scope("V"):
                T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub)

                T.barrier_all()
                T.wholereducemin(b_ub, a_ub, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)
                T.barrier_all()

                T.copy(b_ub, B[bx * block_M + vid * block_M // VEC_NUM, by * 2 * block_N // mask])

    return main


func = wholereducemin(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)

torch.manual_seed(0)

a = torch.randn(M, N, dtype=torch.float16).npu()

torch.npu.synchronize()
print("init successful!")

b = func(a)

num_groups = M * N // mask
ref_b = torch.zeros((1, 2 * num_groups)).to(torch.float16)
a_flag = a.reshape(-1)
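# Reference: for every group of `mask` elements, the expected output is a
# (min value, index) pair; the index is produced as a uint16 whose bit pattern is
# reinterpreted as float16 (the layout this test expects T.wholereducemin to write).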
for i in range(num_groups):
    start = i * mask
    end = start + mask
    group = a_flag[start:end]
    min_val = torch.min(group).item()
    min_idx_in_group = torch.argmin(group).item()
    result = torch.tensor([min_idx_in_group], dtype=torch.uint16).view(torch.float16).float().item()
    ref_b[0, 2 * i] = min_val
    ref_b[0, 2 * i + 1] = result
ref_b = ref_b.reshape(M, 2 * N // mask)
ref_b = ref_b.npu().to(dtype=torch.float16)

torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")
79 changes: 79 additions & 0 deletions examples/elementwise/elementwise_wholereducesum.py
@@ -0,0 +1,79 @@
import argparse

import tilelang
import tilelang.language as T
import torch

tilelang.cache.clear_cache()

parser = argparse.ArgumentParser(description="NPU Kernel Compilation")
parser.add_argument("--m", type=int, default=1024, help="Matrix M dimension")
parser.add_argument("--n", type=int, default=1024, help="Matrix N dimension")
args = parser.parse_args()

M = 2
N = 512
block_M = 2
block_N = 128
mask = 64
repeatTimes = 2
dstRepStride = 1
srcBlkStride = 1
srcRepStride = 4


@tilelang.jit(out_idx=[-1])
def wholereducesum(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride, dtype="float16"):
    m_num = M // block_M
    n_num = N // block_N

    VEC_NUM = 2

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N // mask), dtype),
    ):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
            bx = cid // n_num
            by = cid % n_num

            a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
            b_ub = T.alloc_ub((block_M // VEC_NUM, block_N // mask), dtype)
            with T.Scope("V"):
                T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub)

                T.barrier_all()
                T.wholereducesum(b_ub, a_ub, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)
                T.barrier_all()

                T.copy(b_ub, B[bx * block_M + vid * block_M // VEC_NUM, by * block_N // mask])

    return main


func = wholereducesum(M, N, block_M, block_N, mask, repeatTimes, dstRepStride, srcBlkStride, srcRepStride)

torch.manual_seed(0)

a = torch.randn(M, N, dtype=torch.float16).npu()

torch.npu.synchronize()
print("init successful!")

b = func(a)

num_groups = M * N // mask
ref_b = torch.zeros((1, num_groups))
a_flag = a.reshape(-1)
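# Reference: each group of `mask` elements reduces to a single fp16 sum, one value
# per group (the layout this test expects T.wholereducesum to write).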
for i in range(num_groups):
    start = i * mask
    end = start + mask
    group = a_flag[start:end]
    sum_val = torch.sum(group).item()
    ref_b[0, i] = sum_val
ref_b = ref_b.reshape(M, N // mask)
ref_b = ref_b.npu().to(dtype=torch.float16)

torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)
print("Kernel Output Match!")