| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
1 | 6 | import math |
2 | 7 | from dataclasses import dataclass, field |
3 | 8 | from typing import Callable, ClassVar, Literal, Optional, Tuple |
@@ -123,14 +128,17 @@ def fill_indices_cpu( |
123 | 128 | # For each local expert |
124 | 129 | for e in range(experts_per_rank): |
125 | 130 | write_start = write_offsets[e].item() |
| 131 | + assert isinstance(write_start, int) |
126 | 132 | # For each remote rank |
127 | 133 | for r in range(num_ranks): |
128 | | - i = r * experts_per_rank + e |
| 134 | + i: int = r * experts_per_rank + e |
129 | 135 | start_index = start_index_values[i].item() |
130 | 136 | length = tokens_per_expert_group[i].item() |
| 137 | + assert isinstance(start_index, int) |
| 138 | + assert isinstance(length, int) |
131 | 139 | # Fill in the indices |
132 | 140 | if length > 0: |
133 | | - end_idx = min(write_start + length, max_len) |
| 141 | + end_idx: int = min(write_start + length, max_len) |
134 | 142 | permuted_indices[write_start:end_idx] = torch.arange( |
135 | 143 | start_index, |
136 | 144 | start_index + (end_idx - write_start), |
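In torch's type stubs, `Tensor.item()` is annotated as returning a plain Python number rather than `int`, so the asserts added above narrow the type before the values are used as slice bounds and `arange` arguments. A minimal sketch of the same pattern, using a hypothetical `copy_first_n` helper:

```python
import torch

def copy_first_n(count: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Tensor.item() is typed as returning a generic Python number,
    # so a type checker cannot assume an int without narrowing.
    n = count.item()
    assert isinstance(n, int)
    # After the assert, n is a valid slice bound and arange argument.
    out[:n] = torch.arange(n)
    return out
```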
@@ -256,6 +264,8 @@ def wrapper( |
256 | 264 | ) -> torch.Tensor: |
257 | 265 | global TOKEN_GROUP_ALIGN_SIZE_M |
258 | 266 | if isinstance(w1, DTensor): |
| 267 | + assert isinstance(w2, DTensor) |
| 268 | + assert isinstance(w3, DTensor) |
259 | 269 | w1 = w1.to_local() |
260 | 270 | w2 = w2.to_local() |
261 | 271 | w3 = w3.to_local() |
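The two new asserts encode an invariant the type checker cannot see on its own: the expert weights are assumed to always be sharded together, so if `w1` is a `DTensor` then its siblings are too, and `.to_local()` is valid on each. A reduced sketch of the pattern (hypothetical `unshard_pair`, assuming the public `torch.distributed.tensor` import path of recent PyTorch):

```python
import torch
from torch.distributed.tensor import DTensor

def unshard_pair(w1: torch.Tensor, w2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    if isinstance(w1, DTensor):
        # isinstance() only narrows w1; the assert extends the
        # "sharded together" invariant to w2 for the checker.
        assert isinstance(w2, DTensor)
        w1, w2 = w1.to_local(), w2.to_local()
    return w1, w2
```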
@@ -343,19 +353,19 @@ def _run_experts_for_loop( |
343 | 353 | w1: torch.Tensor, |
344 | 354 | w2: torch.Tensor, |
345 | 355 | w3: torch.Tensor, |
346 | | - x: torch.Tensor, |
347 | | - num_tokens_per_expert: torch.Tensor, |
| 356 | + x_: torch.Tensor, |
| 357 | + num_tokens_per_expert_: torch.Tensor, |
348 | 358 | ) -> torch.Tensor: |
349 | 359 | # NOTE: this would incur a synchronization between device and host |
350 | | - num_tokens_per_expert = num_tokens_per_expert.tolist() |
| 360 | + num_tokens_per_expert: list[int] = num_tokens_per_expert_.tolist() |
351 | 361 |
352 | 362 | # side-effect code due to the usage of generate_permute_indices |
353 | | - num_padding = x.shape[0] - sum(num_tokens_per_expert) |
| 363 | + num_padding: int = x_.shape[0] - sum(num_tokens_per_expert) |
354 | 364 |
355 | 365 | # a tuple of tensors indexed by experts |
356 | 366 | # each with shape (tokens_per_expert(varying), dim) |
357 | | - x = torch.split( |
358 | | - x[: sum(num_tokens_per_expert)], |
| 367 | + x: tuple[torch.Tensor, ...] = torch.split( |
| 368 | + x_[: sum(num_tokens_per_expert)], |
359 | 369 | split_size_or_sections=num_tokens_per_expert, |
360 | 370 | dim=0, |
361 | 371 | ) |
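The trailing-underscore renames exist because mypy does not allow a variable to change type: `num_tokens_per_expert` goes from `Tensor` to `list[int]` via `.tolist()`, and `x` from `Tensor` to the tuple returned by `torch.split`. Binding the converted values to fresh, annotated names keeps each variable's type stable. A self-contained sketch of the same split-by-counts step:

```python
import torch

def split_by_counts(x_: torch.Tensor, counts_: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # .tolist() syncs device to host and yields plain Python ints.
    counts: list[int] = counts_.tolist()
    # torch.split returns one view per entry in counts along dim 0.
    x: tuple[torch.Tensor, ...] = torch.split(x_, split_size_or_sections=counts, dim=0)
    return x

chunks = split_by_counts(torch.randn(6, 4), torch.tensor([2, 1, 3]))
assert [c.shape[0] for c in chunks] == [2, 1, 3]
```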
@@ -444,7 +454,7 @@ def batched_histc( |
444 | 454 |
445 | 455 |
446 | 456 | @batched_histc.register_fake |
447 | | -def batched_histc( |
| 457 | +def batched_histc_fake( |
448 | 458 | x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0 |
449 | 459 | ) -> torch.Tensor: |
450 | 460 | assert max - min == bins |
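Before the rename, the fake implementation reused the name `batched_histc` and thereby shadowed the custom-op object it was registered against; giving the fake kernel its own name keeps the op itself callable. A sketch of the pattern, assuming the `torch.library.custom_op` / `register_fake` API of recent PyTorch (2.4+) and a simplified, hypothetical `demo::histc` op:

```python
import torch

@torch.library.custom_op("demo::histc", mutates_args=())
def demo_histc(x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0) -> torch.Tensor:
    return torch.histc(x, bins=bins, min=min, max=max)

# A distinct function name avoids shadowing demo_histc, which must
# stay bound to the custom-op object so it remains callable.
@demo_histc.register_fake
def demo_histc_fake(x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0) -> torch.Tensor:
    # The fake kernel only describes the output shape/dtype for tracing;
    # it performs no real computation.
    return x.new_empty(bins)
```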
@@ -987,10 +997,9 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): |
987 | 997 | route_scale=moe_args.route_scale, |
988 | 998 | ) |
989 | 999 | self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k) |
990 | | - self.shared_experts = ( |
991 | | - FeedForward(dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts) |
992 | | - if moe_args.num_shared_experts > 0 |
993 | | - else None |
| 1000 | + assert moe_args.num_shared_experts > 0 |
| 1001 | + self.shared_experts = FeedForward( |
| 1002 | + dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts |
994 | 1003 | ) |
995 | 1004 | self.score_before_experts = moe_args.score_before_experts |
996 | 1005 |
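Making `shared_experts` unconditional (backed by the new assert) changes its type from `Optional[FeedForward]` to `FeedForward`, so downstream code no longer needs a `None` guard before calling it. A minimal sketch of the trade-off, with a hypothetical `Block` standing in for the MoE module:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_shared_experts: int) -> None:
        super().__init__()
        # Constructing the submodule unconditionally gives it a
        # non-Optional type; configurations without shared experts are
        # rejected up front instead of handled with None checks later.
        assert num_shared_experts > 0
        self.shared_experts = nn.Linear(dim, hidden_dim * num_shared_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No "if self.shared_experts is not None" branch needed here.
        return self.shared_experts(x)
```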
@@ -1060,6 +1069,7 @@ def init_weights( |
1060 | 1069 | self.experts.num_experts, dtype=torch.float32 |
1061 | 1070 | ) |
1062 | 1071 | if self.load_balance_coeff is not None: |
| 1072 | + assert isinstance(self.expert_bias, torch.Tensor) |
1063 | 1073 | self.expert_bias = torch.zeros( |
1064 | 1074 | self.experts.num_experts, dtype=torch.float32 |
1065 | 1075 | ) |
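`expert_bias` is a buffer that can start out as `None` and is (re)built in `init_weights`, so its declared type is Optional; the added isinstance assert narrows it for the type checker. A hypothetical, reduced version of the pattern:

```python
import torch
import torch.nn as nn

class Gate(nn.Module):
    def __init__(self, num_experts: int) -> None:
        super().__init__()
        # Buffers are commonly declared Optional so they can start as
        # None and be filled in later by init_weights().
        self.expert_bias: torch.Tensor | None
        self.register_buffer("expert_bias", None, persistent=False)
        self.num_experts = num_experts

    def init_weights(self) -> None:
        self.expert_bias = torch.zeros(self.num_experts, dtype=torch.float32)

    def biased_scores(self, scores: torch.Tensor) -> torch.Tensor:
        # Narrow the Optional buffer before doing arithmetic on it.
        assert isinstance(self.expert_bias, torch.Tensor)
        return scores + self.expert_bias
```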
@@ -1513,13 +1523,14 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): |
1513 | 1523 | self.model_args = model_args |
1514 | 1524 |
1515 | 1525 | def init_weights(self, buffer_device: torch.device | None = None) -> None: |
1516 | | - buffer_device = buffer_device or self.freqs_cis.device |
| 1526 | + buffer_device = buffer_device or self.freqs_cis.device # type: ignore[has-type] |
1517 | 1527 | with torch.device(buffer_device): |
1518 | 1528 | self.freqs_cis = precompute_freqs_cis(self.model_args) |
1519 | 1529 | if self.tok_embeddings is not None: |
1520 | 1530 | nn.init.normal_(self.tok_embeddings.weight) |
1521 | 1531 | for layer in self.layers.values(): |
1522 | 1532 | if layer is not None: |
| 1533 | + assert isinstance(layer, TransformerBlock) |
1523 | 1534 | layer.init_weights(buffer_device=buffer_device) |
1524 | 1535 | if self.norm is not None: |
1525 | 1536 | self.norm.reset_parameters() |
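Two typing fixes appear in this hunk: the `type: ignore[has-type]` covers reading `self.freqs_cis` before mypy has seen its first assignment (it is created inside this very method), and the isinstance assert is needed because iterating an `nn.ModuleDict` yields values typed as bare `nn.Module`. A small sketch of the second pattern:

```python
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleDict({"0": nn.Linear(4, 4)})

    def init_weights(self) -> None:
        for layer in self.layers.values():
            # ModuleDict.values() is typed as yielding nn.Module, so
            # subclass-specific methods need an isinstance narrow first.
            assert isinstance(layer, nn.Linear)
            layer.reset_parameters()
```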