
Commit 0510df2

xmfan and fmassa authored

Add example_ds3_local_map.py to CI (#220)

* Add example_ds3_local_map.py to CI

  stack-info: PR: #220, branch: xmfan/stack/16

* Try fix lint + CI

Co-authored-by: Francisco Massa <[email protected]>

1 parent a6cef81 commit 0510df2

File tree

4 files changed (+36, -19 lines)


.github/workflows/test_cuda.yml

Lines changed: 1 addition & 0 deletions
@@ -45,3 +45,4 @@ jobs:
           python examples/example_llama3.py
           python examples/example_dcp.py
           python examples/example_local_map.py
+          python examples/example_ds3_local_map.py

autoparallel/_passes/graph_partition.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Any, Callable

 import torch

autoparallel/_testing/models/dsv3.py

Lines changed: 25 additions & 14 deletions
@@ -1,3 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 import math
 from dataclasses import dataclass, field
 from typing import Callable, ClassVar, Literal, Optional, Tuple
@@ -123,14 +128,17 @@ def fill_indices_cpu(
     # For each local expert
     for e in range(experts_per_rank):
         write_start = write_offsets[e].item()
+        assert isinstance(write_start, int)
         # For each remote rank
         for r in range(num_ranks):
-            i = r * experts_per_rank + e
+            i: int = r * experts_per_rank + e
             start_index = start_index_values[i].item()
             length = tokens_per_expert_group[i].item()
+            assert isinstance(start_index, int)
+            assert isinstance(length, int)
             # Fill in the indices
             if length > 0:
-                end_idx = min(write_start + length, max_len)
+                end_idx: int = min(write_start + length, max_len)
                 permuted_indices[write_start:end_idx] = torch.arange(
                     start_index,
                     start_index + (end_idx - write_start),
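
The asserts in this hunk exist because Tensor.item() is typed as returning a Python scalar union, so mypy will not accept the result where a plain int is required; asserting int narrows the static type before the value feeds arithmetic and slicing. A minimal standalone sketch of the pattern, with an illustrative function name:

import torch

def first_offset(write_offsets: torch.Tensor) -> int:
    # .item() is typed as a Python scalar union (bool | int | float),
    # so mypy rejects returning it directly where an int is required
    start = write_offsets[0].item()
    # the assert narrows the static type to int for everything below
    assert isinstance(start, int)
    return start + 1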
@@ -256,6 +264,8 @@ def wrapper(
     ) -> torch.Tensor:
         global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
+            assert isinstance(w2, DTensor)
+            assert isinstance(w3, DTensor)
             w1 = w1.to_local()
             w2 = w2.to_local()
             w3 = w3.to_local()
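
The two new asserts are needed because isinstance(w1, DTensor) only narrows w1; inside the branch, mypy still sees w2 and w3 at their original union type. The same behavior in a self-contained toy (illustrative names, plain unions in place of DTensor):

def add_ints(a: int | str, b: int | str) -> int:
    if isinstance(a, int):
        # mypy narrows `a` to int here, but the check on `a` says
        # nothing about `b`, so `b` needs its own assert before use
        assert isinstance(b, int)
        return a + b
    return 0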
@@ -343,19 +353,19 @@ def _run_experts_for_loop(
     w1: torch.Tensor,
     w2: torch.Tensor,
     w3: torch.Tensor,
-    x: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
+    x_: torch.Tensor,
+    num_tokens_per_expert_: torch.Tensor,
 ) -> torch.Tensor:
     # NOTE: this would incur a synchronization between device and host
-    num_tokens_per_expert = num_tokens_per_expert.tolist()
+    num_tokens_per_expert: list[int] = num_tokens_per_expert_.tolist()

     # side-effect code due to the usage of generate_permute_indices
-    num_padding = x.shape[0] - sum(num_tokens_per_expert)
+    num_padding: int = x_.shape[0] - sum(num_tokens_per_expert)

     # a tuple of tensors indexed by experts
     # each with shape (tokens_per_expert(varying), dim)
-    x = torch.split(
-        x[: sum(num_tokens_per_expert)],
+    x: tuple[torch.Tensor, ...] = torch.split(
+        x_[: sum(num_tokens_per_expert)],
         split_size_or_sections=num_tokens_per_expert,
         dim=0,
     )
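
The trailing-underscore renames follow from mypy not allowing a name to be rebound at a different type: num_tokens_per_expert went from Tensor to list[int], and x from Tensor to a tuple, within one function body. A standalone sketch of the convention, with a hypothetical helper name:

import torch

def split_per_expert(
    x_: torch.Tensor, counts_: torch.Tensor
) -> tuple[torch.Tensor, ...]:
    # keep the tensor arguments under `_`-suffixed names so the
    # converted values can get fresh names with their own static types
    counts: list[int] = counts_.tolist()
    x: tuple[torch.Tensor, ...] = torch.split(x_[: sum(counts)], counts, dim=0)
    return x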
@@ -444,7 +454,7 @@ def batched_histc(


 @batched_histc.register_fake
-def batched_histc(
+def batched_histc_fake(
     x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0
 ) -> torch.Tensor:
     assert max - min == bins
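
Renaming the fake implementation avoids redefining batched_histc in the module namespace (a lint/mypy redefinition error), while register_fake binds it regardless of the function's name. A minimal sketch of the pattern with a hypothetical op (the real batched_histc signature is in the hunk above):

import torch

# a hypothetical custom op, only to illustrate the naming convention
@torch.library.custom_op("mylib::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

# the fake implementation only describes output shape/dtype; giving it
# a distinct name leaves `row_sum` itself untouched in the namespace
@row_sum.register_fake
def row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    return x.new_empty(x.shape[:-1])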
@@ -987,10 +997,9 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
             route_scale=moe_args.route_scale,
         )
         self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k)
-        self.shared_experts = (
-            FeedForward(dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts)
-            if moe_args.num_shared_experts > 0
-            else None
+        assert moe_args.num_shared_experts > 0
+        self.shared_experts = FeedForward(
+            dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts
         )
         self.score_before_experts = moe_args.score_before_experts

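Replacing the conditional with an assert also changes the attribute's inferred type from FeedForward | None to plain FeedForward, so downstream uses need no None checks. A toy illustration with hypothetical classes:

class Expert:
    pass

class MoELayer:
    def __init__(self, num_shared_experts: int) -> None:
        # `Expert() if num_shared_experts > 0 else None` would type the
        # attribute as `Expert | None`; asserting the invariant up front
        # keeps it a plain `Expert`
        assert num_shared_experts > 0
        self.shared: Expert = Expert()
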
@@ -1060,6 +1069,7 @@ def init_weights(
             self.experts.num_experts, dtype=torch.float32
         )
         if self.load_balance_coeff is not None:
+            assert isinstance(self.expert_bias, torch.Tensor)
             self.expert_bias = torch.zeros(
                 self.experts.num_experts, dtype=torch.float32
             )
@@ -1513,13 +1523,14 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.model_args = model_args

     def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        buffer_device = buffer_device or self.freqs_cis.device
+        buffer_device = buffer_device or self.freqs_cis.device  # type: ignore[has-type]
         with torch.device(buffer_device):
             self.freqs_cis = precompute_freqs_cis(self.model_args)
             if self.tok_embeddings is not None:
                 nn.init.normal_(self.tok_embeddings.weight)
             for layer in self.layers.values():
                 if layer is not None:
+                    assert isinstance(layer, TransformerBlock)
                     layer.init_weights(buffer_device=buffer_device)
             if self.norm is not None:
                 self.norm.reset_parameters()
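
The type: ignore[has-type] is the usual workaround when an attribute is only ever assigned inside the method that first reads it: mypy cannot determine the type of self.freqs_cis at the point of the read. A reduced sketch with an illustrative module:

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def init_weights(self, buffer_device: torch.device | None = None) -> None:
        # `freqs_cis` is first assigned one line below, so at this read
        # mypy reports "Cannot determine type of 'freqs_cis'" [has-type]
        buffer_device = buffer_device or self.freqs_cis.device  # type: ignore[has-type]
        self.freqs_cis = torch.zeros(4, device=buffer_device)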

tests/test_graph_partition.py

Lines changed: 5 additions & 5 deletions
@@ -115,14 +115,14 @@ def input_fn():

     # Symbolically evaluate in case you want to test running a graph bigger than your gpu

-    with (
-        FakeTensorMode(
+    mode: nullcontext[None] | FakeTensorMode = nullcontext()
+    if fake_evaluate:
+        mode = FakeTensorMode(
             allow_non_fake_inputs=True,
             shape_env=ShapeEnv(),
         )
-        if fake_evaluate
-        else nullcontext()
-    ):
+
+    with mode:
         # # now let's run it
         outputs = pp_mod(*x)
         assert len(outputs) == 1
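
The refactor binds the context manager to a name with a plain annotated union instead of nesting a conditional expression inside with (...), which mypy types poorly. A self-contained sketch of the same shape (FakeTensorMode and ShapeEnv are internal PyTorch APIs and may change):

from contextlib import nullcontext

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_evaluate = True
mode: nullcontext[None] | FakeTensorMode = nullcontext()
if fake_evaluate:
    mode = FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv())

with mode:
    # under FakeTensorMode no real device memory is allocated, so this
    # matmul "runs" symbolically even for very large shapes
    y = torch.empty(8, 8) @ torch.empty(8, 8)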
