
[Feature] Expert parallelism support #1435

Open
chongli-uw opened this issue Sep 16, 2024 · 1 comment

chongli-uw commented Sep 16, 2024

Motivation

Hi team,
First of all, thank you so much for such a great project. Is there a plan to support Expert Parallelism for MoE models?

Related resources

https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html

Ying1123 added the enhancement (New feature or request) label Sep 16, 2024

merrymercy (Contributor) commented:

# Standard imports for this snippet. The tensor-parallel helpers
# (get_tensor_model_parallel_rank/world_size, tensor_model_parallel_all_reduce),
# ReplicatedLinear, QuantizationConfig, and MixtralMLP are provided by the
# surrounding model and parallel-utility modules.
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import MixtralConfig


class MixtralMoE(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks.
        self.expert_indices = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indices:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
        # Only the experts owned by this rank are instantiated; the rest are None.
        self.experts = nn.ModuleList(
            [
                (
                    MixtralMLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                    if idx in self.expert_indices
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        # The router (gate) is replicated on every rank.
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Each rank computes only its local experts; contributions from experts
        # owned by other ranks are gathered via the all-reduce at the end.
        final_hidden_states = None
        for expert_idx in self.expert_indices:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)
        return tensor_model_parallel_all_reduce(final_hidden_states)

This is an early example: each rank only instantiates the experts assigned to it, and the per-rank partial outputs are combined with a tensor-parallel all-reduce.
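
For intuition, here is a minimal standalone sketch of the expert partitioning used in the snippet above; the expert count and tensor-parallel size are hypothetical values chosen for illustration:

import numpy as np

# Hypothetical setup: 8 experts split across a tensor-parallel group of size 4.
num_total_experts = 8
tp_size = 4

for rank in range(tp_size):
    # Same partitioning rule as in MixtralMoE.__init__ above.
    expert_indices = np.array_split(range(num_total_experts), tp_size)[rank].tolist()
    print(f"rank {rank} owns experts {expert_indices}")

# Output:
# rank 0 owns experts [0, 1]
# rank 1 owns experts [2, 3]
# rank 2 owns experts [4, 5]
# rank 3 owns experts [6, 7]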
