Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
940 changes: 940 additions & 0 deletions lzero/entry/train_unizero_multitask_segment_ddp copy.py

Large diffs are not rendered by default.

49 changes: 14 additions & 35 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ def train_unizero_multitask_segment_ddp(
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
benchmark_name: str = "atari"
benchmark_name: str = "atari",
finetune_components=[]
) -> 'Policy':
"""
Overview:
Expand All @@ -391,15 +392,17 @@ def train_unizero_multitask_segment_ddp(
# 原始的 RANDOM_SCORES 和 HUMAN_SCORES
if benchmark_name == "atari":
RANDOM_SCORES = np.array([
227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5,
152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3,
-20.7, 24.9, 163.9, 11.5, 68.4, 533.4
148.0 # SpaceInvader
])
HUMAN_SCORES = np.array([
7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4,
1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6,
14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2
1652.0 # SpaceInvader
])
# RANDOM_SCORES = np.array([
# 148.0
# ])
# HUMAN_SCORES = np.array([
# 1652.0
# ])
elif benchmark_name == "dmc":
# RANDOM_SCORES = np.array([0]*26)
# HUMAN_SCORES = np.array([1000]*26)
Expand All @@ -415,33 +418,8 @@ def train_unizero_multitask_segment_ddp(
# PrivateEye, UpNDown, Qbert, Breakout]
# 映射为原始数组中的索引(注意:索引均从0开始)
new_order = [
20, # Pong
19, # MsPacman
24, # Seaquest
6, # Boxing
0, # Alien
8, # ChopperCommand
14, # Hero
23, # RoadRunner
1, # Amidar
2, # Assault
3, # Asterix
4, # BankHeist
5, # BattleZone
9, # CrazyClimber
10, # DemonAttack
11, # Freeway
12, # Frostbite
13, # Gopher
15, # Jamesbond
16, # Kangaroo
17, # Krull
18, # KungFuMaster
21, # PrivateEye
25, # UpNDown
22, # Qbert
7 # Breakout
]
0 # SpaceInvader (唯一任务,索引为0)
]
global new_RANDOM_SCORES, new_HUMAN_SCORES
# 根据 new_order 生成新的数组
new_RANDOM_SCORES = RANDOM_SCORES[new_order]
Expand Down Expand Up @@ -521,12 +499,13 @@ def train_unizero_multitask_segment_ddp(
# 编译配置
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# 创建共享的policy
cfg.policy.learn.learner.hook.log_show_after_iter=100
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# 加载预训练模型(如果提供)
if model_path is not None:
logging.info(f'开始加载模型: {model_path}')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device),finetune_components=finetune_components)
logging.info(f'完成加载模型: {model_path}')

# 创建TensorBoard日志记录器
Expand Down
28 changes: 11 additions & 17 deletions lzero/model/unizero_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,33 +123,27 @@ def __init__(
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
))
elif world_model_cfg.encoder_type == "vit":

lora_config={
'r': world_model_cfg.get('encoder_lora_r', 0),
'alpha': world_model_cfg.get('encoder_lora_alpha', 0),
'dropout': world_model_cfg.get('encoder_lora_dropout', 0),
}

for task_id in range(1): # TODO: one share encoder
if world_model_cfg.task_num <=8:
# # vit base
# self.representation_network.append(ViT(
# image_size =observation_shape[1],
# patch_size = 8,
# num_classes = obs_act_embed_dim,
# dim = 768,
# depth = 12,
# heads = 12,
# mlp_dim = 3072,
# dropout = 0.1,
# emb_dropout = 0.1,
# final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
# ))
# vit small
self.representation_network.append(ViT(
image_size =observation_shape[1],
patch_size = 8,
num_classes = obs_act_embed_dim,
dim = 768,
depth = 6,
heads = 6,
mlp_dim = 2048,
depth = 12,
heads = 12,
mlp_dim = 3072,
dropout = 0.1,
emb_dropout = 0.1,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
lora_config=lora_config
))
elif world_model_cfg.task_num > 8:
# vit base
Expand Down
92 changes: 75 additions & 17 deletions lzero/model/unizero_world_models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,57 @@ def __init__(self, init=1.0, s_max=1.5):
def forward(self):
return self.s_max * torch.sigmoid(self.logit)

##############################################
# LoRALinear 实现
##############################################

class LoRALinear(nn.Module):
    """
    Linear layer extended with an optional low-rank (LoRA) update.

    - The regular ``weight`` and ``bias`` parameters of a linear layer are kept.
    - When ``r > 0`` two extra matrices ``lora_A`` (r x in_features) and
      ``lora_B`` (out_features x r) form a low-rank decomposition.
    - Forward pass:
      ``F.linear(x, W, bias) + scaling * lora_B(lora_A(dropout(x)))``
      where ``scaling = lora_alpha / r``.

    With ``r <= 0`` no LoRA matrices are created and the layer computes only
    the plain linear projection.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.r, self.lora_alpha = r, lora_alpha

        use_lora = r > 0
        if use_lora:
            self.scaling = lora_alpha / r
        else:
            self.scaling = 1.0

        # Dropout is applied only on the LoRA branch; identity when disabled.
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        # Base projection parameters (weight is registered before bias).
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

        # Weight first, then bias, so random numbers are drawn in a fixed order.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if bias:
            fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
            limit = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -limit, limit)

        # Low-rank factors: A gets small random values, B starts at zero, so
        # the LoRA branch contributes nothing right after construction.
        if not use_lora:
            self.lora_A = None
            self.lora_B = None
            return
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base linear output plus the scaled LoRA term, if any."""
        base = F.linear(x, self.weight, self.bias)

        lora_ready = (
            self.r != 0
            and self.lora_A is not None
            and self.lora_B is not None
        )
        if not lora_ready:
            return base

        down = F.linear(self.lora_dropout(x), self.lora_A)
        up = F.linear(down, self.lora_B)
        return base + self.scaling * up

##############################################
# CurriculumLoRALinear 实现
##############################################
Expand Down Expand Up @@ -132,9 +183,10 @@ def set_curriculum_stage(self, stage: int):

同时将 log 出模块信息和状态变化。
"""
# return
assert 0 <= stage < self.curriculum_stage_num, f"stage 必须在 [0, {self.curriculum_stage_num-1}] 范围内"
self.curriculum_stage = stage

