|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +import triton |
| 9 | +import triton.language as tl |
| 10 | + |
| 11 | + |
| 12 | +__all__ = ["generate_permute_indices"] |
| 13 | + |
| 14 | + |
| 15 | +# parallelized kernel |
| 16 | +@triton.jit |
| 17 | +def _fill_indices_kernel( |
| 18 | + tokens_per_expert_group_ptr, |
| 19 | + start_index_values_ptr, |
| 20 | + write_offsets_ptr, |
| 21 | + output_ptr, |
| 22 | + experts_per_rank: tl.constexpr, |
| 23 | + num_ranks: tl.constexpr, |
| 24 | + BLOCK_SIZE: tl.constexpr, # Number of threads per block |
| 25 | +): |
| 26 | + pid = tl.program_id(axis=0) |
| 27 | + num_programs = tl.num_programs(axis=0) |
| 28 | + |
| 29 | + # map programs (blocks) to the experts and loop (grid stride) if needed |
| 30 | + for expert_id in range(pid, experts_per_rank, num_programs): |
| 31 | + |
| 32 | + # read this experts write offset |
| 33 | + write_offset = tl.load(write_offsets_ptr + expert_id) |
| 34 | + |
| 35 | + # loop over all ranks |
| 36 | + for r in range(num_ranks): |
| 37 | + # index into tokens_per_expert_group array |
| 38 | + i = r * experts_per_rank + expert_id |
| 39 | + |
| 40 | + # load start index and number of tokens for this expert-rank pair |
| 41 | + start_index = tl.load(start_index_values_ptr + i) |
| 42 | + length = tl.load(tokens_per_expert_group_ptr + i) |
| 43 | + |
| 44 | + # each thread in block processes tokens in parallel |
| 45 | + offsets = tl.arange(0, BLOCK_SIZE) |
| 46 | + |
| 47 | + # tokens are processed in chunks of BLOCK_SIZE |
| 48 | + for chunk_start in range(0, length, BLOCK_SIZE): |
| 49 | + chunk_offsets = chunk_start + offsets |
| 50 | + |
| 51 | + # mask valid indices |
| 52 | + mask = chunk_offsets < length |
| 53 | + |
| 54 | + values = start_index + chunk_offsets |
| 55 | + |
| 56 | + # destination |
| 57 | + dest_indices = write_offset + chunk_offsets |
| 58 | + |
| 59 | + # store |
| 60 | + tl.store(output_ptr + dest_indices, values, mask=mask) |
| 61 | + |
| 62 | + # update write offset for next rank |
| 63 | + write_offset += length |
| 64 | + |
| 65 | + |
| 66 | +# ============== |
| 67 | +# wrapper |
| 68 | +# ============== |
| 69 | + |
| 70 | + |
| 71 | +def fill_indices_wrapper( |
| 72 | + tokens_per_expert_group: torch.Tensor, |
| 73 | + start_index_values: torch.Tensor, |
| 74 | + write_offsets: torch.Tensor, |
| 75 | + experts_per_rank: int, |
| 76 | + num_ranks: int, |
| 77 | + max_len: int, |
| 78 | + block_size: int = 128, |
| 79 | + max_blocks: int = 1024, # cap on total number of blocks to launch |
| 80 | +): |
| 81 | + # preallocate output |
| 82 | + permuted_indices = torch.full( |
| 83 | + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device |
| 84 | + ) |
| 85 | + |
| 86 | + # write offsets is per local expert... |
| 87 | + num_blocks = min(experts_per_rank, max_blocks) |
| 88 | + # grid = one block per expert unless capped and then we loop... |
| 89 | + grid = (num_blocks,) |
| 90 | + |
| 91 | + # launch kernel |
| 92 | + _fill_indices_kernel[grid]( |
| 93 | + tokens_per_expert_group, |
| 94 | + start_index_values, |
| 95 | + write_offsets, |
| 96 | + permuted_indices, |
| 97 | + experts_per_rank, |
| 98 | + num_ranks, |
| 99 | + BLOCK_SIZE=block_size, |
| 100 | + ) |
| 101 | + return permuted_indices |
| 102 | + |
| 103 | + |
| 104 | +# reference |
| 105 | +def fill_indices_cpu( |
| 106 | + tokens_per_expert_group: torch.Tensor, |
| 107 | + start_index_values: torch.Tensor, |
| 108 | + write_offsets: torch.Tensor, |
| 109 | + experts_per_rank: int, |
| 110 | + num_ranks: int, |
| 111 | + max_len: int, |
| 112 | +): |
| 113 | + # We need to preallocate the output - we ignore device and force it on cpu |
| 114 | + # device = tokens_per_expert_group.device |
| 115 | + permuted_indices = torch.full( |
| 116 | + (max_len,), |
| 117 | + -1, |
| 118 | + dtype=torch.int32, |
| 119 | + ) # device=device) |
| 120 | + # Fill the permuted indices |
| 121 | + # For each local expert |
| 122 | + for e in range(experts_per_rank): |
| 123 | + write_start = write_offsets[e].item() |
| 124 | + # For each remote rank |
| 125 | + for r in range(num_ranks): |
| 126 | + i = r * experts_per_rank + e |
| 127 | + start_index = start_index_values[i].item() |
| 128 | + length = tokens_per_expert_group[i].item() |
| 129 | + # Fill in the indices |
| 130 | + if length > 0: |
| 131 | + end_idx = min(write_start + length, max_len) |
| 132 | + permuted_indices[write_start:end_idx] = torch.arange( |
| 133 | + start_index, |
| 134 | + start_index + (end_idx - write_start), |
| 135 | + dtype=torch.int32, |
| 136 | + # device=device, |
| 137 | + ) |
| 138 | + write_start += length |
| 139 | + return permuted_indices |
| 140 | + |
| 141 | + |
| 142 | +def generate_permute_indices( |
| 143 | + tokens_per_expert_group: torch.Tensor, |
| 144 | + experts_per_rank: int, |
| 145 | + num_ranks: int, |
| 146 | + max_len: int, |
| 147 | + alignment: int, |
| 148 | + use_cpu: bool = False, |
| 149 | +): |
| 150 | + """ |
| 151 | + Prepare permutation indices and the number of tokens for each expert. |
| 152 | +
|
| 153 | + Args: |
| 154 | + tokens_per_expert_group: number of tokens for each expert from all ranks. |
| 155 | + experts_per_rank: number of experts per rank. |
| 156 | + num_ranks: number of ranks. |
| 157 | + max_len: maximum length of the output index vector. |
| 158 | + alignment: alignment for each returned element in `m_sizes`. |
| 159 | + use_cpu: whether to use CPU implementation. |
| 160 | + use_optimized: whether to use optimized Triton implementation. |
| 161 | + block_size: block size for optimized implementation. |
| 162 | +
|
| 163 | + Returns: |
| 164 | + permuted_indices: permutation indices. |
| 165 | + m_sizes: number of tokens for each expert. |
| 166 | + """ |
| 167 | + # prefix sum to get start index of each expert (parallel scan kernel in future?) |
| 168 | + start_index_values = ( |
| 169 | + torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group |
| 170 | + ) |
| 171 | + |
| 172 | + # chunk sizes for each expert |
| 173 | + chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) |
| 174 | + |
| 175 | + # align the chunk sizes (cdiv) |
| 176 | + m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to( |
| 177 | + torch.int32 |
| 178 | + ) |
| 179 | + |
| 180 | + # additional prefix sum to get write offset of each expert in permuted_indices |
| 181 | + # write offsets is per local expert, not global |
| 182 | + write_offsets = torch.cumsum(m_sizes, 0) - m_sizes |
| 183 | + |
| 184 | + # Select the implementation to use |
| 185 | + if use_cpu: |
| 186 | + permuted_indices = fill_indices_cpu( |
| 187 | + tokens_per_expert_group, |
| 188 | + start_index_values, |
| 189 | + write_offsets, |
| 190 | + experts_per_rank, |
| 191 | + num_ranks, |
| 192 | + max_len, |
| 193 | + ) |
| 194 | + else: |
| 195 | + permuted_indices = fill_indices_wrapper( |
| 196 | + tokens_per_expert_group, |
| 197 | + start_index_values, |
| 198 | + write_offsets, |
| 199 | + experts_per_rank, |
| 200 | + num_ranks, |
| 201 | + max_len, |
| 202 | + ) |
| 203 | + |
| 204 | + return permuted_indices, m_sizes |
| 205 | + |
| 206 | + |
| 207 | +# Below is for testing only |
| 208 | + |
| 209 | + |
| 210 | +def simple_test(): |
| 211 | + device = torch.device("cuda", 0) |
| 212 | + experts_per_rank = 4 |
| 213 | + num_ranks = 4 |
| 214 | + tokens_per_expert_group = torch.full( |
| 215 | + (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device |
| 216 | + ) |
| 217 | + max_len = 128 |
| 218 | + alignment = 32 |
| 219 | + # Use the GPU kernel |
| 220 | + permuted_indices_gpu, m_sizes = generate_permute_indices( |
| 221 | + tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment |
| 222 | + ) |
| 223 | + # Use the CPU method |
| 224 | + permuted_indices_cpu, _ = generate_permute_indices( |
| 225 | + tokens_per_expert_group, |
| 226 | + experts_per_rank, |
| 227 | + num_ranks, |
| 228 | + max_len, |
| 229 | + alignment, |
| 230 | + use_cpu=True, |
| 231 | + ) |
| 232 | + # Check that the results are the same |
| 233 | + |
| 234 | + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) |
| 235 | + assert torch.equal( |
| 236 | + torch.remainder(m_sizes, alignment), |
| 237 | + torch.zeros(experts_per_rank, device=device), |
| 238 | + ) |
| 239 | + # Print the results |
| 240 | + print(f"{permuted_indices_gpu=}, \n{permuted_indices_cpu=}") |
| 241 | + print(f"{m_sizes=}") |
| 242 | + print("Success") |
| 243 | + return True # assert would have failed meaning getting here is success. |
| 244 | + |
| 245 | + |
| 246 | +if __name__ == "__main__": |
| 247 | + simple_test() |
0 commit comments