Commit 4f532e0

[DeepSeek][Kernels] add on device, parallelized permute_indices kernel (#1077)
This PR updates the original Triton kernel in indices.py (#1062), moving it to a more parallelized and vectorized implementation to better support larger runs. This provides roughly an 8.7x perf improvement for 256 experts.

Functionally, this kernel achieves the same result as before: prepare permutation indices and the number of tokens for each expert, with any required alignment, where alignment is driven by the group gemm requirement.

* The permutation indices are the indices of the tokens for each expert.
* The number of tokens for each expert is the sum of the number of tokens for that expert from all ranks.

However, this kernel improves the parallelization by assigning one block to each expert, and threads in each block process the tokens in chunks, all while using vectorized loads and stores. This yields the same indices while delivering the performance improvement.

Testing:

~~~
python moe_kernels.py
~~~

(runs a simple verification using the CPU reference, the previous kernel, and the new optimized kernel)

Benchmark of the performance improvement using a fixed 256 total experts across various world sizes:

~~~
Summary for fixed total experts (256), where we mimic various DeepSeek experts across a number of ranks:

Ranks | Experts/Rank | Original (ms) | Optimized (ms) | Block Size | Speedup | Correct
------------------------------------------------------------------------------------------
    1 |          256 |         0.241 |          0.020 |        128 |  11.95x |       1
    4 |           64 |         0.222 |          0.020 |        128 |  10.97x |       1
   16 |           16 |         0.221 |          0.021 |        128 |  10.40x |       1
   32 |            8 |         0.224 |          0.025 |         64 |   8.81x |       1
   64 |            4 |         0.225 |          0.037 |        128 |   6.01x |       1
  128 |            2 |         0.217 |          0.051 |        128 |   4.26x |       1
~~~

Added unit testing:

~~~
..Testing 256 experts per rank across 1 ranks (total: 256)
Testing 128 experts per rank across 2 ranks (total: 256)
Testing 64 experts per rank across 4 ranks (total: 256)
Testing 32 experts per rank across 8 ranks (total: 256)
Testing 16 experts per rank across 16 ranks (total: 256)
Testing 8 experts per rank across 32 ranks (total: 256)
Testing 4 experts per rank across 64 ranks (total: 256)
Testing 2 experts per rank across 128 ranks (total: 256)
.
----------------------------------------------------------------------
Ran 3 tests in 1.104s

OK
~~~
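For readers new to this code path, here is a minimal sketch of what the kernel computes on a toy input. The import path and the tiny values below are illustrative assumptions, not taken from the benchmark above; the expected outputs follow the CPU reference implementation in the diff below.

~~~
# Illustrative only: 2 ranks x 2 local experts, alignment of 4, run via the CPU reference path.
import torch

from indices import generate_permute_indices  # assumed import path for the new file

# tokens_per_expert_group is laid out rank-major: [r0e0, r0e1, r1e0, r1e1]
tokens_per_expert_group = torch.tensor([1, 2, 3, 4], dtype=torch.int32)

permuted_indices, m_sizes = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank=2,
    num_ranks=2,
    max_len=12,
    alignment=4,
    use_cpu=True,
)

# m_sizes == [4, 8]: per-expert totals (1+3=4 and 2+4=6) rounded up to the alignment.
# permuted_indices == [0, 3, 4, 5, 1, 2, 6, 7, 8, 9, -1, -1]:
#   expert 0 gathers token 0 (rank 0) and tokens 3-5 (rank 1);
#   expert 1 gathers tokens 1-2 (rank 0) and tokens 6-9 (rank 1), then -1 pads to the aligned size.
~~~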
1 parent 34c2346 commit 4f532e0

File tree

2 files changed, +670 -0 lines changed
Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl


__all__ = ["generate_permute_indices"]

# parallelized kernel
@triton.jit
def _fill_indices_kernel(
    tokens_per_expert_group_ptr,
    start_index_values_ptr,
    write_offsets_ptr,
    output_ptr,
    experts_per_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # Number of threads per block
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    # map programs (blocks) to the experts and loop (grid stride) if needed
    for expert_id in range(pid, experts_per_rank, num_programs):

        # read this expert's write offset
        write_offset = tl.load(write_offsets_ptr + expert_id)

        # loop over all ranks
        for r in range(num_ranks):
            # index into tokens_per_expert_group array
            i = r * experts_per_rank + expert_id

            # load start index and number of tokens for this expert-rank pair
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)

            # each thread in block processes tokens in parallel
            offsets = tl.arange(0, BLOCK_SIZE)

            # tokens are processed in chunks of BLOCK_SIZE
            for chunk_start in range(0, length, BLOCK_SIZE):
                chunk_offsets = chunk_start + offsets

                # mask valid indices
                mask = chunk_offsets < length

                values = start_index + chunk_offsets

                # destination
                dest_indices = write_offset + chunk_offsets

                # store
                tl.store(output_ptr + dest_indices, values, mask=mask)

            # update write offset for next rank
            write_offset += length

# ==============
# wrapper
# ==============


def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,  # cap on total number of blocks to launch
):
    # preallocate output
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )

    # write offsets is per local expert...
    num_blocks = min(experts_per_rank, max_blocks)
    # grid = one block per expert unless capped and then we loop...
    grid = (num_blocks,)

    # launch kernel
    _fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
        BLOCK_SIZE=block_size,
    )
    return permuted_indices

# reference
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output - we ignore device and force it on cpu
    # device = tokens_per_expert_group.device
    permuted_indices = torch.full(
        (max_len,),
        -1,
        dtype=torch.int32,
    )  # device=device)
    # Fill the permuted indices
    # For each local expert
    for e in range(experts_per_rank):
        write_start = write_offsets[e].item()
        # For each remote rank
        for r in range(num_ranks):
            i = r * experts_per_rank + e
            start_index = start_index_values[i].item()
            length = tokens_per_expert_group[i].item()
            # Fill in the indices
            if length > 0:
                end_idx = min(write_start + length, max_len)
                permuted_indices[write_start:end_idx] = torch.arange(
                    start_index,
                    start_index + (end_idx - write_start),
                    dtype=torch.int32,
                    # device=device,
                )
            write_start += length
    return permuted_indices

def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """
    Prepare permutation indices and the number of tokens for each expert.

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector.
        alignment: alignment for each returned element in `m_sizes`.
        use_cpu: whether to use the CPU reference implementation instead of the Triton kernel.

    Returns:
        permuted_indices: permutation indices.
        m_sizes: number of tokens for each expert.
    """
    # prefix sum to get start index of each expert (parallel scan kernel in future?)
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )

    # chunk sizes for each expert
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)

    # align the chunk sizes (cdiv)
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )

    # additional prefix sum to get write offset of each expert in permuted_indices
    # write offsets is per local expert, not global
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes

    # Select the implementation to use
    if use_cpu:
        permuted_indices = fill_indices_cpu(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )
    else:
        permuted_indices = fill_indices_wrapper(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )

    return permuted_indices, m_sizes

# Below is for testing only


def simple_test():
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    # Check that every m_size is a multiple of the alignment (dtype made explicit to match m_sizes)
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, dtype=torch.int32, device=device),
    )
    # Print the results
    print(f"{permuted_indices_gpu=}, \n{permuted_indices_cpu=}")
    print(f"{m_sizes=}")
    print("Success")
    return True  # assert would have failed meaning getting here is success.


if __name__ == "__main__":
    simple_test()
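
As a closing note on why the -1 padding and aligned m_sizes matter downstream: below is a hedged sketch of one way the outputs could be consumed ahead of a grouped GEMM. The `routed_tokens` tensor and the gather/masking shown here are illustrative assumptions and are not part of this commit; it continues from the toy example above.

~~~
# Hypothetical consumer, for illustration only: routed_tokens is assumed to be the
# [num_received_tokens, hidden_dim] activations gathered from all ranks.
total = int(m_sizes.sum())                        # rows actually used, already alignment-padded
idx = permuted_indices[:total].long()             # token index per output row, -1 for padding
rows = routed_tokens[idx.clamp(min=0)]            # padding rows temporarily read row 0 ...
rows = rows * (idx >= 0).unsqueeze(-1)            # ... and are zeroed out here
per_expert = torch.split(rows, m_sizes.tolist())  # one aligned chunk per local expert
~~~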