# 输出 log 信息,展示当前模块(可结合 in_features, out_features 标识)
module_id = f"({self.in_features}x{self.out_features})"
if stage == 0:
Expand Down Expand Up @@ -202,7 +254,7 @@ def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Modul
- 并且 config 中配置了 curriculum_stage_num > 1
否则,若仅满足基础 LoRA 条件,则返回原有 LoRALinear;否则返回原始的线性层。
"""
if config.lora_r > 0 and (module_label in config.lora_target_modules) and getattr(config, "curriculum_stage_num", 1) > 1:
if False and config.lora_r > 0 and (module_label in config.lora_target_modules) and getattr(config, "curriculum_stage_num", 1) > 1:
new_linear = CurriculumLoRALinear(
in_features=linear.in_features,
out_features=linear.out_features,
Expand All @@ -217,20 +269,20 @@ def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Modul
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
return new_linear
# elif config.lora_r > 0 and (module_label in config.lora_target_modules):
# # 若不使用课程学习,则调用原有 LoRALinear 实现(未展示,此处假设其已定义)
# new_linear = LoRALinear(
# in_features=linear.in_features,
# out_features=linear.out_features,
# bias=(linear.bias is not None),
# r=config.lora_r,
# lora_alpha=config.lora_alpha,
# lora_dropout=config.lora_dropout
# )
# new_linear.weight.data.copy_(linear.weight.data)
# if linear.bias is not None:
# new_linear.bias.data.copy_(linear.bias.data)
# return new_linear
elif config.lora_r > 0 and (module_label in config.lora_target_modules):
# Without curriculum learning, fall back to the plain LoRALinear class defined earlier in this file.
new_linear = LoRALinear(
in_features=linear.in_features,
out_features=linear.out_features,
bias=(linear.bias is not None),
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout
)
new_linear.weight.data.copy_(linear.weight.data)
if linear.bias is not None:
new_linear.bias.data.copy_(linear.bias.data)
return new_linear
else:
return linear

Expand Down Expand Up @@ -346,7 +398,13 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None:

else:
self.use_register_token = False # TODO


# if config.lora_r > 0:
# set_curriculum_stage_for_transformer(self,)
# # set_curriculum_stage_for_transformer(
# self.policy._learn_model.world_model.transformer,
# self.stage
# )

def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor:
"""
Expand Down
Loading