Skip to content

梯度爆炸 #8

@tanjiarui

Description

@tanjiarui

你好,最近把你们模改的tutel拿来运行了一下,跑了一个简单的模型来验证代码。但是不管怎么调整都出现了损失消失的问题。可以帮我看一下是超参数不对还是模型结构的问题吗?谢谢

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from torch.optim import Adam
from tutel import moe


# 定义多头注意力
class MultiHeadAttention(nn.Module):
	def __init__(self, dim, heads=8):
		super().__init__()
		self.heads = heads
		self.scale = dim ** -0.5
		self.to_qkv = nn.Linear(dim, dim * 3)
		self.out_proj = nn.Linear(dim, dim)

	def forward(self, x):
		B, T, D = x.shape
		qkv = self.to_qkv(x).chunk(3, dim=-1)
		q, k, v = [t.reshape(B, T, self.heads, D // self.heads).transpose(1, 2) for t in qkv]
		attn_weights = (q @ k.transpose(-2, -1)) * self.scale
		attn = attn_weights.softmax(dim=-1)
		out = (attn @ v).transpose(1, 2).reshape(B, T, D)
		return self.out_proj(out)


# Transformer block + GMoE FFN
class TransformerBlock(nn.Module):
	def __init__(self, dim=512, num_experts=4, hidden_dim=2048):
		super().__init__()
		self.attn = MultiHeadAttention(dim)
		self.norm1 = nn.LayerNorm(dim)
		self.norm2 = nn.LayerNorm(dim)
		self.norm3 = nn.LayerNorm(dim)
		self.moe_layer = moe.moe_layer(gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0}, experts={'type': 'ffn', 'count_per_node': num_experts, 'hidden_size_per_expert': hidden_dim, 'activation_fn': lambda x: nn.functional.relu(x)}, group=2, model_dim=dim)
		self.ffn = nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim))

	def forward(self, x):
		x = x + self.attn(self.norm1(x))
		moe_out = self.moe_layer(self.norm2(x))
		return x + moe_out

	# def forward(self, x):
	# 	x = x + self.attn(self.norm1(x))
	# 	x = x + self.norm3(self.ffn(self.norm2(x)))
	# 	return x


# 🧪 模拟训练数据
def build_dataloader(batch_size=8):
	X = torch.randn(100, 16, 64)
	W = torch.randn(64, 3)
	Y = (X.mean(dim=1) @ W).argmax(dim=1)
	dataset = TensorDataset(X, Y)
	return DataLoader(dataset, batch_size=batch_size, shuffle=True)


# 🎯 简单分类头(接在 TransformerBlock 后)
class ClassificationModel(nn.Module):
	def __init__(self, dim=512, num_experts=4, hidden_dim=2048, num_classes=10):
		super().__init__()
		self.projection = nn.Linear(64, dim)
		self.transformer = TransformerBlock(dim, num_experts, hidden_dim)
		self.pool = nn.AdaptiveAvgPool1d(1)
		self.classifier = nn.Linear(dim, num_classes)

	def forward(self, x):
		x = self.projection(x)
		x = self.transformer(x)
		x = x.transpose(1, 2)  # [B, D, T]
		x = self.pool(x).squeeze(-1)  # [B, D]
		logits = self.classifier(x)  # [B, num_classes]
		return logits


# 🚀 开始训练
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataloader = build_dataloader()
model = ClassificationModel(dim=128, num_experts=4, hidden_dim=1024, num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)
for epoch in range(1, 200):
	epoch_loss = 0
	epoch_correct = 0
	epoch_total = 0
	for x, y in dataloader:
		x, y = x.to(device), y.to(device)
		optimizer.zero_grad()
		# model.transformer.moe_layer.begin_record_routing()
		logits = model(x)
		loss = criterion(logits, y) + .0001 * model.transformer.moe_layer.l_aux
		loss.backward()

		# for p in model.parameters():
		# 	if not hasattr(p, 'skip_allreduce') and p.grad is not None:
		# 		p.grad = net.simple_all_reduce(p.grad)

		optimizer.step()

		prediction = logits.argmax(dim=1, keepdim=True)
		correct = prediction.eq(y.view_as(prediction)).sum().item()
		total_items = int(logits.size(0))

		epoch_loss += loss.item()
		epoch_correct += correct
		epoch_total += total_items

	# if epoch % 10 == 0:
	# 	model.transformer.moe_layer.adaptive_update_experts()
	# 	model.transformer.moe_layer.end_record_routing()
	if epoch % 10 == 0:
		print(f'epoch {epoch}: loss = {epoch_loss / len(dataloader):.4f}, accuracy = {epoch_correct / epoch_total * 100:.2f}%')

不管参数gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0, 'gate_noise': 1.0}还是gate_type={'type': 'gated_multi_gate', 'max_expert_num': 16},运行出来的损失都消失了

Gate types:  ['LinearTopKGate']
4
epoch 10: loss = nan, accuracy = 36.00%
epoch 20: loss = nan, accuracy = 36.00%
epoch 30: loss = nan, accuracy = 36.00%
epoch 40: loss = nan, accuracy = 36.00%
epoch 50: loss = nan, accuracy = 36.00%
epoch 60: loss = nan, accuracy = 36.00%
epoch 70: loss = nan, accuracy = 36.00%
epoch 80: loss = nan, accuracy = 36.00%
epoch 90: loss = nan, accuracy = 36.00%
epoch 100: loss = nan, accuracy = 36.00%
epoch 110: loss = nan, accuracy = 36.00%
epoch 120: loss = nan, accuracy = 36.00%
epoch 130: loss = nan, accuracy = 36.00%
epoch 140: loss = nan, accuracy = 36.00%
epoch 150: loss = nan, accuracy = 36.00%
epoch 160: loss = nan, accuracy = 36.00%
epoch 170: loss = nan, accuracy = 36.00%
epoch 180: loss = nan, accuracy = 36.00%
epoch 190: loss = nan, accuracy = 36.00%

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions