[muti_backend] #294

Status: Open. Wants to merge 32 commits into base: master.

Changes from 27 commits.

Commits (32):
fd24be2  new feature, muti_backend (Galaxy1458, Nov 16, 2024)
c8c1bcd  update auto_tune_module (Galaxy1458, Nov 16, 2024)
8bb14e3  update auto_tune_module (Galaxy1458, Nov 16, 2024)
cca8de5  update auto_tune_module (Galaxy1458, Nov 16, 2024)
1b8cc65  update __init__ (Galaxy1458, Nov 18, 2024)
d1596dd  rebase (Galaxy1458, Nov 18, 2024)
e6bbce4  rebase (Galaxy1458, Nov 18, 2024)
81ee7a8  fix bug (Galaxy1458, Nov 18, 2024)
f6b3e94  modifiy auto_tune_config (Galaxy1458, Nov 18, 2024)
8b77745  fix bug (Galaxy1458, Nov 20, 2024)
9fbb1e3  fix bug (Galaxy1458, Nov 20, 2024)
d55a303  fix conflict (Galaxy1458, Nov 20, 2024)
0adbfe5  update (Galaxy1458, Nov 20, 2024)
389f726  update (Galaxy1458, Nov 20, 2024)
c2b03c1  update scatter&gather (Galaxy1458, Nov 20, 2024)
51b22cf  fix auto_tune (Galaxy1458, Nov 20, 2024)
8e0617e  add gen_torch_device_fn (Galaxy1458, Nov 20, 2024)
d359b6a  fix codestyle (Galaxy1458, Nov 20, 2024)
df73601  fix codestyle (Galaxy1458, Nov 20, 2024)
3f64fff  fix comflict (Galaxy1458, Nov 21, 2024)
1b98131  Modify code based on comments (Galaxy1458, Nov 28, 2024)
8c83a46  Adjust the code based on the comments (Galaxy1458, Nov 28, 2024)
dbb90d7  Modify gen_impl with loops instead of recursion (Galaxy1458, Nov 28, 2024)
14a9a19  Update code structure (Galaxy1458, Nov 29, 2024)
bcf2d78  Polish code (Galaxy1458, Nov 29, 2024)
6854b5b  update (Galaxy1458, Nov 29, 2024)
5e46495  Polish code (Galaxy1458, Nov 29, 2024)
98fa16a  Modify code based on comments (Galaxy1458, Dec 3, 2024)
b961015  modify based on comment (Galaxy1458, Dec 3, 2024)
ace1b82  Modify code based on comments (Galaxy1458, Dec 4, 2024)
ce19b23  update (Galaxy1458, Dec 4, 2024)
08c8ab7  resolve conflict (Galaxy1458, Dec 4, 2024)
335 changes: 176 additions & 159 deletions src/flag_gems/__init__.py

Large diffs are not rendered by default.
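The __init__.py diff is collapsed above, but every per-op change below follows the same pattern: a hardcoded triton.Config list (or a local cfggen() helper) is deleted, and the @triton.autotune decorator instead pulls its configs from runtime.get_op_tune_config(op_name), so each backend can supply its own tuning table. A minimal sketch of how such a lookup could work follows; the runtime module itself is not part of this excerpt, so everything here except the get_op_tune_config(op_name) call is an illustrative assumption.

# Hypothetical sketch of a backend-aware tuning-config registry. Only the
# get_op_tune_config(op_name) call appears in this PR; the table layout,
# _CONFIG_TABLES, and BACKEND are illustrative assumptions.
import triton

# The "all" and "index_select" values are copied from the cfggen() helpers
# this PR deletes; a real deployment would hold one table per backend.
_CONFIG_TABLES = {
    "nvidia": {
        "all": [
            triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4)
            for m in [1, 2, 4, 8]
        ],
        "index_select": [
            triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
            for m in [1, 2, 4]
            for n in [1024, 2048, 4096]
        ],
    },
}

BACKEND = "nvidia"  # assumed to be selected once, e.g. from an env variable


def get_op_tune_config(op_name):
    """Return the autotune config list for op_name on the active backend."""
    return _CONFIG_TABLES[BACKEND][op_name]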

44 changes: 2 additions & 42 deletions src/flag_gems/ops/addmm.py
@@ -4,54 +4,14 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import libentry
from ..utils import triton_lang_extension as tle


@libentry()
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
],
configs=runtime.get_op_tune_config("addmm"),
key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
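The eight deleted addmm configs are exactly what get_op_tune_config("addmm") must now supply. Expressed as data, they compress to a short table; in the sketch below the tile tuples are copied verbatim from the removed list, while the surrounding names (_ADDMM_TILES, ADDMM_CONFIGS) are illustrative assumptions.

# (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps) tuples
# copied verbatim from the deleted config list above.
import triton

_ADDMM_TILES = [
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
    (128, 128, 32, 4, 4),
    (128, 64, 32, 4, 4),
    (64, 128, 32, 4, 4),
    (128, 32, 32, 4, 4),
    (64, 32, 32, 5, 2),
    (32, 64, 32, 5, 2),
]

ADDMM_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k},
        num_stages=s,
        num_warps=w,
    )
    for m, n, k, s, w in _ADDMM_TILES
]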
10 changes: 2 additions & 8 deletions src/flag_gems/ops/all.py
@@ -5,19 +5,13 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle


# torch.all tests whether all elements of the input evaluate to True. If the
# input dtype is not bool, it tests whether all elements are non-zero.
# In the Triton kernel it is therefore sufficient to test for non-zero values.
def cfggen():
block_m = [1, 2, 4, 8]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
]
return configs


@triton.jit
@@ -26,7 +20,7 @@ def reduce_all(a, b):


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.autotune(configs=runtime.get_op_tune_config("all"), key=["M", "N"])
@triton.jit
def all_kernel_dim(
inp,
12 changes: 3 additions & 9 deletions src/flag_gems/ops/amax.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle

@@ -18,6 +19,7 @@ def amax_kernel_1(
BLOCK_SIZE: tl.constexpr,
):
pid = tle.program_id(0)

offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
@@ -38,16 +40,8 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
tl.store(out, amax_val)


def cfggen():
block_m = [1, 2, 4, 8]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
]
return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.autotune(configs=runtime.get_op_tune_config("amax"), key=["M", "N"])
@triton.jit
def amax_kernel(
inp,
10 changes: 2 additions & 8 deletions src/flag_gems/ops/any.py
@@ -5,19 +5,13 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle


# torch.any tests whether any element of the input evaluates to True. If the
# input dtype is not bool, it tests whether any element is non-zero.
# In the Triton kernel it is therefore sufficient to test for non-zero values.
def cfggen():
block_m = [1, 2, 4, 8]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
]
return configs


@triton.jit
@@ -26,7 +20,7 @@ def reduce_any(a, b):


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.autotune(configs=runtime.get_op_tune_config("any"), key=["M", "N"])
@triton.jit
def any_kernel_dim(
inp,
7 changes: 2 additions & 5 deletions src/flag_gems/ops/argmax.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import libentry
from ..utils import triton_lang_extension as tle

@@ -50,11 +51,7 @@ def heur_block_n(args):

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 8}, num_warps=8),
triton.Config({"BLOCK_M": 16}, num_warps=8),
triton.Config({"BLOCK_M": 32}, num_warps=8),
],
configs=runtime.get_op_tune_config("argmax"),
key=[
"M",
"N",
64 changes: 2 additions & 62 deletions src/flag_gems/ops/bmm.py
@@ -4,6 +4,7 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import libentry
from ..utils import triton_lang_extension as tle

@@ -22,68 +23,7 @@ def heur_divisible_k(args):

@libentry()
@triton.autotune(
configs=[
triton.Config(
{"TILE_M": 32, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 1},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 64, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 64, "TILE_N": 64, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 64, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 128, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=2,
),
triton.Config(
{"TILE_M": 32, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 1},
num_warps=4,
num_stages=3,
),
triton.Config(
{"TILE_M": 64, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=3,
),
triton.Config(
{"TILE_M": 64, "TILE_N": 64, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=3,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 32, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=3,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 64, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=3,
),
triton.Config(
{"TILE_M": 128, "TILE_N": 128, "TILE_K": 32, "GROUP_M": 2},
num_warps=4,
num_stages=3,
),
],
configs=runtime.get_op_tune_config("bmm"),
key=["M", "N", "K"],
)
@triton.heuristics(
37 changes: 7 additions & 30 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -4,17 +4,14 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import libentry
from ..utils import triton_lang_extension as tle


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index"])
@@ -75,11 +72,7 @@ def celoss_indices_kernel(

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing"])
@@ -138,11 +131,7 @@ def celoss_probability_kernel(

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
@@ -223,11 +212,7 @@ def celoss_indices_smooth_kernel(

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "mean_num"])
@@ -298,11 +283,7 @@ def celoss_indices_bwd(

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
@@ -387,11 +368,7 @@ def celoss_probability_bwd(

@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=runtime.get_op_tune_config("cross_entropy_loss"),
key=["C", "D"],
)
@triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
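All six cross-entropy kernels above share the key ["C", "D"] and now draw from the single "cross_entropy_loss" entry. Judging from the deleted inline lists, that entry needs to cover the same 3 x 3 grid; a hedged sketch (grid values copied from the deleted code, the name CELOSS_CONFIGS is an illustrative assumption):

# BLOCK_C/BLOCK_D grid copied from the deleted inline lists above.
import triton

CELOSS_CONFIGS = [
    triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
    for c in [256, 512, 1024]
    for d in [1, 4, 16]
]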
17 changes: 5 additions & 12 deletions src/flag_gems/ops/gather.py
@@ -16,7 +16,9 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
code.writeline("import triton.language as tl")
code.newline()
code.writeline("from flag_gems.utils import libentry")
code.writeline("from flag_gems import runtime")
code.writeline("from flag_gems.utils import triton_lang_extension as tle")

code.newline()
code.newline()
return code
@@ -31,24 +33,15 @@ def generate_gather_kernel(
code.newline()

# the autotune function
code.writeline("def cfggen():")
with code.indent():
code.writeline("block_m = [1, 2, 4, 8]")
code.writeline("block_n = [256, 512, 1024, 2048]")
code.writeline("configs = [")
with code.indent():
code.writeline('triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)')
code.writeline("for m in block_m")
code.writeline("for n in block_n")
code.writeline("]")
code.writeline("return configs")

code.newline()
code.newline()

# the decorators
code.writeline("@libentry()")
code.writeline('@triton.autotune(configs=cfggen(), key=["M", "N"])')
code.writeline(
'@triton.autotune(configs=runtime.get_op_tune_config("gather"), key=["M", "N"])'
)
code.writeline("@triton.jit")

# signature
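Assembled from the writeline calls above, the prologue the updated generator emits reads roughly as follows. The kernel name and parameter list below are placeholders, since the real signature is produced by codegen outside this hunk.

# Approximate generated output: import block plus decorator stack. The kernel
# name and parameters are placeholders; the real ones are generated elsewhere.
import triton
import triton.language as tl

from flag_gems.utils import libentry
from flag_gems import runtime
from flag_gems.utils import triton_lang_extension as tle


@libentry()
@triton.autotune(configs=runtime.get_op_tune_config("gather"), key=["M", "N"])
@triton.jit
def gather_kernel(inp, out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tle.program_id(0)  # placeholder body; the real body is generated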
14 changes: 2 additions & 12 deletions src/flag_gems/ops/index_select.py
@@ -4,23 +4,13 @@
import triton
import triton.language as tl

from .. import runtime
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle


def cfggen():
block_m = [1, 2, 4]
block_n = [1024, 2048, 4096]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
for m in block_m
for n in block_n
]
return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.autotune(configs=runtime.get_op_tune_config("index_select"), key=["M", "N"])
@triton.jit
def index_select_kernel(
inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr