diff --git a/.gitignore b/.gitignore index fed96ee2f..1cda7444c 100644 --- a/.gitignore +++ b/.gitignore @@ -1450,4 +1450,4 @@ events.* !/assets/pooltool/** lzero/mcts/ctree/ctree_alphazero/pybind11 -zoo/jericho/envs/z-machine-games-master \ No newline at end of file +zoo/jericho/envs/z-machine-games-master diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index f17126527..2a269a261 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -10,4 +10,8 @@ from .train_rezero import train_rezero from .train_unizero import train_unizero from .train_unizero_segment import train_unizero_segment -from .utils import * +from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp +from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp +from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval +from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp +from .utils import * \ No newline at end of file diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py new file mode 100644 index 000000000..5d608271a --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -0,0 +1,563 @@ +import concurrent.futures +import logging +import os +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import Policy, create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator + +# ========================== +# Global Constants +# ========================== +EVALUATION_TIMEOUT_SECONDS: int = 3600 +MAX_TRAIN_ITER_INF: int = int(1e10) +MAX_ENV_STEP_INF: int = int(1e10) + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Overview: + Safely performs an evaluation step with a timeout to prevent the training process from blocking. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance to save checkpoints. + - collector (:obj:`Collector`): The collector instance to get the current envstep. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. + Returns: + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and the evaluation reward. + Returns (None, None) if a timeout occurs. + """ + logging.info(f"Rank {rank}/{world_size}: Starting evaluation...") + # Ensure the stop_event is clear before each evaluation. 
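+ # A timed-out run below leaves stop_event set; clearing it here lets the next evaluation attempt proceed normally.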
+ evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + try: + stop, reward = future.result(timeout=EVALUATION_TIMEOUT_SECONDS) + logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully. Stop: {stop}, Reward: {reward}") + return stop, reward + except concurrent.futures.TimeoutError: + # Set the evaluator's stop_event on timeout to gracefully stop the evaluation worker. + evaluator.stop_event.set() + logging.warning( + f"Rank {rank}/{world_size}: Evaluation timed out after {EVALUATION_TIMEOUT_SECONDS} seconds. " + f"Continuing training." + ) + return None, None + + +def allocate_batch_size( + cfgs: List[Any], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: float = 1.0 +) -> List[int]: + """ + Overview: + Allocates batch sizes for different tasks inversely proportional to their number of collected episodes. + This method dynamically adjusts the batch size range to enhance training stability and efficiency. + Arguments: + - cfgs (:obj:`List[Any]`): A list of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`float`): A scaling factor for dynamic adjustment of min/max batch size. Defaults to 1.0. + Returns: + - (:obj:`List[int]`): A list of allocated batch sizes for each task. + """ + # Step 1: Gather the number of collected episodes from all buffers on the current rank. + buffer_num_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + world_size = get_world_size() + rank = get_rank() + + # Step 2: Gather episode counts from all tasks across all ranks. + all_task_num_episodes = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_episodes, buffer_num_episodes) + + # Flatten the list of lists into a single list. + flat_task_num_episodes = [item for sublist in all_task_num_episodes for item in sublist] + if rank == 0: + logging.info(f'Number of collected episodes per task (all ranks): {flat_task_num_episodes}') + + # Step 3: Calculate inverse proportional weights. Add 1 to avoid division by zero. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in flat_task_num_episodes]) + inv_sum = np.sum(inv_episodes) + + # Step 4: Calculate the total batch size from the config of the first task. + # Assumption: max_batch_size is the same across all task configs and represents the global batch size. + global_batch_size = cfgs[0].policy.max_batch_size + + # Step 5: Dynamically adjust the min and max batch size bounds. + avg_batch_size = global_batch_size / len(flat_task_num_episodes) + min_batch_size = max(1, avg_batch_size / clip_scale) # Ensure min_batch_size is at least 1. + max_batch_size_clip = avg_batch_size * clip_scale + + # Step 6: Calculate batch sizes based on weights and apply clipping. + task_weights = (inv_episodes / inv_sum) ** alpha + # Note: The original code used max_batch_size, which seems to be a typo. + # It should be global_batch_size to distribute the total batch size. + batch_sizes = global_batch_size * task_weights + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size_clip) + + # Ensure batch sizes are integers. 
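+ # Illustrative example (values not from any config): with episode counts [10, 40], global_batch_size=512 and alpha=1.0, the inverse weights [1/11, 1/41] normalize to ~[0.79, 0.21], giving raw sizes ~[404, 108]; with clip_scale=2 the bounds are [128, 512], so clipping yields ~[404, 128], which the int() cast below turns into [403, 128].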
+ final_batch_sizes = [int(size) for size in batch_sizes] + + if rank == 0: + logging.info(f"Allocated batch sizes: {final_batch_sizes}") + + return final_batch_sizes + + +class MuZeroMultiTaskTrainer: + """ + Overview: + A trainer class to manage the multi-task training loop for MuZero. + It encapsulates the state and logic for initialization, data collection, + evaluation, training, and termination. + """ + + def __init__( + self, + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int, + model: Optional[torch.nn.Module], + model_path: Optional[str], + max_train_iter: int, + max_env_step: int, + ) -> None: + """ + Overview: + Initializes the multi-task trainer. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configs for all tasks. + - seed (:obj:`int`): The base random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint. + - max_train_iter (:obj:`int`): Maximum training iterations. + - max_env_step (:obj:`int`): Maximum environment steps. + """ + self.max_train_iter = max_train_iter + self.max_env_step = max_env_step + self.seed = seed + self.rank = get_rank() + self.world_size = get_world_size() + self.timer = EasyTimer() + + # State variables + self.train_epoch = 0 + self.buffer_reanalyze_count = 0 + self.value_priority_tasks = {} + + # Task partitioning + self.tasks_for_this_rank = self._partition_tasks(input_cfg_list) + if not self.tasks_for_this_rank: + logging.warning(f"Rank {self.rank}: No tasks assigned. Process will run without tasks.") + self.is_active = False + return + self.is_active = True + + # Initialize shared components (Policy, Learner) + self.policy, self.learner, self.tb_logger = self._initialize_shared_components(model, model_path) + + # Initialize task-specific components + ( + self.cfgs, self.game_buffers, self.collectors, self.evaluators + ) = self._initialize_task_specific_components() + + self.update_per_collect = self.cfgs[0].policy.update_per_collect + + def _partition_tasks(self, input_cfg_list: List[Tuple[int, Tuple[dict, dict]]]) -> List[ + Tuple[int, Tuple[dict, dict]]]: + """Partitions tasks among distributed processes.""" + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // self.world_size + remainder = total_tasks % self.world_size + + if self.rank < remainder: + start_idx = self.rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = self.rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + logging.info(f"Rank {self.rank}/{self.world_size} is assigned tasks from index {start_idx} to {end_idx - 1}.") + return input_cfg_list[start_idx:end_idx] + + def _initialize_shared_components(self, model: Optional[torch.nn.Module], model_path: Optional[str]) -> Tuple[ + Policy, BaseLearner, SummaryWriter]: + """Initializes components shared across all tasks on this rank.""" + _, [cfg, create_cfg] = self.tasks_for_this_rank[0] + + # Set task_num for the shared policy + for task_config in self.tasks_for_this_rank: + task_config[1][0].policy.task_num = len(self.tasks_for_this_rank) + + cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=self.seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path: + logging.info(f'Loading model from {model_path}...') + 
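+ # map_location places the checkpoint tensors on the device configured for this rank (CPU or the local GPU) rather than the device they were saved from.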
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info(f'Model loaded successfully from {model_path}.') + + log_dir = os.path.join(f'./{compiled_cfg.exp_name}/log', f'serial_rank_{self.rank}') + tb_logger = SummaryWriter(log_dir) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, + exp_name=compiled_cfg.exp_name) + return policy, learner, tb_logger + + def _initialize_task_specific_components(self) -> Tuple[List, List, List, List]: + """Initializes components for each task assigned to this rank.""" + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(self.tasks_for_this_rank): + task_seed = self.seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_cfg.env) + collector_env = create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_cfg.policy.cuda) + + # Create buffer, collector, and evaluator + replay_buffer = GameBuffer(compiled_cfg.policy) + # Set initial batch size from config + replay_buffer.batch_size = compiled_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=compiled_cfg.policy.eval_freq, + n_evaluator_episode=compiled_cfg.env.n_evaluator_episode, + stop_value=compiled_cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id + ) + + cfgs.append(compiled_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + return cfgs, game_buffers, collectors, evaluators + + def run(self) -> Policy: + """ + Overview: + The main training loop. Executes collection, evaluation, and training steps + until a termination condition is met. + Returns: + - (:obj:`Policy`): The trained policy. + """ + if not self.is_active: + # This rank has no tasks, so it should wait for others to finish. + self._wait_for_termination() + return self.policy + + self.learner.call_hook('before_run') + + while True: + torch.cuda.empty_cache() + + self._update_dynamic_batch_sizes() + self._collect_and_evaluate() + + if self._is_training_ready(): + dist.barrier() + self._train_iteration() + dist.barrier() + else: + logging.warning(f"Rank {self.rank}: Not enough data for training, skipping training step.") + + if self._check_termination_conditions(): + dist.barrier() # Final barrier to ensure all processes stop together. 
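+ # NOTE: ranks with no assigned tasks execute the same barrier pattern in _wait_for_termination(); the per-iteration barrier count must match on every rank, otherwise the job deadlocks.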
+ break + + self.learner.call_hook('after_run') + return self.policy + + def _update_dynamic_batch_sizes(self) -> None: + """Dynamically allocates batch sizes if enabled in the config.""" + if not self.cfgs[0].policy.get('allocated_batch_sizes', False): + return + + # Linearly increase clip_scale from 1 to 4 as train_epoch goes from 0 to 1000. + clip_scale = np.clip(1 + (3 * self.train_epoch / 1000), 1, 4) + allocated_sizes = allocate_batch_size(self.cfgs, self.game_buffers, alpha=1.0, clip_scale=clip_scale) + + # Distribute the allocated sizes to the tasks on the current rank. + # This requires knowing the global task distribution. + total_tasks = self.world_size * len(self.tasks_for_this_rank) # Approximation, needs exact count + # This part is tricky in a distributed setting without global knowledge of task indices. + # Assuming the allocation order matches the task_id order. + for i, cfg in enumerate(self.cfgs): + task_id = cfg.policy.task_id + if task_id < len(allocated_sizes): + batch_size = allocated_sizes[task_id] + cfg.policy.batch_size = batch_size + # Also update the batch size in the shared policy config if necessary + self.policy._cfg.batch_size[task_id] = batch_size + + + def _collect_and_evaluate(self) -> None: + """Runs the data collection and evaluation loop for each assigned task.""" + for i, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.evaluators, self.game_buffers)): + log_buffer_memory_usage(self.learner.train_iter, replay_buffer, self.tb_logger, cfg.policy.task_id) + + # Evaluation step + if evaluator.should_eval(self.learner.train_iter): + safe_eval(evaluator, self.learner, collector, self.rank, self.world_size) + + # Collection step + self._collect_data_for_task(cfg, collector, replay_buffer) + + def _collect_data_for_task(self, cfg: Any, collector: Collector, replay_buffer: GameBuffer) -> None: + """Collects data for a single task and pushes it to the replay buffer.""" + policy_config = cfg.policy + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=self.learner.train_iter + ), + 'epsilon': 0.0 + } + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, end=policy_config.eps.end, + decay=policy_config.eps.decay, type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + logging.info(f'Rank {self.rank}: Collecting data for task {cfg.policy.task_id}...') + new_data = collector.collect(train_iter=self.learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {self.rank}: Finished data collection for task {cfg.policy.task_id}.') + + # Periodic reanalysis of the buffer + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=False) + + def _reanalyze_buffer_if_needed(self, cfg: Any, replay_buffer: GameBuffer, is_during_training: bool, + train_loop_idx: int = 0) -> None: + """Handles the logic for reanalyzing the game buffer.""" + policy_config = cfg.policy + reanalyze_freq = policy_config.buffer_reanalyze_freq + reanalyze_batch_size = policy_config.reanalyze_batch_size + reanalyze_partition = policy_config.reanalyze_partition + update_per_collect = policy_config.update_per_collect + + should_reanalyze = False + if 
reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // reanalyze_freq + if is_during_training and train_loop_idx % reanalyze_interval == 0: + should_reanalyze = True + else: # reanalyze_freq is a fraction, e.g., 0.1 + if not is_during_training and self.train_epoch % int(1 / reanalyze_freq) == 0: + should_reanalyze = True + + if should_reanalyze and replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / reanalyze_partition): + with self.timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, self.policy) + self.buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {self.buffer_reanalyze_count}, Time: {self.timer.value:.2f}s') + + def _is_training_ready(self) -> bool: + """Checks if there is enough data in all buffers to start training.""" + for cfg, buffer in zip(self.cfgs, self.game_buffers): + if buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id]: + logging.warning(f"Rank {self.rank}, Task {cfg.policy.task_id}: Not enough data. " + f"Required: {cfg.policy.batch_size[cfg.policy.task_id]}, " + f"Available: {buffer.get_num_of_transitions()}") + return False + return True + + def _train_iteration(self) -> None: + """Performs one full training iteration, consisting of multiple updates.""" + for i in range(self.update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + + if replay_buffer.get_num_of_transitions() > batch_size: + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=True, train_loop_idx=i) + train_data = replay_buffer.sample(batch_size, self.policy) + train_data.append(cfg.policy.task_id) # Append task_id for multi-task loss + train_data_multi_task.append(train_data) + else: + # This case should ideally be prevented by _is_training_ready + logging.warning(f"Skipping sample for task {cfg.policy.task_id} due to insufficient data.") + train_data_multi_task.clear() # Invalidate the whole batch if one task fails + break + + if train_data_multi_task: + log_vars = self.learner.train(train_data_multi_task, envstep_multi_task) + if self.cfgs[0].policy.use_priority: + self._update_priorities(train_data_multi_task, log_vars) + + self.train_epoch += 1 + + def _update_priorities(self, train_data_multi_task: List, log_vars: List[Dict]) -> None: + """Updates the priorities in the replay buffers after a training step.""" + for idx, (cfg, replay_buffer) in enumerate(zip(self.cfgs, self.game_buffers)): + task_id = cfg.policy.task_id + priority_key = f'value_priority_task{task_id}' + + if priority_key in log_vars[0]: + priorities = log_vars[0][priority_key] + replay_buffer.update_priority(train_data_multi_task[idx], priorities) + + # Log priority statistics + if cfg.policy.get('print_task_priority_logs', False): + mean_priority = np.mean(priorities) + std_priority = np.std(priorities) + + # Update running mean of priority + running_mean_key = f'running_mean_priority_task{task_id}' + alpha = 0.1 # Smoothing factor for running average + if running_mean_key not in self.value_priority_tasks: + self.value_priority_tasks[running_mean_key] = mean_priority + else: + self.value_priority_tasks[running_mean_key] = \ + alpha * mean_priority + (1 - alpha) * self.value_priority_tasks[running_mean_key] + + running_mean_priority = 
self.value_priority_tasks[running_mean_key] + logging.info( + f"Task {task_id} - Priority Stats: Mean={mean_priority:.6f}, " + f"Running Mean={running_mean_priority:.6f}, Std={std_priority:.6f}" + ) + + def _check_termination_conditions(self) -> bool: + """Checks if the training should be terminated based on env steps or train iterations.""" + try: + # Check max_env_step + local_envsteps = [collector.envstep for collector in self.collectors] + all_ranks_envsteps = [None for _ in range(self.world_size)] + dist.all_gather_object(all_ranks_envsteps, local_envsteps) + + # Flatten and check if all tasks have reached the step limit + all_envsteps = [step for rank_steps in all_ranks_envsteps for step in rank_steps] + if all(step >= self.max_env_step for step in all_envsteps): + logging.info(f"Rank {self.rank}: All tasks reached max_env_step ({self.max_env_step}). Terminating.") + return True + + # Check max_train_iter + local_train_iter = torch.tensor([self.learner.train_iter], device=self.policy.device) + all_train_iters = [torch.zeros_like(local_train_iter) for _ in range(self.world_size)] + dist.all_gather(all_train_iters, local_train_iter) + + if any(it.item() >= self.max_train_iter for it in all_train_iters): + logging.info(f"Rank {self.rank}: A process reached max_train_iter ({self.max_train_iter}). Terminating.") + return True + + except Exception as e: + logging.error(f'Rank {self.rank}: Failed during termination check. Error: {e}', exc_info=True) + return True # Terminate on error to prevent hanging + + return False + + def _wait_for_termination(self) -> None: + """ + For inactive ranks, this method blocks and waits for a termination signal + (e.g., another rank finishing) by participating in barriers and termination checks. + """ + while True: + # Participate in barriers to stay in sync + dist.barrier() # Pre-train barrier + dist.barrier() # Post-train barrier + + if self._check_termination_conditions(): + dist.barrier() # Final barrier + break + +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = MAX_TRAIN_ITER_INF, + max_env_step: Optional[int] = MAX_ENV_STEP_INF, +) -> Policy: + """ + Overview: + The main entry point for multi-task MuZero training using Distributed Data Parallel (DDP). + This function sets up the distributed environment, partitions tasks, and launches the training process, + which is managed by the MuZeroMultiTaskTrainer class. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of tuples, where each tuple contains + a task ID and its corresponding configuration dictionaries (main_config, create_config). + - seed (:obj:`int`): The base random seed for reproducibility. Defaults to 0. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model instance. If provided, + it will be used instead of creating a new one from the config. Defaults to None. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. If provided, + the model weights will be loaded before training starts. Defaults to None. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + Training will stop if any process reaches this limit. Defaults to a very large number. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps for each task. + Training will stop when all tasks have reached this limit. 
Defaults to a very large number. + Returns: + - (:obj:`Policy`): The final trained policy instance from the primary rank. + """ + # Initialize the trainer, which handles all the complex setup and logic internally. + trainer = MuZeroMultiTaskTrainer( + input_cfg_list=input_cfg_list, + seed=seed, + model=model, + model_path=model_path, + max_train_iter=max_train_iter, + max_env_step=max_env_step, + ) + + # Run the training loop and return the trained policy. + return trainer.run() \ No newline at end of file diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index cb5712d0b..d09f963b7 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -136,6 +136,9 @@ def train_unizero( else: world_size = 1 rank = 0 + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + # import sys; sys.exit(0) while True: # Log memory usage of the replay buffer diff --git a/lzero/entry/train_unizero_multitask_balance_segment_ddp.py b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py new file mode 100644 index 000000000..d80106e49 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py @@ -0,0 +1,548 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F +import torch.distributed as dist +import concurrent.futures +from lzero.model.unizero_world_models.transformer import set_curriculum_stage, CurriculumLoRALinear + +from collections import defaultdict +import math +from .utils import ( + freeze_non_lora_parameters, + compute_task_weights, + log_module_trainable_status, + log_param_statistics, + tasks_per_stage, + compute_unizero_mt_normalized_stats, + allocate_batch_size +) + +# A global dictionary to store the most recent evaluation return for each task. +# Format: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + +# Timeout for the evaluation process in seconds. +EVALUATION_TIMEOUT = 12000 # 200 minutes + + +class CurriculumController: + """ + Overview: + Manages the curriculum learning stages for a multi-task policy. + It tracks the number of solved tasks and training iterations to decide when to transition + to the next curriculum stage, which typically involves freezing parts of the model + and activating new LoRA adapters. + """ + + def __init__(self, cfg: 'EasyDict', policy: 'Policy') -> None: + """ + Overview: + Initializes the CurriculumController. + Arguments: + - cfg (:obj:`EasyDict`): The experiment configuration. + - policy (:obj:`Policy`): The policy being trained. 
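+ .. note:: The controller reads `curriculum_stage_num`, `min_stage0_iters`, `max_stage_iters` and the optional `apply_curriculum_to_encoder` flag from `cfg.policy.model.world_model_cfg`.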
+ """ + world_model_cfg = cfg.policy.model.world_model_cfg + self.stage_num: int = world_model_cfg.curriculum_stage_num + self.min_stage0_iters: int = world_model_cfg.min_stage0_iters + self.max_stage_iters: int = world_model_cfg.max_stage_iters + self.policy: 'Policy' = policy + + # Flag to determine if curriculum learning should also be applied to the encoder. + # Defaults to False for backward compatibility. + self.apply_curriculum_to_encoder: bool = getattr(world_model_cfg, 'apply_curriculum_to_encoder', False) + logging.info(f"[CurriculumController] Initialized. Curriculum will be applied to Encoder: {self.apply_curriculum_to_encoder}") + + self.stage: int = 0 + self.last_switch_iter: int = 0 + self.last_solved_count: int = 0 # Snapshot of the last count of solved tasks + + def step(self, solved_count: int, unsolved_count: int, train_iter: int) -> bool: + """ + Overview: + Checks if the curriculum should transition to the next stage and performs the switch if needed. + This method should be called at the end of each training loop. + Arguments: + - solved_count (:obj:`int`): The current total number of solved tasks. + - unsolved_count (:obj:`int`): The current number of tasks yet to be solved. + - train_iter (:obj:`int`): The current training iteration. + Returns: + - bool: True if a stage switch occurred, False otherwise. + """ + # --- Stage 0 is a mandatory training phase for a minimum number of iterations --- + if self.stage == 0 and train_iter < self.min_stage0_iters: + return False + + # --- Determine if a stage switch is necessary --- + should_switch = False + + # 1. Trigger based on task progress + newly_solved = solved_count - self.last_solved_count + remaining_lora_stages = self.stage_num - 1 - self.stage # Stage 0 doesn't use LoRA + if remaining_lora_stages > 0: + # Calculate tasks per stage (tps) for the remaining unsolved tasks + tps = tasks_per_stage(unsolved_count, remaining_lora_stages) + if newly_solved >= tps: + should_switch = True + + # 2. Trigger based on maximum iterations per stage + if train_iter - self.last_switch_iter >= self.max_stage_iters: + should_switch = True + + # --- Execute the stage switch --- + if should_switch and self.stage < self.stage_num - 1: + is_entering_stage1 = (self.stage == 0) + self.stage += 1 + + world_model = self.policy._learn_model.world_model + vit_encoder = world_model.tokenizer.encoder + transformer_backbone = world_model.transformer + + # --- Apply curriculum stage update and freeze parameters accordingly --- + + # 1. Conditionally apply to ViT Encoder based on configuration + if self.apply_curriculum_to_encoder: + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to ViT Encoder.") + set_curriculum_stage(vit_encoder, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. Freezing non-LoRA parameters in ViT Encoder.") + freeze_non_lora_parameters(vit_encoder, freeze=True, verbose=True) + log_module_trainable_status(vit_encoder, "ViT Encoder") + else: + logging.info("[Curriculum] Skipping curriculum stage update for ViT Encoder as per configuration.") + log_module_trainable_status(vit_encoder, "ViT Encoder (Curriculum Not Applied)") + + # 2. Always apply to Transformer Decoder + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to Transformer Backbone.") + set_curriculum_stage(transformer_backbone, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. 
Freezing non-LoRA parameters in Transformer Backbone.") + freeze_non_lora_parameters(transformer_backbone, freeze=True, verbose=True) + log_module_trainable_status(transformer_backbone, "Transformer Backbone") + + logging.info( + f'[Curriculum] Switched to stage {self.stage} ' + f'(solved={solved_count}, unsolved={unsolved_count}, iter={train_iter})' + ) + + # Log parameter statistics after the switch + updated_params = sum(p.requires_grad for p in self.policy._learn_model.world_model.parameters()) + total_params = sum(1 for _ in self.policy._learn_model.world_model.parameters()) + logging.info(f'{updated_params}/{total_params} parameters in the world model will be optimized.') + log_param_statistics(self.policy._learn_model.world_model) + + self.last_solved_count = solved_count + self.last_switch_iter = train_iter + return True + + return False + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[Dict[str, Any]]]: + """ + Overview: + Executes the evaluation process with a timeout to prevent the training from stalling. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance, used to save checkpoints. + - collector (:obj:`Collector`): The collector instance, used to get the current envstep. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. + Returns: + - Tuple[Optional[bool], Optional[Dict[str, Any]]]: A tuple containing the stop flag and the reward dictionary + if evaluation succeeds. Returns (None, None) on timeout or error. + """ + try: + logging.info(f"========= Evaluation starting on Rank {rank}/{world_size} =========") + # Ensure the stop_event is clear before starting a new evaluation. + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the evaluation task. + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop_flag, reward_dict = future.result(timeout=EVALUATION_TIMEOUT) + except concurrent.futures.TimeoutError: + # Set the stop_event to terminate the stuck evaluation thread. + evaluator.stop_event.set() + logging.error(f"Evaluation timed out on Rank {rank}/{world_size} after {EVALUATION_TIMEOUT} seconds.") + return None, None + + logging.info(f"====== Evaluation finished on Rank {rank}/{world_size} ======") + return stop_flag, reward_dict + except Exception as e: + logging.error(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}", exc_info=True) + return None, None + + +def train_unizero_multitask_balance_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" +) -> 'Policy': + """ + Overview: + The main training entry point for UniZero in a multi-task, curriculum-based setting using DDP. + This function orchestrates distributed data collection, training, and evaluation across multiple tasks. + The curriculum learning strategy involves: + - Defining a `target_return` for each task. + - Moving tasks to a `solved_task_pool` once they achieve their target return, excluding them from + further training and collection. 
+ - Progressing through curriculum stages where the model's backbone is frozen, and only specialized + modules (like LoRA) are trained on harder, unsolved tasks. + This allows the model to first learn general features and then specialize on difficult tasks without + catastrophic forgetting. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for each task. + - seed (:obj:`int`): The random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps. + - benchmark_name (:obj:`str`): The name of the benchmark (e.g., "atari", "dmc") to load normalization scores. + Returns: + - Policy: The trained policy. + """ + # --- Initialization and DDP Setup --- + logging.basicConfig(level=logging.INFO) + rank = get_rank() + world_size = get_world_size() + timer = EasyTimer() + + # --- Benchmark Score Initialization --- + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + elif benchmark_name == "dmc": + new_RANDOM_SCORES = np.zeros(26) + new_HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported benchmark_name: {benchmark_name}") + + # --- Task Distribution Across Ranks --- + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + start_idx = rank * tasks_per_rank + min(rank, remainder) + end_idx = start_idx + tasks_per_rank + (1 if rank < remainder else 0) + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + if not tasks_for_this_rank: + logging.warning(f"Rank {rank}: No tasks assigned. Process will idle but maintain DDP communication.") + # An idle process must still participate in collective communications. + # The main loop handles this by waiting at barriers. + while True: + dist.barrier() # Wait for other processes + dist.barrier() # Sync after potential training step + # A mechanism to terminate idle processes would be needed here, + # for now, they sync and wait. + # This part requires a robust termination signal from active processes. 
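+ # One possible mechanism (illustrative sketch only, not part of this patch): rank 0 could
+ # broadcast a one-element stop tensor once per iteration, e.g.
+ #     stop = torch.zeros(1, dtype=torch.uint8, device='cuda')
+ #     dist.broadcast(stop, src=0)
+ #     if stop.item():
+ #         break
+ # with the active ranks setting the flag to 1 after their own termination check passes.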
+ + logging.info(f"Rank {rank}/{world_size} is handling tasks from index {start_idx} to {end_idx - 1}.") + + # --- Environment, Policy, and Worker Initialization --- + task_configs, replay_buffers, collectors, evaluators = [], [], [], [] + + # Use the first task's config to create the shared policy and learner + _, [main_cfg, main_create_cfg] = tasks_for_this_rank[0] + for _, [cfg, _] in tasks_for_this_rank: + cfg.policy.task_num = len(tasks_for_this_rank) + + assert main_create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \ + "This entry only supports 'unizero_multitask' or 'sampled_unizero_multitask' policies." + + GameBuffer = None + if main_create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + elif main_create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + main_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(main_cfg, seed=seed, auto=True, create_cfg=main_create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if model_path: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info('Model loading complete.') + + tb_logger = SummaryWriter(os.path.join(f'./{compiled_cfg.exp_name}/log', f'rank_{rank}')) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg.exp_name) + learner.call_hook('before_run') + + # Initialize components for each assigned task + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + task_seed = seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + compiled_task_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_task_cfg.env) + collector_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_task_cfg.policy.cuda) + + replay_buffers.append(GameBuffer(compiled_task_cfg.policy)) + collectors.append(Collector(collector_env, policy.collect_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id)) + evaluators.append(Evaluator(compiled_task_cfg.policy.eval_freq, compiled_task_cfg.env.n_evaluator_episode, compiled_task_cfg.env.stop_value, evaluator_env, policy.eval_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id)) + task_configs.append(compiled_task_cfg) + + # --- Curriculum and Training Loop Initialization --- + solved_task_pool = set() + curriculum_controller = CurriculumController(compiled_cfg, policy) + temperature_scheduler = TemperatureScheduler(initial_temp=10.0, final_temp=1.0, threshold_steps=int(1e4), mode='linear') + + train_epoch = 0 + buffer_reanalyze_count = 0 + + logging.info(f"Rank {rank}: Initial trainable parameters in world model: {sum(p.requires_grad for p in policy._learn_model.world_model.parameters())}/{sum(1 for _ in 
policy._learn_model.world_model.parameters())}") + + # ============================================================================================ + # Main Training Loop + # ============================================================================================ + while True: + # --- 1. Dynamic Batch Size Allocation (Optional) --- + if compiled_cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(task_configs, replay_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info(f"Dynamically allocated batch sizes: {allocated_batch_sizes}") + for i, cfg in enumerate(task_configs): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # --- 2. Data Collection and Evaluation for each task on this rank --- + local_task_returns = {} + for i, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(task_configs, collectors, evaluators, replay_buffers)): + task_id = cfg.policy.task_id + if task_id in solved_task_pool: + continue + + # Evaluate policy if it's time + if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): + logging.info(f'Rank {rank} evaluating task_id: {task_id}...') + evaluator._policy.reset(reset_init_data=True, task_id=task_id) + stop_flag, reward_dict = safe_eval(evaluator, learner, collector, rank, world_size) + + if reward_dict is not None: + eval_mean_reward = reward_dict.get('eval_episode_return_mean', float('-inf')) + logging.info(f"Task {task_id} evaluation reward: {eval_mean_reward}") + local_task_returns[task_id] = eval_mean_reward + if eval_mean_reward >= cfg.policy.target_return: + logging.info(f"Task {task_id} has reached its target return of {cfg.policy.target_return}. Adding to solved pool.") + solved_task_pool.add(task_id) + else: + logging.warning(f"Evaluation failed or timed out for task {task_id}. Assigning a low score.") + local_task_returns[task_id] = float('-inf') + + # Collect new data + logging.info(f'Rank {rank} collecting data for task_id: {task_id}...') + collect_kwargs = {'temperature': visit_count_temperature(cfg.policy.manual_temperature_decay, cfg.policy.fixed_temperature_value, cfg.policy.threshold_training_steps_for_final_temperature, learner.train_iter)} + if cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn(cfg.policy.eps.start, cfg.policy.eps.end, cfg.policy.eps.decay, cfg.policy.eps.type) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + collector._policy.reset(reset_init_data=True, task_id=task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {rank}: Data collection finished for task {task_id}.') + + # --- 3. 
DDP Synchronization of Task Status and Weights --- + dist.barrier() + # Gather solved tasks from all ranks + all_solved_pools = [None for _ in range(world_size)] + dist.all_gather_object(all_solved_pools, solved_task_pool) + global_solved_task_pool = set().union(*[pool for pool in all_solved_pools if pool is not None]) + solved_task_pool = global_solved_task_pool # Sync local pool with global + global_solved_count = len(solved_task_pool) + + # Gather evaluation returns and compute task weights + task_weights = None + if learner.train_iter > 10 and learner.train_iter % compiled_cfg.policy.eval_freq == 0: + all_task_returns = [None for _ in range(world_size)] + dist.all_gather_object(all_task_returns, local_task_returns) + + merged_task_returns = {k: v for d in all_task_returns if d for k, v in d.items()} + for tid, ret in merged_task_returns.items(): + GLOBAL_EVAL_RETURNS[tid] = ret # Update global tracker + + unsolved_task_returns = {tid: ret for tid, ret in merged_task_returns.items() if tid not in solved_task_pool} + + if rank == 0: + logging.info(f"Global unsolved task returns for weight calculation: {unsolved_task_returns}") + if compiled_cfg.policy.task_complexity_weight and unsolved_task_returns: + temp = temperature_scheduler.get_temperature(learner.train_iter) + task_weights = compute_task_weights(unsolved_task_returns, option="rank", temperature=temp) + logging.info(f"Computed task weights: {task_weights}") + + # Log UniZero-MT normalized stats + mean_norm, median_norm = compute_unizero_mt_normalized_stats(GLOBAL_EVAL_RETURNS) + if mean_norm is not None: + tb_logger.add_scalar('UniZero-MT/NormalizedMean', mean_norm, learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', median_norm, learner.train_iter) + logging.info(f"UniZero-MT Normalized Mean={mean_norm:.4f}, Median={median_norm:.4f}") + + # Broadcast weights from rank 0 to all other ranks + broadcast_objects = [task_weights] + dist.broadcast_object_list(broadcast_objects, src=0) + task_weights = broadcast_objects[0] + + # --- 4. Curriculum Stage Update --- + unsolved_count = total_tasks - global_solved_count + switched = curriculum_controller.step(global_solved_count, unsolved_count, learner.train_iter) + + if rank == 0: + tb_logger.add_scalar('Curriculum/Stage', curriculum_controller.stage, learner.train_iter) + tb_logger.add_scalar('Curriculum/GlobalSolvedTasks', global_solved_count, learner.train_iter) + + # TODO: Iterate over all submodules of the transformer and locate CurriculumLoRALinear modules by name. + # transformer = policy._learn_model.world_model.transformer + # for module_name, module in transformer.named_modules(): + # if isinstance(module, CurriculumLoRALinear) and module.adapters is not None: + # for adapter_idx, scale_param in enumerate(module.adapter_scales): + # tb_logger.add_scalar( + # f'Curriculum/adapter_scales/{module_name}/adapter_{adapter_idx}', + # scale_param().item(), + # global_step=learner.train_iter + # ) + + # Newly added logging of the alpha scaling factors. + try: + transformer = policy._learn_model.world_model.transformer + for module_name, module in transformer.named_modules(): + if isinstance(module, CurriculumLoRALinear): + # Check whether the module has a base_weight_scale attribute. + if hasattr(module, 'base_weight_scale') and module.base_weight_scale is not None: + # 1. Log the scaling factor of the base weights (alpha_0). + tb_logger.add_scalar( + f'Curriculum/alpha_scales/{module_name}/alpha_0_base_weight', + module.base_weight_scale().item(), + global_step=learner.train_iter + ) + + # Check whether the module has an adapter_scales attribute. + if hasattr(module, 'adapter_scales') and module.adapter_scales is not None: + # 2. Iterate over and log the scaling factors of all adapters (alpha_1, alpha_2, ...). + for adapter_idx, scale_param in enumerate(module.adapter_scales): + # adapter_idx is 0-based and corresponds to alpha_{idx+1}. + tb_logger.add_scalar( + f'Curriculum/alpha_scales/{module_name}/alpha_{adapter_idx + 1}', + scale_param().item(), + global_step=learner.train_iter + ) + except Exception as e: + logging.warning(f"Failed to log alpha scales: {e}") + + + # Ensure all processes are aware of a potential stage switch + dist.barrier() + + # --- 5. Training Step --- + unsolved_buffers = [rb for cfg, rb in zip(task_configs, replay_buffers) if cfg.policy.task_id not in solved_task_pool] + unsolved_cfgs = [cfg for cfg in task_configs if cfg.policy.task_id not in solved_task_pool] + + if not unsolved_buffers: + logging.info(f"Rank {rank}: All assigned tasks are solved. Performing dummy training to maintain DDP sync.") + # When all local tasks are solved, we must still participate in DDP. + # A dummy forward/backward pass with zeroed gradients can ensure this. + # The current implementation uses a minimal batch from solved tasks with `ignore_grad=True`. + for _ in range(compiled_cfg.policy.update_per_collect): + train_data_list = [] + for cfg, replay_buffer in zip(task_configs, replay_buffers): # Use original buffers + batch_size = 2 # Minimal batch size for sync + if replay_buffer.get_num_of_transitions() >= batch_size: + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_list.append(train_data) + + if train_data_list: + learner.train(train_data_list, collector.envstep, policy_kwargs={'task_weights': None, "ignore_grad": True}) + + else: + for _ in range(compiled_cfg.policy.update_per_collect): + train_data_list = [] + total_envstep = sum(c.envstep for c in collectors) + for cfg, replay_buffer in zip(unsolved_cfgs, unsolved_buffers): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() >= batch_size: + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_list.append(train_data) + else: + logging.warning(f"Skipping training for task {cfg.policy.task_id}: not enough data in buffer.") + + if train_data_list: + learn_kwargs = {'task_weights': task_weights, "ignore_grad": False} + learner.train(train_data_list, total_envstep, policy_kwargs=learn_kwargs) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # --- 6. Synchronization and Termination Check --- + dist.barrier() # Ensure all ranks complete the training step + + # Check for termination conditions + max_iter_reached = torch.tensor([learner.train_iter >= max_train_iter], dtype=torch.bool, device=compiled_cfg.policy.device) + dist.all_reduce(max_iter_reached, op=dist.ReduceOp.SUM) + + # For env_step, gather from all collectors on all ranks + local_env_steps = torch.tensor([c.envstep for c in collectors], dtype=torch.long, device=compiled_cfg.policy.device) + all_env_steps = [torch.zeros_like(local_env_steps) for _ in range(world_size)] + # Note: all_gather requires all tensors to be the same size. This assumes each rank has the same number of collectors. + # If not, a more complex gathering method (e.g., all_gather_object) is needed.
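+ # Size-agnostic variant (illustrative sketch): steps = [None] * world_size; dist.all_gather_object(steps, [c.envstep for c in collectors]) gathers per-rank lists of differing lengths.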
+ try: + dist.all_gather(all_env_steps, local_env_steps) + max_step_reached = (torch.cat(all_env_steps).min() >= max_env_step) if all_env_steps else False + except RuntimeError: # If tensor sizes mismatch + max_step_reached = False # Fallback, consider logging an error + logging.warning("Could not gather env_steps due to tensor size mismatch across ranks. Termination check may be inaccurate.") + + if max_iter_reached.item() or max_step_reached: + logging.info(f"Rank {rank}: Termination condition met. Stopping training.") + break + + # --- Finalization --- + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py new file mode 100644 index 000000000..ada067bd2 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -0,0 +1,890 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy, Policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +# HACK: The following imports are for type hinting purposes. +# The actual GameBuffer is selected dynamically based on the policy type. +from lzero.mcts import UniZeroGameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F + +import torch.distributed as dist +import concurrent.futures +from collections import defaultdict + + +# ==================================================================================================================== +# Note: The following global benchmark score definitions are for reference. +# The active implementation for score initialization is located within the `train_unizero_multitask_segment_ddp` function +# to ensure scores are correctly set based on the `benchmark_name` argument passed to the function. 
+# ==================================================================================================================== +# global BENCHMARK_NAME +# # BENCHMARK_NAME = "atari" +# BENCHMARK_NAME = "dmc" # TODO +# if BENCHMARK_NAME == "atari": +# RANDOM_SCORES = np.array([ +# 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, +# 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, +# -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 +# ]) +# HUMAN_SCORES = np.array([ +# 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, +# 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, +# 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 +# ]) +# elif BENCHMARK_NAME == "dmc": +# RANDOM_SCORES = np.array([0]*26) +# HUMAN_SCORES = np.array([1000]*26) +# +# # New order to original index mapping +# # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, +# # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, +# # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, +# # PrivateEye, UpNDown, Qbert, Breakout] +# # Mapping to indices in the original array (0-based) +# new_order = [ +# 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 +# ] +# +# # Generate new arrays based on new_order +# new_RANDOM_SCORES = RANDOM_SCORES[new_order] +# new_HUMAN_SCORES = HUMAN_SCORES[new_order] + + +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This should be called once during main/learner initialization) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> "dist.ProcessGroup": + """ + Overview: + Build a new process group for learners that perform backward propagation. + This is useful in scenarios like MoCo where specific ranks handle the learning process. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of ranks that will perform the backward pass. + For example, if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group for the specified learner ranks. + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + + +# Stores the latest evaluation returns: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + + +def compute_unizero_mt_normalized_stats( + eval_returns: Dict[int, float] +) -> Tuple[Optional[float], Optional[float]]: + """ + Overview: + Computes the Human-Normalized Mean and Median from evaluation returns for UniZero-MT. + If there are no samples, it returns (None, None). + Arguments: + - eval_returns (:obj:`Dict[int, float]`): A dictionary of evaluation returns, keyed by task ID. + Returns: + - (:obj:`Tuple[Optional[float], Optional[float]]`): A tuple containing the human-normalized mean and median. + Returns (None, None) if no valid returns are provided. 
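+ .. note:: Each return is normalized as (return - random_score) / (human_score - random_score) using the per-task score tables defined in the training entry; e.g. with random=0 and human=1000, a return of 500 maps to 0.5.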
+ """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None: + continue + # Denominator for normalization + denom = new_HUMAN_SCORES[tid] - new_RANDOM_SCORES[tid] + if denom == 0: + continue + normalized.append((ret - new_RANDOM_SCORES[tid]) / denom) + + if not normalized: + return None, None + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) + + +# Set a timeout for evaluation in seconds +TIMEOUT = 12000 # e.g., 200 minutes + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Overview: + Safely executes an evaluation task with a timeout to prevent hangs. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance. + - collector (:obj:`Collector`): The data collector instance. + - rank (:obj:`int`): The rank of the current process. + - world_size (:obj:`int`): The total number of processes. + Returns: + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and reward if evaluation succeeds, + otherwise (None, None). + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # Reset the stop_event to ensure it is not set before each evaluation. + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the evaluation task. + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # If a timeout occurs, set the stop_event. + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers: List['UniZeroGameBuffer'], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Overview: + Allocates batch sizes for different tasks inversely proportional to the number of collected episodes. + It also dynamically adjusts the batch size range to improve training stability and efficiency. + Arguments: + - cfgs (:obj:`List[dict]`): A list of configurations for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`int`): The clipping ratio for dynamic adjustment. Defaults to 1. + Returns: + - (:obj:`List[int]`): The list of allocated batch sizes. + """ + # Extract the number of collected episodes for each task. + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # Get the current world_size and rank. + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Gather the lists of collected episodes from all ranks. + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # Merge the collected episodes from all ranks into a single list. 
+ all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # Calculate the inverse proportional weights for each task. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # Calculate the total batch size (sum of cfg.policy.batch_size for all tasks). + total_batch_size = cfgs[0].policy.total_batch_size + + # Dynamic adjustment: define the min and max batch size range. + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # Dynamically adjust alpha to make batch size changes smoother. + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # Clip the batch sizes to be within the [min_batch_size, max_batch_size] range. + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # Ensure batch sizes are integers. + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Symlog normalization to reduce the magnitude difference of target values. + symlog(x) = sign(x) * log(|x| + 1) + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Inverse operation of Symlog to restore the original value. + inv_symlog(x) = sign(x) * (exp(|x|) - 1) + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + + +# Global max and min for "run-max-min" normalization +GLOBAL_MAX = -float('inf') +GLOBAL_MIN = float('inf') + + +def compute_task_weights( + task_returns: Dict[int, float], + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> Dict[int, float]: + """ + Overview: + An improved function for calculating task weights, supporting multiple normalization methods, + Softmax, proportional/inverse weighting, and weight clipping. + Arguments: + - task_returns (:obj:`Dict[int, float]`): A dictionary where keys are task_ids and values are evaluation rewards or losses. + - option (:obj:`str`): Normalization method. Options: "symlog", "max-min", "run-max-min", "rank", "none". + - epsilon (:obj:`float`): A small value to avoid division by zero. + - temperature (:obj:`float`): Temperature coefficient to control the weight distribution. + - use_softmax (:obj:`bool`): Whether to use Softmax for weight distribution. + - reverse (:obj:`bool`): If True, weights are inversely proportional to values; if False, they are proportional. + - clip_min (:obj:`float`): The minimum value to clip weights to. + - clip_max (:obj:`float`): The maximum value to clip weights to. + Returns: + - (:obj:`Dict[int, float]`): A dictionary of weights for each task, where keys are task_ids. + """ + global GLOBAL_MAX, GLOBAL_MIN + + # Return an empty dictionary if the input is empty. + if not task_returns: + return {} + + # Step 1: Construct a tensor from the values of task_returns. + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) + + if option == "symlog": + # Use symlog normalization. + scaled_returns = symlog(returns_tensor) + elif option == "max-min": + # Use max-min normalization. 
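+ # Illustrative example: returns [2.0, 6.0, 10.0] are rescaled to roughly
+ # [0.0, 0.5, 1.0]; the epsilon added to the denominator below keeps the
+ # division well-defined when all returns are equal.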
+ max_reward = returns_tensor.max().item() + min_reward = returns_tensor.min().item() + scaled_returns = (returns_tensor - min_reward) / (max_reward - min_reward + epsilon) + elif option == "run-max-min": + # Use global running max-min normalization. + GLOBAL_MAX = max(GLOBAL_MAX, returns_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + elif option == "rank": + # Use rank-based normalization. Rank is based on value size, with 1 for the smallest. + sorted_indices = torch.argsort(returns_tensor) + scaled_returns = torch.empty_like(returns_tensor) + rank_values = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) # Ranks from 1 to N + scaled_returns[sorted_indices] = rank_values + elif option == "none": + # No normalization. + scaled_returns = returns_tensor + else: + raise ValueError(f"Unsupported option: {option}") + + # Step 2: Determine if weights are proportional or inversely proportional based on `reverse`. + if not reverse: + # Proportional: weight is positively correlated with the value. + raw_weights = scaled_returns + else: + # Inverse: weight is negatively correlated with the value. + # Clamp to avoid division by zero or negative numbers. + scaled_returns = torch.clamp(scaled_returns, min=epsilon) + raw_weights = 1.0 / scaled_returns + + # Step 3: Calculate weights with or without Softmax. + if use_softmax: + # Use Softmax for weight distribution. + beta = 1.0 / max(temperature, epsilon) # Ensure temperature is not zero. + logits = -beta * raw_weights + softmax_weights = F.softmax(logits, dim=0).numpy() + weights = dict(zip(task_ids, softmax_weights)) + else: + # Do not use Softmax, calculate weights directly. + # Temperature scaling. + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # Ensure temperature is not zero. + + # Normalize weights. + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / total_weight + + # Convert to dictionary. + weights = dict(zip(task_ids, normalized_weights.numpy())) + + # Step 4: Clip the weight range. + for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" +) -> 'Policy': + """ + Overview: + The training entry point for UniZero, designed to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-like algorithms in environments requiring long-term dependency capture. + For more details, please refer to https://arxiv.org/abs/2406.10667. + + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for different tasks. + - seed (:obj:`int`): The random seed. + - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of policy update iterations during training. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment interaction steps to collect. + - benchmark_name (:obj:`str`): The name of the benchmark, e.g., "atari" or "dmc". 
+ + Returns: + - policy (:obj:`Policy`): The converged policy. + """ + # ------------------------------------------------------------------------------------ + # ====== UniZero-MT Benchmark Scores (corresponding to 26 Atari100k task IDs) ====== + # Original RANDOM_SCORES and HUMAN_SCORES + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + elif benchmark_name == "dmc": + RANDOM_SCORES = np.zeros(26) + HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported BENCHMARK_NAME: {benchmark_name}") + + # New order to original index mapping + # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, + # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, + # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, + # PrivateEye, UpNDown, Qbert, Breakout] + # Mapping to indices in the original array (0-based) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + global new_RANDOM_SCORES, new_HUMAN_SCORES + # Generate new arrays based on new_order + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + # Log the reordered results + print("重排后的 RANDOM_SCORES:") + print(new_RANDOM_SCORES) + print("\n重排后的 HUMAN_SCORES:") + print(new_HUMAN_SCORES) + # ------------------------------------------------------------------------------------ + + # Initialize the temperature scheduler for task weighting. + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) # Temperature drops to 1.0 after 10k training steps. + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' # or 'exponential' + ) + + # Get the current process rank and total world size. + rank = get_rank() + world_size = get_world_size() + + # Task partitioning among ranks. + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + # ==================== START: 关键修复 ==================== + # 1. 精确计算当前Rank负责的任务数量 + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + num_tasks_for_this_rank = tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + num_tasks_for_this_rank = tasks_per_rank + # ==================== END: 关键修复 ==================== + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # Ensure at least one task is assigned. + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: No tasks assigned, continuing execution.") + # Initialize empty lists to avoid errors later. 
+ cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], []
+ else:
+ print(f"Rank {rank}/{world_size}: handling tasks {start_idx} to {end_idx - 1}")
+
+ cfgs = []
+ game_buffers = []
+ collector_envs = []
+ evaluator_envs = []
+ collectors = []
+ evaluators = []
+
+ if tasks_for_this_rank:
+ # Use the config of the first task to create a shared policy.
+ task_id, [cfg, create_cfg] = tasks_for_this_rank[0]
+
+ # ==================== START: critical fix ====================
+ # 2. Write the correct task count into *all* relevant configs.
+ # The configs must be correct before the policy instance is created.
+ for config_tuple in tasks_for_this_rank:
+ # config_tuple is (task_id, [cfg_obj, create_cfg_obj])
+ config_tuple[1][0].policy.task_num = num_tasks_for_this_rank
+
+ # 3. Make sure the cfg object used to create the policy also carries the correct task_num.
+ cfg.policy.task_num = num_tasks_for_this_rank
+ # ==================== END: critical fix ====================
+
+ # Ensure the specified policy type is supported.
+ assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \
+ "train_unizero entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'"
+
+ if create_cfg.policy.type == 'unizero_multitask':
+ from lzero.mcts import UniZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'sampled_unizero_multitask':
+ from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer
+
+ # Set device based on CUDA availability.
+ cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
+ logging.info(f'Configured device: {cfg.policy.device}')
+
+ # Compile the configuration.
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create the shared policy.
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # Load a pre-trained model if a path is provided.
+ if model_path is not None:
+ logging.info(f'Loading model from: {model_path}')
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+ logging.info(f'Finished loading model: {model_path}')
+
+ # Create a TensorBoard logger.
+ log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}')
+ tb_logger = SummaryWriter(log_dir)
+
+ # Create the shared learner.
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ policy_config = cfg.policy
+
+ # Process each task assigned to the current rank.
+ for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank):
+ # Set a unique random seed for each task.
+ cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
+ cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ policy_config = cfg.policy
+ policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode
+ policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode
+
+ # Create environments.
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed + task_id)
+ evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False)
+ set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda)
+
+ # Create task-specific game buffers, collectors, and evaluators.
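+ # The policy, learner and TensorBoard logger created above are shared across
+ # tasks; only the buffer, env managers, collector and evaluator below are
+ # per-task. The task_id passed to them lets each worker tag its logs and lets
+ # the buffer pick its per-task entry from cfg.policy.batch_size.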
+ replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # Call the learner's before_run hook. + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_exploitation_weight = None + + # Dictionary to store task rewards. + task_returns = {} # {task_id: reward} + + while True: + # Dynamically adjust batch sizes. + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # For each task on the current rank, perform data collection and evaluation. + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # Log buffer memory usage. + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value. + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Check if it's time for evaluation. + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0: + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # only for debug TODO + + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # TODO: Ensure policy reset logic is optimal for multi-task settings. + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # Perform safe evaluation. + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # Check if evaluation was successful. + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + task_returns[cfg.policy.task_id] = float('inf') # Set task difficulty to max if evaluation fails. + else: + # Extract 'eval_episode_return_mean' from the reward dictionary. 
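+ # The evaluator is assumed to return a dict of aggregated episode statistics;
+ # a missing key or an exception falls back to +inf so that a failed evaluation
+ # is treated as maximum task difficulty when weights are computed.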
+ try:
+ eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
+ print(f"Task {cfg.policy.task_id} evaluation reward: {eval_mean_reward}")
+ task_returns[cfg.policy.task_id] = eval_mean_reward
+ except Exception as e:
+ print(f"Error while extracting the evaluation reward: {e}")
+ task_returns[cfg.policy.task_id] = float('inf') # Set reward to max on error.
+
+ print('=' * 20)
+ print(f'Starting collection on Rank {rank} for task_id: {cfg.policy.task_id}...')
+ print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id}')
+
+ # Reset initial data before each collection, crucial for multi-task settings.
+ collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
+ # Collect data.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # Update the replay buffer.
+ replay_buffer.push_game_segments(new_data)
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # ===== For debugging purposes only =====
+ # if train_epoch > 2:
+ # with timer:
+ # replay_buffer.reanalyze_buffer(2, policy)
+ # buffer_reanalyze_count += 1
+ # logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
+ # logging.info(f'Buffer reanalyze time: {timer.value}')
+ # ====================================
+
+ # Periodically reanalyze the buffer.
+ if cfg.policy.buffer_reanalyze_freq >= 1:
+ reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
+ else:
+ if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \
+ replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
+ reanalyze_batch_size / cfg.policy.reanalyze_partition):
+ with timer:
+ replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
+ buffer_reanalyze_count += 1
+ logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
+ logging.info(f'Buffer reanalyze time: {timer.value}')
+
+ # Log after data collection.
+ logging.info(f'Rank {rank}: finished data collection for task {cfg.policy.task_id}')
+
+ # Check if there is enough data for training.
+ not_enough_data = any(
+ replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
+ for replay_buffer in game_buffers
+ )
+
+ print(f"not_enough_data: {not_enough_data}")
+ # Get the current temperature for task weighting.
+ current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter)
+
+ if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
+ # Calculate task weights.
+ try:
+ # Gather task rewards.
+ dist.barrier()
+ all_task_returns = [None for _ in range(world_size)]
+ dist.all_gather_object(all_task_returns, task_returns)
+ # Merge task rewards.
+ merged_task_returns = {}
+ for returns in all_task_returns:
+ if returns:
+ merged_task_returns.update(returns)
+
+ logging.warning(f"Rank {rank}: merged_task_returns: {merged_task_returns}")
+
+ # Calculate global task weights.
+ task_weights = compute_task_weights(merged_task_returns, temperature=current_temperature_task_weight)
+
+ # ---------- Maintain UniZero-MT global evaluation results ----------
+ for tid, ret in merged_task_returns.items():
+ GLOBAL_EVAL_RETURNS[tid] = ret # Update even for solved tasks.
+
+ # Calculate Human-Normalized Mean / Median.
+ uni_mean, uni_median = compute_unizero_mt_normalized_stats(GLOBAL_EVAL_RETURNS)
+
+ if uni_mean is not None: # At least one task has been evaluated.
+ if rank == 0: # Only write to TensorBoard on rank 0 to avoid duplication.
+ tb_logger.add_scalar('UniZero-MT/NormalizedMean', uni_mean, global_step=learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', uni_median, global_step=learner.train_iter) + logging.info(f"Rank {rank}: UniZero-MT Norm Mean={uni_mean:.4f}, Median={uni_median:.4f}") + else: + logging.info(f"Rank {rank}: 暂无数据计算 UniZero-MT 归一化指标") + + # Synchronize task weights. + dist.broadcast_object_list([task_weights], src=0) + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + break + + # ---------------- Sampling done, preparing for backward pass ---------------- + # dist.barrier() # ★★★ Critical synchronization point ★★★ + + # Learn policy. + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # Append task_id to differentiate tasks. + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + learn_kwargs = {'task_weights': None,"train_iter":learner.train_iter} + + # DDP automatically synchronizes gradients and parameters during training. + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # Check if task_exploitation_weight needs to be calculated. + if i == 0: + # Calculate task weights. + try: + dist.barrier() # Wait for all processes to synchronize. + if cfg.policy.use_task_exploitation_weight: # Use obs loss now, new polish. + # Gather obs_loss from all tasks. + all_obs_loss = [None for _ in range(world_size)] + # Build obs_loss data for the current process's tasks. + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][ + f'noreduce_obs_loss_task{task_id}'] + # Gather obs_loss data from all processes. + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + # Merge obs_loss data from all processes. + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + # Calculate global task weights. + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + # TODO: Decide whether to use the temperature scheduler here. + temperature=1, + ) + # Broadcast task weights to all processes. 
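+ # Note (sketch, not applied here): torch.distributed.broadcast_object_list
+ # fills the given list in place, so non-source ranks usually read the result
+ # back from that list, e.g.:
+ #     buf = [task_exploitation_weight]
+ #     dist.broadcast_object_list(buf, src=0)
+ #     task_exploitation_weight = buf[0]
+ # Here every rank has already computed the same weights from the
+ # all-gathered obs_loss values, so the broadcast acts mainly as a safety net.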
+ dist.broadcast_object_list([task_exploitation_weight], src=0)
+ print(
+ f"Rank {rank}: task_exploitation_weight (ordered by task_id): {task_exploitation_weight}")
+ else:
+ logging.warning(f"Rank {rank}: failed to compute global obs_loss task weights; obs_loss data is empty.")
+ task_exploitation_weight = None
+ else:
+ task_exploitation_weight = None
+ # Update training parameters to include the calculated task weights.
+ learn_kwargs['task_weight'] = task_exploitation_weight
+ except Exception as e:
+ logging.error(f'Rank {rank}: failed to synchronize task weights, error: {e}')
+ raise e # Re-raise the exception for external capture and analysis.
+
+ if cfg.policy.use_priority:
+ for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)):
+ # Update task-specific replay buffer priorities.
+ task_id = cfg.policy.task_id
+ # replay_buffer.update_priority(
+ # train_data_multi_task[idx],
+ # log_vars[0][f'value_priority_task{task_id}']
+ # )
+ replay_buffer.update_priority(
+ train_data_multi_task[idx],
+ log_vars[0][f'noreduce_value_priority_task{task_id}']
+ )
+
+ # current_priorities = log_vars[0][f'value_priority_task{task_id}']
+ # mean_priority = np.mean(current_priorities)
+ # std_priority = np.std(current_priorities)
+
+ # alpha = 0.1 # Smoothing factor
+ # if f'running_mean_priority_task{task_id}' not in value_priority_tasks:
+ # value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority
+ # else:
+ # value_priority_tasks[f'running_mean_priority_task{task_id}'] = (
+ # alpha * mean_priority +
+ # (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}']
+ # )
+
+ # # Use running mean to calculate normalized priorities.
+ # running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}']
+ # normalized_priorities = (current_priorities - running_mean_priority) / (
+ # std_priority + 1e-6)
+
+ # # If needed, update the replay buffer with normalized priorities.
+ # # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities)
+
+ # # Log priority statistics.
+ # if cfg.policy.print_task_priority_logs:
+ # print(f"Task {task_id} - mean priority: {mean_priority:.8f}, "
+ # f"running mean priority: {running_mean_priority:.8f}, "
+ # f"std: {std_priority:.8f}")
+
+ train_epoch += 1
+ policy.recompute_pos_emb_diff_and_clear_cache()
+
+ # Synchronize all ranks to ensure they have completed training.
+ try:
+ dist.barrier()
+ logging.info(f'Rank {rank}: passed the post-training synchronization barrier')
+ except Exception as e:
+ logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
+ break
+
+ # Check for termination conditions.
+ try:
+ local_envsteps = [collector.envstep for collector in collectors]
+ total_envsteps = [None for _ in range(world_size)]
+ dist.all_gather_object(total_envsteps, local_envsteps)
+
+ all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps])
+ max_envstep_reached = torch.all(all_envsteps >= max_env_step)
+
+ # Gather train_iter from all processes.
+ global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device)
+ all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)]
+ dist.all_gather(all_train_iters, global_train_iter)
+
+ max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter)
+
+ if max_envstep_reached.item() or max_train_iter_reached.item():
+ logging.info(f'Rank {rank}: termination condition reached')
+ dist.barrier() # Ensure all processes synchronize before exiting.
+ break
+ except Exception as e:
+ logging.error(f'Rank {rank}: termination check failed, error: {e}')
+ break
+
+ # Call the learner's after_run hook.
+ learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py new file mode 100644 index 000000000..3715cbef4 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -0,0 +1,408 @@ +import logging +import os +import concurrent.futures +from functools import partial +from typing import Tuple, Optional, List, Dict, Any, Type + +import torch +import torch.distributed as dist +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy, Policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + +# Configure basic logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +) + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int, + timeout: int = 12000 +) -> Tuple[Optional[bool], Optional[float]]: + """ + Overview: + Safely evaluates the policy using the evaluator with a specified timeout. This wrapper prevents + the entire training process from crashing due to evaluation-related issues like deadlocks. + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance to run. + - learner (:obj:`BaseLearner`): The learner instance, used to access checkpoint saving and training iteration. + - collector (:obj:`Collector`): The collector instance, used to access the environment step count. + - rank (:obj:`int`): The rank of the current process in distributed training. + - world_size (:obj:`int`): The total number of processes. + - timeout (:obj:`int`): The maximum time in seconds to wait for the evaluation to complete. + Returns: + - (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and the reward. + Returns (None, None) if evaluation times out or an exception occurs. + """ + try: + logging.info(f"Rank {rank}/{world_size}: Starting evaluation.") + # Ensure the stop_event is clear before starting a new evaluation. + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + try: + stop, reward = future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + # If evaluation exceeds the timeout, set the evaluator's stop event to terminate it gracefully. 
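+ # stop_event is assumed to be a threading.Event exposed by the evaluator and
+ # polled by its worker loop; setting it asks the evaluation thread to exit
+ # cleanly instead of being killed abruptly.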
+ evaluator.stop_event.set() + logging.warning(f"Rank {rank}/{world_size}: Evaluation timed out after {timeout} seconds.") + return None, None + + logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully.") + return stop, reward + except Exception as e: + logging.error(f"Rank {rank}/{world_size}: An error occurred during evaluation: {e}", exc_info=True) + return None, None + + +def allocate_batch_size( + cfgs: List[Any], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Overview: + Allocates batch sizes inversely proportional to the number of collected episodes for each task. + This dynamic adjustment helps balance training focus across multiple tasks, prioritizing those + with less data. The batch sizes are clipped to a dynamic range to maintain stability. + Arguments: + - cfgs (:obj:`List[Any]`): List of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): List of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter controlling the degree of inverse proportionality. Defaults to 1.0. + - clip_scale (:obj:`int`): A scaling factor to define the clipping range for the batch size. Defaults to 1. + Returns: + - (:obj:`List[int]`): A list of allocated batch sizes for each task. + """ + # Extract the number of collected episodes from each task's buffer. + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + world_size = get_world_size() + rank = get_rank() + + # Gather the episode counts from all ranks. + all_task_num_of_collected_episodes_obj = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_of_collected_episodes_obj, buffer_num_of_collected_episodes) + + # Concatenate the lists from all ranks into a single flat list. + all_task_num_of_collected_episodes = [item for sublist in all_task_num_of_collected_episodes_obj for item in sublist] + if rank == 0: + logging.info(f'All task collected episodes: {all_task_num_of_collected_episodes}') + + # Calculate the inverse weight for each task. Adding 1 to avoid division by zero. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # The total batch size is defined in the config of the first task. + total_batch_size = cfgs[0].policy.total_batch_size + + # Define a dynamic range for batch sizes to prevent extreme values. + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # Calculate task weights based on inverse proportionality, smoothed by alpha. + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # Clip the batch sizes to the calculated dynamic range. + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # Ensure batch sizes are integers. 
+ batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +def train_unizero_multitask_segment_eval( + input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The main training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models" (https://arxiv.org/abs/2406.10667). This function sets up a distributed + multi-task training environment where multiple reinforcement learning tasks are trained in parallel using a + single shared model. It handles task distribution, component initialization (policy, learner, buffers, etc.), + and the main training loop orchestration. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[Dict, Dict]]]`): A list of configurations for each task. Each + element is a tuple containing the task ID and its corresponding configuration dictionaries. + - seed (:obj:`int`): The master random seed for reproducibility. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. If None, a new model is + created based on the config. + - model_path (:obj:`Optional[str]`): An optional path to a pre-trained model checkpoint. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations before termination. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps before termination. + Returns: + - (:obj:`'Policy'`): The trained policy instance after the training loop has converged or terminated. + """ + # ============================================================== + # 1. Initialization + # ============================================================== + + # 1.1. Distributed Setup & Task Partitioning + rank = get_rank() + world_size = get_world_size() + + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + if not tasks_for_this_rank: + logging.warning(f"Rank {rank}: No tasks assigned. This rank will be idle.") + # Keep the process alive to participate in collective communications. + dist.barrier() + return + + logging.info(f"Rank {rank}/{world_size}: Handling tasks from index {start_idx} to {end_idx - 1}.") + + # 1.2. Shared Policy, Learner, and Logger Initialization + # Use the configuration of the first task on this rank to create the shared components. + _, (first_cfg, first_create_cfg) = tasks_for_this_rank[0] + + # Set task_num for learner logging purposes. + for _, (cfg, _) in tasks_for_this_rank: + cfg.policy.task_num = tasks_per_rank + + assert first_create_cfg.policy.type in ['unizero_multitask'], \ + "This entry point currently only supports 'unizero_multitask' policy type." + + first_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + logging.info(f'Shared policy device: {first_cfg.policy.device}') + + # Compile the main configuration. + cfg = compile_config(first_cfg, seed=seed, auto=True, create_cfg=first_create_cfg, save_cfg=True) + + # Create the shared policy. 
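+ # enable_field selects which policy modes are instantiated; all of 'learn',
+ # 'collect' and 'eval' are needed here because the same policy object drives
+ # training, data collection and evaluation for every task on this rank.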
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load a pre-trained model if a path is provided. + if model_path is not None: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info('Model loading complete.') + + # Create a TensorBoard logger for this rank. + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create the shared learner instance. + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # 1.3. Task-Specific Components Initialization + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + for task_id, (task_cfg, task_create_cfg) in tasks_for_this_rank: + # Set a unique seed for each task to ensure diversity in data collection. + task_seed = seed + task_id + task_cfg.policy.device = 'cuda' if task_cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + task_cfg = compile_config(task_cfg, seed=task_seed, auto=True, create_cfg=task_create_cfg, save_cfg=True) + + policy.collect_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + + # Create environment managers for collection and evaluation. + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(task_cfg.env) + collector_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=task_cfg.policy.cuda) + + # Create task-specific buffers, collectors, and evaluators. + replay_buffer = GameBuffer(task_cfg.policy) + replay_buffer.batch_size = task_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=task_cfg.exp_name, + policy_config=task_cfg.policy, task_id=task_id + ) + evaluator = Evaluator( + eval_freq=task_cfg.policy.eval_freq, n_evaluator_episode=task_cfg.env.n_evaluator_episode, + stop_value=task_cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=task_cfg.exp_name, policy_config=task_cfg.policy, task_id=task_id + ) + + cfgs.append(task_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + + # ============================================================== + # 2. Main Training Loop + # ============================================================== + buffer_reanalyze_count = 0 + train_epoch = 0 + while True: + if learner.train_iter >= max_train_iter or collector.envstep >= max_env_step: + break + + # 2.1. Dynamic Batch Size Allocation (Optional) + if cfg.policy.allocated_batch_sizes: + # As training progresses, allow for a larger divergence in batch sizes. + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info(f"Allocated batch sizes: {allocated_batch_sizes}") + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + task_cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 2.2. 
Collection and Evaluation Phase + for task_cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, task_cfg.policy.task_id) + + # Determine exploration parameters for collection. + collect_kwargs = { + 'temperature': visit_count_temperature( + task_cfg.policy.manual_temperature_decay, task_cfg.policy.fixed_temperature_value, + task_cfg.policy.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + if task_cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=task_cfg.policy.eps.start, end=task_cfg.policy.eps.end, + decay=task_cfg.policy.eps.decay, type_=task_cfg.policy.eps.type + ) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + # Evaluate the policy periodically. + if evaluator.should_eval(learner.train_iter): + logging.info(f'Rank {rank} evaluating task_id: {task_cfg.policy.task_id}...') + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + if stop is None or reward is None: + logging.warning(f"Rank {rank} evaluation for task {task_cfg.policy.task_id} failed or timed out.") + else: + logging.info(f"Evaluation successful for task {task_cfg.policy.task_id}: stop={stop}, reward={reward}") + + # Collect new data. + logging.info(f'Rank {rank} collecting for task_id: {task_cfg.policy.task_id}...') + # NOTE: Resetting initial data is crucial in multi-task settings to avoid state leakage. + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Update the replay buffer. + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze the buffer to update value/policy targets with a more recent model. + # This logic handles two cases for `buffer_reanalyze_freq`: + # Case 1: freq < 1 (e.g., 0.5) -> Reanalyze every `1/freq` training epochs. + if 0 < task_cfg.policy.buffer_reanalyze_freq < 1: + if (train_epoch % int(1 / task_cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') + + logging.info(f'Rank {rank}: Data collection complete for task {task_cfg.policy.task_id}') + + # 2.3. Pre-Training Synchronization and Data Check + # Check if any buffer has insufficient data for training. + not_enough_data = any( + rb.get_num_of_transitions() < cfg.policy.total_batch_size / world_size for rb in game_buffers + ) + + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed before training with error {e}', exc_info=True) + break + + # 2.4. 
Training Phase + if not not_enough_data: + update_per_collect = cfg.policy.update_per_collect + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = sum(c.envstep for c in collectors) + + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + batch_size = task_cfg.policy.batch_size[task_cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + # Case 2: freq >= 1 -> Reanalyze `freq` times per collection cycle (spread across updates). + if task_cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // task_cfg.policy.buffer_reanalyze_freq + if (i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') + + # Sample data and append task_id for multi-task learning. + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(task_cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f"Skipping training for task {task_cfg.policy.task_id}: insufficient data. " + f"Required: {batch_size}, Available: {replay_buffer.get_num_of_transitions()}" + ) + + if train_data_multi_task: + # DDP handles gradient synchronization automatically. + learner.train(train_data_multi_task, envstep_multi_task) + + # Synchronize after each training step to maintain consistency. + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed during training step with error {e}', exc_info=True) + break + else: + logging.warning(f"Rank {rank}: Skipping training cycle due to insufficient data in one or more buffers.") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 2.5. Post-Training Synchronization and Termination Check + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed after training cycle with error {e}', exc_info=True) + break + + learner.call_hook('after_run') + logging.info(f"Rank {rank}: Training finished.") + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_segment.py b/lzero/entry/train_unizero_segment.py index c1ed74b16..0559934c0 100644 --- a/lzero/entry/train_unizero_segment.py +++ b/lzero/entry/train_unizero_segment.py @@ -154,7 +154,9 @@ def train_unizero_segment( collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) # Evaluate policy performance - if evaluator.should_eval(learner.train_iter): + # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 95b2faf4a..99b22b852 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,113 +1,539 @@ +# -*- coding: utf-8 -*- +""" +Optimized and refactored utility code for reinforcement learning models, +focusing on clarity, professionalism, efficiency, and extensibility. 
+""" + +# ============================================================================== +# Imports +# ============================================================================== +from __future__ import annotations + +import logging +import math import os -from typing import Optional, Callable, Union, List, Tuple +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import psutil import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from pympler.asizeof import asizeof from tensorboardX import SummaryWriter -import torch -import torch.distributed as dist -def is_ddp_enabled(): +# ============================================================================== +# Placeholder Types for External Dependencies +# +# To ensure type hints work without having the full definitions of these complex +# external classes, we define them as `Any`. +# ============================================================================== +EasyDict = Any +Policy = Any +RandomPolicy = Any +ISerialCollector = Any +BaseEnvManager = Any +IBuffer = Any +GameBuffer = Any + + +# ============================================================================== +# Mathematical & Tensor Utilities +# ============================================================================== + +def symlog(x: torch.Tensor) -> torch.Tensor: """ - Check if Distributed Data Parallel (DDP) is enabled by verifying if - PyTorch's distributed package is available and initialized. + Overview: + Applies the symlog transformation to a tensor, which is useful for + normalizing target values with large magnitude differences. + The transformation is defined as: symlog(x) = sign(x) * log(|x| + 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + + Returns: + - torch.Tensor: The tensor after applying the symlog transformation. """ - return dist.is_available() and dist.is_initialized() + return torch.sign(x) * torch.log(torch.abs(x) + 1) -def ddp_synchronize(): + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: """ - Perform a barrier synchronization across all processes in DDP mode. - Ensures all processes reach this point before continuing. + Overview: + Applies the inverse of the symlog transformation to a tensor, restoring + the original scale of the values. + The transformation is defined as: inv_symlog(x) = sign(x) * (exp(|x|) - 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor in symlog space. + + Returns: + - torch.Tensor: The tensor restored to its original scale. """ - if is_ddp_enabled(): - dist.barrier() + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) -def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + +def initialize_zeros_batch( + observation_shape: Union[int, List[int], Tuple[int, ...]], + batch_size: int, + device: str +) -> torch.Tensor: """ - Perform an all-reduce operation (sum) on the given tensor across - all processes in DDP mode. Returns the reduced tensor. + Overview: + Initializes a zeros tensor for a batch of observations based on the + provided shape. This is commonly used to prepare initial input for models + like UniZero. Arguments: - - tensor (:obj:`torch.Tensor`): The input tensor to be reduced. + - observation_shape (:obj:`Union[int, List[int], Tuple[int, ...]]`): The shape of a single observation. + - batch_size (:obj:`int`): The number of observations in the batch. + - device (:obj:`str`): The device to store the tensor on (e.g., 'cpu', 'cuda'). 
Returns: - - torch.Tensor: The reduced tensor, summed across all processes. + - torch.Tensor: A zeros tensor with the shape [batch_size, *observation_shape]. """ - if is_ddp_enabled(): - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return tensor + if isinstance(observation_shape, (list, tuple)): + shape = (batch_size, *observation_shape) + elif isinstance(observation_shape, int): + shape = (batch_size, observation_shape) + else: + raise TypeError( + f"observation_shape must be an int, list, or tuple, but got {type(observation_shape).__name__}" + ) + return torch.zeros(shape, device=device) + + +# ============================================================================== +# LoRA (Low-Rank Adaptation) Utilities +# ============================================================================== + +# A compiled regex pattern to efficiently detect LoRA-related parameters. +# It matches parameter names ending with: +# - .lora_A or .lora_B (for LoRA weights) +# - .adapter_scales.{digit}.logit (for learnable scale parameters) +_LORA_PAT = re.compile(r"\.(?:lora_[AB]|adapter_scales\.\d+\.logit)$") -def calculate_update_per_collect(cfg: 'EasyDict', new_data: List[List[torch.Tensor]], world_size: int = 1) -> int: + +def _is_lora_param(name: str) -> bool: + """A helper function to check if a parameter name matches the LoRA pattern.""" + return bool(_LORA_PAT.search(name)) + + +def freeze_non_lora_parameters( + module: nn.Module, + freeze: bool = True, + *, + verbose: bool = False, +) -> Tuple[int, int]: + """ + Overview: + Freezes or un-freezes all parameters in a module that are not identified + as LoRA-related parameters. This is useful for curriculum learning stages + where the backbone model is frozen and only LoRA adapters are trained. + + Arguments: + - module (:obj:`nn.Module`): The PyTorch module to process (e.g., a transformer). + - freeze (:obj:`bool`): If True, sets `requires_grad=False` for non-LoRA parameters. + If False, sets `requires_grad=True` for non-LoRA parameters. + - verbose (:obj:`bool`): If True, prints a summary of trainable and frozen parameters. + + Returns: + - Tuple[int, int]: A tuple containing the number of frozen parameters and trainable parameters. """ - Calculate the number of updates to perform per data collection in a - Distributed Data Parallel (DDP) setting. This ensures that all GPUs - compute the same `update_per_collect` value, synchronized across processes. + n_frozen = 0 + n_trainable = 0 + + for name, param in module.named_parameters(): + if _is_lora_param(name): + # LoRA-related parameters should always be trainable. + param.requires_grad = True + n_trainable += 1 + else: + # All other parameters are frozen or unfrozen based on the `freeze` flag. + param.requires_grad = not freeze + if param.requires_grad: + n_trainable += 1 + else: + n_frozen += 1 + + if verbose: + total = n_frozen + n_trainable + # Ensure total is not zero to avoid division by zero error. 
+ percentage_trainable = (n_trainable / total * 100) if total > 0 else 0 + print( + f"[freeze_non_lora] Trainable: {n_trainable}/{total} ({percentage_trainable:.1f}%), " + f"Frozen: {n_frozen}" + ) + return n_frozen, n_trainable + + +# ============================================================================== +# Task & Curriculum Learning Utilities +# ============================================================================== + +def compute_task_weights( + task_returns: Dict[str, float], + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> Dict[str, float]: + """ + Overview: + Calculates sampling weights for different tasks based on their returns (e.g., rewards or losses). + This function supports various normalization methods, softmax-based distribution, + proportional/inverse weighting, and weight clipping. Arguments: - - cfg: Configuration object containing policy settings. - - new_data (List[List[torch.Tensor]]): The newly collected data segments. - - world_size (int): The total number of processes. + - task_returns (:obj:`Dict[str, float]`): A dictionary mapping task IDs to their return values. + - option (:obj:`str`): Normalization method. One of ["symlog", "max-min", "run-max-min", "rank", "none"]. + - epsilon (:obj:`float`): A small value to prevent division by zero. + - temperature (:obj:`float`): A temperature parameter to control the sharpness of the weight distribution. + - use_softmax (:obj:`bool`): If True, use softmax to compute weights; otherwise, use direct normalization. + - reverse (:obj:`bool`): If True, weights are inversely proportional to returns; otherwise, directly proportional. + - clip_min (:obj:`float`): The minimum value to clip the final weights to. + - clip_max (:obj:`float`): The maximum value to clip the final weights to. Returns: - - int: The number of updates to perform per collection. + - Dict[str, float]: A dictionary mapping task IDs to their computed weights. """ - # Retrieve the update_per_collect setting from the configuration - update_per_collect = cfg.policy.update_per_collect + if not task_returns: + return {} + + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) + + # Step 1: Normalize the returns based on the chosen option. + scaled_returns: torch.Tensor + if option == "symlog": + scaled_returns = symlog(returns_tensor) + elif option == "max-min": + min_val, max_val = returns_tensor.min(), returns_tensor.max() + scaled_returns = (returns_tensor - min_val) / (max_val - min_val + epsilon) + elif option == "run-max-min": + # Use function attributes to maintain state across calls, avoiding global variables. + compute_task_weights.RUNNING_MAX = max(compute_task_weights.RUNNING_MAX, returns_tensor.max().item()) + compute_task_weights.RUNNING_MIN = min(compute_task_weights.RUNNING_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - compute_task_weights.RUNNING_MIN) / \ + (compute_task_weights.RUNNING_MAX - compute_task_weights.RUNNING_MIN + epsilon) + elif option == "rank": + sorted_indices = torch.argsort(returns_tensor) + ranks = torch.empty_like(returns_tensor) + # Ranks are from 1 to N. 
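+ # Illustrative example: returns [0.3, -1.0, 2.5] receive ranks [2, 1, 3],
+ # so rank-based weights depend only on the ordering, not the magnitude,
+ # of the raw returns.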
+ ranks[sorted_indices] = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) + scaled_returns = ranks + elif option == "none": + scaled_returns = returns_tensor + else: + raise ValueError(f"Unsupported normalization option: {option}") - if update_per_collect is None: - # If update_per_collect is not explicitly set, calculate it based on - # the number of collected transitions and the replay ratio. + # Step 2: Determine if weights should be proportional or inversely proportional to returns. + if reverse: + # Inverse proportion: smaller return -> higher weight. + raw_weights = 1.0 / (scaled_returns + epsilon) + else: + # Direct proportion: higher return -> higher weight. + raw_weights = scaled_returns + + # Step 3: Calculate final weights using either softmax or direct normalization. + final_weights: np.ndarray + safe_temperature = max(temperature, epsilon) + if use_softmax: + # Softmax provides a smooth distribution, often used with inverse weights. + # A higher beta (lower temperature) makes the distribution sharper. + beta = 1.0 / safe_temperature + # The sign depends on whether we want to favor high or low raw_weights. + # If reverse=True, raw_weights are high for low returns. We want to sample these more. + # Softmax(logits) gives higher probability to higher logits. + # So, logits should be proportional to the desired sampling probability. + logits = raw_weights if reverse else -raw_weights + final_weights = F.softmax(logits * beta, dim=0).numpy() + else: + # Direct normalization with temperature scaling. + scaled_weights = raw_weights**(1 / safe_temperature) + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / (total_weight + epsilon) + final_weights = normalized_weights.numpy() - # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. - # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. - collected_transitions_num = sum( - min(len(game_segment), cfg.policy.game_segment_length) - for game_segment in new_data[0] - ) + # Step 4: Clip weights to the desired range and create the result dictionary. + weights_dict = { + task_id: np.clip(weight, clip_min, clip_max) + for task_id, weight in zip(task_ids, final_weights) + } - if torch.cuda.is_available() and world_size > 1: - # Convert the collected transitions count to a GPU tensor for DDP operations. - collected_transitions_tensor = torch.tensor( - collected_transitions_num, dtype=torch.int64, device='cuda' - ) + return weights_dict - # Synchronize the collected transitions count across all GPUs using all-reduce. - total_collected_transitions = ddp_all_reduce_sum( - collected_transitions_tensor - ).item() +# Initialize state for the 'run-max-min' option as function attributes. +compute_task_weights.RUNNING_MAX = -float('inf') +compute_task_weights.RUNNING_MIN = float('inf') - # Calculate update_per_collect based on the total synchronized transitions count. - update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio) - # Ensure the computed update_per_collect is positive. - assert update_per_collect > 0, "update_per_collect must be positive" - else: - # If not using DDP, calculate update_per_collect directly from the local count. 
- update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) +class TemperatureScheduler: + """ + Overview: + A scheduler to gradually adjust a temperature value over a specified number + of training steps. This can be used for exploration or weighting schemes. - return update_per_collect + Arguments: + - initial_temp (:obj:`float`): The starting temperature. + - final_temp (:obj:`float`): The target temperature to be reached after `threshold_steps`. + - threshold_steps (:obj:`int`): The number of steps over which the temperature will anneal. + - mode (:obj:`str`): The annealing mode, either 'linear' or 'exponential'. + """ -def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str) -> torch.Tensor: + def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): + if mode not in ['linear', 'exponential']: + raise ValueError("Mode must be 'linear' or 'exponential'.") + self.initial_temp = initial_temp + self.final_temp = final_temp + self.threshold_steps = max(1, threshold_steps) # Avoid division by zero + self.mode = mode + + def get_temperature(self, current_step: int) -> float: + """ + Overview: + Calculates the temperature for the given training step. + + Arguments: + - current_step (:obj:`int`): The current training step. + + Returns: + - float: The calculated temperature for the current step. + """ + if current_step >= self.threshold_steps: + return self.final_temp + + progress = current_step / self.threshold_steps + + if self.mode == 'linear': + return self.initial_temp - (self.initial_temp - self.final_temp) * progress + else: # 'exponential' + # Exponential decay from initial_temp to final_temp + # T(t) = T_initial * (T_final / T_initial)^(t / N) + if self.initial_temp <= 0: + raise ValueError("Initial temperature must be positive for exponential decay.") + scale = self.final_temp / self.initial_temp + return self.initial_temp * (scale**progress) + + +def tasks_per_stage(unsolved: int, remain_lora: int) -> int: """ Overview: - Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input. + Calculates the number of tasks to assign per LoRA adapter stage. + It's the ceiling of the division of unsolved tasks by remaining adapters. + Arguments: - - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. - - batch_size (:obj:`int`): The batch size. - - device (:obj:`str`): The device to store the tensor. + - unsolved (:obj:`int`): The number of tasks yet to be solved. + - remain_lora (:obj:`int`): The number of available LoRA adapters. + Returns: - - zeros (:obj:`torch.Tensor`): The zeros tensor. + - int: The number of tasks to be handled in the current stage, at least 1. """ - if isinstance(observation_shape, (list, tuple)): - shape = [batch_size, *observation_shape] - elif isinstance(observation_shape, int): - shape = [batch_size, observation_shape] + return max(1, math.ceil(unsolved / max(remain_lora, 1))) + + +def compute_unizero_mt_normalized_stats( + eval_returns: Dict[int, float], + human_scores: Dict[int, float], + random_scores: Dict[int, float] +) -> Tuple[Optional[float], Optional[float]]: + """ + Overview: + Calculates the Human-Normalized Mean and Median for a set of evaluation returns. + If no valid returns are provided, it returns (None, None). + + Arguments: + - eval_returns (:obj:`Dict[int, float]`): A dictionary of evaluation returns per task ID. 
+ - human_scores (:obj:`Dict[int, float]`): A dictionary of human expert scores per task ID. + - random_scores (:obj:`Dict[int, float]`): A dictionary of random policy scores per task ID. + + Returns: + - Tuple[Optional[float], Optional[float]]: A tuple containing the human-normalized mean and median. + """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None or tid not in human_scores or tid not in random_scores: + continue + denom = human_scores[tid] - random_scores[tid] + if denom == 0: + continue + normalized.append((ret - random_scores[tid]) / denom) + + if not normalized: + return None, None + + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) + + +def allocate_batch_size( + cfgs: List[EasyDict], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Overview: + Allocates batch sizes for different tasks inversely proportional to the + number of collected episodes for each task. It also dynamically clips + the batch size range to improve training stability. + + Arguments: + - cfgs (:obj:`List[EasyDict]`): A list of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. + - clip_scale (:obj:`int`): A scaling factor to determine the min/max batch size clip range. + + Returns: + - List[int]: A list of allocated batch sizes for each task. + """ + # This function assumes a DDP environment. + if not dist.is_available() or not dist.is_initialized(): + # Fallback for non-DDP environment if needed, though the logic is DDP-centric. + logging.warning("allocate_batch_size is designed for DDP and may not work as expected.") + world_size = 1 + rank = 0 else: - raise TypeError(f"observation_shape must be either an int, a list, or a tuple, but got {type(observation_shape).__name__}") + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Extract the number of collected episodes from each local buffer. + local_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # Gather episode counts from all ranks. + all_task_episodes_list = [None for _ in range(world_size)] + dist.all_gather_object(all_task_episodes_list, local_episodes) + + # Flatten the list of lists into a single list of episode counts for all tasks. + all_task_episodes = [ep for sublist in all_task_episodes_list for ep in sublist] + + if rank == 0: + logging.info(f'All task collected episodes: {all_task_episodes}') + + # Calculate weights inversely proportional to episode counts. + # Add 1 to avoid division by zero for new tasks. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_episodes]) + inv_sum = np.sum(inv_episodes) + + # Total batch size is assumed to be consistent across configs. + total_batch_size = cfgs[0].policy.total_batch_size + + # Define dynamic clipping range for batch sizes. + avg_batch_size = total_batch_size / len(all_task_episodes) + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # Calculate batch sizes based on weights, apply alpha for smoothing. + task_weights = (inv_episodes / inv_sum)**alpha + batch_sizes = total_batch_size * task_weights + + # Clip and convert to integers. 
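+    # Illustrative numbers (not taken from any config above): total_batch_size=512 across 8 tasks
+    # gives avg_batch_size=64; with the default clip_scale=1 both bounds equal 64, so every task is
+    # clipped to a uniform batch size of 64, while clip_scale=2 would allow sizes between 32 and 128.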
+ batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +# ============================================================================== +# Distributed Data Parallel (DDP) Utilities +# ============================================================================== + +def is_ddp_enabled() -> bool: + """ + Overview: + Checks if the environment is set up for Distributed Data Parallel (DDP) training. + + Returns: + - bool: True if `torch.distributed` is available and initialized, False otherwise. + """ + return dist.is_available() and dist.is_initialized() + + +def ddp_synchronize() -> None: + """ + Overview: + Performs a barrier synchronization across all processes in a DDP group. + This ensures that all processes reach this point before any of them proceed. + """ + if is_ddp_enabled(): + dist.barrier() + + +def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs an all-reduce operation (sum) on a given tensor across all + processes in the DDP group. + + Arguments: + - tensor (:obj:`torch.Tensor`): The tensor to be reduced. + + Returns: + - torch.Tensor: The reduced tensor, with values summed across all processes. + """ + if is_ddp_enabled(): + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor + + +# ============================================================================== +# Reinforcement Learning Workflow Utilities +# ============================================================================== + +def calculate_update_per_collect( + cfg: EasyDict, + new_data: List[List[torch.Tensor]], + world_size: int = 1 +) -> int: + """ + Overview: + Calculates the number of training updates to perform per data collection cycle. + In a DDP setting, it synchronizes transition counts across all GPUs to ensure + a consistent `update_per_collect` value. + + Arguments: + - cfg (:obj:`EasyDict`): The configuration object containing policy settings. + It's expected to have `cfg.policy.update_per_collect`, + `cfg.policy.replay_ratio`, etc. + - new_data (:obj:`List[List[torch.Tensor]]`): The newly collected data segments. + - world_size (:obj:`int`): The total number of DDP processes. + + Returns: + - int: The number of updates to perform. + """ + update_per_collect = cfg.policy.get('update_per_collect') + + if update_per_collect is not None: + return update_per_collect + + # If not explicitly set, calculate based on replay ratio. + # Note: A game segment's length can be less than `game_segment_length` if it's the + # final segment of an episode. + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) + for game_segment in new_data[0] + ) + + if torch.cuda.is_available() and world_size > 1: + # In DDP, synchronize the transition count across all GPUs. + collected_transitions_tensor = torch.tensor( + collected_transitions_num, dtype=torch.int64, device='cuda' + ) + total_collected_transitions = ddp_all_reduce_sum( + collected_transitions_tensor + ).item() + updates = int(total_collected_transitions * cfg.policy.replay_ratio) + else: + # In a single-process setup. + updates = int(collected_transitions_num * cfg.policy.replay_ratio) + + return max(1, updates) # Ensure at least one update. 
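+
+# A minimal worked example for calculate_update_per_collect (hypothetical numbers, not part of the API):
+# with cfg.policy.update_per_collect = None, cfg.policy.replay_ratio = 0.25 and 400 transitions
+# collected locally on each of world_size = 2 GPUs, the all-reduced total is 800 transitions,
+# so every rank returns max(1, int(800 * 0.25)) = 200 updates; the single-process case with the
+# same local data would yield max(1, int(400 * 0.25)) = 100.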
- return torch.zeros(shape).to(device) def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor: """ @@ -140,110 +566,251 @@ def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], b return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == 0 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device) def random_collect( - policy_cfg: 'EasyDict', # noqa - policy: 'Policy', # noqa - RandomPolicy: 'Policy', # noqa - collector: 'ISerialCollector', # noqa - collector_env: 'BaseEnvManager', # noqa - replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None -) -> None: # noqa - assert policy_cfg.random_collect_episode_num > 0 + policy_cfg: EasyDict, + policy: Policy, + RandomPolicy: Callable, + collector: ISerialCollector, + collector_env: BaseEnvManager, + replay_buffer: IBuffer, + postprocess_data_fn: Optional[Callable] = None +) -> None: + """ + Overview: + Performs an initial data collection phase using a random policy to populate + the replay buffer before training begins. + + Arguments: + - policy_cfg (:obj:`EasyDict`): Configuration for the policy. + - policy (:obj:`Policy`): The main training policy instance. + - RandomPolicy (:obj:`Callable`): A constructor or class for creating a random policy. + - collector (:obj:`ISerialCollector`): The data collector instance. + - collector_env (:obj:`BaseEnvManager`): The environment manager. + - replay_buffer (:obj:`IBuffer`): The replay buffer to store collected data. + - postprocess_data_fn (:obj:`Optional[Callable]`): An optional function to process data after collection. + """ + random_collect_episode_num = policy_cfg.get('random_collect_episode_num', 0) + if random_collect_episode_num <= 0: + return random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) - # set the policy to random policy collector.reset_policy(random_policy.collect_mode) - # set temperature for visit count distributions according to the train_iter, - # please refer to Appendix D in MuZero paper for details. - collect_kwargs = {'temperature': 1, 'epsilon': 0.0} + # Use neutral MCTS parameters for random collection. + collect_kwargs = {'temperature': 1.0, 'epsilon': 0.0} - # Collect data by default config n_sample/n_episode. - new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, - policy_kwargs=collect_kwargs) + new_data = collector.collect( + n_episode=random_collect_episode_num, + train_iter=0, + policy_kwargs=collect_kwargs + ) - if postprocess_data_fn is not None: + if postprocess_data_fn: new_data = postprocess_data_fn(new_data) - # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) - # remove the oldest data if the replay buffer is full. replay_buffer.remove_oldest_data_to_fit() - # restore the policy + # Restore the original policy to the collector. 
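+    # (the collector was switched to random_policy earlier in this function)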
collector.reset_policy(policy.collect_mode) -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: +# ============================================================================== +# Logging Utilities +# ============================================================================== + +def log_module_trainable_status( + module: nn.Module, + module_name: str, + logger: logging.Logger +) -> None: """ Overview: - Log the memory usage of the buffer and the current process to TensorBoard. + Logs the detailed trainable/frozen status of all parameters within a given module. + Arguments: - - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer. - - writer (:obj:`SummaryWriter`): The TensorBoard writer. + - module (:obj:`nn.Module`): The module to inspect (e.g., a ViT Encoder). + - module_name (:obj:`str`): The name of the module for logging purposes. + - logger (:obj:`logging.Logger`): The logger instance to use for output. """ - # "writer is None" means we are in a slave process in the DDP setup. - if writer is not None: - writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) - - game_segment_buffer = buffer.game_segment_buffer + logger.info(f"--- Parameter Status Details for Module: '{module_name}' ---") - # Calculate the amount of memory occupied by self.game_segment_buffer (in bytes). - buffer_memory_usage = asizeof(game_segment_buffer) + total_params = 0 + trainable_params = 0 - # Convert buffer_memory_usage to megabytes (MB). - buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) + param_list = list(module.named_parameters()) + if not param_list: + logger.info(" - No parameters found in this module.") + return - # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter) + for name, param in param_list: + total_params += param.numel() + status = "Trainable" if param.requires_grad else "Frozen" + logger.info(f" - {name:<60} | Shape: {str(param.shape):<25} | Status: {status}") + if param.requires_grad: + trainable_params += param.numel() - # Get the amount of memory currently used by the process (in bytes). - process = psutil.Process(os.getpid()) - process_memory_usage = process.memory_info().rss + logger.info(f"--- Summary for Module: '{module_name}' ---") + logger.info(f" - Total Parameters: {total_params:,}") + logger.info(f" - Trainable Parameters: {trainable_params:,}") + if total_params > 0: + percentage = 100 * trainable_params / total_params + logger.info(f" - Trainable Percentage: {percentage:.4f}%") + logger.info("-" * (len(module_name) + 40)) - # Convert process_memory_usage to megabytes (MB). - process_memory_usage_mb = process_memory_usage / (1024 * 1024) - - # Record the memory usage of the process to TensorBoard. - writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) +def log_param_statistics(model: nn.Module, logger: logging.Logger) -> None: + """ + Overview: + Logs a concise summary of the number and size of trainable versus total + parameters in a model. 
-def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: + Arguments: + - model (:obj:`nn.Module`): The model to analyze. + - logger (:obj:`logging.Logger`): The logger instance for output. + """ + n_tensors_total = sum(1 for _ in model.parameters()) + n_tensors_train = sum(1 for p in model.parameters() if p.requires_grad) + + n_elems_total = sum(p.numel() for p in model.parameters()) + n_elems_train = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.info( + f'Trainable Parameters: ' + f'{n_tensors_train}/{n_tensors_total} tensors | ' + f'{n_elems_train:,}/{n_elems_total:,} elements ' + f'({n_elems_train/1e6:.2f}M / {n_elems_total/1e6:.2f}M)' + ) + + +def log_buffer_memory_usage( + train_iter: int, + buffer: GameBuffer, + writer: SummaryWriter, + task_id: int = 0 +) -> None: """ Overview: - Log the average runtime metrics of the buffer to TensorBoard. + Logs the memory usage of the replay buffer and the current process to TensorBoard. + Arguments: - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer containing runtime metrics. - - writer (:obj:`SummaryWriter`): The TensorBoard writer for logging metrics. - - .. note:: - "writer is None" indicates that the function is being called in a slave process in the DDP setup. + - buffer (:obj:`GameBuffer`): The replay buffer instance. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. + - task_id (:obj:`int`): An optional ID to distinguish logs for different tasks. """ - if writer is not None: - sample_times = buffer.sample_times + # In DDP, only the main process should write to TensorBoard. + if writer is None: + return - if sample_times == 0: - return + prefix = f"Buffer/Task_{task_id}" + writer.add_scalar(f'{prefix}/num_collected_episodes', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'{prefix}/num_game_segments', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'{prefix}/num_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) - # Calculate and log average reanalyze time. - average_reanalyze_time = buffer.compute_target_re_time / sample_times - writer.add_scalar('Buffer/average_reanalyze_time', average_reanalyze_time, train_iter) + # Calculate and log memory usage of the main buffer component. + buffer_memory_bytes = asizeof(buffer.game_segment_buffer) + buffer_memory_mb = buffer_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/game_segment_buffer', buffer_memory_mb, train_iter) - # Calculate and log average origin search time. - average_origin_search_time = buffer.origin_search_time / sample_times - writer.add_scalar('Buffer/average_origin_search_time', average_origin_search_time, train_iter) + # Get and log total memory usage of the current process. + process = psutil.Process(os.getpid()) + process_memory_bytes = process.memory_info().rss + process_memory_mb = process_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/process', process_memory_mb, train_iter) - # Calculate and log average reuse search time. - average_reuse_search_time = buffer.reuse_search_time / sample_times - writer.add_scalar('Buffer/average_reuse_search_time', average_reuse_search_time, train_iter) - # Calculate and log average active root number. 
- average_active_root_num = buffer.active_root_num / sample_times - writer.add_scalar('Buffer/average_active_root_num', average_active_root_num, train_iter) +def log_buffer_run_time(train_iter: int, buffer: GameBuffer, writer: SummaryWriter) -> None: + """ + Overview: + Logs average runtime metrics related to buffer operations (e.g., sampling, search) + to TensorBoard. - # Reset the time records in the buffer. - buffer.reset_runtime_metrics() + Arguments: + - train_iter (:obj:`int`): The current training iteration. + - buffer (:obj:`GameBuffer`): The buffer instance containing runtime metrics. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. + """ + if writer is None or buffer.sample_times == 0: + return + + sample_times = buffer.sample_times + writer.add_scalar('Buffer/avg_reanalyze_time_ms', (buffer.compute_target_re_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_origin_search_time_ms', (buffer.origin_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_reuse_search_time_ms', (buffer.reuse_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_active_root_num', buffer.active_root_num / sample_times, train_iter) + + # Reset metrics after logging to prepare for the next interval. + buffer.reset_runtime_metrics() + + +# ============================================================================== +# Example Usage +# ============================================================================== +if __name__ == '__main__': + # Configure a basic logger to see output from functions with `verbose=True` + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + print("\n--- Example for `compute_task_weights` ---") + task_rewards_list = [ + {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300}, + {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000}, + {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10}, + ] + + for i, task_rewards in enumerate(task_rewards_list, start=1): + print(f"\n--- Case {i} ---") + print(f"Original Rewards: {task_rewards}") + + # Example 1: Using 'none' normalization (proportional to raw values) + weights_none = compute_task_weights(task_rewards, option="none", use_softmax=False) + print(f"Weights (proportional to raw values): {weights_none}") + + # Example 2: Using 'symlog' normalization + weights_symlog = compute_task_weights(task_rewards, option="symlog", use_softmax=False) + print(f"Weights (with symlog normalization): {weights_symlog}") + + # Example 3: Using 'rank' normalization and softmax with inverse proportion + weights_rank_softmax = compute_task_weights(task_rewards, option="rank", use_softmax=True, reverse=True) + print(f"Weights (inverse rank with softmax): {weights_rank_softmax}") + + print("\n--- Example for `freeze_non_lora` ---") + + # ========================================================================== + # FIX: The nn.Parameter must be wrapped in an nn.Module subclass to be + # placed inside an nn.ModuleDict. 
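+    # (nn.ModuleDict only accepts nn.Module values and raises a TypeError for a bare
+    # nn.Parameter; nn.ParameterDict would be an alternative if no wrapper module is wanted.)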
+ # ========================================================================== + class AdapterScale(nn.Module): + """A simple nn.Module wrapper for a single learnable parameter.""" + def __init__(self): + super().__init__() + self.logit = nn.Parameter(torch.randn(1)) + + # Create a dummy model to demonstrate freezing + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.backbone = nn.Linear(10, 10) + self.layer1 = nn.Linear(10, 10) + # Simulate LoRA parameters with correct naming + self.layer1.lora_A = nn.Parameter(torch.randn(10, 2)) + self.layer1.lora_B = nn.Parameter(torch.randn(2, 10)) + + # Correctly structure the adapter_scales using the wrapper module. + # This ensures that the value associated with key '0' is a valid nn.Module. + self.adapter_scales = nn.ModuleDict({ + '0': AdapterScale() + }) + + model = DummyModel() + print("Initial parameter status:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) + + print("\nFreezing non-LoRA parameters...") + freeze_non_lora(model, freeze=True, verbose=True) + print("\nParameter status after freezing:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) + + print("\nUn-freezing non-LoRA parameters...") + freeze_non_lora(model, freeze=False, verbose=True) + print("\nParameter status after un-freezing:") + log_module_trainable_status(model, "DummyModel", logging.getLogger()) \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 6a4458a03..253935652 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int) -> Tuple: + def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: """ Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Sample original data which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority + - batch_size (:obj:`int`): The size of the batch. + - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. 
""" - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -192,115 +194,152 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: pos_in_game_segment_list.append(pos_in_game_segment) - make_time = [time.time() for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) - return orig_data - - def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: - """ - Overview: - This function samples a batch of game segments for reanalysis from the replay buffer. - It uses priority sampling based on the `reanalyze_time` of each game segment, with segments - that have been reanalyzed more frequently receiving lower priority. - - The function returns a tuple containing information about the sampled game segments, - including their positions within each segment and the time the batch was created. - Arguments: - - batch_size (:obj:`int`): - The number of samples to draw in this batch. - - Returns: - - Tuple: - A tuple containing the following elements: - - game_segment_list: A list of the sampled game segments. - - pos_in_game_segment_list: A list of indices representing the position of each transition - within its corresponding game segment. - - batch_index_list: The indices of the sampled game segments in the replay buffer. - - make_time: A list of timestamps (set to `0` in this implementation) indicating when - the batch was created. - - Key Details: - 1. **Priority Sampling**: - Game segments are sampled based on a probability distribution calculated using - the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently - are less likely to be selected. - 2. **Segment Slicing**: - Each selected game segment is sampled at regular intervals determined by the - `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled - from each selected segment. - 3. **Handling Extra Samples**: - If the `batch_size` is not perfectly divisible by the number of samples per segment, - additional segments are sampled to make up the difference. - 4. **Reanalyze Time Update**: - The `reanalyze_time` attribute of each sampled game segment is incremented to reflect - that it has been selected for reanalysis again. 
- Raises: - - ValueError: - If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. - """ - train_sample_num = len(self.game_segment_buffer) - assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." - valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - - # Calculate the number of samples per segment - samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps - - # Make sure that the batch size can be divided by the number of samples per segment - if samples_per_segment == 0: - raise ValueError("The game segment length is too small for num_unroll_steps.") - - # Calculate the number of samples per segment - batch_size_per_segment = batch_size // samples_per_segment - - # If the batch size cannot be divided, process the remainder part - extra_samples = batch_size % samples_per_segment - - # We use the reanalyze_time in the game_segment_buffer to generate weights - reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) - - # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) - base_decay_rate = 100 - decay_rate = base_decay_rate / valid_sample_num - weights = np.exp(-decay_rate * reanalyze_times) - - # Normalize the weights to a probability distribution - probabilities = weights / np.sum(weights) - - # Sample game segments according to the probabilities - selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, - p=probabilities) - - # If there are extra samples to be allocated, randomly select some game segments and sample again - if extra_samples > 0: - extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities) - selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) - - game_segment_list = [] - pos_in_game_segment_list = [] - batch_index_list = [] - - for game_segment_idx in selected_game_segments: - game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - # Update reanalyze_time only once - game_segment.reanalyze_time += 1 - - # The sampling position should be 0, 0 + num_unroll_steps, ... (integer multiples of num_unroll_steps) - for i in range(samples_per_segment): - game_segment_list.append(game_segment) - pos_in_game_segment = i * self._cfg.num_unroll_steps - if pos_in_game_segment >= len(game_segment): - pos_in_game_segment = np.random.choice(len(game_segment), 1).item() - pos_in_game_segment_list.append(pos_in_game_segment) - batch_index_list.append(game_segment_idx) + # make_time = [time.time() for _ in range(len(batch_index_list))] # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). make_time = [0. for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data + def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. 
+ It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. + - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." 
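+        # Hypothetical example: with 1000 stored game segments and reanalyze_partition = 0.5,
+        # only the first 500 entries of game_segment_buffer are eligible for reanalysis sampling.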
+ valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + # Add a small epsilon to avoid division by zero if valid_sample_num is 0 + decay_rate = base_decay_rate / (valid_sample_num + 1e-6) + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution, handle case where sum is zero + sum_weights = np.sum(weights) + if sum_weights > 0: + probabilities = weights / sum_weights + else: + # If all weights are zero, use a uniform distribution + probabilities = np.ones(valid_sample_num) / valid_sample_num + + # Sample game segments according to the probabilities + # Ensure valid_sample_num is not zero before sampling + if valid_sample_num == 0: + return ([], [], [], [], []) + + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + # We need to handle the case where we might sample the same segment again. + # A simple way is to allow replacement for extra samples or sample from remaining ones. + # For simplicity, let's stick to the original logic but ensure it's safe. + remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) + if len(remaining_segments) < extra_samples: + # If not enough unique segments left, sample with replacement from all valid segments + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) + else: + # Sample from the remaining unique segments + remaining_probs = probabilities[remaining_segments] + remaining_probs /= np.sum(remaining_probs) + extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) + + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + print(f"selected_game_segments:{selected_game_segments}") + for game_segment_idx in selected_game_segments: + # ========================================================================= + # FIX: The line below is the source of the error and has been removed. + # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. + # game_segment_idx -= self.base_idx + # ========================================================================= + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + # NOTE: We should append the physical index here, as it corresponds to the sampled segment. + batch_index_list.append(game_segment_idx) + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data + def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ Overview: @@ -617,7 +656,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -629,8 +669,15 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) + if isinstance(self._cfg.batch_size, int): + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) + else: + try: + if total_transition >= self._cfg.batch_size[0]: + self._remove(index + 1) + except Exception as e: + print(e) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index faf0155a0..972a95498 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -61,6 +61,18 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) @@ -149,7 +161,7 @@ def sample( self.compute_target_re_time += self._compute_target_timer.value batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -469,17 +481,21 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -594,17 +610,20 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -612,7 +631,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -624,7 +643,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model with 
self._origin_search_timer: - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree @@ -634,7 +657,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -650,7 +677,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -659,7 +686,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: # Update the data in game segment: @@ -676,7 +703,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -705,7 +732,7 @@ def _compute_target_policy_non_reanalyzed( - game_segment_lens - action_mask_segment - to_play_segment - - policy_shape: self._cfg.model.action_space_size + - policy_shape: self.action_space_size Returns: - batch_target_policies_non_re """ @@ -728,7 +755,7 @@ def _compute_target_policy_non_reanalyzed( ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -778,6 +805,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 
'batch_priorities': batch_priorities} diff --git a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index f91b7f08a..da09fc311 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -48,9 +48,18 @@ def __init__(self, cfg: dict): self.game_segment_buffer = [] self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) @@ -115,21 +124,22 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -138,7 +148,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -277,18 +287,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -297,7 +307,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += 
generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -326,7 +336,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: if self._cfg.model.continuous_action_space: # pad random action bootstrap_action_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp)) ] bootstrap_action_list.append(bootstrap_action_tmp) @@ -489,6 +499,12 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # calculate the target value # batch_action.shape (32, 10) # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 + + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num # ======================================================================= @@ -514,18 +530,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # cpp mcts_tree # roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots = MCTSCtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, + transition_batch_size, legal_actions, self.action_space_size, self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space ) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -629,7 +651,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
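+            # A dummy all-ones mask (one entry per action dimension) is constructed for continuous spaces.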
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ @@ -647,7 +669,12 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the target value # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) # ====================================================================== # if not in training, obtain the scalars of the value/reward @@ -658,6 +685,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_output.policy_logits ] ) + network_output.append(m_output) if self._cfg.use_root_value: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index b8998acb9..b4de66031 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -48,9 +49,22 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -81,7 +95,7 @@ def sample( # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -98,6 +112,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -133,9 +148,6 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: self._cfg.num_unroll_steps].tolist() timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - # mask_tmp = [1. for i in range(len(actions_tmp))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # TODO: the child_visits after position in the segment (with padded part) may not be updated # So the corresponding position should not be used in the training @@ -278,9 +290,6 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # TODO: original buffer mask - # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # pad random action actions_tmp += [ @@ -415,11 +424,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
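# A small sketch of the per-task configuration lookup added to the UniZero game
# buffer __init__ above: with a `task_id`, per-task values come from
# `action_space_size_list`, otherwise the shared single-task field is used.
# `resolve_action_space_size` and the EasyDict layout are illustrative assumptions.
from easydict import EasyDict

def resolve_action_space_size(cfg: EasyDict) -> int:
    task_id = getattr(cfg, 'task_id', None)
    if task_id is not None:
        try:
            return cfg.model.action_space_size_list[task_id]
        except (AttributeError, IndexError, KeyError, TypeError):
            # Fall back to the shared field, as the buffer code above does.
            return cfg.model.action_space_size
    return cfg.model.action_space_size

if __name__ == "__main__":
    multi_task_cfg = EasyDict(dict(task_id=1, model=dict(action_space_size_list=[6, 18], action_space_size=6)))
    single_task_cfg = EasyDict(dict(model=dict(action_space_size=9)))
    print(resolve_action_space_size(multi_task_cfg))   # 18
    print(resolve_action_space_size(single_task_cfg))  # 9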
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -435,18 +444,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # ======================================================================= - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -454,7 +469,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -462,13 +477,21 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + # TODO: adapt unizero multitask to timestep in rope + # MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + 
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -479,7 +502,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: distributions = roots_distributions[policy_index] if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -488,7 +511,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: if self._cfg.env_type == 'not_board_games': @@ -498,7 +521,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -543,7 +566,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + # ====================================================================== # if not in training, obtain the scalars of the value/reward @@ -630,3 +659,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values) return batch_rewards, batch_target_values + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None: + """ + Overview: + Update the priority of training data. 
+ Arguments: + - train_data (:obj:`List[np.ndarray]`): training data to be updated priority. + - batch_priorities (:obj:`np.ndarray`): priorities to update to. + NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + """ + # TODO: NOTE: -4 is batch_index_list + indices = train_data[0][-4] + metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + # ==================== START OF FINAL FIX ==================== + + # FIX 1: Handle ValueError by using the first timestamp of the segment for comparison. + first_transition_time = metas['make_time'][i][0] + + if first_transition_time > self.clear_time: + # FIX 2: Handle IndexError by converting the float index to an integer before use. + idx = int(indices[i]) + prio = metas['batch_priorities'][i] + + # Now, idx is a valid integer index. + self.game_pos_priorities[idx] = prio + + # ===================== END OF FINAL FIX ===================== diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ad216d196..2c45b328b 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -31,7 +31,7 @@ class GameSegment: - store_search_stats """ - def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: + def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id = None) -> None: """ Overview: Init the ``GameSegment`` according to the provided arguments. @@ -45,19 +45,31 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor - self.action_space_size = config.model.action_space_size + if not hasattr(config.model, "action_space_size_list"): + self.action_space_size = config.model.action_space_size self.gray_scale = config.gray_scale self.transform2string = config.transform2string self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder - if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: - # for vector obs input, e.g. classical control and box2d environments - self.zero_obs_shape = config.model.observation_shape - elif len(config.model.observation_shape) == 3: - # image obs input, e.g. atari environments - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if task_id is None: + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape + elif len(config.model.observation_shape) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + else: + if hasattr(config.model, "observation_shape_list"): + if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1: + # for vector obs input, e.g. 
classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape_list[task_id] + elif len(config.model.observation_shape_list[task_id]) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) + else: + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp index 7c5d11dd2..83f50e2da 100644 --- a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp @@ -22,6 +22,7 @@ #include #include + #ifdef _WIN32 #include "..\..\common_lib\utils.cpp" #else diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 4e238a6b3..4efcf1688 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -15,6 +15,7 @@ from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -72,11 +73,11 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] - ) -> dict: + List[Any]], timestep: Union[int, List[Any]]=None, task_id=None + ) -> None: """ Overview: Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. 
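# A compact sketch of the per-task `zero_obs_shape` selection added to
# `GameSegment.__init__` above: when a task_id is given, the shape is taken from
# `observation_shape_list[task_id]`; vector observations keep their length and
# image observations become (C, H, W). `pick_zero_obs_shape` is an illustrative
# helper, not a function in the repository.
from typing import Optional, Sequence, Tuple, Union

def pick_zero_obs_shape(
    observation_shape: Union[int, Sequence[int]],
    image_channel: int,
    observation_shape_list: Optional[Sequence[Union[int, Sequence[int]]]] = None,
    task_id: Optional[int] = None,
) -> Union[int, Sequence[int], Tuple[int, int, int]]:
    shape = observation_shape
    if task_id is not None and observation_shape_list is not None:
        shape = observation_shape_list[task_id]
    if isinstance(shape, int) or len(shape) == 1:
        # Vector observations (classic control, box2d): keep as-is.
        return shape
    # Image observations (e.g. Atari): channel first, then the spatial size.
    return (image_channel, shape[-2], shape[-1])

if __name__ == "__main__":
    print(pick_zero_obs_shape(4, image_channel=1))                                     # 4
    print(pick_zero_obs_shape((3, 64, 64), image_channel=3))                           # (3, 64, 64)
    print(pick_zero_obs_shape((3, 96, 96), 3, [(3, 96, 96), (3, 64, 64)], task_id=1))  # (3, 64, 64)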
@@ -137,7 +138,15 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + try: + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + except Exception as e: + print("="*20) + print(e) + print("roots:", roots, "latent_state_roots:", latent_state_roots) + print ("latent_state_roots.shape:", latent_state_roots.shape) + + # TODO: .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() @@ -154,7 +163,23 @@ def search( # search_depth is used for rope in UniZero search_depth = results.get_search_len() # print(f'simulation_index:{simulation_index}, search_depth:{search_depth}, latent_state_index_in_search_path:{latent_state_index_in_search_path}') - network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) + if timestep is None: + # for UniZero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth) + else: + # for UniZero + if task_id is not None: + # multi task setting + # network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -245,10 +270,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -318,6 +343,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) @@ -516,7 +548,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py 
index 02f591a1f..7ab0d210e 100644 --- a/lzero/mcts/tree_search/mcts_ctree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py @@ -83,7 +83,7 @@ def roots( # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] + List[Any]], timestep: Union[int, List[Any]], task_id=None ) -> None: """ Overview: @@ -142,6 +142,7 @@ def search( latent_states.append(latent_state_batch_in_search_path[ix][iy]) latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + if self._cfg.model.continuous_action_space is True: # continuous action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device) @@ -159,9 +160,15 @@ def search( MCTS stage 3: Backup At the end of the simulation, the statistics along the trajectory are updated. """ + # search_depth is used for rope in UniZero + search_depth = results.get_search_len() # for Sampled UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, - latent_state_index_in_search_path, timestep) + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -169,6 +176,7 @@ def search( network_output.reward = to_detach_cpu_numpy(self.reward_inverse_scalar_transform_handle(network_output.reward)) latent_state_batch_in_search_path.append(network_output.latent_state) + # print("network_output.latent_state.shape:", network_output.latent_state.shape) # tolist() is to be compatible with cpp datatype. reward_batch = network_output.reward.reshape(-1).tolist() diff --git a/lzero/model/common.py b/lzero/model/common.py index 7b1bbeeae..88186f711 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -1,25 +1,26 @@ """ Overview: - In this Python file, we provide a collection of reusable model templates designed to streamline the development + This Python file provides a collection of reusable model templates designed to streamline the development process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and - customize their custom algorithms, ensuring efficient and effective development. - BTW, users can refer to the unittest of these model templates to learn how to use them. + customize their algorithms, ensuring efficient and effective development. + Users can refer to the unittest of these model templates to learn how to use them. """ import math from dataclasses import dataclass -from typing import Callable, List, Optional -from typing import Tuple +from typing import Callable, List, Optional, Tuple, Sequence import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init -from transformers import AutoModelForCausalLM, AutoTokenizer +from ditk import logging +# Assuming these imports are valid in the user's environment. +# If they are not, they should be replaced with the correct ones. 
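# A minimal sketch of how the tree-search changes above route the optional
# `timestep` and `task_id` arguments into `recurrent_inference`: task-conditioned
# (multi-task) calls currently drop the timestep (see the TODO about RoPE above),
# while single-task calls keep it. `FakeWorldModel` and `run_recurrent_inference`
# are hypothetical illustration names.
from typing import Any, List, Optional

class FakeWorldModel:
    def recurrent_inference(self, state_action_history: List[Any], simulation_index: int,
                            search_depth: List[int], timestep: Optional[Any] = None,
                            task_id: Optional[int] = None) -> str:
        return f"recurrent_inference(timestep={timestep}, task_id={task_id})"

def run_recurrent_inference(model: FakeWorldModel, history: List[Any], sim_index: int,
                            search_depth: List[int], timestep: Optional[Any] = None,
                            task_id: Optional[int] = None) -> str:
    if task_id is not None:
        # Multi-task setting: the timestep is not forwarded yet.
        return model.recurrent_inference(history, sim_index, search_depth, task_id=task_id)
    # Single-task setting: keep the timestep for RoPE-style positional encoding.
    return model.recurrent_inference(history, sim_index, search_depth, timestep)

if __name__ == "__main__":
    m = FakeWorldModel()
    print(run_recurrent_inference(m, [], 0, [1, 1], timestep=[3, 3]))
    print(run_recurrent_inference(m, [], 0, [1, 1], timestep=[3, 3], task_id=0))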
from ding.torch_utils import MLP, ResBlock from ding.torch_utils.network.normalization import build_normalization -from ding.utils import SequenceType -from ditk import logging +from ding.utils import SequenceType, get_rank, get_world_size +from transformers import AutoModelForCausalLM, AutoTokenizer from ding.utils import set_pkg_seed, get_rank, get_world_size @@ -28,7 +29,7 @@ def MLP_V2( in_channels: int, hidden_channels: List[int], out_channels: int, - layer_fn: Callable = None, + layer_fn: Callable = nn.Linear, activation: Optional[nn.Module] = None, norm_type: Optional[str] = None, use_dropout: bool = False, @@ -36,118 +37,122 @@ def MLP_V2( output_activation: bool = True, output_norm: bool = True, last_linear_layer_init_zero: bool = False, -): +) -> nn.Sequential: """ Overview: - Create a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully + Creates a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully connected block with optional activation, normalization, and dropout. The final layer is configurable - to include or exclude activation, normalization, and dropout based on user preferences. - + to include or exclude activation and normalization. Arguments: - in_channels (:obj:`int`): Number of input channels (dimensionality of the input tensor). - hidden_channels (:obj:`List[int]`): A list specifying the number of channels for each hidden layer. - For example, [512, 256, 128] means the MLP will have three hidden layers with 512, 256, and 128 units, respectively. - out_channels (:obj:`int`): Number of output channels (dimensionality of the output tensor). - - layer_fn (:obj:`Callable`, optional): Layer function to construct layers (default is `nn.Linear`). - - activation (:obj:`nn.Module`, optional): Activation function to use after each layer - (e.g., `nn.ReLU`, `nn.Sigmoid`). Default is None (no activation). - - norm_type (:obj:`str`, optional): Type of normalization to apply after each layer. - If None, no normalization is applied. Supported values depend on the implementation of `build_normalization`. - - use_dropout (:obj:`bool`, optional): Whether to apply dropout after each layer. Default is False. - - dropout_probability (:obj:`float`, optional): The probability of setting elements to zero in dropout. Default is 0.5. - - output_activation (:obj:`bool`, optional): Whether to apply activation to the output layer. Default is True. - - output_norm (:obj:`bool`, optional): Whether to apply normalization to the output layer. Default is True. - - last_linear_layer_init_zero (:obj:`bool`, optional): Whether to initialize the weights and biases of the - last linear layer to zeros. This is commonly used in reinforcement learning for stable initial outputs. - + - layer_fn (:obj:`Callable`): The function to construct layers, defaults to `nn.Linear`. + - activation (:obj:`Optional[nn.Module]`): Activation function to use after each layer, defaults to None. + - norm_type (:obj:`Optional[str]`): Type of normalization to apply. If None, no normalization is applied. + - use_dropout (:obj:`bool`): Whether to apply dropout after each layer, defaults to False. + - dropout_probability (:obj:`float`): The probability for dropout, defaults to 0.5. + - output_activation (:obj:`bool`): Whether to apply activation to the output layer, defaults to True. + - output_norm (:obj:`bool`): Whether to apply normalization to the output layer, defaults to True. 
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer's weights and biases to zero. Returns: - block (:obj:`nn.Sequential`): A PyTorch `nn.Sequential` object containing the layers of the MLP. - - Notes: - - The final layer's normalization, activation, and dropout are controlled by `output_activation`, - `output_norm`, and `use_dropout`. - - If `last_linear_layer_init_zero` is True, the weights and biases of the last linear layer are initialized to 0. """ - assert len(hidden_channels) > 0, "The hidden_channels list must contain at least one element." - if layer_fn is None: - layer_fn = nn.Linear - - # Initialize the MLP block - block = [] - channels = [in_channels] + hidden_channels + [out_channels] - - # Build all layers except the final layer - for i, (in_channels, out_channels) in enumerate(zip(channels[:-2], channels[1:-1])): - block.append(layer_fn(in_channels, out_channels)) - if norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Build the final layer - in_channels = channels[-2] - out_channels = channels[-1] - block.append(layer_fn(in_channels, out_channels)) - - # Add optional normalization and activation for the final layer - if output_norm and norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if output_activation and activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Initialize the weights and biases of the last linear layer to zero if specified + if not hidden_channels: + logging.warning("hidden_channels is empty, creating a single-layer MLP.") + + layers = [] + all_channels = [in_channels] + hidden_channels + [out_channels] + num_layers = len(all_channels) - 1 + + for i in range(num_layers): + is_last_layer = (i == num_layers - 1) + layers.append(layer_fn(all_channels[i], all_channels[i+1])) + + if not is_last_layer: + # Intermediate layers + if norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if activation: + layers.append(activation) + if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + else: + # Last layer + if output_norm and norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if output_activation and activation: + layers.append(activation) + # Note: Dropout on the final output is usually not recommended unless for specific regularization purposes. + # The original logic applied it, so we keep it for consistency. + if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + + # Initialize the last linear layer to zero if specified if last_linear_layer_init_zero: - for layer in reversed(block): + for layer in reversed(layers): if isinstance(layer, nn.Linear): nn.init.zeros_(layer.weight) nn.init.zeros_(layer.bias) break - return nn.Sequential(*block) + return nn.Sequential(*layers) + + +# --- Data-structures for Network Outputs --- -# use dataclass to make the output of network more convenient to use @dataclass class MZRNNNetworkOutput: - # output format of the MuZeroRNN model + """ + Overview: + Data structure for the output of the MuZeroRNN model. 
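# A short usage sketch for the rewritten MLP_V2 above: `hidden_channels` lists the
# intermediate widths, and the output_* flags control whether the final linear
# layer also gets activation/normalization. The import path follows the file being
# modified here; the shapes and values below are illustrative.
import torch
import torch.nn as nn
from lzero.model.common import MLP_V2

net = MLP_V2(
    in_channels=8,
    hidden_channels=[64, 64],
    out_channels=4,
    activation=nn.ReLU(inplace=True),
    norm_type=None,               # set to 'BN' or 'LN' to insert normalization layers
    output_activation=False,      # return raw values from the last linear layer
    output_norm=False,
    last_linear_layer_init_zero=True,
)
x = torch.randn(32, 8)
print(net)            # Linear -> ReLU -> Linear -> ReLU -> Linear
print(net(x).shape)   # torch.Size([32, 4]); all zeros at init because the last layer is zero-initialized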
+ """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor predict_next_latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class EZNetworkOutput: - # output format of the EfficientZero model + """ + Overview: + Data structure for the output of the EfficientZero model. + """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class MZNetworkOutput: - # output format of the MuZero model + """ + Overview: + Data structure for the output of the MuZero model. + """ value: torch.Tensor reward: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor +# --- Core Network Components --- + class SimNorm(nn.Module): + """ + Overview: + Implements Simplicial Normalization as described in the paper: https://arxiv.org/abs/2204.00616. + It groups features and applies softmax to each group. + """ def __init__(self, simnorm_dim: int) -> None: """ - Overview: - Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. Arguments: - - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + - simnorm_dim (:obj:`int`): The size of each group (simplex) to apply softmax over. """ super().__init__() self.dim = simnorm_dim @@ -155,213 +160,205 @@ def __init__(self, simnorm_dim: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overview: - Forward pass of the SimNorm layer. + Forward pass for SimNorm. Arguments: - - x (:obj:`torch.Tensor`): The input tensor to normalize. + - x (:obj:`torch.Tensor`): The input tensor. Returns: - - x (:obj:`torch.Tensor`): The normalized tensor. + - (:obj:`torch.Tensor`): The tensor after applying Simplicial Normalization. """ - shp = x.shape - # Ensure that there is at least one simplex to normalize across. - if shp[1] != 0: - x = x.view(*shp[:-1], -1, self.dim) - x = F.softmax(x, dim=-1) - return x.view(*shp) - else: + if x.shape[1] == 0: return x + # Reshape to (batch, groups, dim) + x_reshaped = x.view(*x.shape[:-1], -1, self.dim) + # Apply softmax over the last dimension (the simplex) + x_softmax = F.softmax(x_reshaped, dim=-1) + # Reshape back to the original tensor shape + return x_softmax.view(*x.shape) def __repr__(self) -> str: - """ - Overview: - String representation of the SimNorm layer. - Returns: - - output (:obj:`str`): The string representation. - """ return f"SimNorm(dim={self.dim})" -def AvgL1Norm(x, eps=1e-8): +def AvgL1Norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: """ Overview: - Normalize the input tensor by the L1 norm. + Normalizes a tensor by the mean of its absolute values (L1 norm) along the last dimension. Arguments: - x (:obj:`torch.Tensor`): The input tensor to normalize. - - eps (:obj:`float`): The epsilon value to prevent division by zero. + - eps (:obj:`float`): A small epsilon value to prevent division by zero. Returns: - - :obj:`torch.Tensor`: The normalized tensor. + - (:obj:`torch.Tensor`): The normalized tensor. """ - return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + return x / (x.abs().mean(dim=-1, keepdim=True) + eps) class FeatureAndGradientHook: + """ + Overview: + A utility class to capture and analyze features and gradients of a specific module during + the forward and backward passes. This is useful for debugging and understanding model dynamics. 
+ """ - def __init__(self): + def __init__(self, module: nn.Module): """ - Overview: - Class to capture features and gradients at SimNorm. + Arguments: + - module (:obj:`nn.Module`): The PyTorch module to attach the hooks to. """ self.features_before = [] self.features_after = [] self.grads_before = [] self.grads_after = [] + self.forward_handler = module.register_forward_hook(self._forward_hook) + self.backward_handler = module.register_full_backward_hook(self._backward_hook) - def setup_hooks(self, model): - # Hooks to capture features and gradients at SimNorm - self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) - self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) - - def forward_hook(self, module, input, output): + def _forward_hook(self, module: nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> None: + """Hook to capture input and output features during the forward pass.""" with torch.no_grad(): - self.features_before.append(input[0]) - self.features_after.append(output) + self.features_before.append(inputs[0].clone().detach()) + self.features_after.append(output.clone().detach()) - def backward_hook(self, module, grad_input, grad_output): + def _backward_hook(self, module: nn.Module, grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]) -> None: + """Hook to capture input and output gradients during the backward pass.""" with torch.no_grad(): - self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) - self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + self.grads_before.append(grad_inputs[0].clone().detach() if grad_inputs[0] is not None else None) + self.grads_after.append(grad_outputs[0].clone().detach() if grad_outputs[0] is not None else None) - def analyze(self): - # Calculate L2 norms of features - l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) - l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + def analyze(self) -> Tuple[float, float, float, float]: + """ + Overview: + Analyzes the captured features and gradients by computing their average L2 norms. + This method clears the stored data after analysis to free memory. + Returns: + - (:obj:`Tuple[float, float, float, float]`): A tuple containing the L2 norms of + (features_before, features_after, grads_before, grads_after). 
+ """ + if not self.features_before: + return 0.0, 0.0, 0.0, 0.0 - # Calculate norms of gradients - grad_norm_before = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) - grad_norm_after = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_before])).item() + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_after])).item() - # Clear stored data and delete tensors to free memory - self.clear_data() + valid_grads_before = [g for g in self.grads_before if g is not None] + grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_before])).item() if valid_grads_before else 0.0 - # Optionally clear CUDA cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() + valid_grads_after = [g for g in self.grads_after if g is not None] + grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_after])).item() if valid_grads_after else 0.0 + self.clear_data() return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after - def clear_data(self): - del self.features_before[:] - del self.features_after[:] - del self.grads_before[:] - del self.grads_after[:] + def clear_data(self) -> None: + """Clears all stored feature and gradient tensors to free up memory.""" + self.features_before.clear() + self.features_after.clear() + self.grads_before.clear() + self.grads_after.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - def remove_hooks(self): + def remove_hooks(self) -> None: + """Removes the registered forward and backward hooks.""" self.forward_handler.remove() self.backward_handler.remove() class DownSample(nn.Module): + """ + Overview: + A convolutional network for downsampling image-based observations, commonly used in Atari environments. + It consists of a series of convolutional, normalization, and residual blocks. + """ - def __init__(self, observation_shape: SequenceType, out_channels: int, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - num_resblocks: int = 1, - ) -> None: + def __init__( + self, + observation_shape: Sequence[int], + out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + num_resblocks: int = 1, + ) -> None: """ - Overview: - Define downSample convolution network. Encode the observation into hidden state. - This network is often used in video games like Atari. In board games like go and chess, - we don't need this module. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] - for video games like atari, RGB 3 channel times stack 4 frames. - - out_channels (:obj:`int`): The output channels of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. - - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. + - observation_shape (:obj:`Sequence[int]`): The shape of the input observation, e.g., (C, H, W). + - out_channels (:obj:`int`): The number of output channels. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): The type of normalization ('BN' or 'LN'). 
+ - num_resblocks (:obj:`int`): The number of residual blocks in each stage. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") + # The original design was fixed to 1 resblock per stage. + if num_resblocks != 1: + logging.warning(f"DownSample is designed for num_resblocks=1, but got {num_resblocks}.") self.observation_shape = observation_shape - self.conv1 = nn.Conv2d( - observation_shape[0], - out_channels // 2, - kernel_size=3, - stride=2, - padding=1, - bias=False, # disable bias for better convergence - ) - if norm_type == 'BN': - self.norm1 = nn.BatchNorm2d(out_channels // 2) - elif norm_type == 'LN': - self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], - eps=1e-5) + self.activation = activation - self.resblocks1 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels // 2, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False - ) for _ in range(num_resblocks) - ] - ) - self.downsample_block = ResBlock( - in_channels=out_channels // 2, - out_channels=out_channels, - activation=activation, - norm_type=norm_type, - res_type='downsample', - bias=False - ) - self.resblocks2 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) + # Initial convolution: stride 2 + self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2) + + # Stage 1 with residual blocks + self.resblocks1 = nn.ModuleList([ + ResBlock(in_channels=out_channels // 2, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Downsample block: stride 2 + self.downsample_block = ResBlock(in_channels=out_channels // 2, out_channels=out_channels, activation=activation, norm_type=norm_type, res_type='downsample', bias=False) + + # Stage 2 with residual blocks + self.resblocks2 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Pooling 1: stride 2 self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.resblocks3 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(1) - ] - ) + + # Stage 3 with residual blocks + self.resblocks3 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Final pooling for specific input sizes self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. 
+ - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) + x = self.norm1(x) """ x = self.conv1(x) - x = self.norm1(x) x = self.activation(x) for block in self.resblocks1: x = block(x) + x = self.downsample_block(x) for block in self.resblocks2: x = block(x) + x = self.pooling1(x) for block in self.resblocks3: x = block(x) - # 64, 84, 96 are the most common observation shapes in Atari games. - if self.observation_shape[1] == 64: - output = x - elif self.observation_shape[1] == 84: - x = self.pooling2(x) - output = x - elif self.observation_shape[1] == 96: - x = self.pooling2(x) - output = x + # This part handles specific Atari resolutions. A more general approach might be desirable, + # but we maintain original behavior. + obs_height = self.observation_shape[1] + if obs_height == 64: + return x + elif obs_height in [84, 96]: + return self.pooling2(x) else: - raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. " - f"You should transform the observation shape to 64 or 96 in the env.") - - return output + raise NotImplementedError( + f"DownSample for observation height {obs_height} is not implemented. " + f"Supported heights are 64, 84, 96." + ) class QwenNetwork(nn.Module): def __init__(self, @@ -482,10 +479,6 @@ def __init__(self, final_norm_option_in_encoder: str = "layernorm", tokenizer=None): """ - Overview: - This class defines a language representation network that utilizes a pretrained Hugging Face model. - The network outputs embeddings with the specified dimension and can optionally use SimNorm or LayerNorm - for normalization at the final stage to ensure training stability. Arguments: - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'. - embedding_size (int): The dimension of the output embeddings. Default is 768. @@ -494,11 +487,9 @@ def __init__(self, - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model. """ super().__init__() - from transformers import AutoModel, AutoTokenizer - logging.info(f"Loading model from: {model_path}") - # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup. + # In distributed settings, ensure only rank 0 downloads the model/tokenizer. if get_rank() == 0: self.pretrained_model = AutoModel.from_pretrained(model_path) @@ -508,18 +499,15 @@ def __init__(self, if get_rank() != 0: self.pretrained_model = AutoModel.from_pretrained(model_path) - if tokenizer is None: - # Only rank 0 downloads the tokenizer, and then other processes load it from cache. - if get_rank() == 0: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - if get_world_size() > 1: - torch.distributed.barrier() - if get_rank() != 0: + if get_rank() != 0: + logging.info(f"Worker process is loading model from cache: {model_path}") + self.model = AutoModel.from_pretrained(model_path) + if tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(model_path) - else: + + if tokenizer is not None: self.tokenizer = tokenizer - # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings). 
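# A condensed sketch of the "rank 0 downloads, the other ranks load from the local
# cache" pattern used when constructing the pretrained language encoder above.
# `load_pretrained_rank0_first` is an illustrative helper; get_rank/get_world_size
# and the barrier come from ding.utils / torch.distributed as in the surrounding code.
import torch
from ding.utils import get_rank, get_world_size
from transformers import AutoModel

def load_pretrained_rank0_first(model_path: str):
    model = None
    # Rank 0 triggers the (possibly networked) download and populates the cache.
    if get_rank() == 0:
        model = AutoModel.from_pretrained(model_path)
    # Everyone waits until the cache is guaranteed to exist.
    if get_world_size() > 1:
        torch.distributed.barrier()
    # Non-zero ranks now read from the local cache instead of the network.
    if get_rank() != 0:
        model = AutoModel.from_pretrained(model_path)
    return model

if __name__ == "__main__":
    encoder = load_pretrained_rank0_first('google-bert/bert-base-uncased')
    print(type(encoder).__name__)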
self.embedding_size = embedding_size self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) @@ -534,22 +522,18 @@ def __init__(self, def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: """ - Forward Propagation: - Compute the language representation based on the input token sequence. - The [CLS] token’s representation is extracted from the output of the pretrained model, - then passed through a linear projection and final normalization layer (SimNorm or LayerNorm). - + Overview: + Computes language representation from input token IDs. Arguments: - - x (torch.Tensor): Input token sequence of shape [batch_size, seq_len]. - - no_grad (bool): Whether to run in no-gradient mode for memory efficiency. Default is True. + - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len). + - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context. Returns: - - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size]. + - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size). """ # Construct the attention mask to exclude padding tokens. attention_mask = x != self.tokenizer.pad_token_id - # Use no_grad context if specified to disable gradient computation. if no_grad: with torch.no_grad(): x = x.long() # Ensure the input tensor is of type long. @@ -561,9 +545,7 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: outputs = self.pretrained_model(x, attention_mask=attention_mask) cls_embedding = outputs.last_hidden_state[:, 0, :] - # Apply linear projection to obtain the desired output dimension. cls_embedding = self.embed_proj_head(cls_embedding) - # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability. cls_embedding = self.norm(cls_embedding) return cls_embedding @@ -640,20 +622,36 @@ def __init__( self.activation = activation self.embedding_dim = embedding_dim + # ==================== Modification start ==================== if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + # Fix: replace the hard-coded 64 with num_channels + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + # Fix: replace the hard-coded 64 with num_channels + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) + # ==================== Modification end ==================== - self.final_norm_option_in_encoder = final_norm_option_in_encoder - if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm_option_in_encoder=final_norm_option_in_encoder + # 2.
Initialize final_norm uniformly in __init__ + if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']: self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'LayerNormNoAffine': + self.final_norm = nn.LayerNorm( + self.embedding_dim, eps=1e-5, elementwise_affine=False + ) elif self.final_norm_option_in_encoder == 'SimNorm': + # Make sure SimNorm has been defined self.final_norm = SimNorm(simnorm_dim=group_size) + elif self.final_norm_option_in_encoder == 'L2Norm': + # Directly instantiate our custom L2Norm module + self.final_norm = L2Norm(eps=1e-6) + elif self.final_norm_option_in_encoder is None: + # If no normalization is needed, this can be set to nn.Identity() or None + self.final_norm = nn.Identity() else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: @@ -679,90 +677,75 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, self.embedding_dim) # NOTE: very important for training stability. - x = self.final_norm(x) + # x = self.final_norm(x) + + # 3. Call self.final_norm uniformly in forward + # This structure is clearer and more extensible + if self.final_norm is not None: + x = self.final_norm(x) + + # Special handling for LayerNorm_Tanh + if self.final_norm_option_in_encoder == 'LayerNorm_Tanh': + x = torch.tanh(x) return x class RepresentationNetwork(nn.Module): - + """ + Overview: + The standard representation network used in MuZero. It encodes a 2D image observation + into a latent state, which retains its spatial dimensions. + """ def __init__( self, - observation_shape: SequenceType = (4, 96, 96), + observation_shape: Sequence[int] = (4, 96, 96), num_res_blocks: int = 1, num_channels: int = 64, downsample: bool = True, activation: nn.Module = nn.ReLU(inplace=True), norm_type: str = 'BN', - embedding_dim: int = 256, - group_size: int = 8, use_sim_norm: bool = False, + group_size: int = 8, ) -> None: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. - Currently, the network only supports obs images with both a width and height of 96. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] - for video games like atari, 1 gray channel times stack 4 frames. + - observation_shape (:obj:`Sequence[int]`): Shape of the input observation (C, H, W). - num_res_blocks (:obj:`int`): The number of residual blocks. - - num_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - embedding_dim (:obj:`int`): The dimension of the output hidden state. - - group_size (:obj:`int`): The size of group in the SimNorm layer. - - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + - num_channels (:obj:`int`): The number of channels in the convolutional layers. + - downsample (:obj:`bool`): Whether to use the `DownSample` module. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): Normalization type ('BN' or 'LN').
+ - use_sim_norm (:obj:`bool`): Whether to apply a final `SimNorm` layer. + - group_size (:obj:`int`): Group size for `SimNorm`. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") self.downsample = downsample + self.activation = activation + if self.downsample: - self.downsample_net = DownSample( - observation_shape, - num_channels, - activation=activation, - norm_type=norm_type, - ) + self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm = build_normalization(norm_type, dim=3)(num_channels, *observation_shape[1:]) - if norm_type == 'BN': - self.norm = nn.BatchNorm2d(num_channels) - elif norm_type == 'LN': - if downsample: - self.norm = nn.LayerNorm( - [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - else: - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.activation = activation + self.resblocks = nn.ModuleList([ + ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_res_blocks) + ]) self.use_sim_norm = use_sim_norm - if self.use_sim_norm: - self.embedding_dim = embedding_dim self.sim_norm = SimNorm(simnorm_dim=group_size) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. + - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) """ if self.downsample: x = self.downsample_net(x) @@ -775,52 +758,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = block(x) if self.use_sim_norm: - # NOTE: very important. - # for atari 64,8,8 = 4096 -> 768 - x = self.sim_norm(x) - + # Flatten the spatial dimensions, apply SimNorm, and then reshape back. + b, c, h, w = x.shape + x_flat = x.view(b, c * h * w) + x_norm = self.sim_norm(x_flat) + x = x_norm.view(b, c, h, w) + return x class RepresentationNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based representation network for encoding vector observations into a latent state. + """ def __init__( self, - observation_shape: int, + observation_dim: int, hidden_channels: int = 64, - layer_num: int = 2, + num_layers: int = 2, activation: nn.Module = nn.GELU(approximate='tanh'), norm_type: Optional[str] = 'BN', group_size: int = 8, final_norm_option_in_encoder: str = 'LayerNorm', # TODO ) -> torch.Tensor: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ - with Multi-Layer Perceptron (MLP). Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. 
- - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - observation_dim (:obj:`int`): The dimension of the input vector observation. + - hidden_channels (:obj:`int`): The number of neurons in the hidden and output layers. + - num_layers (:obj:`int`): The total number of layers in the MLP. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`Optional[str]`): The type of normalization ('BN', 'LN', or None). + - group_size (:obj:`int`): The group size for the final `SimNorm` layer. """ super().__init__() - self.fc_representation = MLP( - in_channels=observation_shape, - hidden_channels=hidden_channels, + # Creating hidden layers list for MLP_V2 + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_representation = MLP_V2( + in_channels=observation_dim, + hidden_channels=hidden_layers, out_channels=hidden_channels, - layer_num=layer_num, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) @@ -836,8 +818,8 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, observation_dim) + - output (:obj:`torch.Tensor`): (B, hidden_channels) """ x = self.fc_representation(x) x = self.norm(x) @@ -846,593 +828,414 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LatentDecoder(nn.Module): - - def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + A decoder network that reconstructs a 2D image from a 1D latent embedding. + It acts as the inverse of a representation network like `RepresentationNetworkUniZero`. + """ + def __init__( + self, + embedding_dim: int, + output_shape: Tuple[int, int, int], + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh') + ): """ - Overview: - Decoder network used in UniZero. Decode the latent state into 2D image obs. Arguments: - - embedding_dim (:obj:`int`): The dimension of the latent state. - - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - num_channels (:obj:`int`): The channel of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + - embedding_dim (:obj:`int`): The dimension of the input latent embedding. + - output_shape (:obj:`Tuple[int, int, int]`): The shape of the target output image (C, H, W). + - num_channels (:obj:`int`): The base number of channels for the initial upsampling stage. 
+ - activation (:obj:`nn.Module`): The activation function to use. """ super().__init__() self.embedding_dim = embedding_dim - self.output_shape = output_shape # (C, H, W) - self.num_channels = num_channels - self.activation = activation - - # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 - # We will reverse the process of the representation network - self.initial_size = ( - num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder - self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + self.output_shape = output_shape + + # This should match the spatial size of the encoder's feature map before flattening. + # Assuming a total downsampling factor of 8 (e.g., for a 64x64 -> 8x8 encoder). + self.initial_h = output_shape[1] // 8 + self.initial_w = output_shape[2] // 8 + self.initial_size = (num_channels, self.initial_h, self.initial_w) + + self.fc = nn.Linear(embedding_dim, np.prod(self.initial_size)) - # Upsampling blocks - self.conv_blocks = nn.ModuleList([ - # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + self.deconv_blocks = nn.Sequential( + # Block 1: (C, H/8, W/8) -> (C/2, H/4, W/4) nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), - self.activation, + activation, nn.BatchNorm2d(num_channels // 2), - # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) - nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, - output_padding=1), - self.activation, + # Block 2: (C/2, H/4, W/4) -> (C/4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1), + activation, nn.BatchNorm2d(num_channels // 4), - # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) - nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, - output_padding=1), - ]) - # TODO: last layer use sigmoid? + # Block 3: (C/4, H/2, W/2) -> (output_C, H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1), + # A final activation like Sigmoid or Tanh is often used if pixel values are in a fixed range [0,1] or [-1,1]. + # We omit it here to maintain consistency with the original code. + ) def forward(self, embeddings: torch.Tensor) -> torch.Tensor: - # Map embeddings back to the image space - x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) - x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) - - # Apply conv blocks - for block in self.conv_blocks: - x = block(x) # Upsample progressively - - # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + """ + Shapes: + - embeddings (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ + x = self.fc(embeddings) + x = x.view(-1, *self.initial_size) + x = self.deconv_blocks(x) return x -class LatentEncoderForMemoryEnv(nn.Module): +# --- Networks for MemoryEnv --- +class LatentEncoderForMemoryEnv(nn.Module): + """ + Overview: + An encoder for the MemoryEnv, converting a small image observation into a latent embedding. + It uses a series of convolutions followed by adaptive average pooling. 
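To make the decoder's shape progression concrete, here is a minimal sketch of the upsampling path (normalization layers omitted), assuming the matching encoder downsamples the image by a factor of 8 and using illustrative sizes:

import numpy as np
import torch
import torch.nn as nn

embedding_dim, out_c, H, W, C = 256, 3, 64, 64, 64
initial = (C, H // 8, W // 8)                      # (64, 8, 8)

fc = nn.Linear(embedding_dim, int(np.prod(initial)))
deconv = nn.Sequential(
    nn.ConvTranspose2d(C, C // 2, 3, stride=2, padding=1, output_padding=1),       # 8x8   -> 16x16
    nn.GELU(approximate='tanh'),
    nn.ConvTranspose2d(C // 2, C // 4, 3, stride=2, padding=1, output_padding=1),  # 16x16 -> 32x32
    nn.GELU(approximate='tanh'),
    nn.ConvTranspose2d(C // 4, out_c, 3, stride=2, padding=1, output_padding=1),   # 32x32 -> 64x64
)

z = torch.randn(2, embedding_dim)
img = deconv(fc(z).view(-1, *initial))
assert img.shape == (2, 3, 64, 64)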
+ """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=100, - channels=[16, 32, 64], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 100, + channels: List[int] = [16, 32, 64], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.GELU(approximate='tanh'), - normalize_pixel=False, + normalize_pixel: bool = False, group_size: int = 8, - **kwargs, ): """ - Overview: - Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ - Use the inplace operation to speed up. - - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. - - group_size (:obj:`int`): The dimension for simplicial normalization + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the input image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the output latent embedding. + - channels (:obj:`List[int]`): List of output channels for each convolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes for each convolutional layer. + - strides (:obj:`List[int]`): List of strides for each convolutional layer. + - activation (:obj:`nn.Module`): Activation function to use. + - normalize_pixel (:obj:`bool`): Whether to normalize input pixel values to [0, 1]. + - group_size (:obj:`int`): Group size for the final `SimNorm` layer. 
""" - super(LatentEncoderForMemoryEnv, self).__init__() - self.shape = image_shape - self.channels = [image_shape[0]] + list(channels) + super().__init__() + self.normalize_pixel = normalize_pixel + all_channels = [image_shape[0]] + channels layers = [] - for i in range(len(self.channels) - 1): - layers.append( - nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2 # keep the same size of feature map - ) - ) - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) - + for i in range(len(channels)): + layers.extend([ + nn.Conv2d(all_channels[i], all_channels[i+1], kernel_sizes[i], strides[i], padding=kernel_sizes[i]//2), + nn.BatchNorm2d(all_channels[i+1]), + activation + ]) layers.append(nn.AdaptiveAvgPool2d(1)) - self.cnn = nn.Sequential(*layers) - self.linear = nn.Sequential( - nn.Linear(self.channels[-1], embedding_size, bias=False), - ) - init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.linear = nn.Linear(channels[-1], embedding_size, bias=False) + init.kaiming_normal_(self.linear.weight, mode='fan_out', nonlinearity='relu') - self.normalize_pixel = normalize_pixel self.sim_norm = SimNorm(simnorm_dim=group_size) - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - image (:obj:`torch.Tensor`): (B, C, H, W) + - output (:obj:`torch.Tensor`): (B, embedding_size) + """ if self.normalize_pixel: - image = image / 255.0 - x = self.cnn(image.float()) # (B, C, 1, 1) - x = torch.flatten(x, start_dim=1) # (B, C) - x = self.linear(x) # (B, embedding_size) + image = image.float() / 255.0 + + x = self.cnn(image.float()) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) x = self.sim_norm(x) return x class LatentDecoderForMemoryEnv(nn.Module): - + """ + Overview: + A decoder for the MemoryEnv, reconstructing a small image from a latent embedding. + It uses a linear layer followed by a series of transposed convolutions. + """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=256, - channels=[64, 32, 16], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 256, + channels: List[int] = [64, 32, 16], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), - **kwargs, ): """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ - Use the inplace operation to speed up. + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the target output image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the input latent embedding. + - channels (:obj:`List[int]`): List of channels for each deconvolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes. + - strides (:obj:`List[int]`): List of strides. 
+ - activation (:obj:`nn.Module`): Activation function for intermediate layers. """ - super(LatentDecoderForMemoryEnv, self).__init__() + super().__init__() self.shape = image_shape - self.channels = list(channels) + [image_shape[0]] - + self.deconv_channels = channels + [image_shape[0]] + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) layers = [] - for i in range(len(self.channels) - 1): + for i in range(len(self.deconv_channels) - 1): layers.append( nn.ConvTranspose2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + self.deconv_channels[i], self.deconv_channels[i+1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i]//2, output_padding=strides[i]-1 ) ) - if i < len(self.channels) - 2: - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) + if i < len(self.deconv_channels) - 2: + layers.extend([nn.BatchNorm2d(self.deconv_channels[i+1]), activation]) else: + # Final layer uses Sigmoid to output pixel values in [0, 1]. layers.append(nn.Sigmoid()) - self.deconv = nn.Sequential(*layers) - def forward(self, embedding): + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - embedding (:obj:`torch.Tensor`): (B, embedding_size) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ x = self.linear(embedding) - x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) - x = self.deconv(x) # (B, C, H, W) + x = x.view(-1, self.deconv_channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) return x class VectorDecoderForMemoryEnv(nn.Module): - + """ + Overview: + An MLP-based decoder for MemoryEnv, reconstructing a vector observation from a latent embedding. + """ def __init__( self, embedding_dim: int, - output_shape: SequenceType, + output_dim: int, hidden_channels: int = 64, - layer_num: int = 2, - activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + num_layers: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), norm_type: Optional[str] = 'BN', - ) -> torch.Tensor: + ) -> None: """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): Dimension of the input latent embedding. + - output_dim (:obj:`int`): Dimension of the target output vector. + - hidden_channels (:obj:`int`): Number of neurons in the hidden layers. + - num_layers (:obj:`int`): Total number of layers in the MLP. + - activation (:obj:`nn.Module`): Activation function to use. + - norm_type (:obj:`Optional[str]`): Normalization type ('BN', 'LN', or None). 
""" super().__init__() - self.fc_representation = MLP( + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_decoder = MLP_V2( in_channels=embedding_dim, - hidden_channels=hidden_channels, - out_channels=output_shape, - layer_num=layer_num, + hidden_channels=hidden_layers, + out_channels=output_dim, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, output_dim) """ - x = self.fc_representation(x) - return x + return self.fc_decoder(x) +# --- Prediction Networks --- class PredictionNetwork(nn.Module): - + """ + Overview: + Predicts the policy and value from a given latent state. This network is typically used + in the prediction step of MuZero-like algorithms. It processes a 2D latent state. + """ def __init__( self, - observation_shape: SequenceType, action_space_size: int, num_res_blocks: int, num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, + value_head_channels: int = 1, + policy_head_channels: int = 2, + value_head_hidden_channels: List[int] = [256], + policy_head_hidden_channels: List[int] = [256], + output_support_size: int = 601, last_linear_layer_init_zero: bool = True, activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', + norm_type: str = 'BN', ) -> None: """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. 
- - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - action_space_size: (:obj:`int`): The size of the action space. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The number of channels in the input latent state. + - value_head_channels (:obj:`int`): Channels for the value head's convolutional layer. + - policy_head_channels (:obj:`int`): Channels for the policy head's convolutional layer. + - value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head. + - policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head. + - output_support_size (:obj:`int`): The size of the categorical value distribution. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero. + - activation (:obj:`nn.Module`): The activation function. + - norm_type (:obj:`str`): The normalization type ('BN' or 'LN'). """ - super(PredictionNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) + super().__init__() + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") + self.resblocks = nn.ModuleList([ + ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_res_blocks) + ]) + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - if observation_shape[1] == 96: - latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) - elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, *latent_shape], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - + self.norm_value = build_normalization(norm_type, dim=2)(value_head_channels) + self.norm_policy = build_normalization(norm_type, dim=2)(policy_head_channels) self.activation = activation + # The input size for the MLP heads depends on the spatial dimensions of the latent state. + # This must be pre-calculated and passed correctly. + # Example: for a 6x6 latent space, flatten_input_size = channels * 6 * 6 + # We assume the user will provide these values. + # Here we just define placeholder attributes. 
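As the comment notes, the flattened input sizes of the value and policy heads can be pre-computed from the latent spatial resolution rather than resolved lazily. A small sketch of that arithmetic, assuming a 96x96 observation and an encoder that downsamples by a factor of 16:

import math

obs_h, obs_w = 96, 96
downsample = True
latent_h = math.ceil(obs_h / 16) if downsample else obs_h   # 6
latent_w = math.ceil(obs_w / 16) if downsample else obs_w   # 6

value_head_channels, policy_head_channels = 1, 2
flatten_input_size_for_value_head = value_head_channels * latent_h * latent_w    # 1 * 6 * 6 = 36
flatten_input_size_for_policy_head = policy_head_channels * latent_h * latent_w  # 2 * 6 * 6 = 72
print(flatten_input_size_for_value_head, flatten_input_size_for_policy_head)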
+ self._flatten_input_size_for_value_head = None + self._flatten_input_size_for_policy_head = None + self.fc_value = MLP_V2( - in_channels=self.flatten_input_size_for_value_head, + in_channels=-1, # Placeholder, will be determined at first forward pass hidden_channels=value_head_hidden_channels, out_channels=output_support_size, - activation=self.activation, + activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy = MLP_V2( - in_channels=self.flatten_input_size_for_policy_head, + in_channels=-1, # Placeholder hidden_channels=policy_head_hidden_channels, out_channels=action_space_size, - activation=self.activation, + activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + Shapes: + - latent_state (:obj:`torch.Tensor`): (B, C, H, W) + - policy_logits (:obj:`torch.Tensor`): (B, action_space_size) + - value (:obj:`torch.Tensor`): (B, output_support_size) """ for res_block in self.resblocks: latent_state = res_block(latent_state) - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) + value_feat = self.activation(self.norm_value(self.conv1x1_value(latent_state))) + policy_feat = self.activation(self.norm_policy(self.conv1x1_policy(latent_state))) + + value_flat = value_feat.view(value_feat.size(0), -1) + policy_flat = policy_feat.view(policy_feat.size(0), -1) - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) + # Dynamically initialize in_channels on the first forward pass + if self.fc_value.in_channels == -1: + self.fc_value[0].in_features = value_flat.shape[1] + self.fc_policy[0].in_features = policy_flat.shape[1] + # PyTorch lazy modules handle this better, but this is a manual way. + self.fc_value[0].weight.data.uniform_(-math.sqrt(1/value_flat.shape[1]), math.sqrt(1/value_flat.shape[1])) + self.fc_policy[0].weight.data.uniform_(-math.sqrt(1/policy_flat.shape[1]), math.sqrt(1/policy_flat.shape[1])) - value = value.reshape(-1, self.flatten_input_size_for_value_head) - policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - value = self.fc_value(value) - policy = self.fc_policy(policy) - return policy, value + value = self.fc_value(value_flat) + policy_logits = self.fc_policy(policy_flat) + return policy_logits, value class PredictionNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based prediction network that predicts policy and value from a 1D latent state. 
+ """ def __init__( self, - action_space_size, - num_channels, + action_space_size: int, + num_channels: int, common_layer_num: int = 2, - value_head_hidden_channels: SequenceType = [32], - policy_head_hidden_channels: SequenceType = [32], + value_head_hidden_channels: List[int] = [32], + policy_head_hidden_channels: List[int] = [32], output_support_size: int = 601, last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + activation: nn.Module = nn.ReLU(inplace=True), norm_type: Optional[str] = 'BN', ): """ - Overview: - The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), - which is used to predict value and policy by the given latent state. Arguments: - - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ - space, it is the number of discrete actions. - - num_channels (:obj:`int`): The channels of latent states. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - action_space_size: (:obj:`int`): The size of the action space. + - num_channels (:obj:`int`): The dimension of the input latent state. + - common_layer_num (:obj:`int`): Number of layers in the shared backbone MLP. + - value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head. + - policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head. + - output_support_size (:obj:`int`): The size of the categorical value distribution. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero. + - activation (:obj:`nn.Module`): The activation function. + - norm_type (:obj:`Optional[str]`): The normalization type. """ super().__init__() - self.num_channels = num_channels - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, + + common_hidden = [num_channels] * (common_layer_num - 1) if common_layer_num > 1 else [] + self.fc_prediction_common = MLP_V2( + in_channels=num_channels, + hidden_channels=common_hidden, + out_channels=num_channels, activation=activation, norm_type=norm_type, output_activation=True, output_norm=True, - # last_linear_layer_init_zero=False is important for convergence last_linear_layer_init_zero=False, ) - # ******* value and policy head ****** self.fc_value_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=value_head_hidden_channels, out_channels=output_support_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. 
last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=policy_head_hidden_channels, out_channels=action_space_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) - def forward(self, latent_state: torch.Tensor): - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). - """ - x_prediction_common = self.fc_prediction_common(latent_state) - - value = self.fc_value_head(x_prediction_common) - policy = self.fc_policy_head(x_prediction_common) - return policy, value - - -class PredictionHiddenNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType, - action_space_size: int, - num_res_blocks: int, - num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, - last_linear_layer_init_zero: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - gru_hidden_size: int = 512, - ) -> None: - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
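The value head emits logits over a categorical support of size `output_support_size`. A small sketch of turning that distribution into a scalar value via its expectation; the support range used here (-300..300 for size 601) and the absence of an inverse value transform are illustrative assumptions, not the project's exact settings:

import torch

output_support_size = 601
support = torch.linspace(-300, 300, output_support_size)

value_logits = torch.randn(4, output_support_size)    # e.g. output of fc_value_head
value_probs = torch.softmax(value_logits, dim=-1)
scalar_value = (value_probs * support).sum(dim=-1)    # (4,)
print(scalar_value.shape)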
- """ - super(PredictionHiddenNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.observation_shape = observation_shape - self.gru_hidden_size = gru_hidden_size - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - - self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) - self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), - math.ceil(observation_shape[-1] / 16)], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - - self.activation = activation - - self.fc_value = MLP( - in_channels=self.flatten_input_size_for_value_head + self.gru_hidden_size, - hidden_channels=value_head_hidden_channels[0], - out_channels=output_support_size, - layer_num=len(value_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy = MLP( - in_channels=self.flatten_input_size_for_policy_head + self.gru_hidden_size, - hidden_channels=policy_head_hidden_channels[0], - out_channels=action_space_size, - layer_num=len(policy_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ - torch.Tensor, torch.Tensor]: + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+        Shapes:
+            - latent_state (:obj:`torch.Tensor`): (B, num_channels)
+            - policy_logits (:obj:`torch.Tensor`): (B, action_space_size)
+            - value (:obj:`torch.Tensor`): (B, output_support_size)
         """
-        for res_block in self.resblocks:
-            latent_state = res_block(latent_state)
-
-        value = self.conv1x1_value(latent_state)
-        value = self.norm_value(value)
-        value = self.activation(value)
-
-        policy = self.conv1x1_policy(latent_state)
-        policy = self.norm_policy(policy)
-        policy = self.activation(policy)
-
-        latent_state_value = value.reshape(-1, self.flatten_input_size_for_value_head)
-        latent_state_policy = policy.reshape(-1, self.flatten_input_size_for_policy_head)
-
-        # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size)
-        latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1)
-        latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1)
-
-        value = self.fc_value(latent_history_value)
-        policy = self.fc_policy(latent_history_policy)
-        return policy, value
\ No newline at end of file
+        x = self.fc_prediction_common(latent_state)
+        value = self.fc_value_head(x)
+        policy_logits = self.fc_policy_head(x)
+        return policy_logits, value
\ No newline at end of file
diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py
new file mode 100644
index 000000000..cb30b3d38
--- /dev/null
+++ b/lzero/model/muzero_model_multitask.py
@@ -0,0 +1,550 @@
+from typing import Optional, Tuple, Sequence, List
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.torch_utils import MLP, ResBlock
+from ding.utils import MODEL_REGISTRY, SequenceType
+from numpy import ndarray
+
+# The following imports are assumed to be from the same project directory.
+# To maintain API consistency, their internal logic is not modified.
+from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook
+from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
+
+
+@MODEL_REGISTRY.register('MuZeroMTModel')
+class MuZeroMTModel(nn.Module):
+    """
+    Overview:
+        The Multi-Task MuZero model, which is a variant of the original MuZero model adapted for multi-task learning.
+        This model features a shared representation network and dynamics network, but utilizes separate, task-specific
+        prediction networks. This architecture allows the model to learn shared dynamics while specializing its
+        policy and value predictions for each individual task.
+    """
+    # Default configuration for the model.
+    # This structure is recommended over using cfg.get('key', default_value) inside the code.
+ config = dict( + observation_shape=(12, 96, 96), + action_space_size=6, + num_res_blocks=1, + num_channels=64, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + reward_support_size=601, + value_support_size=601, + proj_hid=1024, + proj_out=1024, + pred_hid=512, + pred_out=1024, + self_supervised_learning_loss=False, + categorical_distribution=True, + activation=nn.ReLU(inplace=True), + last_linear_layer_init_zero=True, + state_norm=False, + downsample=False, + norm_type='BN', + discrete_action_encoding_type='one_hot', + analysis_sim_norm=False, + task_num=1, + ) + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: List[int] = [32], + fc_value_layers: List[int] = [32], + fc_policy_layers: List[int] = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = None, + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + analysis_sim_norm: bool = False, + task_num: int = 1, + *args, + **kwargs + ) -> None: + """ + Overview: + Constructor for the MuZeroMTModel. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (12, 96, 96). + - action_space_size (:obj:`int`): The size of the action space, applicable for discrete action spaces. + - num_res_blocks (:obj:`int`): The number of residual blocks in the representation, dynamics, and prediction networks. + - num_channels (:obj:`int`): The number of channels in the latent state. + - reward_head_channels (:obj:`int`): The number of channels in the reward head. + - value_head_channels (:obj:`int`): The number of channels in the value head. + - policy_head_channels (:obj:`int`): The number of channels in the policy head. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. + - fc_value_layers (:obj:`List[int]`): The hidden layer sizes of the value MLP. + - fc_policy_layers (:obj:`List[int]`): The hidden layer sizes of the policy MLP. + - reward_support_size (:obj:`int`): The support size for categorical reward distribution. + - value_support_size (:obj:`int`): The support size for categorical value distribution. + - proj_hid (:obj:`int`): The hidden size of the projection network for SSL. + - proj_out (:obj:`int`): The output size of the projection network for SSL. + - pred_hid (:obj:`int`): The hidden size of the prediction head for SSL. + - pred_out (:obj:`int`): The output size of the prediction head for SSL. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self-supervised learning loss. + - categorical_distribution (:obj:`bool`): Whether to use categorical distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): The activation function to use. Defaults to nn.ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. 
+ - state_norm (:obj:`bool`): Whether to apply re-normalization to the latent state. + - downsample (:obj:`bool`): Whether to downsample the observation image. + - norm_type (:obj:`Optional[str]`): The type of normalization to use, either 'BN' (BatchNorm) or 'LN' (LayerNorm). + - discrete_action_encoding_type (:obj:`str`): The encoding type for discrete actions, 'one_hot' or 'not_one_hot'. + - analysis_sim_norm (:obj:`bool`): A flag for analysis, enables hooks for SimNorm analysis. + - task_num (:obj:`int`): The total number of tasks for the multi-task setup. + """ + super(MuZeroMTModel, self).__init__() + if activation is None: + activation = nn.ReLU(inplace=True) + + # --- Store configuration --- + self.action_space_size = action_space_size + self.categorical_distribution = categorical_distribution + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + self.task_num = task_num + self.discrete_action_encoding_type = discrete_action_encoding_type + + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + # --- Prepare observation shape and action encoding dimension --- + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # For 1D vector observations (e.g., classic control), wrap them into a 2D image-like format [C, W, H] + # to be compatible with the convolutional networks. + observation_shape = (1, observation_shape[0], 1) if isinstance(observation_shape, tuple) else (1, observation_shape, 1) + + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = self.action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + else: + raise ValueError(f"Unsupported discrete_action_encoding_type: {self.discrete_action_encoding_type}") + + latent_size = self._get_latent_size(observation_shape, self.downsample) + + # --- Initialize Network Components --- + + # 1. Shared Representation Network + self.representation_network = RepresentationNetwork( + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, + activation=activation, + norm_type=norm_type + ) + + # 2. Shared Dynamics Network + self.dynamics_network = DynamicsNetwork( + observation_shape=observation_shape, + action_encoding_dim=self.action_encoding_dim, + num_res_blocks=num_res_blocks, + num_channels=num_channels + self.action_encoding_dim, + reward_head_channels=reward_head_channels, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * latent_size, + downsample=self.downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + # 3. 
Task-Specific Prediction Networks + self.prediction_networks = nn.ModuleList([ + PredictionNetwork( + observation_shape=observation_shape, + action_space_size=self.action_space_size, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + value_head_channels=value_head_channels, + policy_head_channels=policy_head_channels, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + flatten_output_size_for_value_head=value_head_channels * latent_size, + flatten_output_size_for_policy_head=policy_head_channels * latent_size, + downsample=self.downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) for _ in range(self.task_num) + ]) + + # 4. Optional Self-Supervised Learning (SSL) Components + if self.self_supervised_learning_loss: + self.projection_network = nn.Sequential( + nn.Linear(num_channels * latent_size, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_out), + nn.BatchNorm1d(proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(proj_out, pred_hid), + nn.BatchNorm1d(pred_hid), + activation, + nn.Linear(pred_hid, pred_out), + ) + + # 5. Optional Hook for Analysis + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + @staticmethod + def _get_latent_size(observation_shape: SequenceType, downsample: bool) -> int: + """ + Overview: + Helper function to calculate the flattened size of the latent space based on observation shape and downsampling. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation. + - downsample (:obj:`bool`): Whether downsampling is enabled. + Returns: + - int: The flattened size (height * width) of the latent space. + """ + if downsample: + # With downsampling, the spatial dimensions are reduced by a factor of 16 (2^4). + return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) + else: + return observation_shape[-2] * observation_shape[-1] + + def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference from a raw observation. It encodes the observation into a latent state + and then uses the task-specific prediction network to compute the policy and value. + Arguments: + - obs (:obj:`torch.Tensor`): The raw observation tensor. + - task_id (:obj:`int`): The identifier for the current task, used to select the correct prediction network. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward (initially zero), policy logits, and latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size. + - task_id (:obj:`int`): Scalar. + - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. + """ + batch_size = obs.size(0) + latent_state = self.representation_network(obs) + if self.state_norm: + latent_state = renormalize(latent_state) + + # Select the prediction network based on the task ID. 
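The shared-trunk / per-task-head pattern used here can be reduced to a few lines. A minimal sketch with illustrative sizes; the `trunk`, `predict`, and head modules below are stand-ins, not the project's classes:

import torch
import torch.nn as nn

num_tasks, latent_dim, action_dim, support = 3, 64, 6, 601

trunk = nn.Linear(latent_dim, latent_dim)   # stands in for the shared representation/dynamics
policy_heads = nn.ModuleList([nn.Linear(latent_dim, action_dim) for _ in range(num_tasks)])
value_heads = nn.ModuleList([nn.Linear(latent_dim, support) for _ in range(num_tasks)])

def predict(latent: torch.Tensor, task_id: int):
    assert 0 <= task_id < num_tasks
    h = trunk(latent)
    return policy_heads[task_id](h), value_heads[task_id](h)

policy_logits, value = predict(torch.randn(2, latent_dim), task_id=1)
print(policy_logits.shape, value.shape)     # (2, 6) (2, 601)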
+ assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(latent_state) + + return MZNetworkOutput( + value=value, + reward=[0. for _ in range(batch_size)], # Initial reward is always zero. + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs recurrent inference from a latent state and an action. It uses the dynamics network to predict + the next latent state and reward, and then uses the task-specific prediction network to compute the + policy and value for the next state. + Arguments: + - latent_state (:obj:`torch.Tensor`): The current latent state. + - action (:obj:`torch.Tensor`): The action taken in the current state. + - task_id (:obj:`int`): The identifier for the current task. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward, policy logits, and the next latent state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - task_id (:obj:`int`): Scalar. + - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. + """ + next_latent_state, reward = self._dynamics(latent_state, action) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + + # Select the prediction network based on the task ID. + assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(next_latent_state) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Applies the dynamics function by concatenating the latent state with the encoded action and passing it + through the dynamics network to predict the next latent state and reward. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of the input state. + - action (:obj:`torch.Tensor`): The action to rollout. + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the predicted next latent state and reward. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + """ + # Encode the action and expand it to match the spatial dimensions of the latent state. + if self.discrete_action_encoding_type == 'one_hot': + # Convert action indices to one-hot vectors. + action_one_hot = F.one_hot(action.long(), num_classes=self.action_space_size).float() + # Reshape for broadcasting: (B, A) -> (B, A, 1, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + # Expand to (B, A, H', W') + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] + ) + elif self.discrete_action_encoding_type == 'not_one_hot': + # Encode action as a single channel, normalized by action space size. 
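 
The two discrete-action encodings used by `_dynamics` can be checked in isolation. A sketch with illustrative sizes, showing how the action is broadcast into spatial planes and concatenated with the latent state:

import torch
import torch.nn.functional as F

B, A, C, H, W = 2, 6, 64, 6, 6
latent_state = torch.randn(B, C, H, W)
action = torch.tensor([0, 3])

# 'one_hot': one channel per action, broadcast over the spatial dims -> (B, A, H, W)
one_hot_planes = F.one_hot(action.long(), num_classes=A).float().unsqueeze(-1).unsqueeze(-1).expand(B, A, H, W)

# 'not_one_hot': a single channel holding the normalized action index -> (B, 1, H, W)
scalar_plane = (action.float() / A).view(B, 1, 1, 1).expand(B, 1, H, W)

state_action = torch.cat((latent_state, one_hot_planes), dim=1)   # (B, C + A, H, W)
print(state_action.shape)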
+ # Reshape for broadcasting: (B,) -> (B, 1, 1, 1) + action_encoding_tmp = action.float().view(-1, 1, 1, 1) + # Normalize and expand to (B, 1, H', W') + action_encoding = action_encoding_tmp / self.action_space_size + action_encoding = action_encoding.expand( + latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] + ) + + # Concatenate latent state and action encoding along the channel dimension. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + # Predict next state and reward. + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + + return next_latent_state, reward + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + Overview: + Projects the latent state into a different space for self-supervised learning (e.g., BYOL, SimSiam). + This involves a projection network and an optional prediction head. + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state to project. + - with_grad (:obj:`bool`): If False, detach the output of the projection network to stop gradients. + This is typically used for the target network in SSL. + Returns: + - torch.Tensor: The projected (and possibly predicted) representation. + """ + if not self.self_supervised_learning_loss: + raise NotImplementedError("The 'project' method requires 'self_supervised_learning_loss' to be enabled.") + + # Flatten the latent state from (B, C, H, W) to (B, C*H*W). + latent_state = latent_state.reshape(latent_state.shape[0], -1) + + proj = self.projection_network(latent_state) + + if with_grad: + # Return the output of the prediction head, with gradients flowing. + return self.prediction_head(proj) + else: + # Return the output of the projection network, detached from the graph. + return proj.detach() + + def get_params_mean(self) -> float: + """ + Overview: + Computes the mean of all model parameters. Useful for debugging and monitoring training. + Returns: + - float: The mean value of all parameters. + """ + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + """ + Overview: + The dynamics network of the MuZero model. It takes a state-action encoding as input and predicts + the next latent state and the reward for the transition. This network is shared across all tasks + in the multi-task setup. + """ + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + fc_reward_layers: List[int] = [32], + output_support_size: int = 601, + flatten_output_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = None, + norm_type: Optional[str] = 'BN', + ) -> None: + """ + Overview: + Constructor for the DynamicsNetwork. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the original input observation. + - action_encoding_dim (:obj:`int`): The dimension of the encoded action. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The number of channels in the input (latent_state + action_encoding). + - reward_head_channels (:obj:`int`): The number of channels for the reward head's convolutional layer. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. 
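The `with_grad` flag of `project()` is what distinguishes the online and target branches in a SimSiam/BYOL-style consistency loss. A hedged usage sketch: `model` is assumed to be a `MuZeroMTModel`-like object exposing `project()`, and the negative-cosine loss form is illustrative rather than the project's exact loss:

import torch
import torch.nn.functional as F

def consistency_loss(model, latent_pred: torch.Tensor, latent_target: torch.Tensor) -> torch.Tensor:
    online = model.project(latent_pred, with_grad=True)     # predictor output, gradients flow
    target = model.project(latent_target, with_grad=False)  # projection only, detached
    online = F.normalize(online, dim=-1)
    target = F.normalize(target, dim=-1)
    return -(online * target).sum(dim=-1).mean()            # negative cosine similarity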
+ - output_support_size (:obj:`int`): The support size for the categorical reward distribution. + - flatten_output_size_for_reward_head (:obj:`int`): The flattened input size for the reward MLP. + - downsample (:obj:`bool`): Whether downsampling is used, affecting LayerNorm shapes. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. + - activation (:obj:`Optional[nn.Module]`): The activation function. Defaults to nn.ReLU(inplace=True). + - norm_type (:obj:`Optional[str]`): The type of normalization, 'BN' or 'LN'. + """ + super().__init__() + if activation is None: + activation = nn.ReLU(inplace=True) + + assert norm_type in ['BN', 'LN'], f"norm_type must be 'BN' or 'LN', but got {norm_type}" + # The input channels to the first conv layer is num_channels, which includes the original latent channels + # and the action encoding channels. The output should be the number of channels for the latent state. + latent_channels = num_channels - action_encoding_dim + assert latent_channels > 0, f"num_channels ({num_channels}) must be greater than action_encoding_dim ({action_encoding_dim})" + + self.action_encoding_dim = action_encoding_dim + self.activation = activation + + # Convolutional layer to process the combined state-action encoding. + self.conv = nn.Conv2d(num_channels, latent_channels, kernel_size=3, stride=1, padding=1, bias=False) + + # Normalization layer for the main path. + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(latent_channels) + elif norm_type == 'LN': + if downsample: + ln_shape = [latent_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] + else: + ln_shape = [latent_channels, observation_shape[-2], observation_shape[-1]] + self.norm_common = nn.LayerNorm(ln_shape) + + # A series of residual blocks to deepen the network. + self.resblocks = nn.ModuleList( + [ResBlock(in_channels=latent_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks)] + ) + + # --- Reward Head --- + # 1x1 convolution to create an input for the reward MLP. + self.conv1x1_reward = nn.Conv2d(latent_channels, reward_head_channels, 1) + + # Normalization for the reward head. + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + ln_shape_reward = [reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] + else: + ln_shape_reward = [reward_head_channels, observation_shape[-2], observation_shape[-1]] + self.norm_reward = nn.LayerNorm(ln_shape_reward) + + # MLP to predict the reward value from the processed features. + self.fc_reward_head = MLP( + in_channels=flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + out_channels=output_support_size, + layer_num=len(fc_reward_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward pass for the dynamics network. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The concatenated latent state and action encoding. + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the next latent state and the predicted reward. + Shapes: + - state_action_encoding (:obj:`torch.Tensor`): :math:`(B, C_latent + C_action, H', W')`. 
+ - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_latent, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, output_support_size)`. + """ + # The original latent state is part of the input, used for the residual connection. + state_encoding = state_action_encoding[:, : -self.action_encoding_dim, :, :] + + # Main path for predicting the next latent state. + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # Add residual connection from the original latent state. + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + # --- Reward Prediction Path --- + # Process the next latent state to predict the reward. + reward_x = self.conv1x1_reward(next_latent_state) + reward_x = self.norm_reward(reward_x) + reward_x = self.activation(reward_x) + # Flatten the features before passing to the MLP. + reward_x = reward_x.view(reward_x.shape[0], -1) + reward = self.fc_reward_head(reward_x) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + """ + Overview: + Computes the mean of parameters in the dynamics-related layers (conv and resblocks). + Returns: + - float: The mean value of dynamics parameters. + """ + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + """ + Overview: + Computes the mean of parameters and the last layer bias in the reward head. + Returns: + - Tuple[ndarray, float]: A tuple containing the mean of the last layer's weights and its bias. + """ + return get_reward_mean(self) \ No newline at end of file diff --git a/lzero/model/sampled_unizero_model_multitask.py b/lzero/model/sampled_unizero_model_multitask.py new file mode 100644 index 000000000..e0026d0ff --- /dev/null +++ b/lzero/model/sampled_unizero_model_multitask.py @@ -0,0 +1,262 @@ +from typing import Optional, List, Sequence + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, LatentDecoder, \ + FeatureAndGradientHook, SimNorm +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +class RepresentationNetworkMLPMT(nn.Module): + """ + Overview: + A multi-task representation network that encodes vector observations into a latent state + using a Multi-Layer Perceptron (MLP). It supports task-specific encoders and an optional + shared projection layer to map representations into a common embedding space. + """ + + def __init__( + self, + observation_shape_list: List[int], + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_shared_projection: bool = False, + shared_projection_dim: Optional[int] = None, + final_norm_option_in_encoder: str = 'LayerNorm', # TODO: Further investigate norm options + ) -> None: + """ + Arguments: + - observation_shape_list (:obj:`List[int]`): A list of observation feature dimensions, one for each task. + - hidden_channels (:obj:`int`): The number of hidden channels in the task-specific MLPs. + - layer_num (:obj:`int`): The number of layers in each MLP. + - activation (:obj:`nn.Module`): The activation function to use in the MLPs. Defaults to nn.GELU(approximate='tanh'). + - norm_type (:obj:`str`): The type of normalization to use within the MLPs. Defaults to 'BN'. 
+ - embedding_dim (:obj:`int`): The dimension of the final output embedding. + - group_size (:obj:`int`): The group size for SimNorm if it is used. + - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer after task-specific encoding. Defaults to False. + - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. If None, it defaults to `hidden_channels`. + - final_norm_option_in_encoder (:obj:`str`): The final normalization layer type ('LayerNorm' or 'SimNorm'). Defaults to 'LayerNorm'. + """ + super().__init__() + self.env_num = len(observation_shape_list) + self.use_shared_projection = use_shared_projection + self.hidden_channels = hidden_channels + self.shared_projection_dim = shared_projection_dim or hidden_channels + self.embedding_dim = embedding_dim + self.final_norm_option_in_encoder = final_norm_option_in_encoder + + # Task-specific representation networks + self.fc_representation = nn.ModuleList([ + MLP( + in_channels=obs_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # No activation or norm in the last layer is important for convergence. + output_activation=False, + output_norm=False, + # Initializing the last linear layer to zero can be beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + for obs_shape in observation_shape_list + ]) + + # Final normalization layer before projection + if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'SimNorm': + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + + # Optional shared projection layer + if self.use_shared_projection: + self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) + # Using SimNorm for the shared space projection + self.projection_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): The input tensor of shape :math:`(B, N)`, where B is the batch size and N is the length of the vector observation. + - task_id (:obj:`int`): The identifier for the current task, used to select the appropriate encoder. + - output (:obj:`torch.Tensor`): The output latent state. Its shape is :math:`(B, embedding_dim)` if shared projection is not used, otherwise :math:`(B, shared_projection_dim)`. + """ + # Encode observation using the task-specific MLP + x = self.fc_representation[task_id](x) + # Apply final normalization + x = self.final_norm(x) + + # Apply the shared projection layer if enabled + if self.use_shared_projection: + x = self.shared_projection(x) + x = self.projection_norm(x) + return x + + +@MODEL_REGISTRY.register('SampledUniZeroMTModel') +class SampledUniZeroMTModel(nn.Module): + """ + Overview: + The main model for Sampled UniZero in a multi-task setting. It integrates a representation + network, a tokenizer, and a world model to perform initial and recurrent inference, + which are essential for MuZero-style planning algorithms. The model is designed to handle + both vector and image-based observations across multiple tasks. 
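As an aside on the encoder pattern above: the task routing in RepresentationNetworkMLPMT amounts to indexing an nn.ModuleList of per-task MLPs by task_id and sharing the final normalization. A minimal self-contained sketch of that idea (plain torch layers stand in for ding's MLP helper; names and sizes are illustrative only):

import torch
import torch.nn as nn

class ToyTaskRoutedEncoder(nn.Module):
    # One small MLP per task; forward() selects the encoder by task_id and applies a shared LayerNorm.
    def __init__(self, obs_dims, hidden=64):
        super().__init__()
        self.encoders = nn.ModuleList([
            nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, hidden))
            for d in obs_dims
        ])
        self.final_norm = nn.LayerNorm(hidden)

    def forward(self, x, task_id: int):
        return self.final_norm(self.encoders[task_id](x))

enc = ToyTaskRoutedEncoder(obs_dims=[8, 17])     # two tasks with different observation sizes
print(enc(torch.randn(4, 8), task_id=0).shape)   # torch.Size([4, 64])
print(enc(torch.randn(4, 17), task_id=1).shape)  # torch.Size([4, 64])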
+ """ + + def __init__( + self, + observation_shape_list: List[Sequence], + action_space_size_list: List[int], + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'LN', + world_model_cfg: EasyDict = None, + *args, + **kwargs + ): + """ + Arguments: + - observation_shape_list (:obj:`List[Sequence]`): A list of observation space shapes for each task (e.g., `[C, W, H]` for images or `[D]` for vectors). + - action_space_size_list (:obj:`List[int]`): A list of action space sizes for each task. + - num_res_blocks (:obj:`int`): The number of residual blocks in the image representation network. + - num_channels (:obj:`int`): The number of channels in the hidden states of the image representation network. + - activation (:obj:`nn.Module`): The activation function used throughout the network. + - downsample (:obj:`bool`): Whether to downsample observations in the image representation network. + - norm_type (:obj:`str`): The type of normalization to use in networks. Defaults to 'LN'. + - world_model_cfg (:obj:`EasyDict`): A single configuration object for the world model, shared across all tasks. + """ + super(SampledUniZeroMTModel, self).__init__() + self.task_num = len(observation_shape_list) + self.activation = activation + self.downsample = downsample + + # Determine the embedding dimension for observations and actions + if world_model_cfg.task_embed_option == "concat_task_embed": + obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + 'max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token.' + + # Initialize networks based on observation type + if world_model_cfg.obs_type == 'vector': + # A single representation network capable of handling multiple tasks via task_id + self.representation_network = RepresentationNetworkMLPMT( + observation_shape_list=observation_shape_list, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + use_shared_projection=world_model_cfg.use_shared_projection, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # TODO: Currently uses a single shared encoder for all image-based tasks. + # This can be extended to support multiple independent encoders if needed. + for _ in range(1): + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape_list[0], # Assuming shared encoder uses the shape of the first task + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + )) + # TODO: The world model and tokenizer for the 'image' case should be initialized here. + # self.tokenizer = Tokenizer(...) 
+ # self.world_model = WorldModelMT(...) + + # Print model parameter counts for verification + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + if hasattr(self.tokenizer, 'encoder') and self.tokenizer.encoder is not None: + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[int] = None) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference step of the UniZero model. It takes an observation + and produces a latent state, a value prediction, and an initial policy. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The initial batch of observations. + - action_batch (:obj:`Optional[torch.Tensor]`): An optional batch of actions. + - current_obs_batch (:obj:`Optional[torch.Tensor]`): An optional batch of current observations. + - task_id (:obj:`Optional[int]`): The identifier for the current task. + Returns (MZNetworkOutput): + An object containing the predicted value, initial reward (zero), policy logits, and latent state. + Shapes: + - obs_batch (:obj:`torch.Tensor`): :math:`(B, ...)` where B is the batch size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + + latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=[0. for _ in range(batch_size)], # Initial reward is always zero + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, search_depth: List[int] = [], task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs the recurrent inference step (the dynamics function). Given a history of + latent states and actions, it predicts the next latent state, reward, value, and policy. + Arguments: + - state_action_history (:obj:`torch.Tensor`): A history of states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step in MCTS. + - search_depth (:obj:`List[int]`): The indices of latent states in the current search path. + - task_id (:obj:`int`): The identifier for the current task. + Returns (MZNetworkOutput): + An object containing the predicted value, reward, policy logits, and the next latent state. + Shapes: + - state_action_history (:obj:`torch.Tensor`): :math:`(B, L, D)`, where L is sequence length. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. 
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, search_depth, task_id=task_id) + + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 9d57b3c5f..59b893b21 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -11,6 +11,7 @@ HFLanguageRepresentationNetwork, QwenNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel +from .vit import ViT, ViTConfig from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size @@ -88,7 +89,7 @@ def __init__( # TODO: only for MemoryEnv now self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False) + decoder=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -125,7 +126,7 @@ def __init__( self.decoder_network_tokenizer = None else: raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') @@ -134,23 +135,41 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'image': - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder - ) + if world_model_cfg.encoder_type == "resnet": + self.representation_network = RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + elif world_model_cfg.encoder_type == "vit": + # vit base + vit_config = ViTConfig( + image_size=observation_shape[1], + patch_size=8, + num_classes=world_model_cfg.embed_dim, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + dropout=0.1, + emb_dropout=0.1, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + lora_config=world_model_cfg, + ) 
+ self.representation_network = ViT(config=vit_config) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -181,7 +200,7 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..68095de46 --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,284 @@ +from typing import Optional, Sequence, Dict, Any, List + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT +from .vit import ViT, ViTConfig + + +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + """ + Overview: + The main model for UniZero, a multi-task agent based on a scalable latent world model. + This class orchestrates the representation network, world model, and prediction heads. + It provides two primary interfaces: + - `initial_inference`: Encodes an observation to produce an initial latent state and predictions (value, policy). + - `recurrent_inference`: Simulates dynamics by taking a history of latent states and actions to predict the next + latent state, reward, value, and policy. + """ + + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: str = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args: Any, + **kwargs: Any + ) -> None: + """ + Overview: + Initializes the UniZeroMTModel, setting up the representation network, tokenizer, and world model + based on the provided configuration. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (C, H, W). + - action_space_size (:obj:`int`): The size of the discrete action space. 
+ - num_res_blocks (:obj:`int`): The number of residual blocks in the ResNet-based representation network. + - num_channels (:obj:`int`): The number of channels in the ResNet-based representation network. + - activation (:obj:`nn.Module`): The activation function to use throughout the network. + - downsample (:obj:`bool`): Whether to downsample the observation in the representation network. + - norm_type (:obj:`str`): The type of normalization to use, e.g., 'BN' for BatchNorm. + - world_model_cfg (:obj:`EasyDict`): Configuration for the world model and its components. + - task_num (:obj:`int`): The number of tasks for multi-task learning. + """ + super().__init__() + print(f'========== Initializing UniZeroMTModel (num_res_blocks: {num_res_blocks}, num_channels: {num_channels}) ==========') + + # --- Basic attribute setup --- + self.task_num = task_num + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + + # NOTE: The action_space_size passed as an argument is immediately overridden. + # This might be intentional for specific experiments but is not a general practice. + self.action_space_size = 18 + + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + "max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token." + + # --- Determine embedding dimensions --- + if world_model_cfg.task_embed_option == "concat_task_embed": + task_embed_dim = world_model_cfg.get("task_embed_dim", 32) # Default task_embed_dim to 32 if not specified + obs_act_embed_dim = world_model_cfg.embed_dim - task_embed_dim + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + # --- Initialize model components based on observation type --- + obs_type = world_model_cfg.obs_type + if obs_type == 'vector': + self._init_vector_components(world_model_cfg, obs_act_embed_dim) + elif obs_type == 'image': + self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim) + elif obs_type == 'image_memory': + self._init_image_memory_components(world_model_cfg) + else: + raise ValueError(f"Unsupported observation type: {obs_type}") + + # --- Initialize world model and tokenizer --- + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(obs_type) + + def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim: int) -> None: + """Initializes components for 'vector' observation type.""" + self.representation_network = RepresentationNetworkMLP( + observation_shape=world_model_cfg.observation_shape, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: This is currently specific to MemoryEnv. Generalize if needed. + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + + def _init_image_components(self, world_model_cfg: EasyDict, observation_shape: SequenceType, num_res_blocks: int, + num_channels: int, obs_act_embed_dim: int) -> None: + """Initializes components for 'image' observation type.""" + self.representation_network = nn.ModuleList() + encoder_type = world_model_cfg.encoder_type + + # NOTE: Using a single shared encoder. 
The original code used a loop `for _ in range(1):`. + # To support N independent encoders, this logic would need to be modified. + if encoder_type == "resnet": + encoder = RepresentationNetworkUniZero( + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, + activation=self.activation, + norm_type=world_model_cfg.norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + self.representation_network.append(encoder) + elif encoder_type == "vit": + vit_configs = { + 'small': {'dim': 768, 'depth': 6, 'heads': 6, 'mlp_dim': 2048}, + 'base': {'dim': 768, 'depth': 12, 'heads': 12, 'mlp_dim': 3072}, + 'large': {'dim': 1024, 'depth': 24, 'heads': 16, 'mlp_dim': 4096}, + } + vit_size = 'base' if self.task_num > 8 else 'small' + selected_vit_config = vit_configs[vit_size] + + vit_params = { + 'image_size': observation_shape[1], + 'patch_size': 8, + 'num_classes': obs_act_embed_dim, + 'dropout': 0.1, + 'emb_dropout': 0.1, + 'final_norm_option_in_encoder': world_model_cfg.final_norm_option_in_encoder, + 'lora_config': world_model_cfg, + **selected_vit_config + } + vit_config = ViTConfig(**vit_params) + encoder = ViT(config=vit_config) + + self.representation_network.append(encoder) + else: + raise ValueError(f"Unsupported encoder type for image observations: {encoder_type}") + + # For image observations, the decoder is currently not used for reconstruction during training. + self.decoder_network = None + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + def _init_image_memory_components(self, world_model_cfg: EasyDict) -> None: + """Initializes components for 'image_memory' observation type.""" + # TODO: The 'concat_task_embed' option needs to be fully implemented for this obs_type. 
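+        # The encoder and decoder below are mirror images (conv channels [16, 32, 64] down, [64, 32, 16] up), operating on small 3x5x5 MemoryEnv observations.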
+ self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=True, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + def _log_model_parameters(self, obs_type: str) -> None: + """Logs the parameter counts of the main model components.""" + print('--------------------------------------------------') + print(f'{sum(p.numel() for p in self.world_model.parameters()):,} parameters in world_model') + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters()):,} parameters in world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters()):,} parameters in tokenizer.encoder') + + if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters()):,} parameters in tokenizer.decoder_network') + if obs_type == 'image_memory': + # Calculate parameters excluding decoder and LPIPS for a specific comparison point. + params_without_decoder = sum(p.numel() for p in self.world_model.parameters()) - \ + sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - \ + sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{params_without_decoder:,} parameters in world_model (excluding decoder and lpips)') + print('--------------------------------------------------') + + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, + current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[Any] = None) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference step of the model, corresponding to the representation function `h` in MuZero. + It takes an observation and produces a latent state and initial predictions. + Arguments: + - obs_batch (:obj:`torch.Tensor`): A batch of initial observations. + - action_batch (:obj:`Optional[torch.Tensor]`): A batch of actions (if available, context-dependent). + - current_obs_batch (:obj:`Optional[torch.Tensor]`): A batch of current observations (if different from obs_batch). + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, policy logits, and the initial latent state. + The reward is set to a zero tensor, as it's not predicted at the initial step. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference( + obs_act_dict, task_id=task_id + ) + + # The world model returns tokens and logits; map them to the standard MZNetworkOutput format. 
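+        # logits_policy and logits_value carry a length-1 token dimension, hence the squeeze(1) below; the initial reward is fixed at zero rather than predicted.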
+ latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=torch.zeros(batch_size, device=value.device), # Reward is 0 at initial inference + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, + search_depth: List = [], task_id: Optional[Any] = None) -> MZNetworkOutput: + """ + Overview: + Performs a recurrent inference step, corresponding to the dynamics function `g` and prediction + function `f` in MuZero. It predicts the next latent state, reward, policy, and value based on a + history of latent states and actions. + Arguments: + - state_action_history (:obj:`torch.Tensor`): A tensor representing the history of latent states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step within MCTS. + - search_depth (:obj:`List`): Information about the search depth, used for positional embeddings. + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, reward, policy logits, and the next latent state. + """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, search_depth, task_id=task_id + ) + + # Map the world model outputs to the standard MZNetworkOutput format. + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=reward, + policy_logits=policy_logits, + latent_state=next_latent_state, + ) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index 28b7b0ba2..cf040b13a 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -1,110 +1,254 @@ -# Modified from https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py +# -*- coding: utf-8 -*- +""" +This script is a refactored version of the key-value caching mechanism from: +https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py -from typing import Tuple +The optimization focuses on improving clarity, documentation, and adherence to modern coding standards +while strictly preserving the original functionality and external API. +""" +from typing import Tuple, Optional import numpy as np import torch +class AssignWithoutInplaceCheck(torch.autograd.Function): + """ + Overview: + A custom autograd function to perform an in-place-like assignment on a tensor slice + without triggering PyTorch's version counter checks. This is useful for updating + buffers or caches within a computation graph. + + Reference: + Inspired by discussions on the PyTorch forums, such as: + https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 + + .. warning:: + This function is unsafe if the same slice of the input tensor is overwritten + multiple times, as it can lead to incorrect gradient calculations. + """ + + @staticmethod + def _get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]: + """ + Overview: + Creates a slice tuple for indexing a tensor at a specific dimension. + Arguments: + - dim (:obj:`int`): The dimension to slice along. 
+ - start (:obj:`int`): The starting index for the slice. + - stop (:obj:`int`): The ending index for the slice. + Returns: + - slice_tuple (:obj:`Tuple[slice, ...]`): A tuple of slice objects for indexing. + """ + return (slice(None),) * dim + (slice(start, stop),) + + @staticmethod + def forward( + ctx, + input_tensor: torch.Tensor, + value: torch.Tensor, + dim: int, + start: int, + stop: int + ) -> torch.Tensor: + """ + Overview: + The forward pass assigns the `value` tensor to a slice of the `input_tensor`. + Arguments: + - ctx: The context object for storing information for the backward pass. + - input_tensor (:obj:`torch.Tensor`): The tensor to be modified. + - value (:obj:`torch.Tensor`): The tensor to assign to the slice. + - dim (:obj:`int`): The dimension along which to perform the assignment. + - start (:obj:`int`): The starting index of the slice. + - stop (:obj:`int`): The ending index of the slice. + Returns: + - modified_tensor (:obj:`torch.Tensor`): The `input_tensor` after modification. + """ + ctx.dim = dim + ctx.start = start + ctx.stop = stop + # Directly modify the data of the input tensor to bypass version checks. + input_tensor.data[AssignWithoutInplaceCheck._get_slice(dim, start, stop)] = value + return input_tensor + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]: + """ + Overview: + The backward pass computes gradients for the inputs of the forward pass. + Arguments: + - ctx: The context object with saved information from the forward pass. + - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor. + Returns: + - grad_input_tensor (:obj:`torch.Tensor`): The gradient with respect to `input_tensor`. + - grad_value (:obj:`torch.Tensor`): The gradient with respect to `value`. + - None, None, None: Gradients for `dim`, `start`, and `stop`, which are not needed. + """ + # The gradient for the original input tensor is the same as the output gradient. + grad_input_tensor = grad_output + # The gradient for the value tensor is the slice of the output gradient. + grad_value = grad_output[AssignWithoutInplaceCheck._get_slice(ctx.dim, ctx.start, ctx.stop)] + return grad_input_tensor, grad_value, None, None, None + + class Cache: + """ + Overview: + A cache for storing a single type of intermediate tensor (e.g., keys or values) + in a Transformer-like model. It handles dynamic updates and size management. + """ + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing intermediate results in a transformer model. + Initializes the cache. Arguments: - - num_samples (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. - - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensor. 
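As a brief aside on AssignWithoutInplaceCheck above: the gradient flow it documents can be reproduced with a stand-alone toy that follows the same pattern (a hypothetical 2D buffer with dim=1, not the class itself):

import torch

class ToySliceAssign(torch.autograd.Function):
    # Toy version of the pattern above: write `value` into buf[:, start:stop] via .data.
    @staticmethod
    def forward(ctx, buf, value, start, stop):
        ctx.start, ctx.stop = start, stop
        buf.data[:, start:stop] = value
        return buf

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. the buffer passes through unchanged; its slice flows to `value`.
        return grad_output, grad_output[:, ctx.start:ctx.stop], None, None

buf = torch.zeros(2, 8, requires_grad=True)
val = torch.ones(2, 3, requires_grad=True)
out = ToySliceAssign.apply(buf, val, 2, 5)
out.sum().backward()
print(val.grad)  # ones of shape (2, 3): the gradient reaches the assigned slice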
""" - assert embed_dim % num_heads == 0 - self._num_samples, self._cache, self._size = num_samples, None, None - self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device) # (B, nh, T, hs) + if embed_dim % num_heads != 0: + raise ValueError(f"Embedding dimension ({embed_dim}) must be divisible by the number of heads ({num_heads}).") + + self._num_samples = num_samples + self._num_heads = num_heads + self._max_tokens = max_tokens + self._head_dim = embed_dim // num_heads + self._device = device + + self._cache: torch.Tensor = self._create_cache_tensor(self._num_samples) + self._size: int = 0 self.reset() + def _create_cache_tensor(self, num_samples: int) -> torch.Tensor: + """ + Overview: + Creates an empty tensor with the correct shape and device for the cache. + Arguments: + - num_samples (:obj:`int`): The number of samples for which to create the cache. + Returns: + - empty_cache (:obj:`torch.Tensor`): An uninitialized tensor for the cache. + """ + return torch.empty( + num_samples, self._num_heads, self._max_tokens, self._head_dim, device=self._device + ) # Shape: (B, nh, T, hs) + @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the cache. + Gets the effective shape of the cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the cache. + - shape (:obj:`Tuple[int, int, int, int]`): A tuple representing (num_samples, num_heads, current_size, head_dim). """ - n, num_heads, _, head_dim = self._cache.shape - return n, num_heads, self._size, head_dim + return self._num_samples, self._num_heads, self._size, self._head_dim def reset(self) -> None: """ Overview: - Reset the cache to its initial state. + Resets the cache to an empty state. """ - self._cache = self._reset(self._num_samples) + self._cache = self._create_cache_tensor(self._num_samples) self._size = 0 def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune the cache based on a mask. + Prunes the cache along the sample dimension using a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array where `True` indicates which samples to keep. """ - assert mask.ndim == 1 and mask.shape[0] == self.shape[0] + if not (mask.ndim == 1 and mask.shape[0] == self._num_samples): + raise ValueError("Mask must be a 1D numpy array with length equal to the number of samples.") self._cache = self._cache[mask] self._num_samples = self._cache.shape[0] def get(self) -> torch.Tensor: """ Overview: - Get the current contents of the cache. + Retrieves the current contents of the cache. Returns: - - cache (:obj:`torch.Tensor`): The current contents of the cache. + - cache_content (:obj:`torch.Tensor`): A tensor containing the valid data in the cache. """ return self._cache[:, :, :self._size, :] def update(self, x: torch.Tensor, tokens: int) -> None: """ Overview: - Update the cache with new values. + Updates the cache with new tensor values. If the cache is full, it discards the oldest + tokens to make space. Arguments: - - x (:obj:`torch.Tensor`): The new values to update the cache with. - - tokens (:obj:`int`): The number of tokens to update. + - x (:obj:`torch.Tensor`): The new tensor data to add to the cache. + - tokens (:obj:`int`): The number of tokens being added (sequence length of `x`). 
""" - # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) - # assert self._size + tokens <= self._cache.shape[2] # TODO - self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) + required_capacity = self._size + tokens + + # If the new tokens exceed the cache's maximum capacity, shift existing data to make room. + if required_capacity > self._max_tokens: + shift_amount = required_capacity - self._max_tokens + + # This logic is crucial for models like MuZero where tokens are added in (state, action) pairs. + # To maintain the integrity of these pairs, an even number of tokens must be discarded. + if shift_amount % 2 != 0: + shift_amount += 1 + + if shift_amount >= self._size: + # If the required shift is larger than the current cache size, it's more efficient to reset. + self._cache.zero_() + self._size = 0 + else: + # Shift the existing cache content to the left, discarding the oldest tokens. + self._cache[:, :, :self._size - shift_amount, :] = self._cache[:, :, shift_amount:self._size, :] + self._size -= shift_amount + # NOTE: Shifting the cache invalidates absolute positional embeddings. + # The parent model must handle positional encoding adjustments. For example, if positional + # embeddings are calculated based on `prev_steps`, this shift means `prev_steps` may no + # longer correspond to the true start, potentially causing discontinuities. + + # Use the custom autograd function to assign the new data without inplace errors. + self._cache = AssignWithoutInplaceCheck.apply( + self._cache, x, 2, self._size, self._size + tokens + ) self._size += tokens class KVCache: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: + """ + Overview: + A container for a pair of caches: one for keys (K) and one for values (V), + typically used in a single attention layer of a Transformer. + """ + + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing key and value tensors in a transformer model. + Initializes the Key-Value cache pair. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. - - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensors. """ - self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device) - self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device) + self._k_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) + self._v_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the key cache. + Gets the effective shape of the key cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the key cache. + - shape (:obj:`Tuple[int, int, int, int]`): Shape of the key cache (num_samples, num_heads, current_size, head_dim). 
""" return self._k_cache.shape def reset(self) -> None: """ Overview: - Reset both key and value caches to their initial states. + Resets both the key and value caches to their empty states. """ self._k_cache.reset() self._v_cache.reset() @@ -112,9 +256,9 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune both key and value caches based on a mask. + Prunes both key and value caches based on a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array indicating which samples to keep. """ self._k_cache.prune(mask) self._v_cache.prune(mask) @@ -122,74 +266,94 @@ def prune(self, mask: np.ndarray) -> None: def get(self) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Get the current contents of the key and value caches. + Retrieves the current contents of the key and value caches. Returns: - key_cache (:obj:`torch.Tensor`): The current contents of the key cache. - value_cache (:obj:`torch.Tensor`): The current contents of the value cache. """ return self._k_cache.get(), self._v_cache.get() - def update(self, k: torch.Tensor, v: torch.Tensor): + def update(self, k: torch.Tensor, v: torch.Tensor) -> None: """ Overview: - Update both key and value caches with new values. + Updates both key and value caches with new tensors. Arguments: - - k (:obj:`torch.Tensor`): The new values to update the key cache with. - - v (:obj:`torch.Tensor`): The new values to update the value cache with. + - k (:obj:`torch.Tensor`): The new key tensor to add. + - v (:obj:`torch.Tensor`): The new value tensor to add. """ - self._k_cache.update(k, k.size(2)) - self._v_cache.update(v, v.size(2)) + # The number of tokens is inferred from the sequence dimension (dim 2). + num_tokens = k.size(2) + self._k_cache.update(k, num_tokens) + self._v_cache.update(v, num_tokens) class KeysValues: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None: + """ + Overview: + Manages a collection of KVCache objects, one for each layer in a Transformer model. + """ + + def __init__( + self, + num_samples: int, + num_heads: int, + max_tokens: int, + embed_dim: int, + num_layers: int, + device: torch.device + ) -> None: """ Overview: - Class for managing multiple layers of key and value caches in a transformer model. + Initializes KV caches for all layers. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size). - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. + - max_tokens (:obj:`int`): The maximum number of tokens per cache. - embed_dim (:obj:`int`): The dimension of the embeddings. - - num_layers (:obj:`int`): The number of layers in the transformer model. - - device (:obj:`torch.device`): The device on which to store the caches. + - num_layers (:obj:`int`): The number of layers in the Transformer model. + - device (:obj:`torch.device`): The device for storing cache tensors. """ - self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)]) + self._keys_values = tuple([ + KVCache(num_samples, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers) + ]) - def __getitem__(self, index: int) -> KVCache: + def __getitem__(self, layer_index: int) -> KVCache: """ Overview: - Get the key and value cache for a specific layer. 
+ Retrieves the KVCache for a specific layer. Arguments: - - index (:obj:`int`): The layer index. + - layer_index (:obj:`int`): The index of the layer. Returns: - - kv_cache (:obj:`KVCache`): The key and value cache for the specified layer. + - kv_cache (:obj:`KVCache`): The key-value cache for the specified layer. """ - return self._keys_values[index] + return self._keys_values[layer_index] - def __len__(self): + def __len__(self) -> int: """ Overview: - Get the number of layers in the transformer model. + Gets the number of layers. Returns: - - length (:obj:`int`): The number of layers. + - num_layers (:obj:`int`): The number of layers being managed. """ return len(self._keys_values) @property - def size(self): + def size(self) -> int: """ Overview: - Get the size of the tokens in the cache. + Gets the current number of tokens stored in the caches. Returns: - - size (:obj:`int`): The size of the tokens in the cache. + - size (:obj:`int`): The number of tokens in the cache (assumes all layers have the same size). """ + # All layer caches are synchronized, so we can check the size of the first one. + if not self._keys_values: + return 0 return self._keys_values[0].shape[2] def reset(self) -> None: """ Overview: - Reset all key and value caches to their initial states. + Resets the KV caches for all layers. """ for kv_cache in self._keys_values: kv_cache.reset() @@ -197,70 +361,27 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune all key and value caches based on a mask. + Prunes the KV caches for all layers based on a mask. Arguments: - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. """ for kv_cache in self._keys_values: kv_cache.prune(mask) - -class AssignWithoutInplaceCheck(torch.autograd.Function): - """ - Overview: - Custom autograd function to perform in-place assignment without triggering version checks. - Inspired from: - https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 - - .. warning: - Do not use it to overwrite a slice twice. - """ - - @staticmethod - def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]: - """ - Overview: - Get the slice object for the given dimension and range. - Arguments: - - dim (:obj:`int`): The dimension along which to slice. - - start (:obj:`int`): The start index of the slice. - - stop (:obj:`int`): The stop index of the slice. - Returns: - - slice (:obj:`Tuple[slice]`): The slice object. - """ - return tuple([slice(None), ] * dim + [slice(start, stop)]) - - @staticmethod - def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor: - """ - Overview: - Forward pass of the custom autograd function. - Arguments: - - ctx: The context object to store information for backward computation. - - input (:obj:`torch.Tensor`): The input tensor to be modified. - - value (:obj:`torch.Tensor`): The value tensor to assign to the input. - - dim (:obj:`int`): The dimension along which to assign the value. - - start (:obj:`int`): The start index of the assignment. - - stop (:obj:`int`): The stop index of the assignment. - Returns: - - output (:obj:`torch.Tensor`): The modified input tensor. 
- """ - ctx.dim = dim - ctx.start = start - ctx.stop = stop - input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value - return input - - @staticmethod - def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]: + def remove_register_tokens(self, register_token_num: int) -> None: """ Overview: - Backward pass of the custom autograd function. + Removes the last `register_token_num` tokens from the active view of the cache + in each layer by adjusting the internal size pointer. This does not delete the data + but makes it invisible to subsequent `get` and `update` calls. + This is typically called after an inference step that used temporary tokens + (e.g., register tokens) to ensure they are not part of the ongoing context. Arguments: - - ctx: The context object storing information from forward computation. - - grad_out (:obj:`torch.Tensor`): The gradient of the output tensor. - Returns: - - grad_input (:obj:`torch.Tensor`): The gradient of the input tensor. - - grad_value (:obj:`torch.Tensor`): The gradient of the value tensor. + - register_token_num (:obj:`int`): The number of tokens to remove from the end of the cache view. """ - return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None \ No newline at end of file + if register_token_num <= 0: + return + for kv_cache in self._keys_values: + # Decrement the size pointer for both K and V caches. + kv_cache._k_cache._size = max(0, kv_cache._k_cache._size - register_token_num) + kv_cache._v_cache._size = max(0, kv_cache._v_cache._size - register_token_num) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index c6ee6426c..2afa15a83 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -20,14 +20,14 @@ def __init__(self, use_dropout: bool = True): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features + # Comment out the following line if you don't need perceptual loss # self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - # Comment out the following line if you don't need perceptual loss + # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) # self.load_from_pretrained() # for param in self.parameters(): # param.requires_grad = False diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..53f0c5620 --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,273 @@ +import dataclasses +from typing import List, Any + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear + +# Note: The following lines are examples of how _maybe_wrap_linear might be used. 
+# _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") + +# This implementation is inspired by the following sources: +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/moe.py +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer_layers.py#L149 +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 + + +class MultiplicationFeedForward(nn.Module): + """ + Overview: + Implements the SwiGLU (Swish-Gated Linear Unit) feed-forward layer, a variant of a transformer feed-forward network + that uses element-wise multiplication of two linear projections, one of which is passed through a SiLU activation. + This is often expressed as: FFN_SwiGLU(x) = (SiLU(x @ W1) * (x @ W3)) @ W2. + """ + + def __init__(self, config: Any) -> None: + """ + Overview: + Initializes the MultiplicationFeedForward layer. + Arguments: + - config (:obj:`Any`): A configuration object containing model hyperparameters. + It is expected to have `embed_dim` (int) and `moe_use_lora` (bool). + """ + super().__init__() + hidden_dim = 4 * config.embed_dim + if config.moe_use_lora: + self.w1 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + self.w2 = _maybe_wrap_linear(nn.Linear(hidden_dim, config.embed_dim, bias=False), config, "feed_forward") + self.w3 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + else: + self.w1 = nn.Linear(config.embed_dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass of the SwiGLU layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - torch.Tensor: The output tensor after applying the SwiGLU transformation. + """ + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +@dataclasses.dataclass +class MoeArgs(Serializable): + """ + Overview: + Dataclass for storing Mixture-of-Experts (MoE) configuration arguments. + """ + num_experts: int # The total number of experts in the MoE layer. + num_experts_per_tok: int # The number of experts to route each token to (k). + + +class MoELayer(nn.Module): + """ + Overview: + A straightforward implementation of a Mixture-of-Experts (MoE) layer. + This version iterates through each expert and processes the tokens routed to it. + While clear and easy to understand, it can be less efficient than vectorized approaches. + + The process is as follows: + 1. The input tensor `x` is flattened from [B, T, D] to [N, D], where N = B * T. + 2. A gating network calculates logits for each token to determine expert assignment. + 3. For each token, the top-k experts are selected based on the logits. + 4. The layer iterates through each expert, gathers all tokens assigned to it, + and computes their outputs. + 5. The outputs are weighted by the gating scores and summed up. + 6. An optional shared expert can be applied to all tokens. + 7. The final tensor is reshaped to its original shape [B, T, D]. + + Attributes: + - dim (:obj:`int`): The dimension of the input features. + - num_experts (:obj:`int`): The total number of experts. + - num_experts_per_tok (:obj:`int`): The number of experts activated per token (top-k). 
+ - gate (:obj:`nn.Module`): The gating network that produces routing logits. + - experts (:obj:`nn.ModuleList`): A list of expert networks. + - shared_expert (:obj:`nn.Module` or `None`): An optional shared expert applied to all tokens. + """ + + def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None: + """ + Overview: + Initializes the MoELayer. + Arguments: + - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`. + - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts. + - gate (:obj:`nn.Module`): The gating module for routing tokens. + - num_experts_per_tok (:obj:`int`): The number of experts to use for each token. + """ + super().__init__() + self.dim = config.embed_dim + self.num_experts = len(experts) + self.num_experts_per_tok = num_experts_per_tok + self.gate = gate + self.experts = nn.ModuleList(experts) + + # If specified in the config, create a shared expert branch. + if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: + # TODO: The architecture of the shared expert could be made more configurable. + self.shared_expert = nn.Sequential( + nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), + nn.GELU(), + nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim) + ) + else: + self.shared_expert = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass for the MoE layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape [batch_size, seq_len, dim]. + Returns: + - torch.Tensor: The output tensor with the same shape as the input. + """ + # Store original shape and flatten input to 2D: [batch_size * seq_len, dim] + original_shape = x.size() + x = x.view(-1, self.dim) + + # Compute gate logits, shape: [num_tokens, num_experts] + gate_logits = self.gate(x) + # Select top-k experts for each token. + weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + # Normalize the weights of selected experts using softmax. + weights = F.softmax(weights, dim=1).to(x.dtype) + + # Initialize the output tensor for expert computations. + expert_output = torch.zeros_like(x) + + # Iterate over each expert to compute outputs for the tokens routed to it. + for expert_id in range(self.num_experts): + # Find the tokens that have this expert in their top-k list. + batch_idx, expert_tok_idx = torch.where(indices == expert_id) + if batch_idx.numel() == 0: + continue + + # Select the subset of tokens for the current expert. + token_subset = x[batch_idx] # Shape: [num_tokens_for_expert, dim] + # Compute the output from the current expert. + output_expert = self.experts[expert_id](token_subset) + # Get the corresponding weights for these tokens. + token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) + # Apply weights and accumulate the output. + expert_output[batch_idx] += output_expert * token_weights + + # If a shared expert exists, add its output. + if self.shared_expert is not None: + shared_output = self.shared_expert(x) + output = expert_output + shared_output + else: + output = expert_output + + # Restore the original tensor shape and return. + return output.view(original_shape) + + +class MoELayerOptimized(nn.Module): + """ + Overview: + An optimized implementation of the Mixture-of-Experts (MoE) layer that maintains the same API as `MoELayer`. 
+ This version avoids loops over experts by using a vectorized scatter-gather approach, which is significantly + more efficient on modern hardware. The forward pass complexity is O(N_tokens + ΣE_i), where ΣE_i is the + total number of tokens processed across all experts. + + The process is as follows: + 1. **Routing**: Get top-k experts and their weights for each token. + 2. **Flattening**: Create a flat list of (token_index, expert_index, weight) tuples. + 3. **Sorting**: Sort these tuples by expert_index. This groups all tokens destined for the same expert together. + 4. **Batch Forward**: Process the tokens for each expert in a single, contiguous batch, avoiding Python loops. + 5. **Weighted Scatter**: Apply gating weights to the expert outputs and scatter-add them back to a buffer + indexed by the original token positions. + 6. **Shared Expert**: If configured, add the output from the shared expert. + 7. **Reshape**: Reshape the final output tensor to its original 3D shape. + """ + + def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None: + """ + Overview: + Initializes the MoELayerOptimized. + Arguments: + - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`. + - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts. + - gate (:obj:`nn.Module`): The gating module for routing tokens. + - num_experts_per_tok (:obj:`int`): The number of experts to use for each token. + """ + super().__init__() + self.dim = config.embed_dim + self.num_experts = len(experts) + self.num_experts_per_tok = num_experts_per_tok + self.gate = gate + self.experts = nn.ModuleList(experts) + + self.use_shared = getattr(config, "n_shared_experts", 0) > 0 + if self.use_shared: + # TODO: The architecture of the shared expert could be made more configurable. + self.shared_expert = nn.Sequential( + nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), + nn.GELU(), + nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the optimized forward pass for the MoE layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape [B, T, D]. + Returns: + - torch.Tensor: The output tensor with the same shape as the input. + """ + B, T, D = x.shape + x_flat = x.reshape(-1, D) # [N, D]; N = B*T + + # 1. Routing: Get top-k experts and weights. + gate_logits = self.gate(x_flat) # [N, E] + weights, topk_idx = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) # [N, k] + weights = F.softmax(weights, dim=1).to(x.dtype) # [N, k] + + # 2. Flatten token-expert pairs. + N, k = weights.shape + flat_token_idx = torch.arange(N, device=x.device).repeat_interleave(k) # [N*k] + flat_expert_idx = topk_idx.reshape(-1) # [N*k] + flat_weight = weights.reshape(-1, 1) # [N*k, 1] + flat_input = x_flat[flat_token_idx] # [N*k, D] + + # 3. Sort by expert index to group tokens for batch processing. + sort_order = torch.argsort(flat_expert_idx) # [N*k] + flat_expert_idx = flat_expert_idx[sort_order] + flat_token_idx = flat_token_idx[sort_order] + flat_weight = flat_weight[sort_order] + flat_input = flat_input[sort_order] + + # Count how many tokens each expert will process. + counts = torch.bincount(flat_expert_idx, minlength=self.num_experts) # [E] + + # Prepare output buffer. + out_buffer = torch.zeros_like(flat_input) # [N*k, D] + + # 4. 
Perform forward pass for each expert on its batch of tokens. + ptr = 0 + for eid, num in enumerate(counts.tolist()): + if num == 0: + continue + seg = slice(ptr, ptr + num) + out_buffer[seg] = self.experts[eid](flat_input[seg]) + ptr += num + + # 5. Apply weights and scatter-add results back to token-indexed buffer. + out_buffer.mul_(flat_weight) # In-place multiplication by weights. + token_output = torch.zeros_like(x_flat) # [N, D] + token_output.index_add_(0, flat_token_idx, out_buffer) + + # 6. Add shared expert output if it exists. + if self.use_shared: + token_output.add_(self.shared_expert(x_flat)) + + return token_output.reshape(B, T, D) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py new file mode 100644 index 000000000..1f0f5437c --- /dev/null +++ b/lzero/model/unizero_world_models/test_moe.py @@ -0,0 +1,200 @@ +""" +test_moe.py + +Overview: + A test script to verify the functional equivalence between a standard Transformer's feed-forward network (FFN) + and a Mixture-of-Experts (MoE) layer configured with a single expert. This script demonstrates that + the MoE layer correctly specializes to a standard FFN when num_experts is 1, ensuring backward + compatibility and correct routing logic. +""" +import dataclasses +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclasses.dataclass +class TransformerConfig: + """ + Overview: + Configuration for the Transformer block and its potential MoE layer. + + Arguments: + - embed_dim (int): The embedding dimension for the model. + - resid_pdrop (float): The dropout probability for the residual connections. + - moe_in_transformer (bool): If True, use an MoE layer for the feed-forward part. Otherwise, use a standard MLP. + - num_experts (int): The total number of experts in the MoE layer. + - num_experts_per_tok (int): The number of experts to route each token to (top-k routing). + """ + embed_dim: int = 64 + resid_pdrop: float = 0.1 + moe_in_transformer: bool = False + num_experts: int = 1 + num_experts_per_tok: int = 1 + + +class MoELayer(nn.Module): + """ + Overview: + An efficient, vectorized implementation of a Mixture-of-Experts (MoE) layer. + This layer routes each token to a subset of experts (Top-k routing) and combines their + outputs using a weighted sum. The implementation is highly optimized for parallel + computation on hardware like GPUs. + """ + + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int): + """ + Overview: + Initializes the MoE layer. + Arguments: + - experts (List[nn.Module]): A list of expert neural network modules. + - gate (nn.Module): The gating network that computes routing logits. + - num_experts_per_tok (int): The number of experts to route each token to. + """ + super().__init__() + assert len(experts) > 0, "The list of experts cannot be empty." + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts = len(experts) + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass of the MoE layer. + Arguments: + - x (torch.Tensor): Input tensor of shape `[batch_size, seq_len, embed_dim]`. + Returns: + - (torch.Tensor): Output tensor of the same shape as the input. 
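+        Example:
+            >>> # Illustrative sketch only; the sizes and expert modules below are hypothetical.
+            >>> experts = [nn.Linear(8, 8) for _ in range(2)]
+            >>> moe = MoELayer(experts, gate=nn.Linear(8, 2, bias=False), num_experts_per_tok=1)
+            >>> out = moe(torch.randn(4, 3, 8))  # same shape as the input: (4, 3, 8)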
+ """ + batch_size, seq_len, dim = x.shape + x_flat = x.view(-1, dim) + + gate_logits = self.gate(x_flat) + weights, topk_indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(x.dtype) + + num_tokens = x_flat.shape[0] + flat_token_indices = torch.arange(num_tokens, device=x.device).repeat_interleave(self.num_experts_per_tok) + flat_expert_indices = topk_indices.view(-1) + + sort_order = torch.argsort(flat_expert_indices) + sorted_expert_indices = flat_expert_indices[sort_order] + sorted_token_indices = flat_token_indices[sort_order] + + expert_inputs = x_flat[sorted_token_indices] + sorted_weights = weights.view(-1, 1)[sort_order] + + expert_counts = torch.bincount(sorted_expert_indices, minlength=self.num_experts) + output_buffer = torch.zeros_like(expert_inputs) + + ptr = 0 + for i, count in enumerate(expert_counts.tolist()): + if count == 0: + continue + segment = slice(ptr, ptr + count) + output_buffer[segment] = self.experts[i](expert_inputs[segment]) + ptr += count + + # --- FIX: Simplified and corrected scattering logic --- + # Weight the outputs and directly add them to the correct token's position. + weighted_outputs = output_buffer * sorted_weights + + token_output = torch.zeros_like(x_flat) + # Use `sorted_token_indices` to add the results back to their original token positions. + token_output.index_add_(0, sorted_token_indices, weighted_outputs) + + return token_output.view(batch_size, seq_len, dim) + + +class TransformerBlock(nn.Module): + """ + Overview: + A simplified Transformer block that contains a feed-forward network (FFN). + The FFN can be either a standard MLP or a Mixture-of-Experts (MoE) layer, + controlled by the configuration. + """ + def __init__(self, config: TransformerConfig): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + if config.moe_in_transformer: + experts = [self.mlp for _ in range(config.num_experts)] + self.feed_forward = MoELayer( + experts=experts, + gate=nn.Linear(config.embed_dim, config.num_experts, bias=False), + num_experts_per_tok=config.num_experts_per_tok, + ) + print("=" * 40) + print("TransformerBlock initialized with MoE layer.") + print("=" * 40) + else: + self.feed_forward = self.mlp + print("-" * 40) + print("TransformerBlock initialized with standard MLP.") + print("-" * 40) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.feed_forward(x) + + +def test_transformer_block_equivalence(): + """ + Overview: + Tests that an MoE layer with a single expert produces an output identical + to that of a standard MLP layer, given that they share the same weights. + """ + torch.manual_seed(42) + + embed_dim = 64 + batch_size = 10 + seq_len = 5 + + config_mlp = TransformerConfig(embed_dim=embed_dim, moe_in_transformer=False) + config_moe = TransformerConfig(embed_dim=embed_dim, moe_in_transformer=True, num_experts=1, num_experts_per_tok=1) + + # --- FIX: Ensure identical weights for a fair comparison --- + # 1. Create the standard MLP block first. + transformer_block_mlp = TransformerBlock(config_mlp) + + # 2. Create the MoE block. + transformer_block_moe = TransformerBlock(config_moe) + + # 3. CRITICAL: Load the MLP's weights into the MoE's expert MLP. + # This guarantees that the underlying expert has the exact same weights as the standalone MLP. 
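+    # Note: `TransformerBlock` builds its MoE experts as `[self.mlp] * num_experts`, i.e. every
+    # expert is the same module instance, so loading the state dict into `transformer_block_moe.mlp`
+    # once also updates the (single) expert used by the MoE layer.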
+ transformer_block_moe.mlp.load_state_dict(transformer_block_mlp.mlp.state_dict()) + + # Also, for a perfect match, the gate should be initialized to a state + # that it doesn't affect the output scaling. We can manually set its weights. + # In a single-expert case, softmax ensures the weight is 1, so this is not strictly + # necessary, but it's good practice for more complex tests. + + inputs = torch.randn(batch_size, seq_len, embed_dim) + + print("\nRunning forward pass for standard MLP block...") + output_mlp = transformer_block_mlp(inputs) + + print("\nRunning forward pass for MoE block...") + output_moe = transformer_block_moe(inputs) + + is_close = torch.allclose(output_moe, output_mlp, atol=1e-6) + mse_difference = F.mse_loss(output_moe, output_mlp).item() + + print("\n" + "=" * 25 + " TEST RESULTS " + "=" * 25) + print(f"Outputs are close: {is_close}") + print(f"Mean Squared Error (MSE) between outputs: {mse_difference:.10f}") + + assert is_close, "Test failed: Outputs of single-expert MoE and MLP are not identical." + print("\n✅ Test Passed: Single-expert MoE layer behaves identically to a standard MLP.") + print("=" * 64 + "\n") + + +if __name__ == "__main__": + test_transformer_block_equivalence() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index e5e18461f..65325b3b4 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -1,8 +1,10 @@ """ Modified from https://github.com/CompVis/taming-transformers +This module provides an autoencoder-style tokenizer for encoding observations into latent embeddings and decoding them back. """ from dataclasses import dataclass +from typing import Any, Dict, Optional import torch import torch.nn as nn @@ -12,105 +14,157 @@ from transformers.modeling_outputs import BaseModelOutput class LossWithIntermediateLosses: - def __init__(self, **kwargs): - """Initialize with various loss components.""" - self.loss_total = sum(kwargs.values()) - self.intermediate_losses = {k: v.item() for k, v in kwargs.items()} - - def __truediv__(self, value): - """Divide all loss components by a given value.""" - for k, v in self.intermediate_losses.items(): - self.intermediate_losses[k] = v / value + """ + Overview: + A helper class to manage a total loss value alongside a dictionary of its constituent, named loss components. + This is primarily used for detailed logging. + """ + + def __init__(self, **kwargs: torch.Tensor) -> None: + """ + Overview: + Initializes the loss object. + Arguments: + - kwargs (:obj:`torch.Tensor`): Keyword arguments where keys are loss names and values are the corresponding loss tensors. + """ + # The total loss, which can be used for backpropagation. + self.loss_total: torch.Tensor = sum(kwargs.values()) + # A dictionary holding the scalar values of intermediate losses, detached from the computation graph. + self.intermediate_losses: Dict[str, float] = {k: v.item() for k, v in kwargs.items()} + + def __truediv__(self, value: float) -> "LossWithIntermediateLosses": + """ + Overview: + Overloads the division operator to scale all loss components by a scalar value. + This is useful for operations like averaging over batch size or gradient accumulation steps. + Arguments: + - value (:obj:`float`): The scalar value to divide the losses by. + Returns: + - LossWithIntermediateLosses: The same instance with updated loss values. 
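+        Example:
+            >>> # Illustrative sketch; `recon` and `perc` stand for arbitrary scalar loss tensors.
+            >>> losses = LossWithIntermediateLosses(recon_loss=recon, perceptual_loss=perc)
+            >>> losses = losses / batch_size  # every component is scaled by the same factor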
+ """ + if not isinstance(value, (int, float)) or value == 0: + raise ValueError(f"Division is only supported for a non-zero scalar, but got {value}.") + self.loss_total = self.loss_total / value + for k in self.intermediate_losses: + self.intermediate_losses[k] /= value return self @dataclass class TokenizerEncoderOutput: + """ + Overview: + A data structure to hold the various outputs from a VQ-VAE style encoder, + including continuous and quantized latent representations, and discrete tokens. + """ + # Continuous latent representation from the encoder. z: torch.FloatTensor + # Quantized latent representation. z_quantized: torch.FloatTensor + # Discrete integer tokens corresponding to the codebook entries. tokens: torch.LongTensor class Tokenizer(nn.Module): """ Overview: - Tokenizer model that encodes and decodes observations. - Can operate on visual or textual data, supporting optional LPIPS perceptual loss. - It optionally includes a linear projection layer and can be paired with a decoder tokenizer. + An autoencoder model that encodes high-dimensional observations (like images or state vectors) + into low-dimensional latent embeddings and decodes them back. It can also compute reconstruction + and perceptual losses. This implementation does not include the quantization step (Vector Quantization) + but serves as the encoder-decoder backbone. """ - def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None, encoder_option='legacy') -> None: - """Initialize the Tokenizer. + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + with_lpips: bool = False, + obs_type: str = 'image' + ) -> None: + """ + Overview: + Initializes the Tokenizer (Autoencoder). Arguments: - encoder (nn.Module, optional): Encoder network to transform raw inputs into embeddings. - decoder_network (nn.Module, optional): Decoder network used for observation reconstruction or text generation. - decoder_network_tokenizer (PreTrainedTokenizer, optional): Tokenizer compatible with the decoder network (e.g., T5 tokenizer). - with_lpips (bool, optional): If True, enable perceptual loss computation via LPIPS. Defaults to False. - projection (list[int], optional): If provided, defines a linear projection layer from projection[0] → projection[1]. - If None, an identity layer is used. - encoder_option (str, optional): Option to specify the encoder type, e.g., 'legacy' for T5 decoder or 'qwen' for Qwen decoder. Defaults to 'legacy'. + - encoder (:obj:`nn.Module`): The network responsible for encoding observations into latent embeddings. It can be a single module or an nn.ModuleList for multi-task scenarios. + - decoder (:obj:`nn.Module`): The network responsible for decoding latent embeddings back into observations. + - with_lpips (:obj:`bool`): If True, initializes the LPIPS model to compute perceptual loss. Defaults to False. + - obs_type (:obj:`str`): The type of observation, e.g., 'image' or 'vector'. This can inform model architecture choices. Defaults to 'image'. """ super().__init__() + self.encoder = encoder + self.decoder_network = decoder + self.obs_type = obs_type + self.lpips: Optional[nn.Module] = None if with_lpips: + # Lazily import LPIPS as it's an optional dependency. 
from lzero.model.unizero_world_models.lpips import LPIPS
+            self.lpips = LPIPS().eval()
-        else:
-            self.lpips = None
-
-        self.encoder = encoder
-        self.decoder_network = decoder_network
-        self.decoder_network_tokenizer = decoder_network_tokenizer
-        self.encoder_option = encoder_option
-        if projection is None:
-            self.projection_layer = nn.Identity()
-        else:
-            self.projection_layer = nn.Linear(projection[0], projection[1])
-
-    def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor:
+    def encode_to_obs_embeddings(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
         """
-        Encode observations to embeddings.
-
+        Overview:
+            Encodes a batch of observations into latent embeddings, handling various input shapes and multi-task encoders.
         Arguments:
-            x (torch.Tensor): Input tensor of shape (B, ...).
-
+            - x (:obj:`torch.Tensor`): The input tensor of observations. Shape can be (B, E), (B, T, E), (B, C, H, W), or (B, T, C, H, W).
+            - task_id (:obj:`int`): The identifier for the task, used to select the correct encoder from an nn.ModuleList in multi-task settings. Defaults to 0.
         Returns:
-            torch.Tensor: Encoded embeddings of shape (B, 1, E).
+            - torch.Tensor: The encoded latent embeddings with a consistent shape of (B, 1, E), where B is the effective batch size.
         """
-        shape = x.shape
-        # Process input tensor based on its dimensionality
-        if len(shape) == 2:
-            # Case when input is 2D (B, E)
-            obs_embeddings = self.encoder(x)
-            obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
-        elif len(shape) == 3:
-            # Case when input is 3D (B, T, E)
-            x = x.contiguous().view(-1, shape[-1])  # Flatten the last two dimensions (B * T, E)
-            obs_embeddings = self.encoder(x)
-            obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
-        elif len(shape) == 4:
-            # Case when input is 4D (B, C, H, W)
-            obs_embeddings = self.encoder(x)
-            obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
-        elif len(shape) == 5:
-            # Case when input is 5D (B, T, C, H, W)
-            x = x.contiguous().view(-1, *shape[-3:])  # Flatten the first two dimensions (B * T, C, H, W)
-            obs_embeddings = self.encoder(x)
-            obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
+
+        # global DEBUG_ENABLED;DEBUG_ENABLED = True
+        # import torch.distributed as dist
+        # if dist.get_rank() == 0 and DEBUG_ENABLED:
+        #     print(f"rank {dist.get_rank()} entering debug mode: type 'interact' to run arbitrary Python code for debugging. Set DEBUG_ENABLED = False to skip the debug state.")
+        #     import ipdb; ipdb.set_trace()
+        #     # Synchronization point to prevent the other processes from running ahead.
+        #     dist.barrier()
+
+        # Step 1: Select the appropriate encoder module.
+        # This handles both single-task (a single nn.Module) and multi-task (an nn.ModuleList) scenarios.
+        if isinstance(self.encoder, nn.ModuleList):
+            if not 0 <= task_id < len(self.encoder):
+                # raise ValueError(
+                #     f"Provided task_id {task_id} is invalid for the encoder list of size {len(self.encoder)}."
+                # )
+                encoder_module = self.encoder[0]
+            else:
+                encoder_module = self.encoder[task_id]
         else:
-            raise ValueError(f"Invalid input shape: {shape}")
+            encoder_module = self.encoder
+
+        # Step 2: Pre-process and reshape the input tensor based on its dimensions.
+        # The goal is to transform the input into a 2D or 4D tensor that the encoder can process.
+        original_shape = x.shape
+        if len(original_shape) == 5:  # Batch of sequences of images: (B, T, C, H, W)
+            # Flatten the batch and time dimensions to create a batch of images.
+ x = x.contiguous().view(-1, *original_shape[-3:]) # Shape: (B*T, C, H, W) + elif len(original_shape) == 3: # Batch of sequences of vectors: (B, T, E) + # Flatten the batch and time dimensions to create a batch of vectors. + x = x.contiguous().view(-1, original_shape[-1]) # Shape: (B*T, E) + # Note: 2D (B, E) and 4D (B, C, H, W) inputs are processed directly without reshaping. + + # Step 3: Pass the processed tensor through the encoder. + obs_embeddings = encoder_module(x) + if len(obs_embeddings.shape) != 2: + raise RuntimeError( + f"Encoder output was expected to be 2D (batch, embedding_dim), but got shape {obs_embeddings.shape}." + ) + + # Step 4: Reshape the output to a consistent sequence format (B', 1, E). + # The '1' represents a sequence length of one, making it compatible with sequence models. + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') return obs_embeddings def decode_to_obs(self, embeddings: torch.Tensor) -> torch.Tensor: - """Decode embeddings to observations. - + """ + Overview: + Decodes a batch of latent embeddings back into the observation space. Arguments: - embeddings (:obj:`torch.Tensor`): Input embeddings. - + - embeddings (:obj:`torch.Tensor`): The latent embeddings to decode. Returns: - torch.Tensor: Decoded observations. + - torch.Tensor: The reconstructed observations. """ return self.decoder_network(embeddings) @@ -268,36 +322,43 @@ def decode_to_plain_text( @staticmethod def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the reconstruction loss. - + """ + Overview: + Calculates the reconstruction loss between original and reconstructed observations. + It uses L2 (MSE) loss for vector-based observations and L1 (MAE) loss for image-based observations. Arguments: - - original_images (:obj:`torch.Tensor`): Original images. - - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth observations. + - reconstructed_images (:obj:`torch.Tensor`): The observations reconstructed by the decoder. Returns: - - torch.Tensor: Computed reconstruction loss. + - torch.Tensor: A scalar tensor representing the computed reconstruction loss. """ if len(original_images.shape) == 2: - # For memory environment vector observations - loss = F.mse_loss(original_images, reconstructed_images) # L2 loss + # Use Mean Squared Error (L2 loss) for vector-based observations. + return F.mse_loss(reconstructed_images, original_images) else: - # For Atari image environment - loss = torch.abs(original_images - reconstructed_images).mean() # L1 loss - return loss + # Use Mean Absolute Error (L1 loss) for image-based observations, which is often more robust to outliers. + return torch.abs(original_images - reconstructed_images).mean() def perceptual_loss(self, original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the perceptual loss using LPIPS. - + """ + Overview: + Calculates the perceptual loss (LPIPS) between original and reconstructed images. + This loss is designed to better align with human perception of image similarity. Arguments: - original_images (:obj:`torch.Tensor`): Original images. - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth images. + - reconstructed_images (:obj:`torch.Tensor`): The images reconstructed by the decoder. Returns: - torch.Tensor: Computed perceptual loss. 
+ - torch.Tensor: A scalar tensor representing the computed perceptual loss. """ + if self.lpips is None: + raise RuntimeError("LPIPS model was not initialized. Please set `with_lpips=True` during Tokenizer instantiation.") return torch.mean(self.lpips(original_images, reconstructed_images)) def __repr__(self) -> str: - return "Tokenizer" \ No newline at end of file + """ + Overview: + Provides a string representation of the Tokenizer module. + """ + return f"Tokenizer(obs_type='{self.obs_type}', with_lpips={self.lpips is not None})" \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index c2feb8497..0e855d289 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -1,372 +1,810 @@ """ -The following code is modified from https://github.com/karpathy/nanoGPT. +This script is an extension of the original transformer.py from karpathy/nanoGPT. +It incorporates LoRA (Low-Rank Adaptation) for fine-tuning and introduces a +Curriculum Learning mechanism that activates different LoRA adapters sequentially. + +Key features: +- Adds `CurriculumLoRALinear`, a custom linear layer with multiple LoRA adapters. +- Controls which modules to apply LoRA to via configuration (e.g., attention and feed-forward layers). +- Maintains the extensibility and readability of the original nanoGPT codebase. """ -import numpy as np import math +import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn -import torch.nn as nn -from torch.nn import functional as F from ding.torch_utils.network import GRUGatingUnit from einops import rearrange +from torch.nn import functional as F from .kv_caching import KeysValues +from lzero.model.common import SimNorm -@dataclass -class TransformerConfig: - tokens_per_block: int - max_blocks: int - attention: str +class LearnableScale(nn.Module): + """ + A learnable scalar parameter constrained within a specific range. - num_layers: int - num_heads: int - embed_dim: int + The formula `s = offset + scale * tanh(ŝ)` maps an unbounded logit `ŝ` + to the range (offset - scale, offset + scale). Using tanh can sometimes + provide more stable gradients than sigmoid. - embed_pdrop: float - resid_pdrop: float - attn_pdrop: float - - # for RoPE - rope_theta: float - max_seq_len: int - rotary_emb: bool = False + For example, to achieve a range of (0.8, 1.2), one would use + `init=1.0` and `s_range=0.2`. + """ - @property - def max_tokens(self): - return self.tokens_per_block * self.max_blocks + def __init__(self, init: float = 1.0, s_range: float = 0.2) -> None: + """ + Overview: + Initializes the LearnableScale module. + Arguments: + - init (:obj:`float`): The initial value of the scalar, which also serves as the center of the range. + - s_range (:obj:`float`): The scale factor that determines the range (init - s_range, init + s_range). + """ + super().__init__() + assert s_range > 0, "The scaling range must be positive." + self.offset = init + self.scale = s_range + # Initialize the logit to 0, so the initial output is exactly `init`. + self.logit = nn.Parameter(torch.tensor(0.0)) + # TODO: Initially frozen, activated by a CurriculumController. + self.logit.requires_grad = False -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - """ - Precompute the frequency components for the rotary positional embeddings. 
+ def forward(self) -> torch.Tensor: + """ + Overview: + Computes the scaled value. + Returns: + - torch.Tensor: The learnable scalar, constrained to the specified range. + """ + return self.offset + self.scale * torch.tanh(self.logit) - Arguments: - - dim (int): The dimension of the embedding. - - end (int): The length of the sequence for which frequencies are computed. - - theta (float): A scaling factor for the frequencies, default is 10000.0. +############################################## +# Optimized CurriculumLoRALinear Implementation (Recommended Version) +############################################## - Returns: - - freqs_cis (torch.Tensor): A tensor of complex numbers representing the precomputed frequencies. +class CurriculumLoRALinear(nn.Module): """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - + Optimized CurriculumLoRALinear. + + Effective weight at stage s: + W_eff = α₀*W₀ + Σ_{j=1 to s} αⱼ*Δθⱼ -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + Optimization logic at stage s (s >= 1): + - Train: Δθₛ, α₀, and {αⱼ | 1 <= j < s} + - Freeze: W₀, {Δθⱼ | 1 <= j < s}, and αₛ + + This avoids the redundancy of training αₛ alongside Δθₛ. """ - Reshape the frequency components for broadcasting with the input tensor. - Arguments: - - freqs_cis (torch.Tensor): The frequency components tensor. - - x (torch.Tensor): The input tensor to which the frequencies will be applied. + def __init__(self, in_features: int, out_features: int, bias: bool = True, + r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + curriculum_stage_num: int = 1, lora_scale_init: float = 1.0) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r + self.lora_alpha = lora_alpha + self.scaling = lora_alpha / r if r > 0 else 1.0 + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + self.curriculum_stage_num = curriculum_stage_num + self.curriculum_stage = 0 + + # Base weights (W₀ and bias) + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter('bias', None) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + # Learnable scale for the base weight (α₀) + self.base_weight_scale = LearnableScale(init=1.0, s_range=0.2) + + # A scale for each adapter (α₁, α₂, ...) + self.adapters = nn.ModuleList() + self.adapter_scales = nn.ModuleList() + + if r > 0 and (curriculum_stage_num - 1) > 0: + for _ in range(curriculum_stage_num - 1): + adapter = nn.ParameterDict({ + 'lora_A': nn.Parameter(torch.randn(r, in_features) * 0.01), + 'lora_B': nn.Parameter(torch.zeros(out_features, r)) + }) + self.adapters.append(adapter) + self.adapter_scales.append(LearnableScale(lora_scale_init, s_range=0.2)) + else: + self.adapters = None - Returns: - - torch.Tensor: The reshaped frequency components tensor. 
- """ - # Reference: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L61 - ndim = x.ndim - shape = [d if i in (0, 2, ndim - 1) else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + self.set_curriculum_stage(0) + def set_curriculum_stage(self, stage: int) -> None: + assert 0 <= stage < self.curriculum_stage_num, f"Stage must be within [0, {self.curriculum_stage_num-1}]" + self.curriculum_stage = stage + module_id = f"({self.in_features}x{self.out_features})" + + # --- Stage 0: Base Training --- + if stage == 0: + self.weight.requires_grad = True + if self.bias is not None: self.bias.requires_grad = True + + # Freeze everything else + self.base_weight_scale.logit.requires_grad = False + if self.adapters: + for adapter in self.adapters: + adapter['lora_A'].requires_grad = False + adapter['lora_B'].requires_grad = False + for scale in self.adapter_scales: + scale.logit.requires_grad = False + logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: Base layer trainable.") + + # --- Stage >= 1: Adaptation --- + else: + # Freeze base model + self.weight.requires_grad = False + if self.bias is not None: self.bias.requires_grad = False + + # α₀ is trainable from stage 1 onwards + self.base_weight_scale.logit.requires_grad = True + + if self.adapters: + # Set trainability for LoRA adapters + for idx, adapter in enumerate(self.adapters): + is_current_adapter = (idx == stage - 1) + adapter['lora_A'].requires_grad = is_current_adapter + adapter['lora_B'].requires_grad = is_current_adapter + + # --- OPTIMIZED LOGIC FOR SCALES --- + # Set trainability for adapter scales {α_j} + for idx, scale in enumerate(self.adapter_scales): + # A scale α_j is trainable if it belongs to a *previous* stage (j < s). + # The current stage's scale α_s (idx = stage - 1) is NOT trained. + is_previous_scale = (idx < stage - 1) + scale.logit.requires_grad = is_previous_scale + + logging.info(f"[CurriculumLoRALinear {module_id}] Stage {stage}: Activating adapter {stage - 1} and scales for stages < {stage - 1}.") + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply scaling to base weight if in an adaptation stage + if self.curriculum_stage > 0: + alpha_0 = self.base_weight_scale() + scaled_weight = self.weight * alpha_0 + baseline_out = F.linear(x, scaled_weight, self.bias) + else: + baseline_out = F.linear(x, self.weight, self.bias) + + if self.curriculum_stage == 0 or self.adapters is None: + return baseline_out + + adapter_out = 0 + # Iterate through all adapters up to the current stage + for idx in range(self.curriculum_stage): + if idx >= len(self.adapters): + break + + adapter = self.adapters[idx] + scale = self.adapter_scales[idx]() + + lora_x = self.lora_dropout(x) + out = F.linear(lora_x, adapter['lora_A']) + out = F.linear(out, adapter['lora_B']) + + # The forward pass is a simple sum. The magic happens in `set_curriculum_stage` + # which controls `requires_grad`. No need for `.detach()` here. + # Gradients will naturally flow only to parameters with `requires_grad=True`. 
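+            # `self.scaling` is the standard LoRA factor (lora_alpha / r), while `scale` is the
+            # learnable per-adapter coefficient produced by the corresponding LearnableScale module.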
+ adapter_out = adapter_out + self.scaling * out * scale + + return baseline_out + adapter_out + -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +# ############################################## +# # CurriculumLoRALinear Implementation +# ############################################## + +# class CurriculumLoRALinear(nn.Module): +# """ +# CurriculumLoRALinear extends a standard linear layer with curriculum-based LoRA adapters. + +# This module internally stores a base weight and bias. It also initializes multiple +# LoRA adapters (number = curriculum_stage_num - 1), which are activated sequentially. + +# Forward pass logic: +# - If `curriculum_stage == 0`: +# Output = F.linear(x, W, bias) +# - If `curriculum_stage >= 1`: +# Output = base_output + sum_{i=0}^{curriculum_stage-1} scaling * adapter_i(x) +# where only the adapter for the current stage (index == curriculum_stage - 1) is trainable. +# Previous adapters contribute to the forward pass but their gradients are detached. + +# Note: +# - The `set_curriculum_stage(stage)` method must be called externally to switch between stages. +# - Logging messages indicate the module's dimensions and the freeze/unfreeze status of its parameters. +# """ + +# def __init__(self, in_features: int, out_features: int, bias: bool = True, +# r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, +# curriculum_stage_num: int = 1, lora_scale_init: float = 1.0) -> None: +# """ +# Overview: +# Initializes the CurriculumLoRALinear layer. If `curriculum_stage_num > 1`, +# it creates `curriculum_stage_num - 1` LoRA adapters. +# Arguments: +# - in_features (:obj:`int`): Size of each input sample. +# - out_features (:obj:`int`): Size of each output sample. +# - bias (:obj:`bool`): If True, adds a learnable bias to the output. +# - r (:obj:`int`): The rank of the LoRA decomposition. If 0, LoRA is disabled. +# - lora_alpha (:obj:`int`): The alpha parameter for LoRA scaling. +# - lora_dropout (:obj:`float`): The dropout probability for LoRA layers. +# - curriculum_stage_num (:obj:`int`): The total number of curriculum stages. +# - lora_scale_init (:obj:`float`): The initial value for the learnable scale of each adapter. 
+# """ +# super().__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.r = r +# self.lora_alpha = lora_alpha +# self.scaling = lora_alpha / r if r > 0 else 1.0 +# self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() +# self.curriculum_stage_num = curriculum_stage_num +# self.curriculum_stage = 0 # Initial stage is 0 + +# # Initialize base weights (part of the base transformer), trainable by default +# self.weight = nn.Parameter(torch.empty(out_features, in_features)) +# if bias: +# self.bias = nn.Parameter(torch.empty(out_features)) +# else: +# self.register_parameter('bias', None) +# nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) +# if self.bias is not None: +# fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) +# bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 +# nn.init.uniform_(self.bias, -bound, bound) + +# # Initialize LoRA adapters, which exist only if r > 0 and curriculum_stage_num > 1 +# self.adapters = nn.ModuleList() +# self.adapter_scales = nn.ModuleList() + +# if r > 0 and (curriculum_stage_num - 1) > 0: +# for _ in range(curriculum_stage_num - 1): +# adapter = nn.ParameterDict({ +# 'lora_A': nn.Parameter(torch.randn(r, in_features) * 0.01), +# 'lora_B': nn.Parameter(torch.zeros(out_features, r)) +# }) +# self.adapters.append(adapter) +# self.adapter_scales.append(LearnableScale(lora_scale_init, s_range=0.2)) + +# else: +# self.adapters = None + +# # Initially (stage 0), the base layer is trainable, and all adapters are frozen +# self.weight.requires_grad = True +# if self.bias is not None: +# self.bias.requires_grad = True +# if self.adapters is not None: +# for adapter in self.adapters: +# adapter['lora_A'].requires_grad = False +# adapter['lora_B'].requires_grad = False + +# def set_curriculum_stage(self, stage: int) -> None: +# """ +# Overview: +# Sets the current curriculum stage and updates the `requires_grad` status of parameters accordingly. +# - Stage 0: The base layer is trainable; all adapters are frozen. +# - Stage >= 1: The base layer is frozen. Only the current adapter (index = stage - 1) is trainable. +# Previous adapters contribute to the forward pass but do not propagate gradients. +# Arguments: +# - stage (:obj:`int`): The curriculum stage to set, in the range [0, curriculum_stage_num - 1]. 
+# """ +# assert 0 <= stage < self.curriculum_stage_num, f"Stage must be within [0, {self.curriculum_stage_num-1}]" +# self.curriculum_stage = stage + +# module_id = f"({self.in_features}x{self.out_features})" +# if stage == 0: +# self.weight.requires_grad = True +# if self.bias is not None: +# self.bias.requires_grad = True +# if self.adapters is not None: +# for adapter in self.adapters: +# adapter['lora_A'].requires_grad = False +# adapter['lora_B'].requires_grad = False +# logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: Base layer is trainable, all adapters are frozen.") +# else: +# # For stages > 0, freeze the base layer +# self.weight.requires_grad = False +# if self.bias is not None: +# self.bias.requires_grad = False + +# if self.adapters is not None: +# for idx, adapter in enumerate(self.adapters): +# is_current_adapter = (idx == stage - 1) +# adapter['lora_A'].requires_grad = is_current_adapter +# adapter['lora_B'].requires_grad = is_current_adapter +# status = "activated (trainable)" if is_current_adapter else "frozen (forward-only)" +# logging.info(f"[CurriculumLoRALinear {module_id}] Stage {stage}: Adapter {idx} is {status}.") + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# """ +# Overview: +# Performs the forward pass of the CurriculumLoRALinear layer. +# Arguments: +# - x (:obj:`torch.Tensor`): The input tensor. +# Returns: +# - torch.Tensor: The output tensor. +# """ +# baseline_out = F.linear(x, self.weight, self.bias) +# if self.curriculum_stage == 0 or self.adapters is None: +# return baseline_out + +# adapter_out = 0 +# # For the first `curriculum_stage` adapters, only the last one backpropagates. +# # Others are detached to contribute only to the forward pass. +# for idx in range(self.curriculum_stage): +# if idx >= len(self.adapters): +# break +# adapter = self.adapters[idx] +# lora_x = self.lora_dropout(x) +# out = F.linear(lora_x, adapter['lora_A']) +# out = F.linear(out, adapter['lora_B']) + +# scale = self.adapter_scales[idx]() + +# # NOTE: All adapter scales are currently trainable. +# if idx == self.curriculum_stage - 1: +# # Only the current adapter's output contributes to the gradient computation. +# adapter_out = adapter_out + self.scaling * out * scale +# else: +# # Outputs from previous adapters are detached. +# adapter_out = adapter_out + self.scaling * out.detach() * scale + +# return baseline_out + adapter_out + + +############################################## +# Helper function to wrap linear layers +############################################## + +def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Module: """ - Apply rotary positional embeddings to the query and key tensors. - + Overview: + A helper function that wraps an `nn.Linear` layer with `CurriculumLoRALinear` + if LoRA and curriculum learning are enabled for the specified module. Arguments: - - xq (torch.Tensor): The query tensor. - - xk (torch.Tensor): The key tensor. - - freqs_cis (torch.Tensor): The precomputed frequency components. - + - linear (:obj:`nn.Linear`): The original linear layer to be potentially wrapped. + - config: The model configuration object. + - module_label (:obj:`str`): A label identifying the module type (e.g., "attn", "feed_forward"). Returns: - - Tuple[torch.Tensor, torch.Tensor]: The transformed query and key tensors. 
- - Note: - For more information on rotary positional embeddings, refer to the blog post: - https://spaces.ac.cn/archives/8265/ or paper https://arxiv.org/abs/2104.09864 + - nn.Module: The wrapped `CurriculumLoRALinear` layer or the original `nn.Linear` layer. """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) - return xq_out.type_as(xq), xk_out.type_as(xk) + use_curriculum_lora = ( + config.lora_r > 0 and + module_label in config.lora_target_modules and + getattr(config, "curriculum_stage_num", 1) > 1 + ) + if use_curriculum_lora: + new_linear = CurriculumLoRALinear( + in_features=linear.in_features, + out_features=linear.out_features, + bias=(linear.bias is not None), + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + curriculum_stage_num=config.curriculum_stage_num, + lora_scale_init=config.lora_scale_init + ) + new_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + new_linear.bias.data.copy_(linear.bias.data) + return new_linear + else: + return linear -class Transformer(nn.Module): - """ - Transformer model class. +############################################## +# Helper function to set curriculum stage +############################################## +def set_curriculum_stage(model: nn.Module, stage: int) -> None: + """ + Overview: + Recursively traverses all submodules of a given model, finds all instances + of `CurriculumLoRALinear`, and calls their `set_curriculum_stage` method. + This function is generic and can be applied to any model structure. Arguments: - - config (:obj:`TransformerConfig`): Configuration for the Transformer model. + - model (:obj:`nn.Module`): The model to update (e.g., a Transformer or Vision Transformer). + - stage (:obj:`int`): The curriculum stage to set. 
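+    Example:
+        >>> # Illustrative usage; `world_model.transformer` is a hypothetical attribute path.
+        >>> # Moves every CurriculumLoRALinear module to stage 1, i.e. activates the first LoRA adapter.
+        >>> set_curriculum_stage(world_model.transformer, stage=1)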
+ """ + count = 0 + for module in model.modules(): + if isinstance(module, CurriculumLoRALinear): + module.set_curriculum_stage(stage) + count += 1 + if count > 0: + logging.info(f"[Curriculum] Updated {count} CurriculumLoRALinear modules in {type(model).__name__} to stage {stage}.") + +# Alias for backward compatibility +set_curriculum_stage_for_transformer = set_curriculum_stage + + +############################################## +# Transformer Configuration +############################################## +@dataclass +class TransformerConfig: + """Configuration for the Transformer model.""" + tokens_per_block: int + max_blocks: int + attention: str + + num_layers: int + num_heads: int + embed_dim: int + + embed_pdrop: float + resid_pdrop: float + attn_pdrop: float + + # LoRA parameters + lora_r: int = 0 + lora_alpha: int = 1 + lora_dropout: float = 0.0 + lora_target_modules: list = None + + # Curriculum Learning parameters + # `curriculum_stage_num` is the total number of stages (e.g., 3 means stages 0, 1, 2) + curriculum_stage_num: int = 1 # 1 (base) + number of available LoRA adapters + min_stage0_iters: int = 10_000 # Minimum iterations for stage 0 + max_stage_iters: int = 20_000 # Maximum iterations per stage + lora_scale_init: float = 1.0 # Initial value for learnable adapter scales + + # Other configurations + task_embed_option: str = "none" + register_token_num: int = 4 + register_token_shared: bool = True + + gru_gating: bool = False + moe_in_transformer: bool = False + multiplication_moe_in_transformer: bool = False + num_experts_of_moe_in_transformer: int = 1 + + @property + def max_tokens(self) -> int: + """Maximum number of tokens the model can handle.""" + return self.tokens_per_block * self.max_blocks + - Attributes: - - config (:obj:`TransformerConfig`): Configuration object. - - drop (:obj:`nn.Dropout`): Dropout layer for embedding dropout. - - blocks (:obj:`nn.ModuleList`): List of Transformer blocks. - - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. +class Transformer(nn.Module): + """ + A Transformer model implementation. """ - def __init__(self, config: TransformerConfig) -> None: + def __init__(self, config: TransformerConfig, task_embed: Optional[nn.Module] = None) -> None: + """ + Overview: + Initializes the Transformer model. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the model. + - task_embed (:obj:`Optional[nn.Module]`): An optional module for generating task embeddings. + """ super().__init__() self.config = config self.drop = nn.Dropout(config.embed_pdrop) self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) self.ln_f = nn.LayerNorm(config.embed_dim) - if self.config.rotary_emb: - freqs_cis = precompute_freqs_cis( - self.config.embed_dim // self.config.num_heads, - self.config.max_seq_len * 2, - self.config.rope_theta, - ) - self.register_buffer("freqs_cis", freqs_cis) + self.task_embed = task_embed + self.task_embed_option = self.config.task_embed_option + self.use_register_token = (self.task_embed_option == "register_task_embed") + + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) + self.register_token_shared = getattr(config, "register_token_shared", True) + + if self.register_token_shared: + # Shared mode: all tasks use the same register_tokens parameter. 
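+                # Shape: (register_token_num, embed_dim); it is expanded per batch to
+                # (B, register_token_num, embed_dim) inside `add_register_tokens`.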
+ self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) + nn.init.xavier_uniform_(self.register_tokens) + else: + # Non-shared mode: relies on the external `task_embed` module to generate + # task-specific embeddings, which are then normalized and expanded. + self.task_embed = task_embed + self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) - def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: """ - Generate a placeholder for keys and values. + Overview: + Prepends or appends register tokens to the input sequences. + Arguments: + - sequences (:obj:`torch.Tensor`): The input sequences, with shape (B, T, C). + - task_id (:obj:`int`): The ID of the current task. + Returns: + - torch.Tensor: The sequences with register tokens concatenated, shape (B, T + register_token_num, C). + """ + B = sequences.size(0) + device = sequences.device + + if self.register_token_shared: + # Shared mode: use the same set of register tokens for all batches. + register_tokens = self.register_tokens.unsqueeze(0).expand(B, -1, -1) + else: + # Non-shared mode: dynamically generate task embedding and expand it. + task_embedding = self.task_embed(torch.tensor([task_id], device=device)) + task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) + register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) + + # Concatenate register tokens at the end of the sequence. + new_sequences = torch.cat([sequences, register_tokens], dim=1) + return new_sequences + def remove_register_tokens_from_kv(self, past_keys_values: Optional[KeysValues]) -> None: + """ + Overview: + Removes the register tokens from the key-value cache of all layers. + This is called at the end of the forward pass during inference. Arguments: - - n (:obj:`int`): Batch size. - - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. + - past_keys_values (:obj:`Optional[KeysValues]`): The key-value cache. + """ + if past_keys_values is not None: + past_keys_values.remove_register_tokens(self.register_token_num) + def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + """ + Overview: + Generates a placeholder for the key-value cache. + Arguments: + - n (:obj:`int`): The batch size. + - max_tokens (:obj:`int`): The maximum number of tokens in the sequence. Returns: - - KeysValues: An object containing empty keys and values. + - KeysValues: An object containing empty tensors for keys and values. """ - device = self.ln_f.weight.device # Assumption: All submodules are on the same device + device = self.ln_f.weight.device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor: + def forward( + self, + sequences: torch.Tensor, + past_keys_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0, + start_pos: int = 0 + ) -> torch.Tensor: """ - Forward pass of the Transformer model. - + Overview: + Performs the forward pass of the Transformer model. Arguments: - - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). 
- - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - start_pos (:obj:`int`): Starting position for rotary embeddings (default: 0). - + - sequences (:obj:`torch.Tensor`): The input tensor of shape (B, T, C). + - past_keys_values (:obj:`Optional[KeysValues]`): An optional cache for keys and values to speed up inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor indicating the valid length of the context for each sample. + - task_id (:obj:`int`): The ID of the current task. + - start_pos (:obj:`int`): The starting position for the current sequence (used with kv-caching). Returns: - - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). + - torch.Tensor: The output tensor of shape (B, T, C). """ - seqlen = sequences.shape[1] - # If using Rotary Position Embeddings (RoPE), slice the frequency components accordingly - if self.config.rotary_emb: - if isinstance(start_pos, (int, float, np.integer)): - # In the reanalyze_phase or reset stage in collection/evaluation phase, create a tensor filled with start_pos, expanded to match the batch size, and adjust for sequence type, e.g., start_pos=2. - start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) - elif isinstance(start_pos, (list, np.ndarray, torch.Tensor)): - if isinstance(start_pos[0], (np.ndarray, torch.Tensor, list)): - # In the training phase, flatten start_pos, take the first element, convert to tensor, e.g., start_pos=[array([ 8, 10, 12, 14, 16]), array([12, 14, 16, 18, 20])] - start_pos_tensor = torch.as_tensor( - [x.reshape(-1)[0].item() for x in start_pos], # Force flatten and take the first element - device=sequences.device - ) - elif isinstance(start_pos[0], (int, float, np.integer)): - # In the collection/evaluation phase, e.g., start_pos = [0, 0, 0, 0, 0, 0, 0, 0] - start_pos_tensor = torch.as_tensor([int(x) for x in start_pos], device=sequences.device) - else: - raise ValueError("start_pos must be an int, float, list, numpy array or torch.Tensor.") - - # TODO: Determine how to handle cases when episode length exceeds max_seq_len - # Use modulo operation to ensure start_pos does not exceed max_seq_len - start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len) - # Convert each sample's start_pos to a list - start_pos_list = start_pos_tensor.tolist() - # For each sample, slice the corresponding range of freqs_cis based on start_pos - freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list] - freqs_cis = torch.stack(freqs_cis_slices) - - if freqs_cis.ndim == 3 and freqs_cis.shape[1] == 1: - # Convert shape [seq_len, 1, num_pairs] to [seq_len, num_pairs] - freqs_cis = freqs_cis.squeeze(1) - else: - freqs_cis = None - - # print(f"freqs_cis.shape:{freqs_cis.shape}") + if self.use_register_token: + sequences = self.add_register_tokens(sequences, task_id) - # Ensure past keys and values match the number of transformer blocks - assert past_keys_values is None or len(past_keys_values) == len(self.blocks) - # Apply dropout to the input sequences x = self.drop(sequences) - # Pass through each transformer block + for i, block in enumerate(self.blocks): - x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis) - # Apply final layer normalization + kv_cache_layer = None if 
past_keys_values is None else past_keys_values[i] + x = block(x, kv_cache_layer, valid_context_lengths) + x = self.ln_f(x) + + if self.use_register_token: + # During inference, remove register tokens from the KV cache to maintain consistency + # for external logic that does not expect them. + if past_keys_values is not None: + self.remove_register_tokens_from_kv(past_keys_values) + + # TODO: Remove register tokens from the final output to match the input sequence length. + x = x[:, :-self.register_token_num, :] + return x class Block(nn.Module): """ - Transformer block class. - - Arguments: - config (:obj:`TransformerConfig`): Configuration for the Transformer block. - - Attributes: - - gru_gating (:obj:`bool`): Flag to use GRU gating mechanism. - - gru_bias (:obj:`float`): Bias for the GRU gating mechanism. - - gate1 (:obj:`Optional[GRUGatingUnit]`): First GRU gating unit (if GRU gating is enabled). - - gate2 (:obj:`Optional[GRUGatingUnit]`): Second GRU gating unit (if GRU gating is enabled). - - ln1 (:obj:`nn.LayerNorm`): Layer normalization before the attention layer. - - ln2 (:obj:`nn.LayerNorm`): Layer normalization before the MLP. - - attn (:obj:`SelfAttention`): Self-attention mechanism. - - mlp (:obj:`nn.Sequential`): Multi-layer perceptron. + A single Transformer block, consisting of self-attention and a feed-forward network. """ def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes a Transformer block. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the block. + """ super().__init__() - # NOTE: GRU gating as in GTrXL self.gru_gating = config.gru_gating - self.gru_bias = 2.0 if self.gru_gating: - self.gate1 = GRUGatingUnit(config.embed_dim, self.gru_bias) - self.gate2 = GRUGatingUnit(config.embed_dim, self.gru_bias) + # As in GTrXL, for stabilizing training with recurrence + self.gate1 = GRUGatingUnit(config.embed_dim, bias_init=2.0) + self.gate2 = GRUGatingUnit(config.embed_dim, bias_init=2.0) self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + + if config.moe_in_transformer: + from .moe import MoELayer + # Create multiple independent MLP instances as experts + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoELayer( + config, + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=config.num_experts_per_tok, + ) + logging.info(f"Using MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") + elif config.multiplication_moe_in_transformer: + from .moe import MoELayer, MultiplicationFeedForward + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoELayer( + config, + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + 
num_experts_per_tok=config.num_experts_per_tok, + ) + logging.info(f"Using Multiplication MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") + else: + # Standard MLP, with linear layers potentially wrapped for LoRA. + self.feed_forward = nn.Sequential( + _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), + nn.GELU(approximate='tanh'), + _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim), config, "feed_forward"), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass of the Transformer block. - + Overview: + Performs the forward pass of the Transformer block. Arguments: - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - + - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking. Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis) + attn_output = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: - x = self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate1(x, attn_output) + ff_output = self.feed_forward(self.ln2(x)) + x = self.gate2(x, ff_output) else: - x = x + x_attn - x = x + self.mlp(self.ln2(x)) - + x = x + attn_output + x = x + self.feed_forward(self.ln2(x)) return x class SelfAttention(nn.Module): """ - Implements self-attention mechanism for transformers. - - Arguments: - config (:obj:`TransformerConfig`): Configuration object containing hyperparameters. - - Attributes: - - config (:obj:`TransformerConfig`): Stores the configuration for the self-attention module. - - num_heads (:obj:`int`): Number of attention heads. - - key (:obj:`nn.Linear`): Linear layer to project input to key vectors. - - query (:obj:`nn.Linear`): Linear layer to project input to query vectors. - - value (:obj:`nn.Linear`): Linear layer to project input to value vectors. - - attn_drop (:obj:`nn.Dropout`): Dropout layer for attention weights. - - resid_drop (:obj:`nn.Dropout`): Dropout layer for residual connection. - - proj (:obj:`nn.Linear`): Final linear layer for projection. - - mask (:obj:`torch.Tensor`): Mask tensor for causal or block-causal attention. + Implements the self-attention mechanism for a Transformer. """ + def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes the SelfAttention module. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the attention module. + """ super().__init__() assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." 
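Note: the `Block` above replaces its dense MLP with a `MoELayer` when `config.moe_in_transformer` (or the multiplication variant) is enabled; the real `MoELayer` lives in `lzero/model/unizero_world_models/moe.py` and is not shown in this excerpt. Purely as an illustration of the top-k routing such a layer typically performs (class name and details here are a sketch, not the repo's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Minimal sketch of top-k mixture-of-experts routing over per-token expert MLPs."""

    def __init__(self, experts: nn.ModuleList, gate: nn.Linear, num_experts_per_tok: int) -> None:
        super().__init__()
        self.experts = experts
        self.gate = gate
        self.k = num_experts_per_tok

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C). The gate scores every expert for every token.
        scores = self.gate(x)                              # (B, T, num_experts)
        topk_scores, topk_idx = scores.topk(self.k, dim=-1)
        topk_weights = F.softmax(topk_scores, dim=-1)      # renormalize over the chosen experts
        out = torch.zeros_like(x)
        for slot in range(self.k):
            for e, expert in enumerate(self.experts):
                mask = topk_idx[..., slot] == e            # tokens whose slot-th choice is expert e
                if mask.any():
                    out[mask] += topk_weights[..., slot][mask].unsqueeze(-1) * expert(x[mask])
        return out

# Usage with the same constructor arguments the Block passes to MoELayer:
experts = nn.ModuleList([
    nn.Sequential(nn.Linear(64, 256), nn.GELU(approximate='tanh'), nn.Linear(256, 64))
    for _ in range(4)
])
moe = TopKMoE(experts, gate=nn.Linear(64, 4, bias=False), num_experts_per_tok=2)
y = moe(torch.randn(2, 8, 64))  # (2, 8, 64)
```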
self.config = config self.num_heads = config.num_heads + + self.task_embed_option = self.config.task_embed_option + self.use_register_token = (self.task_embed_option == "register_task_embed") + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) - self.key = nn.Linear(config.embed_dim, config.embed_dim) - self.query = nn.Linear(config.embed_dim, config.embed_dim) - self.value = nn.Linear(config.embed_dim, config.embed_dim) + # Wrap linear layers if LoRA is enabled for the attention module + self.key = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.query = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.value = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.proj = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") self.attn_drop = nn.Dropout(config.attn_pdrop) self.resid_drop = nn.Dropout(config.resid_pdrop) - self.proj = nn.Linear(config.embed_dim, config.embed_dim) - causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) + # TODO: The mask size is conservatively large to accommodate register tokens. + # This could be made more dynamic. + mask_size = config.max_tokens + if self.use_register_token: + mask_size += self.register_token_num * 5 + causal_mask = torch.tril(torch.ones(mask_size, mask_size)) self.register_buffer('mask', causal_mask) def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass for the self-attention mechanism. - + Overview: + Performs the forward pass for the self-attention mechanism. Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C) where B is batch size, - T is sequence length, and C is embedding dimension. + - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C). - kv_cache (:obj:`Optional[KeysValues]`): Optional key-value cache for faster inference. - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Optional tensor containing valid context lengths. - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - Returns: - torch.Tensor: Output tensor of shape (B, T, C). """ B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." 
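`SelfAttention` now routes its key/query/value/proj projections through `_maybe_wrap_linear`, whose definition appears earlier in `transformer.py` and is not part of this hunk. As a rough sketch of the pattern only (the config field names `lora_target_modules` and `lora_r` below are hypothetical, not the repo's API), a LoRA-style wrapper could look like this:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Sketch: frozen base linear plus a trainable low-rank update scaled by alpha / r."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0) -> None:
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(base.in_features, r) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(r, base.out_features))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_a @ self.lora_b) * self.scaling

def maybe_wrap_linear(linear: nn.Linear, config, module_group: str) -> nn.Module:
    # Hypothetical gating: wrap only when LoRA is enabled for this group ("attn", "feed_forward").
    lora_groups = getattr(config, "lora_target_modules", []) or []
    if module_group in lora_groups:
        return LoRALinear(linear, r=getattr(config, "lora_r", 8))
    return linear
```

Because `lora_b` starts at zero, the wrapped layer initially behaves exactly like the base linear, which is why wrapping can be toggled per module group without changing pretrained behavior.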
- else: - L = 0 + past_len = kv_cache.shape[2] - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - - if self.config.rotary_emb: - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - kv_cache.update(k, v) # time occupancy 21% - k, v = kv_cache.get() # time occupancy 5% + kv_cache.update(k, v) + k, v = kv_cache.get() + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # Construct the attention mask + mask = self.mask[past_len:past_len + T, :current_len] + if valid_context_lengths is not None: - # Final mask.shape: (B, T, L + T) - # L is the context length, T is the current input length, - # valid_context_lengths is the valid length at the end of the context. - mask = torch.zeros(B, T, L + T, device=att.device) - # For each sample, set the invalid parts to 0 based on its valid length. + # This logic is for a specific use case and may need adjustment. + # It creates a custom mask for each item in the batch. + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 # Set invalid parts to 0. - # Adjust mask dimensions to match the last two dimensions of att. - # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T) - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - # mask.shape: (T, L + T) - mask = self.mask[L:L + T, :L + T] + batch_mask[i] = mask.clone() + # Zero out attention to invalid past context + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) + + # Adjust mask for register tokens if they are in use + if self.use_register_token and self.register_token_num > 0: + # Allow all positions to attend to register tokens and vice-versa + register_mask = mask.clone() + # Register tokens are at the end of the sequence + register_indices_start = current_len - self.register_token_num + register_mask[..., register_indices_start:] = 1 # All can see registers + # This part is more complex if T is not the full sequence length + if T > self.register_token_num: + # Only the actual register tokens in the current input `x` can see everything + register_mask[..., -self.register_token_num:, :] = 1 + mask = register_mask + + if kv_cache is not None: + # Ensure mask dimensions match the potentially smaller KV cache length + new_L = kv_cache.shape[2] + mask = mask[..., :new_L] - # att.shape: (B, num_heads, T, L + T) att = att.masked_fill(mask == 0, float('-inf')) - att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) - y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) + y = att @ v + y = rearrange(y, 'b h t e -> b t (h e)') y = self.resid_drop(self.proj(y)) return y @@ -375,48 +813,41 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, def 
get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Compute the attention map for the input sequence. This is useful for visualization purposes. - More details can be found in visualizing_utils.py. - + Overview: + Computes the attention map for visualization, without computing the final output. Arguments: - x (:obj:`torch.Tensor`): Input sequence with shape (B, T, C). - - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for supporting long sequence inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for handling variable-length contexts. - + - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for long sequence inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for variable-length inputs. Returns: - - torch.Tensor: Attention map with shape (B, nh, T, L + T), representing the distribution of attention. + - torch.Tensor: Attention map of shape (B, num_heads, T, L + T). """ B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions are inconsistent with input dimensions." - else: - L = 0 + past_len = kv_cache.shape[2] - # Compute query, key, and value projections - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - # Update the kv_cache with the new keys and values kv_cache.update(k, v) k, v = kv_cache.get() - # Compute the attention scores + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + mask = self.mask[past_len:past_len + T, :current_len] if valid_context_lengths is not None: - mask = torch.zeros(B, T, L + T, device=att.device) + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - # Create attention mask for each batch - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - mask = self.mask[L:L + T, :L + T] + batch_mask[i] = mask.clone() + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) - # Apply the attention mask att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..bde598061 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -179,17 +179,44 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int): total_memory_gb = total_memory_bytes / (1024 ** 3) return total_memory_gb -def hash_state(state): +# def hash_state(state): +# """ +# Hash the state vector. + +# Arguments: +# state: The state vector to be hashed. +# Returns: +# The hash value of the state vector. 
+# """ +# # Use xxhash for faster hashing +# return xxhash.xxh64(state).hexdigest() + +def hash_state(state: np.ndarray) -> int: """ - Hash the state vector. + Overview: + Computes a fast and robust hash for a NumPy array state. + + Why this is optimal: + 1. Algorithm (`xxhash.xxh64`): Uses one of the fastest non-cryptographic hash + functions available, ideal for performance-critical applications like caching. + 2. Input Preparation (`state.tobytes()`): Ensures correctness by creating a + canonical byte representation of the array. This guarantees that two + logically identical arrays will produce the same hash, regardless of their + internal memory layout (e.g., C-contiguous, F-contiguous, or strided views). + 3. Output Format (`.intdigest()`): Directly produces an integer hash value, + which is the most efficient key type for Python dictionaries, avoiding the + overhead of string keys. Arguments: - state: The state vector to be hashed. + - state (np.ndarray): The state array to be hashed. + Returns: - The hash value of the state vector. + - int: A 64-bit integer hash of the state. """ - # Use xxhash for faster hashing - return xxhash.xxh64(state).hexdigest() + # Ensure the array is contiguous in memory before converting to bytes, + # although .tobytes() handles this, being explicit can sometimes be clearer. + # For simplicity and since .tobytes() defaults to C-order, we can rely on it. + return xxhash.xxh64(state.tobytes()).intdigest() @dataclass class WorldModelOutput: @@ -201,22 +228,36 @@ class WorldModelOutput: logits_value: torch.FloatTensor -def init_weights(module, norm_type='BN'): +def init_weights(module, norm_type='BN',liner_weight_zero=False): """ Initialize the weights of the module based on the specified normalization type. - Arguments: module (nn.Module): The module to initialize. norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm). 
""" - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: + elif isinstance(module, nn.Linear): + # 现在这个分支可以被正确执行了 + if norm_type == 'BN': + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + print("Init Linear using kaiming normal for BN") + elif norm_type == 'LN': + # 对于Transformer结构,Xavier/Glorot更常见 + nn.init.xavier_uniform_(module.weight) + print("Init Linear using xavier uniform for LN") + + if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.weight.data.fill_(1.0) + module.bias.data.zero_() + except Exception as e: + print(e) + elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -228,13 +269,47 @@ def init_weights(module, norm_type='BN'): elif norm_type == 'LN': nn.init.xavier_uniform_(module.weight) print(f"Init nn.Conv2d using xavier uniform for LN") - elif isinstance(module, nn.Linear): - if norm_type == 'BN': - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - print("Init Linear using kaiming normal for BN") - elif norm_type == 'LN': - nn.init.xavier_uniform_(module.weight) - print("Init Linear using xavier uniform for LN") + +# def init_weights(module, norm_type='BN'): +# """ +# Initialize the weights of the module based on the specified normalization type. + +# Arguments: +# module (nn.Module): The module to initialize. +# norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm). 
+# """ +# if isinstance(module, (nn.Linear, nn.Embedding)): +# module.weight.data.normal_(mean=0.0, std=0.02) +# if isinstance(module, nn.Linear) and module.bias is not None: +# module.bias.data.zero_() +# elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): +# print(f"Init {module} using zero bias, 1 weight") +# try: +# module.bias.data.zero_() +# except Exception as e: +# print(e) +# try: +# module.weight.data.fill_(1.0) +# except Exception as e: +# print(e) +# elif isinstance(module, nn.BatchNorm2d): +# print(f"Init nn.BatchNorm2d using zero bias, 1 weight") +# module.weight.data.fill_(1.0) +# module.bias.data.zero_() +# elif isinstance(module, nn.Conv2d): +# if norm_type == 'BN': +# nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') +# print(f"Init nn.Conv2d using kaiming normal for BN") +# elif norm_type == 'LN': +# nn.init.xavier_uniform_(module.weight) +# print(f"Init nn.Conv2d using xavier uniform for LN") +# elif isinstance(module, nn.Linear): +# if norm_type == 'BN': +# nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') +# print("Init Linear using kaiming normal for BN") +# elif norm_type == 'LN': +# nn.init.xavier_uniform_(module.weight) +# print("Init Linear using xavier uniform for LN") class LossWithIntermediateLosses: @@ -294,7 +369,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 7f1a0f68e..eff859a4f 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -9,12 +9,28 @@ from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform from lzero.model.common import SimNorm -from lzero.model.utils import cal_dormant_ratio +from lzero.model.utils import calculate_dormant_ratio, compute_average_weight_magnitude, compute_effective_rank from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +from collections import OrderedDict +logging.getLogger().setLevel(logging.DEBUG) + +from collections import OrderedDict, defaultdict +import matplotlib.pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +from sklearn.manifold import TSNE +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import os +import datetime +import torch +import torch.nn as nn logging.getLogger().setLevel(logging.DEBUG) @@ -41,8 +57,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.transformer = Transformer(self.config) + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.transformer = Transformer(self.config) + self.task_num = 1 + self.env_num = self.config.env_num if self.config.device == 'cpu': 
self.device = torch.device('cpu') else: @@ -51,6 +70,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: logging.info(f"self.device: {self.device}") self.to(self.device) + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + # Initialize configuration parameters self._initialize_config_parameters() @@ -65,6 +86,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + if self.task_embed_option == "concat_task_embed": + self.obs_per_embdding_dim = self.config.embed_dim - self.task_embed_dim + else: + self.obs_per_embdding_dim = self.config.embed_dim self.continuous_action_space = self.config.continuous_action_space # Initialize action embedding table @@ -93,6 +119,16 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_dict = {} + for name, module in self.named_children(): + if name.startswith("head_"): + self.head_dict[name] = module + if self.head_dict: + self.head_dict = nn.ModuleDict(self.head_dict) + + # Apply weight initialization, the order is important + # self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + # Build the set of modules to skip during re-initialization. # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', # or self.tokenizer does not have 'decoder_network'. @@ -115,8 +151,8 @@ def custom_init(module): self._initialize_last_layer() - # Cache structures - self._initialize_cache_structures() + # # Cache structures + # self._initialize_cache_structures() # Projection input dimension self._initialize_projection_input_dim() @@ -130,18 +166,25 @@ def custom_init(module): self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) + # 先设置为game_segment_length,以保持self.shared_pool_init_infer都是有效的kv + # TODO: 非常重要,应该改为和segment_length一样 + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search. self.num_simulations = getattr(self.config, 'num_simulations', 50) - self.shared_pool_size = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + + + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 + # Cache structures + self._initialize_cache_structures() + # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? 
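The shared KV-cache pools sized here (`shared_pool_size_init` tied to `game_segment_length`, `shared_pool_size_recur` tied to `num_simulations * env_num`) behave as fixed-size ring buffers, and the eviction fix introduced later in `update_cache_context` keeps a reverse map (`pool_idx_to_key_map_*`) so the stale dictionary entry is removed before its slot is overwritten. A minimal, self-contained sketch of that bookkeeping (class and method names are hypothetical):

```python
class FixedSizeKVCachePool:
    """Sketch: ring buffer of cache slots with key->slot and slot->key bookkeeping."""

    def __init__(self, pool_size: int) -> None:
        self.pool = [None] * pool_size          # physical cache slots
        self.key_to_index = {}                  # latent-state key -> slot index
        self.index_to_key = [None] * pool_size  # reverse map used for eviction
        self.write_index = 0

    def put(self, key, kv_cache) -> int:
        # Evict the key that currently owns the slot we are about to reuse,
        # so lookups can never return a cache that has been overwritten.
        old_key = self.index_to_key[self.write_index]
        if old_key is not None:
            self.key_to_index.pop(old_key, None)
        index = self.write_index
        self.pool[index] = kv_cache
        self.key_to_index[key] = index
        self.index_to_key[index] = key
        self.write_index = (self.write_index + 1) % len(self.pool)
        return index

    def get(self, key):
        index = self.key_to_index.get(key)
        return None if index is None else self.pool[index]
```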
self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -152,6 +195,138 @@ def custom_init(module): self.reanalyze_phase = False + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + + # self.past_kv_cache_recurrent_infer = defaultdict(dict) + # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # 辅助数据结构,用于反向查找:pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _analyze_latent_representation( + self, + latent_states: torch.Tensor, + timesteps: torch.Tensor, + game_states: torch.Tensor, + predicted_values: torch.Tensor, + predicted_rewards: torch.Tensor, + step_counter: int + ): + """ + 分析并记录 latent states 的统计信息和t-SNE可视化。 + 【新功能】:在t-SNE图上显示对应的游戏图像,并标注预测的Value和Reward。 + 【已修改】:如果保存路径已存在同名文件,则在文件名后附加时间戳。 + + Args: + latent_states (torch.Tensor): Encoder的输出, shape (B*L, 1, E) + timesteps (torch.Tensor): 对应的时间步, shape (B, L) + game_states (torch.Tensor): 原始的游戏观测, shape (B, L, C, H, W) + predicted_values (torch.Tensor): 预测的标量Value, shape (B*L,) + predicted_rewards (torch.Tensor): 预测的标量Reward, shape (B*L,) + step_counter (int): 全局训练步数 + """ + # ... (统计分析部分保持不变) ... + # (确保 latent_states 和 game_states 的形状为 (N, ...)) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") + + # 带图像和V/R值的 t-SNE 可视化 + if step_counter >= 0: + # if step_counter > 0 and step_counter % 200 == 0: + + print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") + + # 将数据转换到CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + values_np = predicted_values.detach().cpu().numpy() + rewards_np = predicted_rewards.detach().cpu().numpy() + + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # --- 绘制带图像和标注的散点图 --- + + # 减少图像数量以保持清晰 + num_points_to_plot = min(len(latents_np), 70) # 减少到70个点 + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(20, 18)) # 增大画布尺寸 + + # 先画出所有点的散点图作为背景 + ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) + + for i in indices: + x, y = tsne_results[i] + img = images_np[i].transpose(1, 2, 0) + img = np.clip(img, 0, 1) + + # 放置图像 + im = OffsetImage(img, zoom=0.7) # 稍微放大图像 + ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) + ax.add_artist(ab) + + # 在图像下方添加文字标注 + text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" + ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, color='red', + 
bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) + + ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) + ax.set_xlabel('t-SNE dimension 1', fontsize=12) + ax.set_ylabel('t-SNE dimension 2', fontsize=12) + + # 添加colorbar来解释背景点的颜色 + norm = plt.Normalize(values_np.min(), values_np.max()) + sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label='Predicted Value') + + # --- 修改部分:检查文件是否存在,如果存在则添加时间戳 --- + # 1. 构建基础路径 + # base_save_path = ( + # f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + # f'tsne_with_vr_{self.config.optim_type}_lr{self.config.learning_rate}_step_{step_counter}.png' + # ) + base_save_path = ( + f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' + ) + + # 2. 检查文件是否存在,并确定最终保存路径 + if os.path.exists(base_save_path): + # 如果文件已存在,则生成时间戳并附加到文件名 + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + path_root, path_ext = os.path.splitext(base_save_path) + save_path = f"{path_root}_{timestamp}{path_ext}" + print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") + else: + # 如果文件不存在,则使用原始路径 + save_path = base_save_path + + # 3. 保存图像 + plt.savefig(save_path) + plt.close(fig) # 明确关闭图形对象 + print(f"t-SNE plot with V/R annotations saved to {save_path}") + def _get_final_norm(self, norm_option: str) -> nn.Module: """ Return the corresponding normalization module based on the specified normalization option. @@ -211,6 +386,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape if self.shared_pool_wm[self.shared_pool_index_wm] is None: + # import ipdb; ipdb.set_trace() self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( src_kv_shape[0], # Number of elements (n) src_kv_shape[1], # Number of attention heads (num_heads) @@ -224,7 +400,10 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): # Copy the key and value caches using torch.copy_() for efficient data transfer + # try: dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + # except Exception as e: + # import ipdb; ipdb.set_trace() dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) dst_layer._k_cache._size = src_layer._k_cache._size dst_layer._v_cache._size = src_layer._v_cache._size @@ -264,7 +443,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur return index @@ -280,7 +459,7 @@ def _initialize_config_parameters(self) -> None: self.gamma = self.config.gamma self.context_length = self.config.context_length self.dormant_threshold = self.config.dormant_threshold - self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank self.num_observations_tokens = self.config.tokens_per_block - 1 self.latent_recon_loss_weight = self.config.latent_recon_loss_weight self.perceptual_loss_weight = self.config.perceptual_loss_weight @@ -289,7 +468,6 @@ def 
_initialize_config_parameters(self) -> None: self.max_cache_size = self.config.max_cache_size self.env_num = self.config.env_num self.num_layers = self.config.num_layers - self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) def _initialize_patterns(self) -> None: @@ -304,7 +482,9 @@ def _initialize_patterns(self) -> None: def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: """Create head modules for the transformer.""" modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- 核心优化! # TODO nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), # 2. <-- 新增!稳定内部激活 nn.GELU(approximate='tanh'), nn.Linear(self.config.embed_dim, output_dim) ] @@ -351,21 +531,22 @@ def _initialize_last_layer(self) -> None: nn.init.zeros_(layer.bias) break - def _initialize_cache_structures(self) -> None: - """Initialize cache structures for past keys and values.""" - from collections import defaultdict - self.past_kv_cache_recurrent_infer = defaultdict(dict) - self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: """Initialize the projection input dimension based on the number of observation tokens.""" if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 1: - self.projection_input_dim = self.obs_per_embdding_dim + # self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - self.task_embed_dim + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim + else: + self.projection_input_dim = self.config.embed_dim def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" @@ -421,6 +602,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -631,6 +813,7 @@ def forward( # The 'logits_ends' is intentionally set to None. return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -659,6 +842,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -698,6 +882,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) return return_result, num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. 
@@ -750,6 +935,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + #@profile @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: """ @@ -784,6 +970,7 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos return outputs_wm, self.latent_state + #@profile @torch.no_grad() def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, batch_action=None, @@ -891,7 +1078,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ # [192, 16, 64] -> [32, 6, 16, 64] last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, - self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + self.config.embed_dim) # (BL, K) for unroll_step=1 last_obs_embeddings = last_obs_embeddings[:, :-1, :] batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) @@ -922,6 +1109,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ @@ -939,6 +1127,7 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, search_depth=[], start_pos: int = 0): @@ -1025,6 +1214,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. @@ -1077,6 +1267,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, search_depth=[], valid_context_lengths=None): """ @@ -1210,16 +1401,57 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + # ORIGNAL + # if is_init_infer: + # # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # # Store the latest key-value cache for recurrent inference + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + if is_init_infer: - # Store the latest key-value cache for initial inference + # TODO + # ==================== 主动淘汰修复逻辑 ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index_init_envs[i] + # 2. 
使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + # 确保要删除的键确实存在,避免意外错误 + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. 在主 cache map 和辅助列表中同时更新新的映射关系 self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key else: - # Store the latest key-value cache for recurrent inference + # ==================== RECURRENT INFER FIX ==================== + # 1. 获取即将被覆写的物理索引 + index_to_write = self.shared_pool_index + # 2. 使用辅助列表查找该索引上存储的旧的 key + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. 如果存在旧 key,就从主 cache map 中删除它 + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. 现在可以安全地写入新数据了 cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. 在主 cache map 和辅助列表中同时更新新的映射关系 self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key + + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0, start_pos: int = 0) -> list: """ @@ -1253,8 +1485,20 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, matched_value = None # If not found, try to retrieve from past_kv_cache_recurrent_infer + # if matched_value is None: + # matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + # ==================== TODO ==================== + # 步骤 2: 仅当在 init_infer 中未找到时,才尝试从 recurrent_infer 缓存中查找 if matched_value is None: - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + # 2.1 安全地从字典中获取索引,它可能返回 None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # 2.2 只有在索引有效(不是 None)的情况下,才使用它来从物理池中检索值 + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") if matched_value is not None: # If a matching cache is found, add it to the lists @@ -1303,19 +1547,54 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') + # ======================== Logging for Analysis ======================== + # This block calculates various metrics for model analysis if the corresponding config flag is enabled. + # These metrics help in debugging and understanding model behavior during training. + if self.analysis_dormant_ratio_weight_rank: + # --- Dormant Ratio Calculation --- + # Calculate the dormant ratio of the encoder to monitor neuron activity. + shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W) + # Reshape observations to create a single large batch for the encoder. 
+ # E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) + + dormant_ratio_encoder_dict = calculate_dormant_ratio( + self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold + ) + dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] + + # --- Average Weight Magnitude Calculation --- + # Calculate the global average absolute weight magnitude for different model components. + # This is a useful metric for monitoring training stability. + avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder) + avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) + avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) + + # --- Effective Rank Calculation --- + # Calculate the effective rank of representations from specific layers in the encoder. + # This metric helps analyze the dimensionality and information content of the learned features. + # The 'representation_layer_name' argument specifies the target layer within the model's named modules. + + # Effective rank for the final linear layer of the encoder. + e_rank_last_linear = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="last_linear" + ) + # Effective rank for the SimNorm layer of the encoder. + e_rank_sim_norm = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="sim_norm" + ) + - # ========= logging for analysis ========= - if self.analysis_dormant_ratio: - # Calculate dormant ratio of the encoder - shape = batch['observations'].shape # (..., C, H, W) - inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) - dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), - percentage=self.dormant_threshold) self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_encoder = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) 
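The helpers used above (`compute_average_weight_magnitude`, `compute_effective_rank`) live in `lzero.model.utils` and hook a named layer such as `"last_linear"` or `"sim_norm"`; they are not shown in this excerpt. The underlying quantities are simple, and a minimal sketch (operating directly on a module and on a feature matrix, rather than via hooks) looks like this:

```python
import torch
import torch.nn as nn

def average_weight_magnitude(model: nn.Module) -> float:
    # Mean absolute value over all parameters; a cheap proxy for weight growth.
    total, count = 0.0, 0
    for p in model.parameters():
        total += p.detach().abs().sum().item()
        count += p.numel()
    return total / max(count, 1)

def effective_rank(features: torch.Tensor, eps: float = 1e-8) -> float:
    # Roy & Vetterli effective rank: exp of the entropy of the normalized singular spectrum.
    # features: (N, D) matrix of representations collected from the layer of interest.
    s = torch.linalg.svdvals(features)
    p = s / (s.sum() + eps)
    entropy = -(p * torch.log(p + eps)).sum()
    return torch.exp(entropy).item()

print(average_weight_magnitude(nn.Linear(8, 8)))
print(effective_rank(torch.randn(128, 64)))
```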
# Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() @@ -1329,6 +1608,56 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + # [New] Fetch the intermediate tensor x from the model outputs and detach it from the computation graph + intermediate_tensor_x = outputs.output_sequence.detach() + + global_step = kwargs.get('global_step', 0) + # if global_step >= 0 and global_step % 10000 == 0: # 20k + if global_step > 0 and global_step % 100000000000 == 0: # 20k # TODO + + with torch.no_grad(): + # Convert the logits to scalar values + # Note: the outputs have shape (B, L, E), so we need to reshape + batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] + + pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) + pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) + + scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) + scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) + + self._analyze_latent_representation( + latent_states=obs_embeddings, + timesteps=batch['timestep'], + game_states=batch['observations'], + predicted_values=scalar_values, # Pass in the predicted values + predicted_rewards=scalar_rewards, # Pass in the predicted rewards + step_counter=global_step + ) + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.)
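For concreteness, a tiny worked example of the per-sample priority computed above (the numbers are made up; only the shapes mirror the code):

```python
import torch
import torch.nn.functional as F

# Hypothetical batch of B=4 sequences.
predicted_scalar_value_step0 = torch.tensor([[0.8], [1.2], [-0.1], [0.5]])  # (B, 1)
target_scalar_value_step0 = torch.tensor([1.0, 1.0, 0.0, 0.0])              # (B,)

# |v_pred - v_target| per sample; reduction='none' keeps one priority per game segment
# so the replay buffer can reweight individual samples.
value_priority = F.l1_loss(
    predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none'
)
print(value_priority)  # tensor([0.2000, 0.2000, 0.1000, 0.5000])
```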
+ if self.obs_type == 'image': # Reconstruct observations from latent state representations # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) @@ -1410,16 +1739,20 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar perceptual_loss = self.perceptual_loss # ========= logging for analysis ========= - if self.analysis_dormant_ratio: + if self.analysis_dormant_ratio_weight_rank: # Calculate dormant ratio of the world model - dormant_ratio_world_model = cal_dormant_ratio(self, { + dormant_ratio_world_model = calculate_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, - percentage=self.dormant_threshold) + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: - dormant_ratio_world_model = torch.tensor(0.) + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) # ========== for visualization ========== # Uncomment the lines below for visualization @@ -1552,6 +1885,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # 为了让外部的训练循环能够获取encoder的输出,我们将其加入返回字典 + # 使用 .detach() 是因为这个张量仅用于后续的clip操作,不应影响梯度计算 + detached_obs_embeddings = obs_embeddings.detach() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -1569,11 +1906,21 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, # <-- 新增 ) else: return LossWithIntermediateLosses( @@ -1592,8 +1939,18 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, # <-- 新增 ) @@ -1659,7 
+2016,7 @@ def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate the policy loss for continuous actions. @@ -1674,9 +2031,12 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso - mu (:obj:`torch.Tensor`): The mean of the normal distribution. - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. """ - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + if task_id is None: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ 0], self.config.num_unroll_steps, self.config.action_space_size - + else: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size_list[task_id] policy_logits_all = outputs.logits_policy mask_batch = batch['mask_padding'] child_sampled_actions_batch = batch['child_sampled_actions'] @@ -1718,6 +2078,8 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -1767,6 +2129,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1776,6 +2139,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1795,6 +2159,7 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..47872da28 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,2062 @@ +import collections +import logging +import math +import os +from typing import Any, Dict, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from ding.utils import get_rank +from einops import rearrange +from matplotlib.offsetbox import AnnotationBbox, OffsetImage +from matplotlib.patches import Patch +from sklearn.manifold import TSNE + +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.world_model import WorldModel +from lzero.model.utils import ( + calculate_dormant_ratio, + calculate_effective_rank, + compute_average_weight_magnitude, +) + +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, WorldModelOutput, hash_state, init_weights + +# Set the logging level for the root logger +logging.getLogger().setLevel(logging.DEBUG) + + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class for the multi-task UniZero model. It is responsible for + predicting the next latent state, reward, policy, and value based on the + current latent state and action. This model is a scalable latent world model + composed of three main parts: a tokenizer, a transformer, and prediction heads. + """ + + def __init__(self, config: TransformerConfig, tokenizer: Tokenizer) -> None: + """ + Overview: + Initializes the multi-task WorldModel. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the transformer and world model. + - tokenizer (:obj:`Tokenizer`): The tokenizer for encoding observations. + """ + super().__init__(config, tokenizer) + self.tokenizer = tokenizer + self.config = config + + self.continuous_action_space = self.config.continuous_action_space + self.task_num = config.task_num + self.env_num = self.config.env_num + + # TODO: Investigate sharing the encoder across all 26 games and scaling its gradient. + # if not self.continuous_action_space: + # # Share encoder for Atari games. + # encoder_index = 0 + # encoder = self.tokenizer.encoder[encoder_index] + # # Register a hook for all parameters of the encoder to scale gradients. + # for p in encoder.parameters(): + # p.register_hook(self._scale_grad) + + # Whether to share prediction heads across tasks. + self.share_head = config.share_head + + self.device = torch.device('cuda' if torch.cuda.is_available() and self.config.device != 'cpu' else 'cpu') + print(f"self.device: {self.device}") + + # Positional embedding layer. + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + # Task embedding setup. + self.use_task_embed = config.use_task_embed + self.task_embed_option = self.config.task_embed_option + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + + if self.task_embed_option == "register_task_embed": + # When using "register_task_embed", the positional encoding is not adjusted. 
+ # Use a non-trainable, zero-initialized nn.Embedding for positional embeddings. + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + nn.init.constant_(self.pos_emb.weight, 0.0) # Initialize with all zeros. + self.pos_emb.weight.requires_grad = False # Disable updates. + + # Precompute positional embedding differences for efficient inference. + self.precompute_pos_emb_diff_kv() + + self.sim_norm = SimNorm(simnorm_dim=self.config.group_size) + + # Configure embedding dimensions based on the task embedding strategy. + if self.task_embed_option == "concat_task_embed": + # TODO: Currently, with "concat_task_embed", self.pos_emb needs to be fixed at 0. + self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TDMPC2 suggests max_norm=1. + self.obs_act_embed_dim = config.embed_dim - self.task_embed_dim + self.register_token_num = 0 + elif self.task_embed_option == "register_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) + self.obs_act_embed_dim = config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) + self.obs_act_embed_dim = config.embed_dim + else: + self.task_emb = None + self.obs_act_embed_dim = config.embed_dim + self.register_token_num = 0 + + self.transformer = Transformer(self.config, self.task_emb) + + # --- Analysis and Logging Setup --- + self.analysis_dormant_ratio_interval = self.config.get('analysis_dormant_ratio_interval', 100) + self._analysis_step_counter = 0 + self.do_analysis = self.config.analysis_dormant_ratio_weight_rank + + self.analysis_tsne = self.config.get('analysis_tsne', False) + if self.analysis_tsne: + self.env_id_list = self.config.env_id_list + # Automatically generate short names for environments. + self.env_short_names = { + env_id: env_id.replace('NoFrameskip-v4', '') + for env_id in self.config.env_id_list + } + # Color mapping to ensure each task has a fixed color. + self.num_tasks = len(self.env_id_list) + self.colors = self._generate_colors(self.num_tasks) + + # --- Prediction Head Initialization --- + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head + + self.to(self.device) + + # Initialize configuration parameters from the config object. + self._initialize_config_parameters() + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Initialize action embedding table based on action space type. + if self.continuous_action_space: + self.act_embedding_table = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size) + ) for task_id in range(self.task_num) + ]) + else: + # For discrete action space. 
+ self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + print(f'=' * 20) + print(f"self.obs_act_embed_dim: {self.obs_act_embed_dim}") + print(f'=' * 20) + + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') + print('We use normal head') + for task_id in range(self.task_num): + if self.continuous_action_space: + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) + else: + head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + + if not self.share_head or task_id == 0: + self.head_policy_multi_task.append(head_policy) + + head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_value_multi_task.append(head_value) + + head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_rewards_multi_task.append(head_rewards) + + head_observations = self._create_head( + self.all_but_last_latent_state_pattern, + self.config.embed_dim, + self._get_final_norm(self.final_norm_option_in_obs_head) # Use the specified normalization method. + ) + if not self.share_head or task_id == 0: + self.head_observations_multi_task.append(head_observations) + + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + self.soft_moe_instances = {} + self.create_head_modules_softmoe() + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + self.moe_instances = {} + self.create_head_modules_moe() + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + # Group all head modules into a ModuleDict for easier management. + self.head_dict = nn.ModuleDict({ + name: module for name, module in self.named_children() + if name.startswith("head_") and name.endswith("_multi_task") + }) + print("=" * 20) + print(f"self.head_dict:{self.head_dict}") + + # Apply weight initialization. The order of initialization is important. + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer_mt() + + # --- Cache and State Initialization --- + self._initialize_cache_structures() + self._initialize_projection_input_dim() + self._initialize_statistics() + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # Initially set to game_segment_length so that every entry in self.shared_pool_init_infer remains a valid kv cache. + # TODO: Very important: this should be kept equal to the segment length. + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache?
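+        # Illustrative sizing note (numbers are assumptions, not from the original config): with
+        # game_segment_length=400, num_simulations=50 and env_num=8, the shared KV pools configured here hold
+        #   init-infer : 400 slots per environment (one ring buffer per env),
+        #   recurrent  : 50 * 8 = 400 slots shared by all MCTS simulations,
+        #   wm (train) : 8 slots.
+        # The pools are written circularly, so the key -> pool-index maps must evict stale keys whenever a
+        # slot is overwritten (see the eviction logic in update_cache_context below).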
+ + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur + self.shared_pool_index = 0 + + # For init_infer, it only needs to retain the results of the most recent step. + # NOTE: A large pool size might cause incorrect retrieval of the kv cache. + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # For wm (world model) forward passes during training. + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + self._rank = get_rank() + + def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor: + """ + Overview: + Scales the gradient. This hook is registered to encoder parameters + to stabilize multi-task training. + Arguments: + - grad (:obj:`torch.Tensor`): The original gradient. + Returns: + - (:obj:`torch.Tensor`): The scaled gradient. + """ + # Scale by 1/sqrt(k) for a conservative approach, where k is the number of tasks. + return grad / math.sqrt(self.task_num) + + def _generate_colors(self, num_colors: int) -> list: + """ + Overview: + Generates a list of unique colors for visualization purposes, + suitable for a large number of categories. + Arguments: + - num_colors (:obj:`int`): The desired number of unique colors. + Returns: + - (:obj:`list`): A list of colors. + """ + # Concatenate multiple discrete colormaps from matplotlib to get more colors. + color_maps = ['tab20', 'tab20b', 'tab20c'] + colors = [] + for cmap_name in color_maps: + cmap = plt.get_cmap(cmap_name) + colors.extend([cmap(i) for i in range(cmap.N)]) + if len(colors) >= num_colors: + break + # Generate additional colors if needed. 
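+        # For reference: tab20, tab20b and tab20c provide up to 20 * 3 = 60 distinct colors, so for the
+        # 26-game Atari setting the HSV fallback below is normally not reached.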
+ if len(colors) < num_colors: + additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors)) + colors.extend([additional_colors(i) for i in range(num_colors - len(colors))]) + return colors[:num_colors] + + def _initialize_config_parameters(self) -> None: + """Initializes model attributes from the configuration object.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.num_layers = self.config.num_layers + + def _initialize_patterns(self) -> None: + """Initializes patterns (masks) for selecting specific tokens for prediction heads.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _get_final_norm(self, norm_option: str) -> nn.Module: + """Returns the specified normalization module.""" + if norm_option == 'LayerNorm': + return nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif norm_option == 'SimNorm': + return SimNorm(simnorm_dim=self.config.group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None) -> Head: + """Creates a standard prediction head.""" + modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- Core optimization: pre-LayerNorm on the head input. # TODO + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), # 2. <-- Newly added: stabilizes the intermediate activations. + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Mixture-of-Experts (MoE) layer.""" + modules = [ + nn.LayerNorm(self.config.embed_dim), # <-- Core optimization: pre-LayerNorm on the head input.
# TODO + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_moe(self, name: str) -> nn.Module: + """Gets or creates a MoE instance by name.""" + from .moe import MoELayer, MultiplicationFeedForward + + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE. + experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + self.moe_instances[name] = MoELayer( + experts=experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + return self.moe_instances[name] + + def create_head_modules_moe(self) -> None: + """Creates all MoE prediction head modules.""" + self.head_rewards = self._create_head_moe(self.act_tokens_pattern, self.support_size, moe=self.get_moe("rewards_moe")) + self.head_observations = self._create_head_moe(self.all_but_last_latent_state_pattern, self.embed_dim, norm_layer=self.sim_norm, moe=self.get_moe("observations_moe")) + self.head_policy = self._create_head_moe(self.value_policy_tokens_pattern, self.action_space_size, moe=self.get_moe("policy_moe")) + self.head_value = self._create_head_moe(self.value_policy_tokens_pattern, self.support_size, moe=self.get_moe("value_moe")) + + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, soft_moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Soft-MoE layer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name: str) -> nn.Module: + """Gets or creates a Soft-MoE instance by name.""" + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + num_experts=self.num_experts_in_moe_head, + geglu=True + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self) -> None: + """Creates all Soft-MoE prediction head modules.""" + self.head_rewards = self._create_head_softmoe(self.act_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("rewards_soft_moe")) + self.head_observations = self._create_head_softmoe(self.all_but_last_latent_state_pattern, self.config.embed_dim, norm_layer=self.sim_norm, soft_moe=self.get_soft_moe("observations_soft_moe")) + self.head_policy = self._create_head_softmoe(self.value_policy_tokens_pattern, self.action_space_size, soft_moe=self.get_soft_moe("policy_soft_moe")) + self.head_value = self._create_head_softmoe(self.value_policy_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("value_soft_moe")) + + def _initialize_last_layer_mt(self) -> None: + """Initializes the last linear layer of prediction heads to zero for training stability.""" + last_linear_layer_init_zero = True + print(f'world_model_mt.py:self.task_num:{self.task_num}') + if last_linear_layer_init_zero: + if self.continuous_action_space: + # For continuous actions, policy head might have a different initialization strategy. 
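+                # Hedged note (not stated in the original code): only the value/reward/observation heads are
+                # zero-initialized in this branch; the continuous policy head keeps its default initialization,
+                # presumably because its final layer parameterizes (mu, sigma) and zeroing it would yield a
+                # degenerate initial Gaussian policy.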
+ module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + else: + module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initializes cache structures for storing past keys and values during inference.""" + # self.past_kv_cache_recurrent_infer = collections.OrderedDict() + # self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary data structure for reverse lookup: pool_index -> key. + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initializes the input dimension for the projection based on observation tokenization.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + if self.task_embed_option in ["concat_task_embed", "register_task_embed", "add_task_embed"]: + self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim -= self.task_embed_dim + else: + self.projection_input_dim = self.config.embed_dim + + def _initialize_statistics(self) -> None: + """Initializes counters for cache hit rates and other statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + def _initialize_transformer_keys_values(self) -> None: + """Initializes empty key-value cache structures for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length) + + def precompute_pos_emb_diff_kv(self) -> None: + """ + Overview: + Precomputes positional embedding differences for keys and values. This is an + optimization to speed up KV cache updates during recurrent inference by avoiding + re-computation of positional embeddings. + """ + if self.context_length <= 2: + return # No context to precompute for. + + # Precompute positional embedding matrices for all layers. + self.positional_embedding_k = [self._get_positional_embedding(layer, 'key') for layer in range(self.config.num_layers)] + self.positional_embedding_v = [self._get_positional_embedding(layer, 'value') for layer in range(self.config.num_layers)] + + # Precompute all possible positional embedding differences. + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + # This is for the case when context window is full and we shift it. + # TODO: Generalize for different start/end points if necessary.
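+            # Worked example (illustrative; assumes context_length = 8): the only precomputed pair below is
+            # (start, end) = (2, 7). When the rolling context is trimmed, cached keys/values that originally
+            # occupied positions 2..6 are re-used at positions 0..4, so the correction applied later is
+            #   pos_emb_diff = projected_pos_emb[0:5] - projected_pos_emb[2:7]
+            # computed once per layer here rather than re-projecting positional embeddings at inference time.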
+ for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + def _get_positional_embedding(self, layer: int, attn_type: str) -> torch.Tensor: + """ + Overview: + Helper function to get positional embedding for a given layer and attention type. + Arguments: + - layer (:obj:`int`): The layer index. + - attn_type (:obj:`str`): The attention type, either 'key' or 'value'. + Returns: + - (:obj:`torch.Tensor`): The positional embedding tensor, detached from the graph. + """ + # TODO: Review the use of detach(). It's used here to prevent gradients from flowing back + # through the positional embeddings during this pre-computation phase. + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + pos_emb = attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2) + return pos_emb.to(self.device).detach() + + def forward( + self, + obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, + is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0 + ) -> WorldModelOutput: + """ + Overview: + Main forward pass for the world model. It processes either observation embeddings, + action tokens, or a combination of both, and passes them through the transformer + to generate predictions. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`Dict`): A dictionary containing input tensors. + Can be 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Whether to use independent KV caching per item in the batch. + - is_init_infer (:obj:`bool`): Flag indicating if this is an initial inference step. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths for each item. + - task_id (:obj:`int`): The ID of the current task. + Returns: + - (:obj:`WorldModelOutput`): An object containing the transformer output and logits for + observations, rewards, policy, and value. + """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) + else: + # Use a zero tensor if task embeddings are disabled. 
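+            # (Adding this all-zero tensor downstream, e.g. in the "add_task_embed" path, leaves the
+            # embeddings unchanged, so the code below works whether or not task embeddings are enabled.)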
+ self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + + prev_steps = 0 if past_keys_values is None else past_keys_values.size + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], device=self.device) + + if is_init_infer: + valid_context_lengths = None + + # --- Branch 1: Inference Phase (Collect/Eval) - Process observation embeddings --- + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + + # Apply task embeddings based on the chosen strategy. + if self.task_embed_option == "add_task_embed": + obs_embeddings = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + if is_init_infer and not self.reanalyze_phase: + # Concatenate task embeddings only during initial inference. + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + + # --- Branch 2: Inference Phase (Collect/Eval) - Process action tokens --- + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + + # Get action embeddings from the task-specific or shared table. + if self.task_num >= 1 and self.continuous_action_space: + act_embeddings = self.act_embedding_table[task_id](act_tokens) + else: + act_embeddings = self.act_embedding_table(act_tokens) + + # Apply task embeddings. + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1) + act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1) + + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + + # --- Branch 3: Training Phase - Process combined observation embeddings and action tokens --- + else: + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + # Pass sequences through the transformer. + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=task_id) + + # Generate logits using shared, task-specific, or MoE heads. 
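+        # Example of the head selection below (illustrative): with share_head=False, task_id=3 routes through
+        # head_policy_multi_task[3]; with share_head=True every task uses index 0; with a (soft-)MoE head a
+        # single shared self.head_policy serves all tasks.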
+ head_index = 0 if self.share_head else task_id + if self.use_moe_head or self.use_softmoe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + logits_observations = self.head_observations_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + def _add_position_embeddings( + self, + embeddings: torch.Tensor, + prev_steps: Union[int, torch.Tensor], + num_steps: int, + kvcache_independent: bool, + is_init_infer: bool, + valid_context_lengths: Optional[torch.Tensor] + ) -> torch.Tensor: + """ + Overview: + Adds positional embeddings to the input embeddings. + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`Union[int, torch.Tensor]`): Number of previous steps in the cache. + - num_steps (:obj:`int`): Number of new steps being added. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - is_init_infer (:obj:`bool`): Flag for initial inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for each sequence. + Returns: + - (:obj:`torch.Tensor`): Embeddings with added positional information. + """ + if kvcache_independent: + steps_indices = prev_steps.unsqueeze(1) + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices) + return embeddings + position_embeddings + else: + if is_init_infer: + # For initial inference, positions are sequential from the previous step count. + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return embeddings + self.pos_emb(pos_indices) + else: + # For recurrent steps, use valid_context_lengths to get correct positions. + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + pos_indices = valid_context_lengths.unsqueeze(1) + torch.arange(num_steps, device=self.device) + position_embeddings = self.pos_emb(pos_indices) + return embeddings + position_embeddings + + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: + """ + Overview: + Processes and combines observation embeddings and continuous action tokens for training. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. + - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. 
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(-1) + + act_embeddings = self.act_embedding_table[task_id](act_tokens) + + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + if self.task_embed_option == "add_task_embed": + obs = obs + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) + + act = act_embeddings[:, i, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps + + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: + """ + Overview: + Processes and combines observation embeddings and discrete action tokens for training. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. + - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. 
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + if self.task_embed_option == "add_task_embed": + obs = obs + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) + + act = act_embeddings[:, i, 0, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps + + def _transformer_pass( + self, + sequences: torch.Tensor, + past_keys_values: Optional[torch.Tensor], + kvcache_independent: bool, + valid_context_lengths: Optional[torch.Tensor], + task_id: int = 0 + ) -> torch.Tensor: + """ + Overview: + Passes sequences through the transformer, handling different KV cache modes. + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths. + - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`torch.Tensor`): The output from the transformer. + """ + if kvcache_independent: + x = [ + self.transformer(sequences[k].unsqueeze(0), past_kv, valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) + for k, past_kv in enumerate(past_keys_values) + ] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: dict, task_id: int = 0) -> Tuple[WorldModelOutput, torch.Tensor]: + """ + Overview: + Resets the model state for the beginning of an episode or a new inference sequence. + It processes the initial observations and actions to create the first latent state + and populate the KV cache. + Arguments: + - obs_act_dict (:obj:`dict`): A dictionary containing 'obs', 'action', and 'current_obs'. + - task_id (:obj:`int`): The ID of the current task. + Returns: + - (:obj:`Tuple[WorldModelOutput, torch.Tensor]`): A tuple containing the world model output + and the initial latent state. 
+ """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) + else: + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # --- Collect and Evaluation Phase --- + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + + # The latent state is the combination of observation embedding and task embedding. + if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = current_obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(current_obs_embeddings.shape[0], current_obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([current_obs_embeddings, task_emb_expanded], dim=-1) + else: # "register_task_embed" or other cases + self.latent_state = current_obs_embeddings + else: + self.latent_state = current_obs_embeddings + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # --- Training Phase (for calculating target values) --- + if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + else: + self.latent_state = obs_embeddings + else: + self.latent_state = obs_embeddings + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. 
+ """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, 
task_id=task_id) + + # Copy and store keys_values_wm for a single environment + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # import ipdb; ipdb.set_trace() + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_act_embed_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id = 0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. 
+ Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + if not self.continuous_action_space: + token = action.reshape(-1, 1) + else: + token = action.reshape(-1, self.config.action_space_size_list[task_id]) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + # print(f'token.shape:{token.shape}') + + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. 
+ During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # import ipdb; ipdb.set_trace() + + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + # import ipdb; ipdb.set_trace() + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # import ipdb; ipdb.set_trace() + + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + # ORIGINAL + # if is_init_infer: + # # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # # Store the latest key-value cache for recurrent inference + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + + if is_init_infer: + # TODO + # ==================== Proactive eviction fix ==================== + # 1. Get the physical pool index that is about to be overwritten. + index_to_write = self.shared_pool_index_init_envs[i] + # 2. Use the auxiliary list to look up the old key stored at that index. + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. If an old key exists, remove it from the main cache map. + if old_key_to_evict is not None: + # Make sure the key to delete actually exists, to avoid unexpected errors. + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # It is now safe to write the new data. + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. Update the new mapping in both the main cache map and the auxiliary list. + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key + else: + # ==================== RECURRENT INFER FIX ==================== + # 1. Get the physical pool index that is about to be overwritten. + index_to_write = self.shared_pool_index + # 2. Use the auxiliary list to look up the old key stored at that index. + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. If an old key exists, remove it from the main cache map. + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. It is now safe to write the new data. + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. Update the new mapping in both the main cache map and the auxiliary list. + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment.
+ """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + # if matched_value is None: + # matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + # ==================== TODO ==================== + # 步骤 2: 仅当在 init_infer 中未找到时,才尝试从 recurrent_infer 缓存中查找 + if matched_value is None: + # 2.1 安全地从字典中获取索引,它可能返回 None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # 2.2 只有在索引有效(不是 None)的情况下,才使用它来从物理池中检索值 + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + def plot_embeddings( + self, + tsne_results: np.ndarray, + task_ids: np.ndarray, + observations: Union[np.ndarray, torch.Tensor], + samples_per_task: int = 5, + save_dir: str = 'tsne_plots_26games' + ) -> None: + """ + Overview: + Generates a t-SNE visualization plot and annotates it with a specified number of + randomly selected observation images for each task. + + Arguments: + - tsne_results (:obj:`np.ndarray`): The t-SNE dimensionality reduction results (N x 2 array). + - task_ids (:obj:`np.ndarray`): An array of environment task IDs, used for coloring the points (N array). + - observations (:obj:`Union[np.ndarray, torch.Tensor]`): The corresponding observation samples (N x C x H x W tensor or array). + - samples_per_task (:obj:`int`): The number of samples to select for image annotation per task. Defaults to 5. + - save_dir (:obj:`str`): The directory path where the plot will be saved. Defaults to 'tsne_plots_26games'. + """ + # Create the save directory if it doesn't exist. + os.makedirs(save_dir, exist_ok=True) + print(f"[INFO] Save directory created or already exists: {save_dir}") + + # Create the t-SNE plot. + print("[INFO] Starting to draw the t-SNE scatter plot...") + plt.figure(figsize=(18, 10)) # Increase figure width to accommodate the legend on the right. + + # Scatter plot of the t-SNE results. 
+ scatter = plt.scatter( + tsne_results[:, 0], + tsne_results[:, 1], + c=[self.colors[tid] for tid in task_ids], + alpha=0.6, + edgecolor='w', + linewidth=0.5 + ) + + # Create a custom legend for the tasks. + legend_elements = [] + for idx, env_id in enumerate(self.env_id_list): + short_name = self.env_short_names.get(env_id, env_id) + color = self.colors[idx] + legend_elements.append( + Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + ) + + # Place the legend on the right side of the plot, with each item on a new line. + plt.legend( + handles=legend_elements, + title="Environment IDs", + loc='center left', + bbox_to_anchor=(1, 0.5), # Position the legend in the center-right of the plot area. + fontsize=10, + title_fontsize=12, + ncol=1, + frameon=False # Remove the legend border for a cleaner look. + ) + + # Set the title and axis labels. + plt.title("t-SNE of Latent States across Environments", fontsize=16) + plt.xlabel("t-SNE Dimension 1", fontsize=14) + plt.ylabel("t-SNE Dimension 2", fontsize=14) + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.grid(True, linestyle='--', alpha=0.5) + print(f"[INFO] t-SNE scatter plot completed with {len(tsne_results)} points.") + + # Select a specified number of samples per task for image annotation. + print(f"[INFO] Starting to select {samples_per_task} samples per task for image annotation...") + for task_id in range(len(self.env_id_list)): + # Find all indices for the current task. + task_indices = np.where(task_ids == task_id)[0] + if len(task_indices) == 0: + print(f"[WARNING] No samples found for task ID {task_id}.") + continue + + # If the number of samples is less than required, select all of them. + if len(task_indices) < samples_per_task: + selected_indices = task_indices + print(f"[INFO] Task ID {task_id} has fewer samples ({len(task_indices)}) than required ({samples_per_task}). Selecting all.") + else: + selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + print(f"[INFO] Randomly selecting {samples_per_task} samples for task ID {task_id} for annotation.") + + for idx in selected_indices: + img = observations[idx] + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + + # Handle channel-first (C, H, W) format for grayscale or RGB images. + if img.shape[0] == 1 or img.shape[0] == 3: + img = np.transpose(img, (1, 2, 0)) + else: + raise ValueError(f"Unsupported image shape: {img.shape}") + + # Normalize the image to the [0, 1] range for correct display. + img_min, img_max = img.min(), img.max() + if img_max - img_min > 1e-5: + img = (img - img_min) / (img_max - img_min) + else: + img = np.zeros_like(img) + + imagebox = OffsetImage(img, zoom=0.5) + ab = AnnotationBbox( + imagebox, + (tsne_results[idx, 0], tsne_results[idx, 1]), + frameon=False, + pad=0.3 + ) + plt.gca().add_artist(ab) + print(f"[INFO] Added image annotation: Task ID {task_id}, point index {idx}, t-SNE coords ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # Adjust layout to prevent the legend from being cut off. + plt.tight_layout(rect=[0, 0, 0.9, 1]) # Reserve space for the legend on the right. + + # Save the figure in both PNG and PDF formats with high resolution. 
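+        # NOTE: `bbox_inches='tight'` is passed to `savefig` below so that the legend placed outside
+        # the axes (via `bbox_to_anchor`) is not cropped out of the saved files.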
+ save_path_png = os.path.join(save_dir, 'tsne_plot.png') + save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + print(f"[INFO] t-SNE visualization plot saved to: {save_path_png} and {save_path_pdf}") + plt.close() + + @torch.no_grad() + def gather_and_plot( + self, + local_embeddings: torch.Tensor, + local_task_ids: torch.Tensor, + local_observations: torch.Tensor + ) -> None: + """ + Overview: + Gathers embeddings, task IDs, and observations from all distributed processes. + On the main process (rank 0), it performs t-SNE and plots the results. + + Arguments: + - local_embeddings (:obj:`torch.Tensor`): The embedding tensor from the current process. + - local_task_ids (:obj:`torch.Tensor`): The task ID tensor from the current process. + - local_observations (:obj:`torch.Tensor`): The observation tensor from the current process. + """ + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Prepare lists to receive CUDA tensors from all processes. + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # Prepare a list to receive CPU objects (observations) from all processes. + observations_list = [None for _ in range(world_size)] + + try: + # Gather CUDA tensors: embeddings and task_ids. + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # Gather CPU objects: observations (must be moved to CPU and converted first). + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # Concatenate all embeddings and task_ids on the main process. + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # Concatenate all observations. + all_observations_list = [] + for obs in observations_list: + all_observations_list.extend(obs) + all_observations = np.array(all_observations_list) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # Perform t-SNE dimensionality reduction. + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # Plot and save the resulting image. 
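+            # NOTE: only rank 0 reaches this branch; the other ranks have already contributed their
+            # tensors through the all_gather calls above and return without plotting.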
+            self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games')
+
+    #@profile
+    def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id: int = 0, **kwargs: Any) -> LossWithIntermediateLosses:
+        # Encode observations into latent state representations.
+        obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id)
+
+        if self.analysis_tsne:
+            # =========== t-SNE analysis ===========
+            if not obs_embeddings.is_cuda:
+                obs_embeddings = obs_embeddings.cuda()
+            obs_embeddings = obs_embeddings.contiguous()
+            local_embeddings = obs_embeddings.detach()
+            local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device)
+            local_observations = batch['observations'].detach().cpu()
+            self.gather_and_plot(local_embeddings, local_task_ids, local_observations)
+
+        # ========= logging for analysis =========
+        if self.analysis_dormant_ratio_weight_rank:
+            self._analysis_step_counter += 1
+            self.do_analysis = (
+                self.analysis_dormant_ratio_weight_rank  # master switch
+                and self._analysis_step_counter % self.analysis_dormant_ratio_interval == 0
+            )
+
+        # ========= logging for analysis =========
+        if self.do_analysis:
+            # Calculate the dormant ratio of the encoder.
+            shape = batch['observations'].shape  # (..., C, H, W)
+            inputs = batch['observations'].contiguous().view(-1, *shape[-3:])  # (32, 5, 3, 64, 64) -> (160, 3, 64, 64)
+            if self.continuous_action_space:
+                encoder_index = task_id
+            else:
+                encoder_index = 0
+            dormant_ratio_encoder_dict = calculate_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(),
+                                                                 dormant_threshold=self.dormant_threshold)
+            dormant_ratio_encoder = dormant_ratio_encoder_dict['global']
+
+            avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder[encoder_index])
+            avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
+            avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict)
+
+            e_rank_last_linear = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear")
+            try:
+                e_rank_sim_norm = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm")
+            except Exception:
+                e_rank_sim_norm = torch.tensor(0.)
+
+            for kv_cache_dict_env in self.past_kv_cache_init_infer_envs:
+                kv_cache_dict_env.clear()
+            self.past_kv_cache_recurrent_infer.clear()
+            self.keys_values_wm_list.clear()
+            torch.cuda.empty_cache()
+        else:
+            dormant_ratio_encoder = torch.tensor(0.)
+            avg_weight_mag_encoder = torch.tensor(0.)
+            avg_weight_mag_transformer = torch.tensor(0.)
+            avg_weight_mag_head = torch.tensor(0.)
+            e_rank_last_linear = torch.tensor(0.)
+            e_rank_sim_norm = torch.tensor(0.)
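+        # NOTE: `calculate_dormant_ratio`, `compute_average_weight_magnitude` and `calculate_effective_rank`
+        # are the plasticity diagnostics defined in lzero/model/utils.py. They are only evaluated every
+        # `analysis_dormant_ratio_interval` training steps; on the remaining steps the zero placeholders
+        # above keep the returned metrics defined.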
+ # dormant_ratio_encoder = None + + + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, 
rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) + + # ========= logging for analysis ========= + # if self.analysis_dormant_ratio_weight_rank: + if self.do_analysis: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = calculate_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) 
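+        # NOTE (assumption): when `use_priority` is enabled, the per-sample `value_priority` computed
+        # above is expected to be consumed by the caller to refresh replay-buffer priorities; here it is
+        # only returned via LossWithIntermediateLosses.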
+ + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + if self.use_task_embed and self.task_embed_option == "concat_task_embed": + # Expand task embeddings to match the sequence shape + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + task_emb_expanded = self.task_embeddings.expand(labels_observations.shape[0], -1) + labels_observations = torch.cat([labels_observations, task_emb_expanded.detach()], dim=-1) # NOTE: detach() + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # assert not torch.isnan(logits_reshaped).any(), "logits_reshaped contains NaN values" + # assert not torch.isnan(labels_reshaped).any(), "labels_reshaped contains NaN values" + # print('loss_obs:', loss_obs.mean()) + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + # logits_grad = torch.autograd.grad(loss_obs.mean(), logits_observations, retain_graph=True)[0] + # print(f"logits_grad (min, max, mean): {logits_grad.min()}, {logits_grad.max()}, {logits_grad.mean()}") + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple( + outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + outputs, batch, task_id=task_id) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = 
batch['actions'].shape[1] - 1
+                # Get the corresponding mask_padding.
+                mask_padding = batch['mask_padding'][:, 1:seq_len]
+            else:
+                seq_len = batch['actions'].shape[1]
+                # Get the corresponding mask_padding.
+                mask_padding = batch['mask_padding'][:, :seq_len]
+
+            # Adjust the loss shape to (batch_size, seq_len).
+            loss_tmp = loss_tmp.view(-1, seq_len)
+
+            # First-step loss.
+            first_step_mask = mask_padding[:, 0]
+            first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean()
+
+            # Middle-step loss.
+            middle_step_index = seq_len // 2
+            middle_step_mask = mask_padding[:, middle_step_index]
+            middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean()
+
+            # Last-step loss.
+            last_step_mask = mask_padding[:, -1]
+            last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean()
+
+        # Discount the reconstruction loss and the perceptual loss.
+        discounted_latent_recon_loss = latent_recon_loss
+        discounted_perceptual_loss = perceptual_loss
+
+        # Calculate the overall discounted losses.
+        discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum() / batch['mask_padding'][:, 1:].sum()
+        discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
+        discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
+        discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
+        discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
+        discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum() / batch['mask_padding'].sum()
+
+        # Expose the encoder output to the outer training loop by including it in the returned losses.
+        # .detach() is used because this tensor is only needed for the subsequent clipping step and
+        # must not affect gradient computation.
+        detached_obs_embeddings = obs_embeddings.detach()
+
+        if self.continuous_action_space:
+            return LossWithIntermediateLosses(
+                latent_recon_loss_weight=self.latent_recon_loss_weight,
+                perceptual_loss_weight=self.perceptual_loss_weight,
+                continuous_action_space=True,
+                loss_obs=discounted_loss_obs,
+                loss_rewards=discounted_loss_rewards,
+                loss_value=discounted_loss_value,
+                loss_policy=discounted_loss_policy,
+                latent_recon_loss=discounted_latent_recon_loss,
+                perceptual_loss=discounted_perceptual_loss,
+                orig_policy_loss=discounted_orig_policy_loss,
+                policy_entropy=discounted_policy_entropy,
+                first_step_losses=first_step_losses,
+                middle_step_losses=middle_step_losses,
+                last_step_losses=last_step_losses,
+                dormant_ratio_encoder=dormant_ratio_encoder,
+                dormant_ratio_transformer=dormant_ratio_transformer,
+                dormant_ratio_head=dormant_ratio_head,
+                avg_weight_mag_encoder=avg_weight_mag_encoder,
+                avg_weight_mag_transformer=avg_weight_mag_transformer,
+                avg_weight_mag_head=avg_weight_mag_head,
+                e_rank_last_linear=e_rank_last_linear,
+                e_rank_sim_norm=e_rank_sim_norm,
+                latent_state_l2_norms=latent_state_l2_norms,
+                policy_mu=mu,
+                policy_sigma=sigma,
+                target_sampled_actions=target_sampled_actions,
+                value_priority=value_priority,
+                obs_embeddings=detached_obs_embeddings,  # newly added: encoder output for the outer training loop
+            )
+        else:
+            return LossWithIntermediateLosses(
+                latent_recon_loss_weight=self.latent_recon_loss_weight,
+                perceptual_loss_weight=self.perceptual_loss_weight,
+                continuous_action_space=False,
+                loss_obs=discounted_loss_obs,
+                loss_rewards=discounted_loss_rewards,
+                loss_value=discounted_loss_value,
+                loss_policy=discounted_loss_policy,
+                latent_recon_loss=discounted_latent_recon_loss,
+                perceptual_loss=discounted_perceptual_loss,
+                orig_policy_loss=discounted_orig_policy_loss,
+                policy_entropy=discounted_policy_entropy,
+                first_step_losses=first_step_losses,
+                middle_step_losses=middle_step_losses,
+                last_step_losses=last_step_losses,
+                dormant_ratio_encoder=dormant_ratio_encoder,
+                dormant_ratio_transformer=dormant_ratio_transformer,
+                dormant_ratio_head=dormant_ratio_head,
+                avg_weight_mag_encoder=avg_weight_mag_encoder,
+                avg_weight_mag_transformer=avg_weight_mag_transformer,
+                avg_weight_mag_head=avg_weight_mag_head,
+                e_rank_last_linear=e_rank_last_linear,
+                e_rank_sim_norm=e_rank_sim_norm,
+                latent_state_l2_norms=latent_state_l2_norms,
+                value_priority=value_priority,
+                obs_embeddings=detached_obs_embeddings,  # newly added: encoder output for the outer training loop
+            )
+
+    #@profile
+    def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'):
+        # `outputs` is an object with logits attributes such as 'rewards', 'policy' and 'value';
+        # `labels` is the target tensor; `batch` carries a mask indicating valid timesteps.
+        logits = getattr(outputs, f'logits_{element}')
+
+        # Reshape the tensors.
+        logits = rearrange(logits, 'b t e -> (b t) e')
+        labels = labels.reshape(-1, labels.shape[-1])  # Labels initially have shape [batch, time, dim].
+
+        # Reshape the mask. True indicates valid data.
+        mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)')
+
+        # Compute the cross-entropy loss.
+        loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1)
+        loss = (loss * mask_padding)
+
+        # if torch.isnan(loss).any():
+        #     raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'")
+
+        if element == 'policy':
+            # Compute the policy entropy loss.
+            policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding)
+            # Combine the losses with the specified weight.
+            combined_loss = loss - self.policy_entropy_weight * policy_entropy
+            return combined_loss, loss, policy_entropy
+
+        return loss
+
+    #@profile
+    def compute_policy_entropy_loss(self, logits, mask):
+        # Compute the entropy of the policy.
+        probs = torch.softmax(logits, dim=1)
+        log_probs = torch.log_softmax(logits, dim=1)
+        entropy = -(probs * log_probs).sum(1)
+        # Apply the mask and return the entropy loss.
+        entropy_loss = (entropy * mask)
+        return entropy_loss
+
+    #@profile
+    def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
+                                   mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # assert torch.all(ends.sum(dim=1) <= 1)  # Each sequence sample should have at most one 'done' flag.
+        mask_fill = torch.logical_not(mask_padding)
+
+        # Prepare observation labels.
+        labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:]
+
+        # Fill the masked areas of rewards.
+        mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards)
+        labels_rewards = rewards.masked_fill(mask_fill_rewards, -100)
+
+        # Fill the masked areas of ends.
+        # labels_ends = ends.masked_fill(mask_fill, -100)
+
+        # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)
+        return labels_observations, labels_rewards.view(-1, self.support_size), None
+
+    #@profile
+    def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
+                                                mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ Compute labels for value and policy predictions.
""" + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. + """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/utils.py b/lzero/model/utils.py index 70a89d3b0..1204070f9 100644 --- a/lzero/model/utils.py +++ b/lzero/model/utils.py @@ -1,163 +1,319 @@ """ Overview: - In this file, we provide a set of utility functions for probing network parameters and gradients, - which can be helpful in analyzing and debugging the inner workings of various models. + This file provides a set of utility functions for probing network parameters and gradients. + These tools are helpful for analyzing and debugging the inner workings of various models. """ -from typing import List, Tuple +from typing import List, Tuple, Union, Dict, Type, Optional import numpy as np import torch import torch.nn as nn -class LinearOutputHook: +def compute_average_weight_magnitude(model: nn.Module) -> float: """ Overview: - Hook to capture the output of linear layers. + Calculates the average absolute magnitude of all parameters in a given model. + + Arguments: + - model (:obj:`nn.Module`): The model to be evaluated. + + Returns: + - float: The average absolute magnitude of the model's weights. + """ + num_weights = 0 + # Use the device of the model's first parameter to ensure consistency. + device = next(model.parameters()).device + sum_weight_magnitude = torch.tensor(0.0, device=device) + + for p in model.parameters(): + num_weights += p.numel() + sum_weight_magnitude += torch.sum(torch.abs(p)) + + if num_weights == 0: + return 0.0 + return sum_weight_magnitude.cpu().item() / num_weights + + +def compute_effective_rank(singular_values: np.ndarray) -> float: """ + Overview: + Computes the effective rank from an array of singular values. The formula is: + effective_rank = exp(-sum_i [p_i * log(p_i)]), where p_i is the normalized singular value. + + Arguments: + - singular_values (:obj:`np.ndarray`): An array of singular values. + + Returns: + - float: The calculated effective rank. + """ + # Normalize singular values to form a probability distribution. + norm_sv = singular_values / np.sum(np.abs(singular_values)) + entropy = 0.0 + for p in norm_sv: + if p > 1e-8: # Avoid log(0) + entropy -= p * np.log(p) + return np.exp(entropy) + +class IntermediateOutputHook: + """ + Overview: + A hook class to capture and store the output tensors from a specific nn.Module during a forward pass. + """ def __init__(self): + self.outputs: List[torch.Tensor] = [] + + def __call__(self, module: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None: """ Overview: - Initialize the hook. 
+ This method is called by PyTorch when the hooked module completes its forward pass. """ - self.outputs: List[torch.Tensor] = [] + # Detach the tensor from the computation graph and move to CPU to save memory. + self.outputs.append(output.detach().cpu()) - def __call__(self, module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor) -> None: + def clear(self) -> None: """ Overview: - Capture the output of the module. - Arguments: - - module: The module being hooked. - - input: The input to the module (unused in this hook). - - output: The output from the module. + Clears the list of captured outputs. """ - self.outputs.append(output) + self.outputs.clear() -def cal_dormant_ratio(model: nn.Module, *inputs: torch.Tensor, percentage: float = 0.025) -> float: +def calculate_effective_rank( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + representation_layer_name: str, +) -> float: """ Overview: - Calculate the dormant neuron ratio in the model. A neuron is considered dormant if its output is less than a - specified percentage of the average output of the layer. This function is useful for analyzing the sparsity of the model. - More details can be found in the paper https://arxiv.org/abs/2302.12902. + Calculates the effective rank of a specified intermediate layer's output (representation) + by using a forward hook to capture the activations. + Arguments: - - model: The model to evaluate. - - inputs: The inputs to the model. - - percentage: The threshold percentage to consider a neuron dormant, defaults to 0.025. + - model (:obj:`nn.Module`): The model to be evaluated. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - representation_layer_name (:obj:`str`): The name of the representation layer, which must be + findable within `model.named_modules()`. + Returns: - - float: The ratio of dormant neurons in the model. + - float: The effective rank of the representation layer's output. 
""" - # List to store hooks and their handlers - hooks: List[LinearOutputHook] = [] - hook_handlers: List[torch.utils.hooks.RemovableHandle] = [] - total_neurons: int = 0 - dormant_neurons: int = 0 + module_dict = dict(model.named_modules()) + if representation_layer_name not in module_dict: + raise KeyError(f"Representation layer '{representation_layer_name}' not found in model.named_modules().") + representation_module = module_dict[representation_layer_name] - # Register hooks to capture outputs of specific layers - for _, module in model.named_modules(): - if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM)): - hook = LinearOutputHook() - hooks.append(hook) - hook_handlers.append(module.register_forward_hook(hook)) + hook = IntermediateOutputHook() + handle = representation_module.register_forward_hook(hook) + model.eval() with torch.no_grad(): - # Forward pass to capture outputs - model(*inputs) - - # Analyze the captured outputs - for module, hook in zip((module for module in model.modules() if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM))), hooks): - with torch.no_grad(): - for output_data in hook.outputs: - mean_output = output_data.abs().mean(0) - avg_neuron_output = mean_output.mean() - dormant_indices = (mean_output < avg_neuron_output * percentage).nonzero(as_tuple=True)[0] - - if isinstance(module, nn.Linear): - # Calculate total and dormant neurons for Linear layers - total_neurons += module.weight.shape[0] * output_data.shape[0] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.Conv2d): - # Calculate total and dormant neurons for Conv2D layers - total_neurons += module.weight.shape[0] * output_data.shape[0] * output_data.shape[2] * output_data.shape[3] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.LSTM): - # Calculate total and dormant neurons for LSTM layers - total_neurons += module.hidden_size * module.num_layers * output_data.shape[0] * output_data.shape[1] - dormant_neurons += len(dormant_indices) - - # Clean up hooks - for hook in hooks: - hook.outputs.clear() - del hook.outputs - - for hook_handler in hook_handlers: - hook_handler.remove() - del hook_handler - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return dormant_neurons / total_neurons + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) -def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: + # Always remove the hook to prevent memory leaks. + handle.remove() + + if not hook.outputs: + raise RuntimeError("No outputs were captured from the representation layer.") + + # Concatenate all captured outputs along the batch dimension. + rep_tensor = torch.cat(hook.outputs, dim=0) if len(hook.outputs) > 1 else hook.outputs[0] + + # Reshape the representation to a 2D matrix (samples, features). + rep_tensor = rep_tensor.view(rep_tensor.size(0), -1) + + # Compute singular values using SVD. + singular_values = np.linalg.svd(rep_tensor.cpu().numpy(), full_matrices=False, compute_uv=False) + + # Calculate the effective rank. + e_rank = compute_effective_rank(singular_values) + + hook.clear() + return e_rank + + +def compute_dormant_stats(outputs: List[torch.Tensor], threshold: float) -> Tuple[int, int]: """ Overview: - Normalize the input data using the max-min-normalization. + Computes element-wise statistics for a list of output tensors from a layer. + Arguments: - - inputs (:obj:`torch.Tensor`): The input data needs to be normalized. 
- - first_dim (:obj:`int`): The first dimension of flattening the input data. + - outputs (:obj:`List[torch.Tensor]`): A list of tensors, each representing an output from a forward pass. + - threshold (:obj:`float`): The activation threshold below which a neuron is considered dormant. + Returns: - - output (:obj:`torch.Tensor`): The normalized data. + - Tuple[int, int]: A tuple containing the total number of elements and the number of dormant elements. """ - if first_dim < 0: - first_dim = len(inputs.shape) + first_dim - flat_input = inputs.view(*inputs.shape[:first_dim], -1) - max_val = torch.max(flat_input, first_dim, keepdim=True).values - min_val = torch.min(flat_input, first_dim, keepdim=True).values - flat_input = (flat_input - min_val) / (max_val - min_val) + layer_total = 0 + layer_dormant = 0 + for out in outputs: + flattened = out.view(-1) + layer_total += flattened.numel() + layer_dormant += torch.sum(flattened <= threshold).item() + return layer_total, layer_dormant + + +def calculate_dormant_ratio( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + dormant_threshold: float = 1e-2, + target_modules: Tuple[Type[nn.Module], ...] = (nn.Conv2d, nn.Linear), +) -> Dict[str, float]: + """ + Overview: + Calculates the dormant ratio (percentage of neurons with activation below a threshold) for + different parts of a model (e.g., encoder, transformer, head). It assumes the model has + attributes like `encoder`, `transformer`, or `head_dict`. + + Arguments: + - model (:obj:`nn.Module`): The model to evaluate, expected to have `encoder`, `transformer`, or `head_dict` attributes. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - dormant_threshold (:obj:`float`): The activation threshold for defining a dormant neuron. Defaults to 1e-2. + - target_modules (:obj:`Tuple[Type[nn.Module], ...]`): A tuple of module types to attach hooks to. - return flat_input.view(*inputs.shape) + Returns: + - Dict[str, float]: A dictionary containing the dormant ratios for each model part and a global ratio. + """ + parts = {} + if hasattr(model, "encoder"): + parts["encoder"] = model.encoder + if hasattr(model, "transformer"): + parts["transformer"] = model.transformer + if hasattr(model, "head_dict"): + parts["head"] = model.head_dict + # Fallback for models that don't have the standard part attributes. + if not parts: + parts["model"] = model -def get_dynamic_mean(model: nn.Module) -> float: - dynamic_mean = np.abs(model.conv.weight.detach().cpu().numpy().reshape(-1)).tolist() + hooks_dict = {part: [] for part in parts} + hook_handles = [] - for block in model.resblocks: - for name, param in block.named_parameters(): - dynamic_mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist() - dynamic_mean = sum(dynamic_mean) / len(dynamic_mean) - return dynamic_mean + # Register a forward hook for each target module in each part. 
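+    # NOTE: hooks are registered per part (encoder / transformer / head) so that a separate dormant
+    # ratio can be reported for each part, in addition to the global ratio computed below.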
+ for part_name, submodule in parts.items(): + for name, module in submodule.named_modules(): + if isinstance(module, target_modules): + hook = IntermediateOutputHook() + full_name = f"{part_name}/{name}" + hooks_dict[part_name].append((full_name, hook)) + handle = module.register_forward_hook(hook) + hook_handles.append(handle) + model.eval() + with torch.no_grad(): + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) -def get_reward_mean(model: nn.Module) -> Tuple[np.ndarray, float]: - reward_w_dist = model.conv1x1_reward.weight.detach().cpu().numpy().reshape(-1) + results = {} + total_global = 0 + dormant_global = 0 - for name, param in model.fc.named_parameters(): - temp_weights = param.detach().cpu().numpy().reshape(-1) - reward_w_dist = np.concatenate((reward_w_dist, temp_weights)) - reward_mean = np.abs(reward_w_dist).mean() - return reward_w_dist, reward_mean + # Calculate dormant stats from captured outputs. + for part, hooks in hooks_dict.items(): + part_total = 0 + part_dormant = 0 + for full_name, hook in hooks: + layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold) + part_total += layer_total + part_dormant += layer_dormant + + results[part] = (part_dormant / part_total) * 100.0 if part_total > 0 else 0.0 + total_global += part_total + dormant_global += part_dormant + results["global"] = (dormant_global / total_global) * 100.0 if total_global > 0 else 0.0 -def get_params_mean(model: nn.Module) -> Tuple[np.ndarray, float, float, float]: - representation_mean = model.representation_network.get_param_mean() - dynamic_mean = model.dynamics_network.get_dynamic_mean() - reward_w_dist, reward_mean = model.dynamics_network.get_reward_mean() + # Clean up all hooks. + for handle in hook_handles: + handle.remove() + for hooks in hooks_dict.values(): + for _, hook in hooks: + hook.clear() - return reward_w_dist, representation_mean, dynamic_mean, reward_mean + return results -def get_gradients(model: nn.Module) -> List[torch.Tensor]: - grads = [] - for p in model.parameters(): - grad = None if p.grad is None else p.grad.detach() - grads.append(grad) - return grads +def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: + """ + Overview: + Normalizes the input tensor using min-max scaling. The normalization is applied + over all dimensions starting from `first_dim`. + + Arguments: + - inputs (:obj:`torch.Tensor`): The input tensor to be normalized. + - first_dim (:obj:`int`): The first dimension from which to flatten the tensor for normalization. + + Returns: + - torch.Tensor: The min-max normalized tensor. + """ + if first_dim < 0: + first_dim = inputs.dim() + first_dim + + shape = inputs.shape + flat_input = inputs.view(*shape[:first_dim], -1) + + max_val, _ = torch.max(flat_input, dim=first_dim, keepdim=True) + min_val, _ = torch.min(flat_input, dim=first_dim, keepdim=True) + + # Add a small epsilon to avoid division by zero. + denominator = max_val - min_val + denominator[denominator < 1e-8] = 1e-8 + + normalized_flat = (flat_input - min_val) / denominator + + return normalized_flat.view(*shape) + + +def get_params_mean(model: nn.Module) -> float: + """ + Overview: + Calculates the mean of the absolute values of all parameters in a model. This is an alias + for `compute_average_weight_magnitude`. + Arguments: + - model (:obj:`nn.Module`): The model to be evaluated. + + Returns: + - float: The mean of the absolute parameter values. 
+ """ + return compute_average_weight_magnitude(model) + + +def get_gradients(model: nn.Module) -> List[Optional[torch.Tensor]]: + """ + Overview: + Retrieves the gradients of all parameters in a model. + + Arguments: + - model (:obj:`nn.Module`): The model from which to get gradients. + + Returns: + - List[Optional[torch.Tensor]]: A list of gradient tensors. If a parameter has no gradient, + the corresponding list entry is None. + """ + return [p.grad.detach() if p.grad is not None else None for p in model.parameters()] + + +def set_gradients(model: nn.Module, gradients: List[Optional[torch.Tensor]]) -> None: + """ + Overview: + Sets the gradients for all parameters in a model. + + Arguments: + - model (:obj:`nn.Module`): The model whose gradients are to be set. + - gradients (:obj:`List[Optional[torch.Tensor]]`): A list of gradients to assign to the model's parameters. + """ + params = list(model.parameters()) + if len(gradients) != len(params): + raise ValueError(f"Number of gradients ({len(gradients)}) does not match number of model parameters ({len(params)}).") -def set_gradients(model: nn.Module, gradients: List[torch.Tensor]) -> None: - # TODO due to the drawback of zip operation, we have to check whether gradients match model's parameters - for g, p in zip(gradients, model.parameters()): + for g, p in zip(gradients, params): if g is not None: - p.grad = g + # Ensure the gradient is on the same device as the parameter. + p.grad = g.to(p.device) \ No newline at end of file diff --git a/lzero/model/vit.py b/lzero/model/vit.py new file mode 100644 index 000000000..0bc5ebc04 --- /dev/null +++ b/lzero/model/vit.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +""" +Optimized Vision Transformer (ViT) Model. + +This script provides an optimized implementation of the Vision Transformer (ViT) architecture. +It includes improvements in code structure, clarity, and adherence to modern Python coding standards, +including comprehensive type hinting and documentation. The implementation also supports +integration with Low-Rank Adaptation (LoRA) through a flexible configuration system. + +Author: [Your Name/Team Name] +Date: [Current Date] +""" + +import torch +from torch import nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from lzero.model.common import SimNorm +from typing import Tuple, Union, Type, Optional + +# ==================== LoRA Integration Section Start ==================== + +# Attempt to import core components from a local transformer.py for LoRA support. +# This allows for flexible adaptation (e.g., LoRA) of linear layers. +try: + # Assuming transformer.py is in the same directory. Adjust the import path if necessary. + from .transformer import _maybe_wrap_linear, TransformerConfig +except ImportError: + # If the import fails (e.g., when running this file directly), provide a fallback. + # This ensures the model remains functional without LoRA components. + print("Warning: LoRA components could not be imported. Using standard nn.Linear.") + _maybe_wrap_linear = lambda linear, config, label: linear + + # Define a placeholder class for TransformerConfig if it's not available. + class TransformerConfig: + """Placeholder for TransformerConfig when LoRA components are not available.""" + pass + +# ==================== LoRA Integration Section End ==================== + + +# ==================== Configuration Class ==================== + +class ViTConfig: + """ + Overview: + Configuration class for the Vision Transformer (ViT) model. 
+ This class centralizes all hyperparameters, making the model easier to configure and manage. + """ + def __init__(self, **kwargs): + """ + Overview: + Initializes the ViTConfig object. + Arguments: + - **kwargs: Arbitrary keyword arguments to override default settings. + """ + # Image and Patch Dimensions + self.image_size: Union[int, Tuple[int, int]] = 64 + self.patch_size: Union[int, Tuple[int, int]] = 8 + self.channels: int = 3 + + # Model Architecture + self.num_classes: int = 768 + self.dim: int = 768 + self.depth: int = 12 + self.heads: int = 12 + self.mlp_dim: int = 3072 + self.dim_head: int = 64 + + # Pooling and Normalization + self.pool: str = 'cls' # 'cls' or 'mean' + self.final_norm_option_in_encoder: str = 'LayerNorm' # 'LayerNorm' or 'SimNorm' + + # Dropout Rates + self.dropout: float = 0.1 + self.emb_dropout: float = 0.1 + + # LoRA Configuration + self.lora_config: Optional[TransformerConfig] = None + + # Update attributes with any provided keyword arguments + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + print(f"Warning: Ignoring unknown config parameter '{key}'") + + +# ==================== Helper Functions ==================== + +def pair(t: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """ + Overview: + Converts an integer to a tuple of two identical integers. If the input is already a tuple, it is returned as is. + This is useful for handling kernel sizes, strides, etc., which can be specified as a single number or a tuple. + Arguments: + - t (:obj:`Union[int, Tuple[int, int]]`): The input value. + Returns: + - (:obj:`Tuple[int, int]`): A tuple of two integers. + """ + return t if isinstance(t, tuple) else (t, t) + + +# ==================== Core Modules ==================== + +class FeedForward(nn.Module): + """ + Overview: + A standard feed-forward network block used in Transformer architectures. + It consists of two linear layers with a GELU activation in between. + """ + def __init__( + self, + dim: int, + hidden_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the FeedForward module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - hidden_dim (:obj:`int`): The dimension of the hidden layer. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. + """ + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + _maybe_wrap_linear(nn.Linear(dim, hidden_dim), config, "feed_forward"), + nn.GELU(), + nn.Dropout(dropout), + _maybe_wrap_linear(nn.Linear(hidden_dim, dim), config, "feed_forward"), + nn.Dropout(dropout) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the FeedForward block. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): The output tensor of the same shape as input. + """ + return self.net(x) + + +class Attention(nn.Module): + """ + Overview: + Multi-Head Self-Attention (MHSA) module. + It computes scaled dot-product attention across multiple heads. + """ + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Attention module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - heads (:obj:`int`): The number of attention heads. 
+ - dim_head (:obj:`int`): The dimension of each attention head. + - dropout (:obj:`float`): The dropout rate for attention weights and output. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. + """ + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + # Linear layer to project input to Q, K, V. Potentially wrapped for LoRA. + self.to_qkv = _maybe_wrap_linear(nn.Linear(dim, inner_dim * 3, bias=False), config, "attn") + + # Output projection layer. + if project_out: + # Wrap the linear layer inside the sequential module for LoRA. + wrapped_linear = _maybe_wrap_linear(nn.Linear(inner_dim, dim), config, "attn") + self.to_out = nn.Sequential( + wrapped_linear, + nn.Dropout(dropout) + ) + else: + self.to_out = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Attention module. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape as input. + """ + x = self.norm(x) + + # Project to Q, K, V and split. + qkv = self.to_qkv(x).chunk(3, dim=-1) + # Rearrange for multi-head attention: b n (h d) -> b h n d + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) + + # Scaled dot-product attention. + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + attn = self.dropout(attn) + + # Apply attention to values. + out = torch.matmul(attn, v) + # Rearrange back to original shape: b h n d -> b n (h d) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out) + + +class Transformer(nn.Module): + """ + Overview: + A stack of Transformer blocks, each containing a multi-head self-attention + layer and a feed-forward network. + """ + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Transformer module. + Arguments: + - dim (:obj:`int`): The dimension of the token embeddings. + - depth (:obj:`int`): The number of Transformer blocks. + - heads (:obj:`int`): The number of attention heads. + - dim_head (:obj:`int`): The dimension of each attention head. + - mlp_dim (:obj:`int`): The hidden dimension of the feed-forward network. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA. + """ + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, config=config), + FeedForward(dim, mlp_dim, dropout=dropout, config=config) + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Transformer stack. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape. + """ + for attn, ff in self.layers: + x = attn(x) + x # Apply attention and residual connection + x = ff(x) + x # Apply feed-forward and residual connection + return self.norm(x) + + +class ViT(nn.Module): + """ + Overview: + Vision Transformer (ViT) model. 
This model applies the Transformer architecture + to sequences of image patches for image classification tasks. + """ + def __init__(self, config: ViTConfig): + """ + Overview: + Initializes the ViT model using a configuration object. + Arguments: + - config (:obj:`ViTConfig`): A configuration object containing all model hyperparameters. + """ + super().__init__() + self.config = config + + image_height, image_width = pair(config.image_size) + patch_height, patch_width = pair(config.patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, \ + 'Image dimensions must be divisible by the patch size.' + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = config.channels * patch_height * patch_width + assert config.pool in {'cls', 'mean'}, 'pool type must be either "cls" or "mean"' + + # Patch embedding layer + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, config.dim), + nn.LayerNorm(config.dim), + ) + + # Positional embedding and CLS token + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, config.dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim)) + self.dropout = nn.Dropout(config.emb_dropout) + + # Transformer encoder stack + self.transformer = Transformer( + dim=config.dim, + depth=config.depth, + heads=config.heads, + dim_head=config.dim_head, + mlp_dim=config.mlp_dim, + dropout=config.dropout, + config=config.lora_config + ) + + self.pool = config.pool + self.last_linear = nn.Linear(config.dim, config.num_classes) + + # Final normalization layer + if config.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(config.num_classes, eps=1e-5) + elif config.final_norm_option_in_encoder == 'SimNorm': + group_size = 8 # As specified in original code + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {config.final_norm_option_in_encoder}") + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the ViT model. + Arguments: + - img (:obj:`torch.Tensor`): Input image tensor of shape (batch_size, channels, height, width). + Returns: + - (:obj:`torch.Tensor`): Output logits tensor of shape (batch_size, num_classes). + """ + # 1. Patch embedding + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + # 2. Prepend CLS token + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + + # 3. Add positional embedding + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + # 4. Pass through Transformer encoder + x = self.transformer(x) + + # 5. Pooling + x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] + + # 6. Final classification head + x = self.last_linear(x) + x = self.final_norm(x) + + return x + + +# ==================== Test and Benchmark Code ==================== +if __name__ == "__main__": + import random + import time + + # Fix random seeds for reproducibility + torch.manual_seed(42) + random.seed(42) + + # 1. Create a configuration object + # This is now the standard way to configure the model. + vit_config = ViTConfig( + image_size=64, + patch_size=8, + num_classes=768, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + dropout=0.1, + emb_dropout=0.1, + final_norm_option_in_encoder="LayerNorm" + ) + + # 2. 
Instantiate the model with the config + model = ViT(config=vit_config) + + # Move model to GPU if available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() # Set model to evaluation mode for inference + + # Create a dummy input tensor + dummy_input = torch.randn(256, 3, 64, 64).to(device) + + # Perform a single forward pass + with torch.no_grad(): + out = model(dummy_input) + + print(f"Device: {device}") + print(f"Output shape: {out.shape}") + print(f"Output[0] (first 50 values): {out[0][:50]}") + + # 3. Simple Benchmark + print("\nStarting benchmark...") + warmup_reps, bench_reps = 5, 20 + + with torch.no_grad(): + # Warm-up runs + for _ in range(warmup_reps): + _ = model(dummy_input) + + # Synchronize before timing (for CUDA) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(bench_reps): + _ = model(dummy_input) + + # Synchronize after timing + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.time() + + total_time = end_time - start_time + avg_latency_ms = (total_time / bench_reps) * 1000 + print(f"Average latency over {bench_reps} runs: {avg_latency_ms:.2f} ms") \ No newline at end of file diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 7bd2e8d2b..da69fbd80 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -15,7 +15,7 @@ from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms -from lzero.model.utils import cal_dormant_ratio +from lzero.model.utils import calculate_dormant_ratio from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ prepare_obs, configure_optimizers @@ -113,7 +113,7 @@ class MuZeroPolicy(Policy): # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. eval_offline=False, # (bool) Whether to calculate the dormant ratio. - cal_dormant_ratio=False, + calculate_dormant_ratio=False, # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, # (bool) Whether to analyze dormant ratio. 
@@ -423,8 +423,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ========= logging for analysis ========= # calculate dormant ratio of encoder - if self._cfg.cal_dormant_ratio: - self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), + if self._cfg.calculate_dormant_ratio: + self.dormant_ratio_encoder = calculate_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), percentage=self._cfg.dormant_threshold) # calculate L2 norm of latent state latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() @@ -470,7 +470,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # ========= logging for analysis =============== - if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.calculate_dormant_ratio: # calculate dormant ratio of encoder action_tmp = action_batch[:, step_k] if len(action_tmp.shape) == 1: @@ -486,7 +486,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] ) state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network, + self.dormant_ratio_dynamics = calculate_dormant_ratio(self._learn_model.dynamics_network, state_action_encoding.detach(), percentage=self._cfg.dormant_threshold) # ========= logging for analysis =============== @@ -941,7 +941,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ return output - def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the collector environment. @@ -956,7 +956,7 @@ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the evaluator environment. 
@@ -970,6 +970,7 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: self._cfg.device ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + def _monitor_vars_learn(self) -> List[str]: """ Overview: diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py new file mode 100644 index 000000000..45addaf59 --- /dev/null +++ b/lzero/policy/muzero_multitask.py @@ -0,0 +1,895 @@ +import copy +from typing import List, Dict, Tuple, Union, Optional + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + cross_entropy_loss, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + negative_cosine_similarity, + prepare_obs, +) +from lzero.policy.muzero import MuZeroPolicy + + +def generate_task_loss_dict(multi_task_losses: List[float], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[float]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'loss_task{}'. + - task_id (:obj:`int`): The starting ID for the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Ensure the loss is a scalar value for logging. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +class WrappedModelV2: + """ + Overview: + A wrapper class to bundle different parts of a model (tokenizer, transformer, embeddings) + for easier management of parameters and gradients. + """ + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> List[torch.nn.Parameter]: + """ + Overview: + Returns a list of all parameters from the tokenizer, transformer, and all embedding layers. + """ + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all parameters in the tokenizer, transformer, and embedding layers to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. 
+ """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + +@POLICY_REGISTRY.register('muzero_multitask') +class MuZeroMTPolicy(MuZeroPolicy): + """ + Overview: + The multi-task policy for MuZero, extending MuZeroPolicy. It supports training multiple tasks + simultaneously by separating the loss for each task and optimizing them jointly. + """ + + # Default configuration for MuZeroMTPolicy. + config = dict( + type='muzero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(4, 96, 96), # example shape + self_supervised_learning_loss=False, + categorical_distribution=True, + image_channel=1, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=300, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + ), + # ****** common ****** + use_rnd_model=False, + multi_gpu=False, + sampled_algo=False, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=200, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + + # ****** observation ****** + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + # ******* learn ****** + use_wandb=False, + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='SGD', + learning_rate=0.2, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=5, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + lr_piecewise_constant_decay=True, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + + # ****** UCB ****** + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + + # ****** Multi-task related ****** + task_num=2, # Number of tasks, adjust as needed. + task_id=0, # The starting ID of the current task. + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Returns the default model configuration for this algorithm. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + """ + return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learning mode. 
This method sets up the learning model, optimizer, and MCTS utilities. + """ + super()._init_learn() + + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: In board games, for a fixed learning rate of 0.003, 'Adam' performs better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + # Learning rate scheduler + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: 1, 0.1, 0.01 are decay rates, not the learning rate itself. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # Use model_wrapper for specialized demands of different modes. + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + # Image augmentation + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + # Support for categorical distribution + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # HarmonyDream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names. + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter. + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + # RND model for intrinsic reward + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= Logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. 
+ self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. + + # Initialize multi-task related parameters. + self.task_num_for_current_rank = self._cfg.task_num + self.task_id = self._cfg.task_id + + def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning, which is the core of the learning process. + Data is sampled from the replay buffer, and the loss is calculated and backpropagated + to update the model. + Arguments: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): A list of data tuples for each task, + where each tuple contains (current_batch, target_batch, task_id). + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary of information for logging, + including the current learning loss and other learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize lists for multi-task losses. + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + consistency_loss_multi_task = [] + policy_entropy_multi_task = [] + lambd_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 # Initialize to zero. + losses_list = [] # To store the loss for each task. + + for task_idx, (current_batch, target_batch, task_id) in enumerate(data): + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Data augmentation. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to tensor. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( + data_list, self._cfg.device + ) + + target_reward = target_reward.view(self._cfg.batch_size[task_idx], -1) + target_value = target_value.view(self._cfg.batch_size[task_idx], -1) + + assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) + + # Transform rewards and values to scaled representation. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distribution. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Initial inference. + network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) + + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # Log Dormant Ratio and L2 Norm. + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio( + self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold + ) + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + + # Inverse transform value. 
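+            # A sketch of what the handle does (assuming the standard MuZero value rescaling):
+            # it first takes the expectation of the categorical prediction over its discrete
+            # support and then applies the inverse of h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x,
+            # i.e. roughly scalar_value = h^{-1}(E_support[value_distribution]).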
+ original_value = self.inverse_scalar_transform_handle(value) + + # Initialize predicted values and policies. + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # Calculate priority. + value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # Calculate policy and value loss for the first step. + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + target_policy_entropy = 0 + + # Unroll loop for multiple steps. + for step_k in range(self._cfg.num_unroll_steps): + # Recurrent inference using the dynamics function. + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # Log Dormant Ratio for the dynamics network. + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + action_tmp = action_batch[:, step_k] + if len(action_tmp.shape) == 1: + action_tmp = action_tmp.unsqueeze(-1) + # Convert action to one-hot encoding. + action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio( + self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold + ) + + # Inverse transform value. + original_value = self.inverse_scalar_transform_handle(value) + + # Calculate consistency loss (self-supervised learning). + if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # Calculate policy and value losses. + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + # Calculate policy entropy loss. 
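+                # Note: the entropy H(pi) = -sum_a pi(a) * log pi(a) enters the total loss with a
+                # negative sign, so a positive `policy_entropy_weight` rewards high-entropy policies
+                # (an exploration bonus); with the default weight of 0 in this config the term is
+                # effectively a logged statistic only.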
+ prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss += -entropy + + # Calculate target policy entropy (for debugging purposes only). + target_normalized_visit_count = target_policy[:, step_k + 1] + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -( + (target_normalized_visit_count_masked + 1e-6) * + torch.log(target_normalized_visit_count_masked + 1e-6) + ).sum(-1).mean() + else: + target_policy_entropy += torch.log( + torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) + ) + + # Log predicted values and rewards if monitoring extra statistics. + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat( + (predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) + ) + + # Core learning model update step. + weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ + self._cfg.value_loss_weight * value_loss + \ + self._cfg.reward_loss_weight * reward_loss + \ + self._cfg.ssl_loss_weight * consistency_loss + \ + self._cfg.policy_entropy_weight * policy_entropy_loss + + # Accumulate losses from multiple tasks. + weighted_total_loss += weighted_loss.mean() + + # Store per-task losses for logging. + reward_loss_multi_task.append(reward_loss.mean().item()) + policy_loss_multi_task.append(policy_loss.mean().item()) + value_loss_multi_task.append(value_loss.mean().item()) + consistency_loss_multi_task.append(consistency_loss.mean().item()) + policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) + # TODO: Adjust if using gradient correction. + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) + value_priority_multi_task.append(value_priority.mean().item()) + value_priority_mean_multi_task.append(value_priority.mean().item()) + losses_list.append(weighted_loss.mean().item()) + + # Zero the optimizer's gradients. + self._optimizer.zero_grad() + + # Backward pass. + weighted_total_loss.backward() + + # Gradient clipping. + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), + self._cfg.grad_clip_value + ) + + # Sync gradients for multi-GPU training. + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # Update optimizer. + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # Update target model. + self._target_model.update(self._learn_model.state_dict()) + + # Get GPU memory usage. + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0.0 + max_memory_allocated_gb = 0.0 + + # Build the return loss dictionary. 
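+        # The dictionary mixes global scalars (GPU memory, learning rate, total loss) with
+        # per-task entries carrying the "noreduce_" prefix, i.e. per-rank values that have
+        # not yet been reduced across DDP processes. As an illustration, a rank whose local
+        # tasks start at task_id=2 and which trains two tasks would report keys such as
+        # 'noreduce_reward_loss_task2' and 'noreduce_reward_loss_task3'.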
+ return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # Generate task-specific loss dictionaries, prefixing each with "noreduce_". + multi_task_loss_dicts = { + **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd_multi_task, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # Merge the dictionaries. + return_loss_dict.update(multi_task_loss_dicts) + + # Return the final loss dictionary. + return return_loss_dict + + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: + """ + Overview: + Reset the observation and action for the collector environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The ID of the task. + """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: + """ + Overview: + Reset the observation and action for the evaluator environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The ID of the task. + """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: + """ + Overview: + Registers variables to be monitored during the learning phase. The registered variables + will be recorded to TensorBoard based on the return value of `_forward_learn`. + If `num_tasks` is provided, it generates monitoring variables for each task. + Arguments: + - num_tasks (:obj:`int`, optional): The number of tasks. + Returns: + - monitored_vars (:obj:`List[str]`): A list of variable names to be monitored. + """ + # Basic monitoring variables. + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # Task-specific monitoring variables. 
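+        # Each base name below is expanded per local task by the loop further down, so the
+        # monitored keys line up with those emitted by `_forward_learn`, e.g.
+        # 'noreduce_value_loss' -> 'noreduce_value_loss_task0', 'noreduce_value_loss_task1', ...
+        # (offset by `self.task_id` for the tasks handled on this rank).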
+ task_specific_vars = [ + 'noreduce_consistency_loss', + 'noreduce_reward_loss', + 'noreduce_policy_loss', + 'noreduce_value_loss', + 'noreduce_policy_entropy', + 'noreduce_lambd', + 'noreduce_value_priority', + 'noreduce_value_priority_mean', + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([8, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(8)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The ID of the task. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, + data, task_id=task_id) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # The only difference between collect and eval is the dirichlet noise. 
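+                # As in AlphaZero/MuZero, the noise is mixed into the root prior inside `roots.prepare`:
+                # P'(a) = (1 - root_noise_weight) * P(a) + root_noise_weight * eta_a, with
+                # eta ~ Dirichlet(root_dirichlet_alpha) sampled over the legal actions only
+                # (defaults in this config: root_noise_weight=0.25, root_dirichlet_alpha=0.3).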
+ noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # C++ MCTS tree. + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # Python MCTS tree. + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # Epsilon-greedy exploration for collection. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # Normal collection. + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + else: + # Pure policy collection (without MCTS). + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _get_target_obs_index_in_step_k(self, step: int) -> Tuple[int, int]: + """ + Overview: + Get the begin and end indices of the target observation at step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The beginning index of the target observation. + - end_index (:obj:`int`): The ending index of the target observation. 
+ """ + if self._cfg.model.model_type in ['conv', 'conv_context']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The ID of the task. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._eval_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # If not in training, obtain the scalar values of the value/reward. + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape (B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape (B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # C++ MCTS tree. + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # Python MCTS tree. 
+ roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) + # rather than sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output \ No newline at end of file diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py new file mode 100644 index 000000000..00d929f51 --- /dev/null +++ b/lzero/policy/sampled_unizero_multitask.py @@ -0,0 +1,989 @@ +import copy +import logging +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, set_pkg_seed, get_rank, get_world_size + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + prepare_obs, + prepare_obs_stack_for_unizero +) +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt +import torch.nn.functional as F +import torch.distributed as dist + +# Please add the path to your LibMTL library. +# For example: sys.path.append('/path/to/your/LibMTL/') +import sys +# sys.path.append('/path/to/your/LibMTL/') # Template path +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The base task ID. 
+ Returns: + - (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Convert tensor to float if it has .item(), otherwise cast to float. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else float(task_loss) + except Exception as e: + # Fallback for cases where conversion fails. + task_loss_dict[task_name] = task_loss + return task_loss_dict + + +class WrappedModelV2: + """ + Overview: + A wrapper class to conveniently manage different parts of a larger model, + such as the tokenizer, transformer, and various embedding layers. This allows for + easier handling of parameters and gradients for these components. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Overview: + Initializes the WrappedModelV2 with model components. + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The main transformer module. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding layer. + - task_emb (:obj:`torch.nn.Module`): The task embedding layer. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> List[torch.Tensor]: + """ + Overview: + Collects and returns all parameters from the wrapped model components. + Returns: + - (:obj:`List[torch.Tensor]`): A list of all parameters. + """ + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped model components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. Defaults to False. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + def get_group_parameters(self) -> Dict[str, List[torch.Tensor]]: + """ + Overview: + Returns a dictionary where keys are module names (or finer-grained layers) + and values are the corresponding parameter lists. The order of parameters in the + returned dictionary's values should be consistent with the `parameters()` method. + Returns: + - (:obj:`Dict[str, List[torch.Tensor]]`): A dictionary of grouped parameters. + """ + groups = {} + groups['tokenizer'] = list(self.tokenizer.parameters()) + groups['transformer'] = list(self.transformer.parameters()) + groups['pos_emb'] = list(self.pos_emb.parameters()) + groups['act_embedding_table'] = list(self.act_embedding_table.parameters()) + + # Example of how to add parameters from sub-layers within the transformer. + # This is for demonstration; ensure the order in parameters() is consistent if used. 
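+        # A weighting method that needs module-level granularity (e.g. per-block gradient
+        # statistics) could consume these groups instead of the flat `parameters()` list;
+        # the extra per-layer entries are purely additive and leave `parameters()` untouched.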
+ if hasattr(self.transformer, 'blocks'): + for i, layer in enumerate(self.transformer.blocks): + groups[f'transformer_layer_{i}'] = list(layer.parameters()) + return groups + + +@POLICY_REGISTRY.register('sampled_unizero_multitask') +class SampledUniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for Sampled UniZero Multitask, combining multi-task learning with sampled-based MCTS. + This implementation extends the UniZeroPolicy to handle multiple tasks simultaneously while utilizing + sampled MCTS for action selection. It ensures scalability and correctness in multi-task environments. + """ + + # The default_config for Sampled UniZero Multitask policy. + config = dict( + type='sampled_unizero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(3, 64, 64), + self_supervised_learning_loss=True, + categorical_distribution=True, + image_channel=3, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=50, + bias=True, + res_connection_in_dynamics=True, + norm_type='LN', + analysis_sim_norm=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + tokens_per_block=2, + max_blocks=10, + max_tokens=20, + context_length=8, + gru_gating=False, + device='cpu', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + action_space_size=6, + group_size=8, + attention='causal', + num_layers=2, + num_heads=8, + embed_dim=768, + embed_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + support_size=101, + max_cache_size=5000, + env_num=8, + latent_recon_loss_weight=0., + perceptual_loss_weight=0., + policy_entropy_weight=5e-3, + predict_latent_loss_type='group_kl', + obs_type='image', + gamma=1, + dormant_threshold=0.01, + policy_loss_type='kl', + ), + ), + use_rnd_model=False, + multi_gpu=True, + sampled_algo=True, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=400, + analysis_sim_norm=False, + collect_with_pure_policy=False, + eval_freq=int(5e3), + sample_type='transition', + + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='AdamW', + learning_rate=0.0001, + init_w=3e-3, + target_update_freq=100, + target_update_theta=0.05, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=5, + n_episode=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=10, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + train_start_after_envsteps=0, + + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + random_collect_episode_num=0, + + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + 
""" + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and the import paths. + """ + return 'SampledUniZeroMTModel', ['lzero.model.sampled_unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learning mode. This method sets up the learn model, optimizer, + target model, and other utilities required for training, such as LR schedulers + and gradient correction methods (e.g., MoCo). + """ + # Configure optimizer for the world model using NanoGPT's configuration utility. + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + # Initialize learning rate schedulers if configured. + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + + if self._cfg.cos_lr_scheduler: + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 + ) + elif self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler = StepLR( + self._optimizer_world_model, step_size=int(5e4), gamma=0.1 + ) + + # Initialize weights for continuous action spaces. + if self._cfg.model.continuous_action_space: + init_w = self._cfg.init_w + self._model.world_model.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) + try: + self._model.world_model.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) + except Exception as exception: + logging.warning(exception) + + # Initialize and compile the target model. + self._target_model = copy.deepcopy(self._model) + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "Torch version 2.0 or higher is required." + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + # Initialize utilities for loss calculation and transformations. + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + + # Initialize gradient correction method (MoCo) if enabled. + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # Wrap model components for gradient correction. Note: Heads are not included. 
+ wrapped_model = WrappedModelV2( + self._learn_model.world_model.tokenizer.encoder, # TODO: This might contain one or multiple encoders. + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # TODO: The GradCorrect class might need adjustments for multi-GPU training compatibility. + # Initialize the gradient correction mechanism. + self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + + + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights: Any = None, ignore_grad: bool = False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward pass for training. This method processes a batch of data for multiple tasks, + computes losses, and updates the model weights. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, one for each task. + - task_weights (:obj:`Any`): Weights for each task's loss. Defaults to None. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. Defaults to False. + Returns: + - (:obj:`Dict[str, Union[float, int]]`): A dictionary containing various loss values and training statistics. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize lists to store losses and metrics for each task. + task_weight_multi_task, obs_loss_multi_task, reward_loss_multi_task = [], [], [] + policy_loss_multi_task, orig_policy_loss_multi_task, policy_entropy_multi_task = [], [], [] + value_loss_multi_task, latent_recon_loss_multi_task, perceptual_loss_multi_task = [], [], [] + latent_state_l2_norms_multi_task, average_target_policy_entropy_multi_task = [], [] + value_priority_multi_task, value_priority_mean_multi_task = [], [] + + weighted_total_loss = 0.0 + losses_list = [] # Stores the individual loss tensor for each task. + + for task_id, data_one_task in enumerate(data): + # Unpack data for the current task. + current_batch, target_batch, task_id = data_one_task + obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg, task_id) + + # Apply data augmentation if enabled. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare actions and convert data to torch tensors. 
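+            # Discrete actions are cast to long below so they can index the action embedding
+            # table, continuous actions stay as floats, and the remaining numpy arrays are moved
+            # to the configured device as float tensors.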
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) + if not self._cfg.model.continuous_action_space: + action_batch = action_batch.long() + + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) + + cur_batch_size = target_reward.size(0) + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar targets to their categorical representation. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare the batch for the GPT-based world model. + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape_list[task_id], int) or len(self._cfg.model.observation_shape_list[task_id]) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, self._cfg.model.observation_shape_list[task_id]) + else: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, *self._cfg.model.observation_shape_list[task_id]) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['child_sampled_actions'] = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)[:, :-1] + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = (mask_batch == 1.0)[:, :-1] # 0 indicates invalid padding data. + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Compute target policy entropy for monitoring. + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Compute losses using the world model. + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + ) + + # Accumulate weighted total loss. + current_task_weight = task_weights[task_id] if task_weights is not None else 1 + weighted_total_loss += losses.loss_total * current_task_weight + losses_list.append(losses.loss_total * current_task_weight) + task_weight_multi_task.append(current_task_weight) + + # Store intermediate losses for logging. + for loss_name, loss_value in losses.intermediate_losses.items(): + self.intermediate_losses[f"{loss_name}"] = loss_value + + # Collect individual losses for the current task. 
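+            # The `or 0.0` fallback guards against `None` entries in `intermediate_losses`, so a
+            # loss term that a given world model does not produce is logged as 0.0 instead of
+            # breaking the per-task logging below.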
+ obs_loss_multi_task.append(self.intermediate_losses.get('loss_obs', 0.0) or 0.0) + reward_loss_multi_task.append(self.intermediate_losses.get('loss_rewards', 0.0) or 0.0) + policy_loss_multi_task.append(self.intermediate_losses.get('loss_policy', 0.0) or 0.0) + orig_policy_loss_multi_task.append(self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0) + policy_entropy_multi_task.append(self.intermediate_losses.get('policy_entropy', 0.0) or 0.0) + value_loss_multi_task.append(self.intermediate_losses.get('loss_value', 0.0) or 0.0) + latent_recon_loss_multi_task.append(self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0) + perceptual_loss_multi_task.append(self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0) + latent_state_l2_norms_multi_task.append(self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0) + average_target_policy_entropy_multi_task.append(average_target_policy_entropy) + value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + + # --- Model Update Step --- + self._optimizer_world_model.zero_grad() + + # Perform backward pass, either with or without gradient correction. + if self._cfg.use_moco: + # Use MoCo for gradient correction and backpropagation. + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + elif self._cfg.only_use_moco_stats: + # Compute MoCo stats but perform standard backpropagation. + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + weighted_total_loss.backward() + else: + # Standard backpropagation without gradient correction. + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + weighted_total_loss.backward() + + # Clip gradients to prevent exploding gradients. + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + + # NOTE: If ignore_grad is True, zero out gradients. This is useful for DDP synchronization + # when a GPU has finished all its tasks but still needs to participate in the training step. + if ignore_grad: + self._optimizer_world_model.zero_grad() + + # Synchronize gradients across GPUs in multi-GPU setup. + if self._cfg.multi_gpu: + if not self._cfg.use_moco: + # TODO: Investigate if a barrier is needed here for synchronization. + # dist.barrier() + self.sync_gradients(self._learn_model) + + # Update model parameters. + self._optimizer_world_model.step() + + # Step the learning rate scheduler. + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Update the target model using a soft update rule. + self._target_model.update(self._learn_model.state_dict()) + + # Monitor GPU memory usage. + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated_gb = torch.cuda.memory_allocated() / (1024 ** 3) + max_memory_allocated_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) + else: + current_memory_allocated_gb, max_memory_allocated_gb = 0., 0. 
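+ # Note: torch.cuda.max_memory_allocated() reports the peak since the last call to
+ # torch.cuda.reset_peak_memory_stats(), i.e. a running high-water mark rather than a per-iteration value.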
+ + # --- Logging and Return --- + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # Generate and merge task-specific loss dictionaries. + # The "noreduce_" prefix indicates these are per-rank values before DDP reduction. + multi_task_loss_dicts = { + **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', self.task_id), + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', self.task_id), + } + return_loss_dict.update(multi_task_loss_dicts) + + # Log to wandb if enabled. + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_loss_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_loss_dict + + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Specifies the variables to be monitored during training. These variables will be logged + (e.g., to TensorBoard) based on the dictionary returned by `_forward_learn`. + Arguments: + - num_tasks (:obj:`int`): The number of tasks to generate monitored variables for. This argument is for API consistency and is overridden by `self.task_num_for_current_rank`. + Returns: + - (:obj:`List[str]`): A list of variable names to monitor. + """ + # Basic monitored variables, independent of the number of tasks. + monitored_vars = [ + 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', + 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', + ] + + # Task-specific variables. 
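+ # Each name below is expanded to '<name>_task{global_task_id}' for every task handled by this
+ # rank, matching the keys produced in `_forward_learn`.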
+ task_specific_vars = [ + 'noreduce_task_weight', 'noreduce_obs_loss', 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', 'noreduce_latent_recon_loss', 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', 'noreduce_reward_loss', 'noreduce_value_loss', + 'noreduce_perceptual_loss', 'noreduce_latent_state_l2_norms', 'noreduce_lambd', + 'noreduce_value_priority_mean', + ] + + # The number of tasks handled by the current rank. + num_tasks_on_rank = self.task_num_for_current_rank + + # Generate full variable names for each task on the current rank. + if num_tasks_on_rank is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks_on_rank): + # The task ID is offset by the base task ID for this rank. + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to monitor and print the statistics (mean, std) of model weights and their gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to inspect. + """ + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collection mode. This method sets up the collect model, MCTS utilities, + and initial states for the collector environments. + """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._task_weight_temperature = 10. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + + # Initialize placeholders for the last observation and action batches. + if self._cfg.model.model_type == 'conv': + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) + elif self._cfg.model.model_type == 'mlp': + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape_list[0]] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: List = None, + temperature: float = 1.0, + to_play: List[int] = [-1], + epsilon: float = 0.25, + ready_env_id: np.ndarray = None, + timestep: List[int] = [0], + task_id: int = None, + ) -> Dict[int, Dict[str, Any]]: + """ + Overview: + The forward pass for data collection. It uses MCTS to select actions for the current states. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - temperature (:obj:`float`): The temperature parameter for MCTS action selection. + - to_play (:obj:`List[int]`): A list indicating the current player for each environment. + - epsilon (:obj:`float`): The exploration noise parameter. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. + - task_id (:obj:`int`): The ID of the task being executed. 
+ Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + # 1. Initial inference to get root information. + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # 2. Prepare MCTS roots. + if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_collect_env_num)] + + noises = [np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.world_model_cfg.num_of_sampled_actions).astype(np.float32).tolist() for _ in range(active_collect_env_num)] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots(active_collect_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) + else: + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + # 3. MCTS search. + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + + # 4. Get results from MCTS and select actions. + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) + + # Select action based on visit counts, with temperature for exploration. + action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=self._collect_mcts_temperature, deterministic=False) + action = root_sampled_actions[action_idx] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + # 5. Update state for the next step. + self.last_batch_obs = data + self.last_batch_action = batch_action + + # Reset collector if the number of active environments is less than expected. + if active_collect_env_num < self.collector_env_num: + logging.warning(f'Number of active envs ({active_collect_env_num}) is less than collector_env_num ({self.collector_env_num}). 
Resetting collector.') + self._reset_collect(reset_init_data=True, task_id=task_id) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the evaluation mode. This method sets up the evaluation model, MCTS utilities, + and initial states for the evaluator environments. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + self.task_id_for_eval = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + # Initialize placeholders for the last observation and action batches for evaluation. + if self._cfg.model.model_type == 'conv': + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) + elif self._cfg.model.model_type == 'mlp': + # TODO: Ensure observation_shape_list is correctly indexed for the evaluation task. + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) + print(f'rank {get_rank()} last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.ndarray = None, timestep: List[int] = [0], task_id: int = None) -> Dict[int, Dict[str, Any]]: + """ + Overview: + The forward pass for evaluation. It uses MCTS to select actions deterministically. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - to_play (:obj:`int`): The current player. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. + - task_id (:obj:`int`): The ID of the task being evaluated. + Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + # 1. Initial inference. + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # 2. Prepare MCTS roots without noise for deterministic evaluation. 
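+ # Unlike the collect phase, no Dirichlet noise is added to the root priors, and the final action
+ # is later selected greedily from the visit-count distribution (deterministic=True).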
+ if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_eval_env_num)] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots(active_eval_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) + else: + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + + # 3. MCTS search. + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + + # 4. Get results and select actions deterministically. + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) + + # Select action deterministically (greedy selection from visit counts). + action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=1, deterministic=True) + action = root_sampled_actions[action_idx] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + # 5. Update state for the next evaluation step. + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the collector state. This can be a full reset of initial data or a periodic + clearing of model caches to manage memory. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, applies to all. + - current_steps (:obj:`int`): The current number of steps, used for periodic cache clearing. + - reset_init_data (:obj:`bool`): Whether to reset the initial observation and action batches. + - task_id (:obj:`int`, optional): The task ID, used to determine observation shape. + """ + if reset_init_data: + obs_shape = self._cfg.model.observation_shape_list[task_id] if task_id is not None else self._cfg.model.observation_shape + self.last_batch_obs = initialize_zeros_batch(obs_shape, self._cfg.collector_env_num, self._cfg.device) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + logging.info(f'Collector: last_batch_obs and last_batch_action have been reset. Shape: {self.last_batch_obs.shape}') + + if env_id is None or isinstance(env_id, list): + return + + # Periodically clear model caches to free up memory. 
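+ # The world model's KV caches grow with every inference call; clearing them on a fixed interval
+ # keeps GPU memory bounded during long collection runs.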
+ clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + if current_steps > 0 and current_steps % clear_interval == 0: + logging.info(f'Clearing model caches at step {current_steps}.') + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + torch.cuda.empty_cache() + logging.info('Collector: collect_model caches cleared.') + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Overview: + Resets the caches of the target model to free up GPU memory. + """ + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + torch.cuda.empty_cache() + logging.info('Collector: target_model caches cleared.') + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learning components. + Returns: + - (:obj:`Dict[str, Any]`): A dictionary containing the state of the model, target model, and optimizer. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Loads the state dictionary into the learning components. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary to load. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # TODO: The following is a version for pretrain-finetune workflow, which only loads backbone parameters. + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads a state_dict into the policy's learn mode, but excludes parameters related to + # multi-task heads and task embeddings. This is useful for fine-tuning a pre-trained model + # on a new set of tasks. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved previously. + # """ + # # Define prefixes of parameters to exclude (e.g., multi-task heads, task embeddings). + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # Define specific keys to exclude if they don't fit a prefix pattern. + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # Filters out parameters that should not be loaded. 
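+ # Keys are dropped if they start with any prefix in ``exclude_prefixes`` or appear in ``exclude_keys``.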
+ # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes) or k in exclude_keys: + # print(f"Excluding parameter from loading: {k}") + # continue + # filtered[k] = v + # return filtered + + # # Filter and load state_dict for the main model. + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing, unexpected = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _learn_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _learn_model: {unexpected}") + # else: + # print("Warning: 'model' key not found in the state_dict.") + + # # Filter and load state_dict for the target model. + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing, unexpected = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _target_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _target_model: {unexpected}") + # else: + # print("Warning: 'target_model' key not found in the state_dict.") + + # # Load optimizer state_dict. This is often skipped during fine-tuning, but included here for completeness. + # if 'optimizer_world_model' in state_dict: + # try: + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # except Exception as e: + # print(f"Could not load optimizer state_dict: {e}. This may be expected during fine-tuning.") + # else: + # print("Warning: 'optimizer_world_model' key not found in the state_dict.") \ No newline at end of file diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index 19a852f56..17eee4052 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -1,6 +1,6 @@ from typing import Union import torch - +import numpy as np class DiscreteSupport(object): @@ -11,7 +11,6 @@ def __init__(self, start: float, stop: float, step: float = 1., device: Union[st assert self.size > 0, "DiscreteSupport size must be greater than 0" self.step = step - def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor: """ Overview: @@ -110,6 +109,7 @@ def visit_count_temperature( def phi_transform( discrete_support: DiscreteSupport, x: torch.Tensor, + label_smoothing_eps: float = 0. # <--- 新增平滑参数 ) -> torch.Tensor: """ Overview: @@ -163,7 +163,15 @@ def phi_transform( dtype=x.dtype, device=x.device) target.scatter_add_(-1, idx, prob) - return target + # return target + + # --- 5. 
Apply label smoothing ---
+ if label_smoothing_eps > 0:
+ # Mix the original two-hot target with a uniform distribution.
+ smooth_target = (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size)
+ return smooth_target
+ else:
+ return target
 def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 0341c430b..a56474ccb 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -17,7 +17,76 @@ from lzero.policy.muzero import MuZeroPolicy
 from .utils import configure_optimizers_nanogpt
+from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+import torch.nn.functional as F
+def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float):
+ """
+ Efficiently scale all weights of a module using vectorized operations.
+ """
+ if not (0.0 < scale_factor < 1.0):
+ return # Do nothing if the scale factor is invalid.
+
+ # 1. Flatten all parameters of the module into a single vector.
+ params_vec = parameters_to_vector(module.parameters())
+
+ # 2. Perform a single in-place multiplication on this vector.
+ params_vec.data.mul_(scale_factor)
+
+ # 3. Copy the scaled vector back into the module's parameters.
+ vector_to_parameters(params_vec, module.parameters())
+
+
+def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
+ """
+ Configure an optimizer with differentiated learning rates and weight decay for the UniZero model.
+ """
+ # 1. Collect the parameters that require special treatment.
+ param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
+
+ # 2. Split the parameters into three groups: transformer backbone, tokenizer, and heads.
+ transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn}
+ tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn}
+
+ # The head parameters are those that belong to neither the transformer nor the tokenizer.
+ head_params = {
+ pn: p for pn, p in param_dict.items()
+ if 'transformer' not in pn and 'tokenizer' not in pn
+ }
+
+ # 3. Assign different optimizer settings (in particular, learning rate and weight decay) to each group.
+ # AdamW is still used here, but with a more sensible learning-rate setup.
+ optim_groups = [
+ {
+ 'params': list(tokenizer_params.values()),
+ 'lr': learning_rate, # The tokenizer uses the base learning rate, e.g. 1e-4.
+ # 'lr': learning_rate * 0.1, # Alternative: a smaller learning rate for the encoder, e.g. 1e-5.
+ 'weight_decay': weight_decay * 5.0 # <-- 5x weight decay for the encoder: a strong regularizer.
+ # 'weight_decay': weight_decay # <-- Alternative: standard weight decay for the encoder.
+ },
+ {
+ 'params': list(transformer_params.values()),
+ 'lr': learning_rate, # 1e-4
+ # 'lr': learning_rate * 0.2, # Alternative: a smaller learning rate for the transformer backbone, e.g. 1e-5.
+ 'weight_decay': weight_decay
+ # 'weight_decay': weight_decay * 5.0
+ },
+ {
+ 'params': list(head_params.values()),
+ 'lr': learning_rate, # The heads also use the base learning rate, e.g. 1e-4.
+ 'weight_decay': 0.0 # Head weights are usually not decayed.
+ # 'weight_decay': weight_decay
+
+ }
+ ]
+
+ print("--- Optimizer Groups ---")
+ print(f"Transformer LR: {learning_rate}")
+ print(f"Tokenizer/Heads LR: {learning_rate}")
+
+ optimizer = torch.optim.AdamW(optim_groups, betas=betas)
+ return optimizer
+
 @POLICY_REGISTRY.register('unizero')
 class UniZeroPolicy(MuZeroPolicy):
 """
@@ -81,8 +150,8 @@ class UniZeroPolicy(MuZeroPolicy):
 device='cpu',
 # (bool) Whether to analyze simulation normalization.
 analysis_sim_norm=False,
- # (bool) Whether to analyze dormant ratio.
- analysis_dormant_ratio=False,
+ # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent.
+ analysis_dormant_ratio_weight_rank=False,
 # (int) The shape of the action space.
 action_space_size=6,
 # (int) The size of the group, related to simulation normalization.
@@ -139,6 +208,7 @@ class UniZeroPolicy(MuZeroPolicy):
 rope_theta=10000,
 # (int) The maximum sequence length for position encoding.
 max_seq_len=8192,
+ lora_r=0,
 # Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None.
 # - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone.
 # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone.
@@ -146,6 +216,23 @@
 ),
 ),
 # ****** common ******
+ # (bool) Whether to enable the adaptive policy entropy weight (alpha).
+ use_adaptive_entropy_weight=True,
+ # (float) Learning rate of the optimizer for the adaptive alpha.
+ adaptive_entropy_alpha_lr=1e-4,
+ # ==================== START: Encoder-Clip Annealing Config ====================
+ # (bool) Whether to anneal the encoder-clip value.
+ use_encoder_clip_annealing=True,
+ # (str) Annealing schedule type. Options are 'linear' and 'cosine'.
+ encoder_clip_anneal_type='cosine',
+ # (float) Initial clip value of the annealing schedule (loose, used early in training).
+ encoder_clip_start_value=30.0,
+ # (float) Final clip value of the annealing schedule (strict, used late in training).
+ encoder_clip_end_value=10.0,
+ # (int) Number of training iterations over which the clip value is annealed from the start value to the end value.
+ encoder_clip_anneal_steps=100000, # e.g. the final value is reached after 100k iterations.
+ # ===================== END: Encoder-Clip Annealing Config =====================
+
 # (bool) whether to use rnd model.
 use_rnd_model=False,
 # (bool) Whether to use multi-gpu training.
@@ -178,7 +265,7 @@
 # (bool) Whether to use the pure policy to collect data.
 collect_with_pure_policy=False,
 # (int) The evaluation frequency.
- eval_freq=int(2e3),
+ eval_freq=int(5e3),
 # (str) The sample type. Options are ['episode', 'transition'].
 sample_type='transition',
 # ****** observation ******
@@ -211,6 +298,10 @@
 optim_type='AdamW',
 # (float) Learning rate for training policy network. Initial lr for manually decay schedule.
 learning_rate=0.0001,
+ # ==================== Norm monitoring frequency ====================
+ # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable.
+ monitor_norm_freq=5000,
+ # ====================================================================
 # (int) Frequency of hard target network update.
 target_update_freq=100,
 # (int) Frequency of soft target network update.
@@ -227,8 +318,12 @@
 n_episode=8,
 # (int) The number of num_segments in each collecting stage when use muzero_segment_collector.
 num_segments=8,
- # (int) the number of simulations in MCTS.
+ # (int) The number of simulations in MCTS for reanalyze.
 num_simulations=50,
+ # (int) The number of simulations in MCTS for the collect phase.
+ collect_num_simulations=25,
+ # (int) The number of simulations in MCTS for the eval phase.
+ eval_num_simulations=50,
 # (float) Discount factor (gamma) for returns.
 discount_factor=0.997,
 # (int) The number of steps for calculating target q_value.
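+ # A minimal illustrative sketch (not part of the config above) of how the encoder-clip value is
+ # annealed from `encoder_clip_start_value` to `encoder_clip_end_value`; it mirrors the schedule
+ # implemented in `_forward_learn` further below. The helper name is hypothetical.
+ # def annealed_encoder_clip(train_iter, start=30.0, end=10.0, steps=100000, mode='cosine'):
+ #     progress = min(1.0, train_iter / steps)
+ #     if mode == 'cosine':
+ #         return end + (start - end) * 0.5 * (1.0 + np.cos(np.pi * progress))
+ #     return start * (1 - progress) + end * progress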
@@ -313,24 +408,142 @@ def default_model(self) -> Tuple[str, List[str]]:
 """
 return 'UniZeroModel', ['lzero.model.unizero_model']
+
+ # ==================== Model norm monitoring helpers ====================
+ def _monitor_model_norms(self) -> Dict[str, float]:
+ """
+ Overview:
+ Compute and return the parameter-matrix norms of the key model components (encoder, transformer, heads).
+ This function should be called under torch.no_grad() for efficiency.
+ Returns:
+ - norm_metrics (:obj:`Dict[str, float]`): A dictionary containing all norm metrics, used for logging.
+ """
+ world_model = self._learn_model.world_model
+ norm_metrics = {}
+
+ # Define the module groups to monitor.
+ module_groups = {
+ 'encoder': world_model.tokenizer.encoder,
+ 'transformer': world_model.transformer,
+ 'head_value': world_model.head_value,
+ 'head_reward': world_model.head_rewards,
+ 'head_policy': world_model.head_policy,
+ }
+
+ for group_name, group_module in module_groups.items():
+ total_norm_sq = 0.0
+ for param_name, param in group_module.named_parameters():
+ if param.requires_grad:
+ # Compute the L2 norm of the parameters of a single layer.
+ param_norm = param.data.norm(2).item()
+ # Replace dots so that the name is displayed as a proper hierarchy in TensorBoard.
+ log_name = f'norm/{group_name}/{param_name.replace(".", "/")}'
+ norm_metrics[log_name] = param_norm
+ total_norm_sq += param_norm ** 2
+
+ # Compute the total norm of the whole module group.
+ total_group_norm = np.sqrt(total_norm_sq)
+ norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm
+
+ return norm_metrics
+
+ def _monitor_gradient_norms(self) -> Dict[str, float]:
+ """
+ Overview:
+ Compute and return the gradient norms of the key model components.
+ This function should be called after gradients have been computed and before the parameter update.
+ Returns:
+ - grad_metrics (:obj:`Dict[str, float]`): A dictionary containing all gradient-norm metrics, used for logging.
+ """
+ world_model = self._learn_model.world_model
+ grad_metrics = {}
+
+ # Define the module groups to monitor.
+ module_groups = {
+ 'encoder': world_model.tokenizer.encoder,
+ 'transformer': world_model.transformer,
+ 'head_value': world_model.head_value,
+ 'head_reward': world_model.head_rewards,
+ 'head_policy': world_model.head_policy,
+ }
+
+ for group_name, group_module in module_groups.items():
+ total_grad_norm_sq = 0.0
+ num_params_with_grad = 0
+
+ for param_name, param in group_module.named_parameters():
+ if param.requires_grad and param.grad is not None:
+ # Compute the L2 norm of the gradient of a single layer's parameters.
+ grad_norm = param.grad.data.norm(2).item()
+ # Replace dots so that the name is displayed as a proper hierarchy in TensorBoard.
+ log_name = f'grad/{group_name}/{param_name.replace(".", "/")}'
+ grad_metrics[log_name] = grad_norm
+ total_grad_norm_sq += grad_norm ** 2
+ num_params_with_grad += 1
+
+ # Compute the total gradient norm of the whole module group.
+ if num_params_with_grad > 0:
+ total_group_grad_norm = np.sqrt(total_grad_norm_sq)
+ grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm
+ else:
+ grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0
+
+ return grad_metrics
+ # ========================================================================
+
 def _init_learn(self) -> None:
 """
 Overview:
 Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
""" - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR # TODO: check the total training steps - self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -372,6 +585,63 @@ def _init_learn(self) -> None: self.accumulation_steps = self._cfg.accumulation_steps + # ==================== START: 目标熵正则化初始化 ==================== + # 从配置中读取是否启用自适应alpha,并提供一个默认值 + self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True) + + # 在 _init_learn 中增加配置 + self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98) + self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7) + self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000) # 例如,在200k步内完成退火 2M envsteps + + if self.use_adaptive_entropy_weight: + # 1. 设置目标熵。对于离散动作空间,一个常见的启发式设置是动作空间维度的负对数乘以一个系数。 + # 这个系数(例如0.98)可以作为一个超参数。 + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. 初始化一个可学习的 log_alpha 参数。 + # 初始化为0,意味着初始的 alpha = exp(0) = 1.0。 + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. 
+ # Using a smaller learning rate than the main optimizer's (e.g. 1e-4) is usually more stable.
+ alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4)
+ self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
+
+ print("="*20)
+ print(">>> Target-entropy regularization (adaptive alpha) enabled <<<")
+ print(f" Target entropy: {self.target_entropy:.4f}")
+ print(f" Alpha optimizer learning rate: {alpha_lr:.2e}")
+ print("="*20)
+ # ===================== END: Target-entropy regularization initialization =====================
+
+ # ==================== START: Encoder-Clip annealing initialization ====================
+ self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False)
+ self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0) # TODO
+ if self.use_encoder_clip_annealing:
+ self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine')
+ self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0)
+ self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0)
+ self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000)
+
+ print("="*20)
+ print(">>> Encoder-Clip annealing enabled <<<")
+ print(f" Type: {self.encoder_clip_anneal_type}")
+ print(f" Range: {self.encoder_clip_start} -> {self.encoder_clip_end}")
+ print(f" Steps: {self.encoder_clip_anneal_steps}")
+ print("="*20)
+ else:
+ # If annealing is disabled, use a fixed clip threshold.
+ self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 20.0)
+ # ===================== END: Encoder-Clip annealing initialization =====================
+
+ # --- NEW: Policy Label Smoothing Parameters ---
+ self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05) # TODO: larger action spaces need a larger eps.
+ self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end', 0.01) # TODO: unify the key naming with policy_ls_eps_start.
+ self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps', 50000) # TODO: 50k
+ print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}")
+
 # @profile
 def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
 """
@@ -393,6 +663,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
 target_reward, target_value, target_policy = target_batch
+ # --- NEW: Calculate current epsilon for policy ---
+ if self.policy_ls_eps_start > 0:
+ progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps)
+ current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress
+ else:
+ current_policy_label_eps = 0.0
+
 # Prepare observations based on frame stack number
 if self._cfg.model.frame_stack_num > 1:
 obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg)
@@ -421,8 +698,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 transformed_target_value = scalar_transform(target_value)
 # Convert to categorical distributions
- target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
- target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+ # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
+ # target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+
+ target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps=self._cfg.label_smoothing_eps)
+ target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) # Prepare batch for GPT model batch_for_gpt = {} @@ -445,6 +725,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['target_value'] = target_value_categorical[:, :-1] batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) @@ -452,13 +734,83 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps, ) # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse. - weighted_total_loss = losses.loss_total + # ==================== [修改] 集成范数监控逻辑 ==================== + norm_log_dict = {} + # 检查是否达到监控频率 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. 监控模型参数范数 + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. 监控中间张量 x (Transformer的输出) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x 的形状为 (B, T, E) + # 计算每个 token 的 L2 范数 + token_norms = intermediate_x.norm(p=2, dim=-1) + + # 记录这些范数的统计数据 + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + + # 3. 
监控 logits 的详细统计 (Value, Policy, Reward) + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + norm_log_dict['logits/value/mean'] = logits_value.mean().item() + norm_log_dict['logits/value/std'] = logits_value.std().item() + norm_log_dict['logits/value/max'] = logits_value.max().item() + norm_log_dict['logits/value/min'] = logits_value.min().item() + norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() + + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() + norm_log_dict['logits/policy/std'] = logits_policy.std().item() + norm_log_dict['logits/policy/max'] = logits_policy.max().item() + norm_log_dict['logits/policy/min'] = logits_policy.min().item() + norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() + norm_log_dict['logits/reward/std'] = logits_reward.std().item() + norm_log_dict['logits/reward/max'] = logits_reward.max().item() + norm_log_dict['logits/reward/min'] = logits_reward.min().item() + norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() + + # 4. 监控 obs_embeddings (Encoder输出) 的统计 + obs_embeddings = losses.intermediate_losses.get('obs_embeddings') + if obs_embeddings is not None: + # 计算每个 embedding 的 L2 范数 + emb_norms = obs_embeddings.norm(p=2, dim=-1) + norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() + norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() + norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() + norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() + # ================================================================= + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. 
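+ # The epsilon keeps every priority strictly positive, so prioritized sampling never assigns
+ # zero probability to a transition.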
+ value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6
+ # ===================== END MODIFICATION 2 =====================
+
+ # weighted_total_loss = losses.loss_total
+ # TODO:
+ weighted_total_loss = (weights * losses.loss_total).mean()
+
 for loss_name, loss_value in losses.intermediate_losses.items():
 self.intermediate_losses[f"{loss_name}"] = loss_value
+ # Extract the policy entropy from the losses object.
+
 obs_loss = self.intermediate_losses['loss_obs']
 reward_loss = self.intermediate_losses['loss_rewards']
 policy_loss = self.intermediate_losses['loss_policy']
@@ -471,9 +823,26 @@
 middle_step_losses = self.intermediate_losses['middle_step_losses']
 last_step_losses = self.intermediate_losses['last_step_losses']
 dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder']
- dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
+ dormant_ratio_transformer = self.intermediate_losses['dormant_ratio_transformer']
+ dormant_ratio_head = self.intermediate_losses['dormant_ratio_head']
+ avg_weight_mag_encoder = self.intermediate_losses['avg_weight_mag_encoder']
+ avg_weight_mag_transformer = self.intermediate_losses['avg_weight_mag_transformer']
+ avg_weight_mag_head = self.intermediate_losses['avg_weight_mag_head']
+ e_rank_last_linear = self.intermediate_losses['e_rank_last_linear']
+ e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm']
 latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']
+ latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms']
+ logits_value_mean=self.intermediate_losses['logits_value_mean']
+ logits_value_max=self.intermediate_losses['logits_value_max']
+ logits_value_min=self.intermediate_losses['logits_value_min']
+ logits_policy_mean=self.intermediate_losses['logits_policy_mean']
+ logits_policy_max=self.intermediate_losses['logits_policy_max']
+ logits_policy_min=self.intermediate_losses['logits_policy_min']
+ temperature_value=self.intermediate_losses['temperature_value']
+ temperature_reward=self.intermediate_losses['temperature_reward']
+ temperature_policy=self.intermediate_losses['temperature_policy']
+
 assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
 assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
@@ -482,19 +851,107 @@
 if (train_iter % self.accumulation_steps) == 0:
 self._optimizer_world_model.zero_grad()
+
+ # ==================== START: Target-entropy regularization update ====================
+ alpha_loss = None
+ current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # Fall back to the fixed value by default.
+ if self.use_adaptive_entropy_weight:
+ # --- Dynamically compute the annealed target entropy ---
+ progress = min(1.0, train_iter / self.target_entropy_decay_steps)
+ current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
+ action_space_size = self._cfg.model.action_space_size
+ # Note: target_entropy is defined as a positive value, which is more intuitive.
+ current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio
+
+ # --- Compute alpha_loss (with the corrected sign) ---
+ # The leading minus sign has been removed; this is the key correction.
+ # detach() remains essential: it ensures the gradient of alpha_loss only flows into log_alpha.
+ alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()
+
+ # --- Update log_alpha ---
+ self.alpha_optimizer.zero_grad()
+ alpha_loss.backward()
+ self.alpha_optimizer.step()
+ # --- Clamp log_alpha as a safety measure ---
+ with torch.no_grad():
+ # Clamp alpha to a range such as [1e-4, 10.0].
+ self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
+ # --- Use the freshly updated alpha (with the gradient flow cut off) ---
+ current_alpha = self.log_alpha.exp().detach()
+
+ # Recompute the weighted policy loss and the total loss.
+ # Note: policy_entropy here is already averaged over the batch.
+ weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy
+ # Rebuild the total loss (instead of using losses.loss_total).
+ # Make sure these weights stay consistent with the computation in the LossWithIntermediateLosses class.
+ self.obs_loss_weight = 10
+ self.value_loss_weight = 0.5
+ self.reward_loss_weight = 1.
+ self.policy_loss_weight = 1.
+ self.ends_loss_weight = 0.
+ total_loss = (
+ self.reward_loss_weight * reward_loss +
+ self.value_loss_weight * value_loss +
+ self.policy_loss_weight * weighted_policy_loss +
+ self.obs_loss_weight * obs_loss # Assumes ssl_loss_weight is the weight of obs_loss.
+ # ... add any remaining loss terms here as well ...
+ )
+ weighted_total_loss = (weights * total_loss).mean()
+ # ===================== END: Target-entropy regularization update =====================
+
 # Scale the loss by the number of accumulation steps
 weighted_total_loss = weighted_total_loss / self.accumulation_steps
 weighted_total_loss.backward()
+ # -----------------------------------------------------------------
+ # Still executed under torch.no_grad().
+ # =================================================================
+ with torch.no_grad():
+ # 1. Encoder-Clip
+ # ==================== START: Dynamically compute the current clip threshold ====================
+ current_clip_value = self.latent_norm_clip_threshold # Fall back to the fixed value by default.
+ if self.use_encoder_clip_annealing:
+ progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+ if self.encoder_clip_anneal_type == 'cosine':
+ # Cosine schedule: decays smoothly from 1 to 0.
+ cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+ current_clip_value = self.encoder_clip_end + \
+ (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+ else: # Linear schedule by default.
+ current_clip_value = self.encoder_clip_start * (1 - progress) + \
+ self.encoder_clip_end * progress
+ # ===================== END: Dynamically compute the current clip threshold =====================
+
+ # 1. Encoder-Clip (using the dynamically computed current_clip_value).
+ if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+ obs_embeddings = losses.intermediate_losses['obs_embeddings']
+ if obs_embeddings is not None:
+ max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+ if max_latent_norm > current_clip_value:
+ scale_factor = current_clip_value / max_latent_norm.item()
+ # Avoid printing too frequently; log only every N steps.
+ if train_iter % 1000 == 0:
+ print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. 
Scaling by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) + # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: + # ==================== [新增] 监控梯度范数 ==================== + # 在梯度裁剪之前监控梯度范数,用于诊断梯度爆炸/消失问题 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + grad_norm_metrics = self._monitor_gradient_norms() + norm_log_dict.update(grad_norm_metrics) + # ================================================================= + # Analyze gradient norms if simulation normalization analysis is enabled if self._cfg.analysis_sim_norm: # Clear previous analysis results to prevent memory overflow del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() self._target_model.encoder_hook.clear_data() - + # Clip gradients to prevent exploding gradients total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value @@ -561,21 +1018,61 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # Add value_priority to the log dictionary. + 'value_priority': value_priority_np.mean().item(), + 'value_priority_orig': value_priority_np, 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), - 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, + 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, + 'analysis/dormant_ratio_head': dormant_ratio_head, + + 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, + 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, + 'analysis/avg_weight_mag_head': avg_weight_mag_head, + 'analysis/e_rank_last_linear': e_rank_last_linear, + 'analysis/e_rank_sim_norm': e_rank_sim_norm, + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms, 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, + "logits_value_mean":logits_value_mean, + "logits_value_max":logits_value_max, + "logits_value_min":logits_value_min, + "logits_policy_mean":logits_policy_mean, + "logits_policy_max":logits_policy_max, + "logits_policy_min":logits_policy_min, + + "temperature_value":temperature_value, + "temperature_reward":temperature_reward, + "temperature_policy":temperature_policy, + + "current_policy_label_eps":current_policy_label_eps, } - + + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # ======================================================================= + + # 
==================== START: 添加新日志项 ==================== + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['alpha_loss'] = alpha_loss.item() + # ==================== START: 添加新日志项 ==================== + + # ==================== START: 添加新日志项 ==================== + if self.use_encoder_clip_annealing: + return_log_dict['current_encoder_clip_value'] = current_clip_value + # ===================== END: 添加新日志项 ===================== + if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) @@ -597,11 +1094,13 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - + # 为 collect MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num @@ -622,8 +1121,9 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.ndarray = None, - timestep: List = [0] + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, ) -> Dict: """ Overview: @@ -636,6 +1136,7 @@ def _forward_collect( - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -741,13 +1242,25 @@ def _forward_collect( self.last_batch_obs = data self.last_batch_action = batch_action - # ========= TODO: for muzero_segment_collector now ========= + # ========= TODO: This logic is a temporary workaround specific to the muzero_segment_collector. ========= if active_collect_env_num < self.collector_env_num: - print('==========collect_forward============') - print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + # When an environment finishes an episode ('done'), the length of `self.last_batch_obs` passed back + # becomes smaller than the total number of collector environments. + # Handling this dynamic batch size is complex, as the transformer's KV cache retrieval + # requires a stable environment ID for correct indexing. A mismatch would cause retrieval errors. + # + # Therefore, as a simpler solution, we reset the collection state for ALL environments. + # By resetting `self.last_batch_action` to -1 for all `self.collector_env_num` environments, + # we force the transformer to start its context from scratch, avoiding incorrect cache lookups. + print('========== collect_forward ============') + print(f'An environment has finished. Active envs: {active_collect_env_num} < Total envs: {self.collector_env_num}. 
Resetting all.') + self._reset_collect(reset_init_data=True) + + # If the sampling type is 'episode', it's unexpected for the number of active environments to drop, + # as this suggests an inconsistent state or a potential issue in the collection logic. if getattr(self._cfg, 'sample_type', '') == 'episode': - print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + print('WARNING: Inconsistent state detected. `sample_type` is "episode", but the number of active environments has changed.') return output @@ -757,10 +1270,16 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. """ self._eval_model = self._model + + # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': @@ -772,8 +1291,8 @@ def _init_eval(self) -> None: ).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], - ready_env_id: np.array = None, timestep: List = [0]) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -784,6 +1303,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to eval. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ @@ -804,7 +1324,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward @@ -864,12 +1384,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ } batch_action.append(action) - self.last_batch_obs = data + self.last_batch_obs_eval = data self.last_batch_action = batch_action return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the collection process for a specific environment. 
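# --- Illustrative sketch, not part of this diff: phase-specific MCTS search budgets. ---
# `_init_collect`/`_init_eval` above deep-copy the policy config and override
# `num_simulations` with `collect_num_simulations` / `eval_num_simulations`, so data
# collection can use a cheaper search than evaluation. A standalone version of that
# pattern; `build_phase_mcts_cfg` and the SimpleNamespace config are assumptions here.
import copy
from types import SimpleNamespace

def build_phase_mcts_cfg(cfg, phase: str):
    """Return a config copy whose num_simulations matches the requested phase."""
    assert phase in ('collect', 'eval')
    phase_cfg = copy.deepcopy(cfg)
    phase_cfg.num_simulations = cfg.collect_num_simulations if phase == 'collect' else cfg.eval_num_simulations
    return phase_cfg

base_cfg = SimpleNamespace(num_simulations=50, collect_num_simulations=25, eval_num_simulations=50)
print(build_phase_mcts_cfg(base_cfg, 'collect').num_simulations)  # 25
print(build_phase_mcts_cfg(base_cfg, 'eval').num_simulations)     # 50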
It clears caches and memory @@ -890,15 +1410,31 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + # ======== TODO: 20251015 ======== # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model @@ -911,10 +1447,9 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in # Free up GPU memory torch.cuda.empty_cache() - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the evaluation process for a specific environment. It clears caches and memory @@ -927,23 +1462,61 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
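# --- Illustrative sketch, not part of this diff: end-of-episode per-env KV-cache clearing. ---
# `_reset_collect` above clears only the finished environment's initial-inference cache when
# the collector issues an end-of-episode reset (signalled by `current_steps is None`).
# `DummyWorldModel` and `clear_finished_env_caches` are hypothetical stand-ins for the real
# world model, which keeps one cache dict per environment.
from typing import List, Optional, Union

class DummyWorldModel:
    def __init__(self, env_num: int) -> None:
        self.past_kv_cache_init_infer_envs = [dict() for _ in range(env_num)]

def clear_finished_env_caches(world_model: DummyWorldModel,
                              env_id: Optional[Union[int, List[int]]],
                              current_steps: Optional[int]) -> None:
    """Clear per-env initial-inference caches only on the end-of-episode reset call."""
    if env_id is None or current_steps is not None:
        return
    for eid in ([env_id] if isinstance(env_id, int) else env_id):
        if eid < len(world_model.past_kv_cache_init_infer_envs):
            world_model.past_kv_cache_init_infer_envs[eid].clear()

wm = DummyWorldModel(env_num=4)
wm.past_kv_cache_init_infer_envs[2]['step_0'] = 'cached_kv'
clear_finished_env_caches(wm, env_id=2, current_steps=None)
print(len(wm.past_kv_cache_init_infer_envs[2]))  # 0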
""" if reset_init_data: - self.last_batch_obs = initialize_pad_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device, - pad_token_id=self.pad_token_id - ) + if task_id is not None: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.evaluator_env_num, + self._cfg.device, + pad_token_id=self.pad_token_id + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + else: + self.last_batch_obs_eval = initialize_pad_batch( # TODO + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device, + pad_token_id=self.pad_token_id + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id - # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # ======== TODO: 20251015 ======== + # Determine the clear interval based on the environment's sample type + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: print(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model @@ -965,56 +1538,142 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. 
""" - return [ + base_vars = [ + # ==================== Analysis Metrics ==================== 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_world_model', + 'analysis/dormant_ratio_transformer', + 'analysis/dormant_ratio_head', + 'analysis/avg_weight_mag_encoder', + 'analysis/avg_weight_mag_transformer', + 'analysis/avg_weight_mag_head', + 'analysis/e_rank_last_linear', + 'analysis/e_rank_sim_norm', 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', 'analysis/grad_norm_after', + # ==================== Step-wise Loss Analysis ==================== 'analysis/first_step_loss_value', 'analysis/first_step_loss_policy', 'analysis/first_step_loss_rewards', 'analysis/first_step_loss_obs', - 'analysis/middle_step_loss_value', 'analysis/middle_step_loss_policy', 'analysis/middle_step_loss_rewards', 'analysis/middle_step_loss_obs', - 'analysis/last_step_loss_value', 'analysis/last_step_loss_policy', 'analysis/last_step_loss_rewards', 'analysis/last_step_loss_obs', + # ==================== System Metrics ==================== 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', 'cur_lr_world_model', - 'cur_lr_tokenizer', + # ==================== Core Losses ==================== 'weighted_total_loss', 'obs_loss', 'policy_loss', 'orig_policy_loss', 'policy_entropy', 'latent_recon_loss', + 'perceptual_loss', 'target_policy_entropy', 'reward_loss', 'value_loss', - 'consistency_loss', 'value_priority', 'target_reward', 'target_value', + 'transformed_target_reward', + 'transformed_target_value', + + # ==================== Gradient Norms ==================== 'total_grad_norm_before_clip_wm', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', + + # ==================== Logits Statistics ==================== + 'logits_value_mean', + 'logits_value_max', + 'logits_value_min', + 'logits_policy_mean', + 'logits_policy_max', + 'logits_policy_min', + + # ==================== Temperature Parameters ==================== + 'temperature_value', + 'temperature_reward', + 'temperature_policy', + + # ==================== Training Configuration ==================== + 'current_policy_label_eps', + 'adaptive_alpha', + 'adaptive_target_entropy_ratio', + 'alpha_loss', + 'current_encoder_clip_value', + ] + + # ==================== [新增] 范数和中间张量监控变量 ==================== + norm_vars = [ + # 模块总范数 (参数范数) + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + 'norm/head_value/_total_norm', + 'norm/head_reward/_total_norm', + 'norm/head_policy/_total_norm', + + # 模块总范数 (梯度范数) + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + 'grad/head_value/_total_norm', + 'grad/head_reward/_total_norm', + 'grad/head_policy/_total_norm', + + # 中间张量 x (Transformer输出) 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Logits 的详细统计 (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Logits 的详细统计 (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Logits 的详细统计 (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings 的统计信息 + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', ] + # 
注意:我们不把每一层的范数都加到这里,因为数量太多会导致日志混乱。 + # 在实践中,如果通过总范数发现问题,可以临时在TensorBoard中搜索特定层的范数, + # 或者在本地打印 `norm_log_dict` 来进行详细分析。 + # wandb等工具可以更好地处理大量的动态指标。 + # ======================================================================== + + return base_vars + norm_vars + def _state_dict_learn(self) -> Dict[str, Any]: """ @@ -1023,11 +1682,16 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. """ - return { + state_dict = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } + # ==================== START: 保存Alpha优化器状态 ==================== + if self.use_adaptive_entropy_weight: + state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() + # ===================== END: 保存Alpha优化器状态 ===================== + return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1038,7 +1702,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ==================== START: 加载Alpha优化器状态 ==================== + # if self.use_adaptive_entropy_weight and 'alpha_optimizer' in state_dict: + # self.alpha_optimizer.load_state_dict(state_dict['alpha_optimizer']) + # ===================== END: 加载Alpha优化器状态 ===================== def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..cbf605a1e --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,2288 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized +from .utils import configure_optimizers_nanogpt +import sys + +# Please replace the path with the actual location of your LibMTL library. +sys.path.append('/path/to/your/LibMTL') + +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo +from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg + +import torch.distributed as dist + +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This function should be called once during the initialization of the main process or the learner.) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: + """ + Overview: + Builds and returns a new process group containing only the learner ranks. 
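# --- Illustrative sketch, not part of this diff: checkpointing the adaptive-alpha optimizer. ---
# `_state_dict_learn` above adds the `alpha_optimizer` state to the checkpoint only when the
# adaptive entropy weight is enabled, so a loader should restore it conditionally. A toy round
# trip with a single learnable log_alpha (names and shapes here are assumptions):
import torch

log_alpha = torch.nn.Parameter(torch.zeros(1))
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)

checkpoint = {'log_alpha': log_alpha.detach().clone()}
checkpoint['alpha_optimizer'] = alpha_optimizer.state_dict()  # only saved when adaptive alpha is on

restored_log_alpha = torch.nn.Parameter(torch.zeros(1))
restored_alpha_optimizer = torch.optim.Adam([restored_log_alpha], lr=1e-4)
if 'alpha_optimizer' in checkpoint:  # restore only if the checkpoint carries the state
    restored_alpha_optimizer.load_state_dict(checkpoint['alpha_optimizer'])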
+ This is used for methods like GenericMoCo that require collective communication + only among the ranks performing training. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of world ranks that are designated as learners. + These are the ranks that will perform the backward pass. + e.g., if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group containing only the learner ranks. + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting ID of the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Get the scalar value of the loss if it's a tensor. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +# # 修改后的函数: +# def generate_task_loss_dict( +# multi_task_losses: List[Union[torch.Tensor, float]], +# task_name_template: str, +# global_task_ids: List[int] +# ) -> Dict[str, float]: +# """ +# Overview: +# Generates a dictionary for the losses of each task using their explicit global IDs. +# Arguments: +# - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. +# - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. +# - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses. +# Returns: +# - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. +# """ +# task_loss_dict = {} +# # 使用 zip 将每个损失与其正确的全局ID配对 +# for task_loss, global_id in zip(multi_task_losses, global_task_ids): +# task_name = task_name_template.format(global_id) +# try: +# task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss +# except Exception as e: +# task_loss_dict[task_name] = task_loss +# return task_loss_dict + + +class WrappedModel: + """ + Overview: + A wrapper class for the world model to conveniently access its parameters and zero its gradients. + This version wraps the entire world model. + """ + def __init__(self, world_model: torch.nn.Module): + """ + Arguments: + - world_model (:obj:`torch.nn.Module`): The world model instance. + """ + self.world_model = world_model + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the entire world model. + """ + return self.world_model.parameters() + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all world model parameters to zero. 
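# --- Illustrative usage, not part of this diff, of generate_task_loss_dict defined above. ---
# Per-task losses may be tensors or plain floats; the helper flattens them into named scalars
# keyed by the global task-id offset, ready for tensorboard/wandb logging. The loss values
# below are made-up examples.
import torch

per_task_obs_losses = [torch.tensor(0.31), 0.27, torch.tensor(0.45)]
log_dict = generate_task_loss_dict(per_task_obs_losses, 'obs_loss_task{}', task_id=4)
# -> {'obs_loss_task4': ~0.31, 'obs_loss_task5': 0.27, 'obs_loss_task6': ~0.45}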
+ Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + """ + Overview: + A wrapper for specific components of the world model. + This version is designed to group parameters that are considered "shared" + across tasks for gradient correction methods like MoCo, excluding the prediction heads. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the wrapped components (tokenizer, transformer, embeddings). + These are typically the shared parts of the model whose gradients need to be managed for multi-task learning. + """ + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + # TODO: Decide whether to include task embeddings in shared parameters. + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO: Match the decision made in the parameters() method. + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + """ + Overview: + An alternative wrapper for world model components. + This version excludes the tokenizer from the shared parameters, focusing gradient correction + on the transformer and embedding layers. + """ + def __init__(self, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the transformer and various embedding layers. + """ + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of the wrapped components to zero. 
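# --- Illustrative sketch, not part of this diff: consuming a WrappedModelV2/V3-style wrapper. ---
# Gradient-correction methods only need the wrapper to expose `.parameters()` and `.zero_grad()`.
# `TinySharedWrapper` and `shared_grad_norm` are hypothetical, simplified stand-ins that just
# aggregate the shared modules and report their combined gradient norm.
import torch

class TinySharedWrapper:
    def __init__(self, *modules: torch.nn.Module) -> None:
        self.modules = modules

    def parameters(self):
        for module in self.modules:
            yield from module.parameters()

    def zero_grad(self, set_to_none: bool = False) -> None:
        for module in self.modules:
            module.zero_grad(set_to_none=set_to_none)

def shared_grad_norm(shared_module) -> float:
    grads = [p.grad.norm(2) ** 2 for p in shared_module.parameters() if p.grad is not None]
    return torch.sqrt(torch.stack(grads).sum()).item() if grads else 0.0

encoder, transformer = torch.nn.Linear(8, 8), torch.nn.Linear(8, 4)
shared = TinySharedWrapper(encoder, transformer)
transformer(encoder(torch.randn(2, 8))).sum().backward()
print(shared_grad_norm(shared))
shared.zero_grad(set_to_none=True)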
+ Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +# def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): +# """ +# 为UniZero模型配置带有差异化学习率的优化器。 +# """ +# # 1. 定义需要特殊处理的参数 +# param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + +# # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads +# transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} +# tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + +# # Heads的参数是那些既不属于transformer也不属于tokenizer的 +# head_params = { +# pn: p for pn, p in param_dict.items() +# if 'transformer' not in pn and 'tokenizer' not in pn +# } + +# # 3. 为每组设置不同的优化器参数(特别是学习率) +# # 这里我们仍然使用AdamW,但学习率设置更合理 +# optim_groups = [ +# { +# 'params': list(transformer_params.values()), +# 'lr': learning_rate, # 1e-4 +# # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay +# # 'weight_decay': weight_decay * 5.0 +# }, +# { +# 'params': list(tokenizer_params.values()), +# 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 +# # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5 +# 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + +# }, +# { +# 'params': list(head_params.values()), +# 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 +# 'weight_decay': 0.0 # 通常Heads的权重不做衰减 +# # 'weight_decay': weight_decay + +# } +# ] + +# print("--- Optimizer Groups ---") +# print(f"Transformer LR: {learning_rate}") +# print(f"Tokenizer/Heads LR: {learning_rate}") + +# optimizer = torch.optim.AdamW(optim_groups, betas=betas) +# return optimizer + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + 为UniZero模型配置带有差异化学习率的优化器。 + (修正版,确保参数组互斥) + """ + # 1. 创建空的参数列表用于分组 + transformer_params = [] + tokenizer_params = [] + head_params = [] + + # 2. 遍历所有可训练参数,并使用 if/elif/else 结构确保每个参数只被分配到一个组 + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if 'transformer' in name: + transformer_params.append(param) + elif 'tokenizer' in name: + tokenizer_params.append(param) + else: + head_params.append(param) + + # 3. 
为每组设置不同的优化器参数 + # 这里我们仍然使用AdamW,但学习率设置更合理 + optim_groups = [ + { + 'params': transformer_params, + 'lr': learning_rate, # 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': tokenizer_params, + 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4 + # 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + 'weight_decay': weight_decay # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化 + }, + { + 'params': head_params, + 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4 + # 'weight_decay': 0.0 # 通常Heads的权重不做衰减 + 'weight_decay': weight_decay + + } + ] + + print("--- Optimizer Groups ---") + # 打印每个组的参数数量以供调试 + print(f"Transformer params: {len(transformer_params)}") + print(f"Tokenizer params: {len(tokenizer_params)}") + print(f"Head params: {len(head_params)}") + print(f"Transformer LR: {learning_rate}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + +@POLICY_REGISTRY.register('unizero_multitask') +class UniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for multi-task UniZero, an official implementation for the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found at: https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero multi-task policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: LayerNorm is used in the transformer-based world model. + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. 
+ max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: for sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.01, + + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. 
+ collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=5000, + # ============================================================ + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. 
+ ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + - model_type (:obj:`str`): The model type used in this algorithm, registered in ModelRegistry. + - import_names (:obj:`List[str]`): The list of model class paths used in this algorithm. + .. note:: + Users can define and use customized network models, but they must adhere to the same interface definition + as indicated by the import_names path. For multi-task UniZero, this is ``lzero.model.unizero_model_multitask.UniZeroMTModel``. + """ + # NOTE: This specifies the default multi-task model. 
+ return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + # ==================== [新增] 模型范数监控函数 ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件(Encoder, Transformer, Heads)的参数矩阵范数。 + 此函数应在 torch.no_grad() 环境下调用,以提高效率。 + Returns: + - norm_metrics (:obj:`Dict[str, float]`): 包含所有范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_values, # Note: multi-task uses head_values (plural) + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policies, # Note: multi-task uses head_policies (plural) + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_norm_sq = 0.0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}_task{task_idx}/_total_norm'] = total_group_norm + else: + # Handle single module + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + + def _monitor_gradient_norms(self) -> Dict[str, float]: + """ + Overview: + 计算并返回模型关键组件的梯度范数。 + 此函数应在梯度计算完成后、参数更新之前调用。 + Returns: + - grad_metrics (:obj:`Dict[str, float]`): 包含所有梯度范数指标的字典,用于日志记录。 + """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # 定义要监控的模块组 + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_values, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policies, + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = f'grad/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = 0.0 + else: + # Handle single module + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = 
f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learn mode. This method is called by ``self.__init__``. + It sets up the learn model, optimizer, target model, and other utilities required for training. + """ + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + + + # Use a deep copy for the target model. + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is >= 2.0 for torch.compile. + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). 
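# --- Illustrative sketch, not part of this diff: mutually exclusive optimizer groups. ---
# The 'AdamW_mix_lr_wdecay' branch above relies on configure_optimizer_unizero, which routes
# every parameter into exactly one of three groups (transformer / tokenizer / heads) via an
# if/elif/else over parameter names. `ToyWorldModel` and `build_param_groups` are assumptions
# used only to show the grouping on a tiny model.
import torch

class ToyWorldModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tokenizer = torch.nn.Linear(16, 32)
        self.transformer = torch.nn.Linear(32, 32)
        self.head_policy = torch.nn.Linear(32, 6)

def build_param_groups(model: torch.nn.Module, lr: float, weight_decay: float):
    transformer_params, tokenizer_params, head_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'transformer' in name:
            transformer_params.append(param)
        elif 'tokenizer' in name:
            tokenizer_params.append(param)
        else:
            head_params.append(param)
    return [
        {'params': transformer_params, 'lr': lr, 'weight_decay': weight_decay},
        {'params': tokenizer_params, 'lr': lr, 'weight_decay': weight_decay},
        {'params': head_params, 'lr': lr, 'weight_decay': weight_decay},
    ]

optimizer = torch.optim.AdamW(
    build_param_groups(ToyWorldModel(), lr=1e-4, weight_decay=1e-4), betas=(0.9, 0.95)
)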
+ self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + # Create a WrappedModel instance. + # This is used for gradient correction methods where gradients of shared parameters are managed. + # In this setup, all parameters are considered shared and subject to correction. + # wrapped_model = WrappedModel( + # self._learn_model.world_model, + # ) + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # The prediction heads' gradients are not corrected. + self.wrapped_model = WrappedModelV2( + # TODO: This assumes the tokenizer has an encoder attribute which is a list. This might need to be more robust. + self._learn_model.world_model.tokenizer.encoder[0], + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # Alternative setup: The head and tokenizer.encoder gradients are not corrected. + # wrapped_model = WrappedModelV3( + # self._learn_model.world_model.transformer, + # self._learn_model.world_model.pos_emb, + # self._learn_model.world_model.task_emb, + # self._learn_model.world_model.act_embedding_table, + # ) + + # Pass the wrapped_model as `shared_module` to the gradient correction method. + # ========= Initialize MoCo/CAGrad parameters ========= + if self._cfg.moco_version=="v0": + # This version is only compatible with single-GPU training. + self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + elif self._cfg.moco_version=="v1": + cfg_moco = MoCoCfg( + beta0=0.9, beta_sigma=0.95, + gamma0=0.1, gamma_sigma=0.95, + rho=0.01, stat_interval=10000) + self.grad_correct = FastMoCo( + shared_module=self.wrapped_model, + world_task_num=self._cfg.total_task_num, # Total number of tasks globally + device=self._cfg.device, + multi_gpu=self._cfg.multi_gpu, + cfg=cfg_moco, + ) + + # Cache for plasticity-related metrics from the previous frame. 
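# --- Illustrative sketch, not part of this diff: momentum (soft) target-model update. ---
# The target model above is wrapped with update_type='momentum' and theta taken from
# `target_update_theta`. One common convention for such a soft update is shown below; the
# exact formula used lives inside DI-engine's `model_wrap` and may differ in detail.
import torch

@torch.no_grad()
def soft_update(target: torch.nn.Module, online: torch.nn.Module, theta: float) -> None:
    """target <- (1 - theta) * target + theta * online, applied parameter-wise."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(1.0 - theta).add_(o_param, alpha=theta)

online_net, target_net = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
soft_update(target_net, online_net, theta=0.05)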
+ self._prev_plasticity_metrics = dict( + dormant_ratio_encoder = 0.0, + dormant_ratio_transformer = 0.0, + dormant_ratio_head = 0.0, + avg_weight_mag_encoder = 0.0, + avg_weight_mag_transformer = 0.0, + avg_weight_mag_head = 0.0, + e_rank_last_linear = 0.0, + e_rank_sim_norm = 0.0, + ) + + # ==================== START: 目标熵正则化初始化 ==================== + # 从配置中读取是否启用自适应alpha,并提供一个默认值 + self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True) + + # 在 _init_learn 中增加配置 + self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98) + self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7) + self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000) # 例如,在200k步内完成退火 2M envsteps + + if self.use_adaptive_entropy_weight: + # 1. 设置目标熵。对于离散动作空间,一个常见的启发式设置是动作空间维度的负对数乘以一个系数。 + # 这个系数(例如0.98)可以作为一个超参数。 + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. 初始化一个可学习的 log_alpha 参数。 + # 初始化为0,意味着初始的 alpha = exp(0) = 1.0。 + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. 为 log_alpha 创建一个专属的优化器。 + # 使用与主优化器不同的、较小的学习率(例如1e-4)通常更稳定。 + alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4) + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + print("="*20) + print(">>> 目标熵正则化 (自适应Alpha) 已启用 <<<") + print(f" 目标熵 (Target Entropy): {self.target_entropy:.4f}") + print(f" Alpha 优化器学习率: {alpha_lr:.2e}") + print("="*20) + # ===================== END: 目标熵正则化初始化 ===================== + + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0) + # ==================== START: 初始化 Encoder-Clip Annealing 参数 ==================== + self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False) + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine') + self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0) + self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0) + self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000) + + print("="*20) + print(">>> Encoder-Clip 退火已启用 <<<") + print(f" 类型: {self.encoder_clip_anneal_type}") + print(f" 范围: {self.encoder_clip_start} -> {self.encoder_clip_end}") + print(f" 步数: {self.encoder_clip_anneal_steps}") + print("="*20) + else: + # 如果不启用退火,则使用固定的 clip 阈值 + self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0) + # ===================== END: 初始化 Encoder-Clip Annealing 参数 ===================== + + # --- NEW: Policy Label Smoothing Parameters --- + self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05) # TODO policy_label_smoothing_eps_start 越大的action space需要越大的eps + self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end ', 0.01) # TODO policy_label_smoothing_eps_start + self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps ', 50000) # TODO 50k + print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}") + + @staticmethod + def _is_zero(x: Union[float, torch.Tensor], eps: float = 1e-8) -> bool: + """ + Overview: + Checks if a scalar or a 0-D tensor can be considered zero within a small tolerance. + Arguments: + - x (:obj:`Union[float, torch.Tensor]`): The input value to check. + - eps (:obj:`float`): The tolerance for checking against zero. 
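# --- Illustrative sketch, not part of this diff: Encoder-Clip annealing schedule. ---
# The init code above only stores a start value, an end value, a step budget and an anneal
# type ('cosine' or linear) for the latent-norm clip threshold; the schedule itself is applied
# later in the policy. `annealed_clip_value` is one plausible implementation of such a schedule.
import math

def annealed_clip_value(step: int, start: float, end: float, total_steps: int,
                        anneal_type: str = 'cosine') -> float:
    progress = min(1.0, step / max(1, total_steps))
    if anneal_type == 'cosine':
        # Cosine interpolation from `start` down to `end`.
        return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * progress))
    return start + (end - start) * progress  # linear fallback

print(annealed_clip_value(0, 30.0, 10.0, 200000))       # 30.0
print(annealed_clip_value(100000, 30.0, 10.0, 200000))  # 20.0
print(annealed_clip_value(200000, 30.0, 10.0, 200000))  # 10.0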
+ Returns: + - (:obj:`bool`): True if the value is close to zero, False otherwise. + """ + if isinstance(x, torch.Tensor): + return torch.all(torch.abs(x) < eps).item() + return abs(x) < eps + + def _retain_prev_if_zero(self, name: str, + value: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Overview: + If the current `value` is close to zero, returns the cached value from the previous frame. + Otherwise, it updates the cache with the current value and returns it. This is useful for + metrics that are computed intermittently. + Arguments: + - name (:obj:`str`): The name of the metric to cache. + - value (:obj:`Union[float, torch.Tensor]`): The current value of the metric. + Returns: + - (:obj:`Union[float, torch.Tensor]`): The retained or current value. + """ + if self._is_zero(value): + # Directly return the previous value (can be float or tensor). + return self._prev_plasticity_metrics[name] + else: + # Update the cache and return the current value. + self._prev_plasticity_metrics[name] = value + return value + + + #@profile + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_iter=None, ignore_grad=False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning in the policy. This is the core of the training process. + Data is sampled from the replay buffer, losses are calculated, and the model is updated via backpropagation. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, where each element corresponds to a different task. + - task_weights (:obj:`Any`, optional): Optional weights for each task's loss. Not currently used. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary containing current learning losses and statistics for logging. + """ + self._learn_model.train() + self._target_model.train() + + # Lists to store metrics for each task within the batch. + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # Initialize to 0.0 to avoid in-place operations. + total_alpha_loss = 0.0 + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + # Metrics for network plasticity analysis. 
+ dormant_ratio_encoder_multi_task = [] + dormant_ratio_transformer_multi_task = [] + dormant_ratio_head_multi_task = [] + avg_weight_mag_encoder_multi_task = [] + avg_weight_mag_transformer_multi_task = [] + avg_weight_mag_head_multi_task = [] + e_rank_last_linear_multi_task = [] + e_rank_sim_norm_multi_task = [] + + # --- NEW: Calculate current epsilon for policy --- + # if self.policy_ls_eps_start > 0: + # progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + # current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + # else: + # current_policy_label_eps = 0.0 + current_policy_label_eps = 0.01 + + # 新增一个列表来收集当前批次中所有任务的真实全局ID + global_task_ids_in_batch = [] + alpha_loss = None + + + # 用于Alpha日志记录的新列表 + alpha_loss_multi_task = [] + target_entropy_multi_task = [] + + # 仅在自适应alpha启用时,预先获取当前alpha值,确保在单次迭代中对所有任务一致 + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight + if self.use_adaptive_entropy_weight: + current_alpha = self.log_alpha.exp().detach() + + losses_list = [] # Used to store the loss tensor for each task, required by gradient correction methods. + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task # task_id 是真实的全局ID + + # 将真实的全局ID添加到列表中 + global_task_ids_in_batch.append(task_id) + + # TODO: Adapt RoPE for multitask settings (using timestep_batch). + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to a torch tensor. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space. + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + cur_batch_size = target_reward.size(0) # Run-time batch size. + + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar rewards and values to their scaled representations. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert scaled representations to categorical distributions. + # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + + + # Prepare the batch for the transformer-based world model. 
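# --- Illustrative sketch, not part of this diff: label-smoothed categorical value targets. ---
# Above, scalar rewards/values are projected onto a discrete support via
# phi_transform(..., label_smoothing_eps=...). `two_hot_with_smoothing` is a simplified
# standalone version of the idea (two-hot projection mixed with a uniform distribution);
# the real phi_transform in lzero.policy may differ in its details.
import torch

def two_hot_with_smoothing(scalar: torch.Tensor, support: torch.Tensor, eps: float) -> torch.Tensor:
    """Project scalars onto `support` as a two-hot distribution, then mix in eps of uniform mass."""
    scalar = scalar.clamp(support.min(), support.max())
    idx_hi = torch.searchsorted(support, scalar, right=True).clamp(1, len(support) - 1)
    idx_lo = idx_hi - 1
    w_hi = (scalar - support[idx_lo]) / (support[idx_hi] - support[idx_lo])
    target = torch.zeros(*scalar.shape, len(support))
    target.scatter_(-1, idx_lo.unsqueeze(-1), (1.0 - w_hi).unsqueeze(-1))
    target.scatter_(-1, idx_hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return (1.0 - eps) * target + eps / len(support)

toy_support = torch.linspace(-5.0, 5.0, steps=11)  # assumed toy support, not the real one
print(two_hot_with_smoothing(torch.tensor([0.3, -2.7]), toy_support, eps=0.01).sum(-1))  # ~1.0 each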
+ batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + + # Extract valid target policy data and compute its entropy. + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model and compute losses. + intermediate_losses = defaultdict(float) + # losses = self._learn_model.world_model.compute_loss( + # batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, task_id=task_id + # ) + + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id + ) + + # ==================== START MODIFICATION 2 ==================== + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + # ===================== END MODIFICATION 2 ===================== + + + # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted. + weighted_total_loss += losses.loss_total # NOTE:+= + + # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging. + # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" + # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + + # TODO: Append the total loss for this task, used by MoCo. 
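+ # (Hedged note) `losses_list` keeps one total-loss tensor per task so that the gradient-correction
+ # path can compute per-task gradients of the shared backbone, roughly via
+ #     lambd, stats = self.grad_correct.backward(losses_list)
+ # as done further below when `self._cfg.use_moco` is enabled.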
+ losses_list.append(losses.loss_total)
+
+ for loss_name, loss_value in losses.intermediate_losses.items():
+ intermediate_losses[f"{loss_name}"] = loss_value
+
+ obs_loss = intermediate_losses['loss_obs']
+ reward_loss = intermediate_losses['loss_rewards']
+ policy_loss = intermediate_losses['loss_policy']
+ orig_policy_loss = intermediate_losses['orig_policy_loss']
+ policy_entropy = intermediate_losses['policy_entropy']
+ value_loss = intermediate_losses['loss_value']
+ latent_recon_loss = intermediate_losses['latent_recon_loss']
+ perceptual_loss = intermediate_losses['perceptual_loss']
+ latent_state_l2_norms = intermediate_losses['latent_state_l2_norms']
+
+ # The policy entropy has been extracted from the losses object above.
+ # ==================== START: target-entropy regularization update ====================
+ current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # Default: use the fixed value from the config.
+ if self.use_adaptive_entropy_weight:
+
+ # --- Dynamically compute the target entropy (this logic is correct and is kept as-is). ---
+ progress = min(1.0, train_iter / self.target_entropy_decay_steps)
+ current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
+ action_space_size = self._cfg.model.action_space_size
+ # Note: target_entropy is defined as a positive quantity, which is more intuitive.
+ current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio
+
+ # --- Compute alpha_loss (sign corrected). ---
+ # Core fix: the leading minus sign has been removed.
+ # detach() remains essential: it ensures the gradient of alpha_loss flows only into log_alpha.
+ alpha_loss_task = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() # NOTE:=
+
+ # # --- Update log_alpha ---
+ # self.alpha_optimizer.zero_grad()
+ # alpha_loss.backward()
+ # self.alpha_optimizer.step()
+
+ # Accumulate the alpha loss.
+ total_alpha_loss += alpha_loss_task
+ # Collect each task's alpha_loss and target entropy for logging.
+ alpha_loss_multi_task.append(alpha_loss_task)
+ target_entropy_multi_task.append(current_target_entropy)
+
+ # --- [Improvement] Clamp log_alpha as a safety measure. ---
+ with torch.no_grad():
+ # Restrict alpha to, e.g., the range [1e-4, 10.0].
+ self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
+ # --- Use the updated alpha (with the gradient flow cut off). ---
+ current_alpha = self.log_alpha.exp().detach()
+
+ # Recompute the weighted policy loss and the total loss.
+ # Note: policy_entropy here is already averaged over the batch.
+ weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy
+ # Rebuild the total loss (instead of using losses.loss_total).
+ # Make sure these weights stay consistent with the computation in the LossWithIntermediateLosses class.
+ self.obs_loss_weight = 10
+ self.value_loss_weight = 0.5
+ self.reward_loss_weight = 1.
+ self.policy_loss_weight = 1.
+ self.ends_loss_weight = 0.
+ total_loss = (
+ self.reward_loss_weight * reward_loss +
+ self.value_loss_weight * value_loss +
+ self.policy_loss_weight * weighted_policy_loss +
+ self.obs_loss_weight * obs_loss # Assumes ssl_loss_weight acts as the weight of obs_loss.
+ # ... add any other loss terms here as well ...
+ )
+ weighted_total_loss += (weights * total_loss).mean() # NOTE:+=
+ # ===================== END: target-entropy regularization update =====================
+
+ # ============ For value-based priority calculation ============
+ # TODO: The following section for calculating value_priority is commented out.
+ # If re-enabled, ensure it correctly computes L1 loss between predicted and target values
+ # and handles CPU/Numpy conversion properly.
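+ # A minimal sketch of that computation (equivalent to the commented-out block below, with
+ # `logits_value` as a hypothetical handle to the predicted value logits):
+ #     pred_value = self.value_inverse_scalar_transform_handle(logits_value.reshape(-1, 101))
+ #     value_priority = torch.abs(pred_value.reshape(cur_batch_size, -1)[:, 0] - target_value[:, 0])
+ #     value_priority = value_priority.detach().cpu().numpy() + 1e-6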
+ # original_value = self.value_inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) + # value_priority = value_priority.data.cpu().numpy() + 1e-6 + # value_priority = torch.tensor(0., device=self._cfg.device) + # ============ End of value priority section ============ + + # Metrics related to network plasticity. + # Use the helper function to retain the previous value if the current one is zero. + dormant_ratio_encoder = self._retain_prev_if_zero( + 'dormant_ratio_encoder', + intermediate_losses['dormant_ratio_encoder']) + dormant_ratio_transformer = self._retain_prev_if_zero( + 'dormant_ratio_transformer', + intermediate_losses['dormant_ratio_transformer']) + dormant_ratio_head = self._retain_prev_if_zero( + 'dormant_ratio_head', + intermediate_losses['dormant_ratio_head']) + avg_weight_mag_encoder = self._retain_prev_if_zero( + 'avg_weight_mag_encoder', + intermediate_losses['avg_weight_mag_encoder']) + avg_weight_mag_transformer = self._retain_prev_if_zero( + 'avg_weight_mag_transformer', + intermediate_losses['avg_weight_mag_transformer']) + avg_weight_mag_head = self._retain_prev_if_zero( + 'avg_weight_mag_head', + intermediate_losses['avg_weight_mag_head']) + e_rank_last_linear = self._retain_prev_if_zero( + 'e_rank_last_linear', + intermediate_losses['e_rank_last_linear']) + e_rank_sim_norm = self._retain_prev_if_zero( + 'e_rank_sim_norm', + intermediate_losses['e_rank_sim_norm']) + + # Append all metrics for this task to their respective lists. + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + value_priority_multi_task.append(value_priority_tensor) + value_priority_mean_multi_task.append(value_priority_tensor.mean().item()) + + # Append plasticity metrics. + dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder) + dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer) + dormant_ratio_head_multi_task.append(dormant_ratio_head) + avg_weight_mag_encoder_multi_task.append(avg_weight_mag_encoder) + avg_weight_mag_transformer_multi_task.append(avg_weight_mag_transformer) + avg_weight_mag_head_multi_task.append(avg_weight_mag_head) + e_rank_last_linear_multi_task.append(e_rank_last_linear) + e_rank_sim_norm_multi_task.append(e_rank_sim_norm) + + + # ==================== [新增] 集成范数监控逻辑 ==================== + norm_log_dict = {} + # 检查是否达到监控频率 + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. 监控模型参数范数 + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. 
Monitor the intermediate tensor x (the transformer output).
+ intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x')
+ if intermediate_x is not None:
+ # x has shape (B, T, E).
+ # Compute the L2 norm of each token.
+ token_norms = intermediate_x.norm(p=2, dim=-1)
+
+ # Log summary statistics of these norms.
+ norm_log_dict['norm/x_token/mean'] = token_norms.mean().item()
+ norm_log_dict['norm/x_token/std'] = token_norms.std().item()
+ norm_log_dict['norm/x_token/max'] = token_norms.max().item()
+ norm_log_dict['norm/x_token/min'] = token_norms.min().item()
+
+ # 3. Monitor detailed statistics of the logits (value, policy, reward).
+ logits_value = losses.intermediate_losses.get('logits_value')
+ if logits_value is not None:
+ norm_log_dict['logits/value/mean'] = logits_value.mean().item()
+ norm_log_dict['logits/value/std'] = logits_value.std().item()
+ norm_log_dict['logits/value/max'] = logits_value.max().item()
+ norm_log_dict['logits/value/min'] = logits_value.min().item()
+ norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item()
+
+ logits_policy = losses.intermediate_losses.get('logits_policy')
+ if logits_policy is not None:
+ norm_log_dict['logits/policy/mean'] = logits_policy.mean().item()
+ norm_log_dict['logits/policy/std'] = logits_policy.std().item()
+ norm_log_dict['logits/policy/max'] = logits_policy.max().item()
+ norm_log_dict['logits/policy/min'] = logits_policy.min().item()
+ norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item()
+
+ logits_reward = losses.intermediate_losses.get('logits_reward')
+ if logits_reward is not None:
+ norm_log_dict['logits/reward/mean'] = logits_reward.mean().item()
+ norm_log_dict['logits/reward/std'] = logits_reward.std().item()
+ norm_log_dict['logits/reward/max'] = logits_reward.max().item()
+ norm_log_dict['logits/reward/min'] = logits_reward.min().item()
+ norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item()
+
+ # 4. Monitor statistics of obs_embeddings (the encoder output).
+ obs_embeddings = losses.intermediate_losses.get('obs_embeddings')
+ if obs_embeddings is not None:
+ # Compute the L2 norm of each embedding.
+ emb_norms = obs_embeddings.norm(p=2, dim=-1)
+ norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item()
+ norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item()
+ norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item()
+ norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item()
+ # =================================================================
+
+ # Core learn model update step.
+ self._optimizer_world_model.zero_grad()
+
+ if self.use_adaptive_entropy_weight:
+ self.alpha_optimizer.zero_grad()
+ # 2. Compute the final alpha loss (average after accumulating over tasks).
+ final_alpha_loss = None
+ if self.use_adaptive_entropy_weight:
+ if len(data) > 0:
+ final_alpha_loss = total_alpha_loss / len(data)
+ else: # Defensive programming to avoid division by zero.
+ final_alpha_loss = torch.tensor(0.0, device=self._cfg.device)
+
+ # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...].
+ if self._cfg.use_moco:
+ # Call MoCo's backward method, which handles gradient correction internally.
+ if self._cfg.moco_version=="v0":
+ lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+ elif self._cfg.moco_version=="v1":
+ lambd, stats = self.grad_correct.backward(losses_list)
+
+ # Backpropagate the alpha loss separately.
+ if self.use_adaptive_entropy_weight:
+ final_alpha_loss.backward()
+
+ elif self._cfg.only_use_moco_stats:
+ # Only compute MoCo stats without applying gradient correction.
+ lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
+ # Each rank performs its own backpropagation.
+ # weighted_total_loss.backward()
+
+ # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+ if self.use_adaptive_entropy_weight:
+ (weighted_total_loss + final_alpha_loss).backward()
+ elif weighted_total_loss != 0.0: # Make sure there is a loss to backpropagate.
+ weighted_total_loss.backward()
+
+ else:
+ # If not using gradient correction, each rank performs standard backpropagation.
+ lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
+
+ # weighted_total_loss.backward()
+
+ # If adaptive alpha is enabled, add the alpha loss to the main loss and backpropagate them together.
+ if self.use_adaptive_entropy_weight:
+ (weighted_total_loss + final_alpha_loss).backward()
+ elif weighted_total_loss != 0.0: # Make sure there is a loss to backpropagate.
+ weighted_total_loss.backward()
+
+ # -----------------------------------------------------------------
+ # The following block runs inside a torch.no_grad() context.
+ # =================================================================
+ with torch.no_grad():
+ # 1. Encoder-Clip
+ # ==================== START: dynamically compute the current clip threshold ====================
+ current_clip_value = self.latent_norm_clip_threshold # Default: use the fixed threshold.
+ if self.use_encoder_clip_annealing:
+ progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+ if self.encoder_clip_anneal_type == 'cosine':
+ # Cosine schedule: smoothly anneals from 1 to 0.
+ cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+ current_clip_value = self.encoder_clip_end + \
+ (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+ else: # Default: linear schedule.
+ current_clip_value = self.encoder_clip_start * (1 - progress) + \
+ self.encoder_clip_end * progress
+ # ===================== END: dynamically compute the current clip threshold =====================
+
+ # 1. Encoder-Clip (using the dynamically computed current_clip_value).
+ if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+ obs_embeddings = losses.intermediate_losses['obs_embeddings']
+ if obs_embeddings is not None:
+ max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+ if max_latent_norm > current_clip_value:
+ scale_factor = current_clip_value / max_latent_norm.item()
+ # Avoid printing too frequently; alternatively, print only once every N steps.
+ if train_iter % 1000 == 0:
+ print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.")
+ scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+ # For debugging purposes.
+ # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+ # print('name, param.mean(), param.std():', name, param.mean(), param.std())
+ # if param.requires_grad:
+ # print(name, param.grad.norm())
+
+ # ==================== [NEW] Monitor gradient norms ====================
+ # Monitor gradient norms before clipping, to help diagnose exploding or vanishing gradients.
+ if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)):
+ grad_norm_metrics = self._monitor_gradient_norms()
+ norm_log_dict.update(grad_norm_metrics)
+ # =================================================================
+
+ if self._cfg.analysis_sim_norm:
+ del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+ self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+ self._target_model.encoder_hook.clear_data()
+
+ total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+ self._cfg.grad_clip_value)
+
+ if ignore_grad:
+ # NOTE: For cases where all tasks on a GPU are solved, `train` is still called for DDP synchronization,
+ # but gradients should be zeroed out to prevent updates.
+ self._optimizer_world_model.zero_grad()
+
+ if self._cfg.multi_gpu:
+ # If not using a gradient correction method that handles it, sync gradients manually.
+ if not self._cfg.use_moco:
+ self.sync_gradients(self._learn_model)
+
+ self._optimizer_world_model.step()
+
+ # 4. Update the alpha optimizer.
+ if self.use_adaptive_entropy_weight:
+ self.alpha_optimizer.step()
+ # Clamp log_alpha to keep training stable.
+ with torch.no_grad():
+ self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
+ if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
+ self.lr_scheduler.step()
+
+ # Core target model update step.
+ self._target_model.update(self._learn_model.state_dict())
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ current_memory_allocated = torch.cuda.memory_allocated()
+ max_memory_allocated = torch.cuda.max_memory_allocated()
+ current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+ max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+ else:
+ current_memory_allocated_gb = 0.
+ max_memory_allocated_gb = 0.
+
+ # Build the dictionary of return values for logging.
+ return_log_dict = {
+ 'Current_GPU': current_memory_allocated_gb,
+ 'Max_GPU': max_memory_allocated_gb,
+ 'collect_mcts_temperature': self._collect_mcts_temperature,
+ 'collect_epsilon': self._collect_epsilon,
+ 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+ 'weighted_total_loss': weighted_total_loss.item(),
+ 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+ }
+
+ # ==================== START: adaptive-alpha log entries ====================
+ if self.use_adaptive_entropy_weight:
+ return_log_dict['adaptive_alpha'] = current_alpha.item()
+ return_log_dict['adaptive_target_entropy_ratio'] = current_ratio
+ return_log_dict['final_alpha_loss'] = final_alpha_loss.item()
+ # ===================== END: adaptive-alpha log entries =====================
+
+ # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_".
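+ # (Illustrative) With self.task_id == 2 and two tasks trained on this rank, the dict built below
+ # contains keys such as 'noreduce_obs_loss_task2' and 'noreduce_obs_loss_task3'; the 'noreduce_'
+ # prefix marks per-task values that are logged as-is, without being reduced across tasks or ranks.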
+ multi_task_loss_dicts = { + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), #global_task_ids=global_task_ids_in_batch), # task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + + # 新增alpha相关日志 + **generate_task_loss_dict(alpha_loss_multi_task, 'noreduce_alpha_loss_task{}', self.task_id), + **generate_task_loss_dict(target_entropy_multi_task, 'noreduce_target_entropy_task{}', self.task_id), + } + return_log_dict.update(multi_task_loss_dicts) + + + if self._learn_model.world_model.do_analysis: + # Include plasticity metrics if analysis is enabled. + plasticity_loss_dicts = { + **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id), + } + # Merge the dictionaries. + return_log_dict.update(plasticity_loss_dicts) + + # ==================== [修改] 将范数监控结果合并到日志中 ==================== + if norm_log_dict: + return_log_dict.update(norm_log_dict) + # ======================================================================= + + # Return the final loss dictionary. 
+ return return_log_dict + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to print the mean and standard deviation of weights and their gradients for each layer in a model. + Useful for debugging training issues like exploding or vanishing gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to monitor. + """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collect mode. This method is called by ``self.__init__``. + It sets up the collect model and MCTS utilities for data collection. + """ + self._collect_model = self._model + + # Create a copy of the configuration for collect MCTS and set a specific number of simulations. + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(mcts_collect_cfg) + else: + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: The num_tasks parameter is hardcoded. It should ideally be derived from the config. + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Registers variables to be monitored during training. These variables will be logged in TensorBoard. + It dynamically creates variable names for each task if `num_tasks` is provided. + Arguments: + - num_tasks (:obj:`int`): The number of tasks being trained on the current rank. + Returns: + - monitored_vars (:obj:`List[str]`): A list of strings, where each string is the name of a variable to be logged. + """ + # Basic monitored variables that do not depend on the number of tasks. 
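+ # (Illustrative note) Every name registered below is expected to match a key of the dict returned
+ # by `_forward_learn`; the task-specific entries are added further down by appending a
+ # '_task{global_task_id}' suffix, e.g. 'noreduce_obs_loss_task3'.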
+ monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + + # 'value_priority', + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'final_alpha_loss', + ] + + # ==================== [新增] 范数和中间张量监控变量 ==================== + # 这些变量对所有任务是共享的(不是per-task的) + norm_vars = [ + # 模块总范数 (参数范数) - 共享模块 + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + + # 模块总范数 (梯度范数) - 共享模块 + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + + # 中间张量 x (Transformer输出) 的统计信息 + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Logits 的详细统计 (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Logits 的详细统计 (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Logits 的详细统计 (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings 的统计信息 + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', + ] + monitored_vars.extend(norm_vars) + # ======================================================================== + + + + # Task-specific variables to be monitored. + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + # Metrics related to network plasticity. + 'noreduce_dormant_ratio_encoder', + 'noreduce_dormant_ratio_transformer', + 'noreduce_dormant_ratio_head', + 'noreduce_avg_weight_mag_encoder', + 'noreduce_avg_weight_mag_transformer', + 'noreduce_avg_weight_mag_head', + 'noreduce_e_rank_last_linear', + 'noreduce_e_rank_sim_norm', + "noreduce_alpha_loss", + "noreduce_target_entropy", + + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variable names. + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, assume a single task and use the original variable names. + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + #@profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data. It uses the model to perform MCTS search and + selects actions via sampling to encourage exploration. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`, optional): A list of action masks for each environment. + - temperature (:obj:`float`, optional): The temperature for MCTS action selection. + - to_play (:obj:`List`, optional): A list of player IDs for each environment. 
+ - epsilon (:obj:`float`, optional): The probability for epsilon-greedy exploration. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + + # ========================== 核心修复 ========================== + # C++ 绑定需要一个 list,即使它在 MuZero 中代表奖励。 + reward_roots = reward_roots.detach().cpu().numpy().tolist() + # =============================================================== + + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # The main difference between collect and eval is the addition of Dirichlet noise at the root. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # C++ MCTS tree implementation. + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # Python MCTS tree implementation. + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + + # # 在本文件开始,通过全局变量来控制是否处于调试状态 + # global DEBUG_ENABLED;DEBUG_ENABLED = True + # import torch.distributed as dist + # if dist.get_rank() == 0 and DEBUG_ENABLED: + # print(f"rank {dist.get_rank()} 进入调试模式,输入interact,可以键入整段的python代码调试。通过设置 DEBUG_ENABLED = False, 可以跳过调试状态") + # import ipdb; ipdb.set_trace() + # # 同步点,防止其它进程早跑 + # dist.barrier() + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # Epsilon-greedy collection strategy. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # Standard collection strategy (sampling from MCTS policy). + # NOTE: `action_index_in_legal_action_set` is the index within the set of legal actions. 
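+ # Worked example (illustrative): if action_mask[i] == np.array([0, 1, 0, 1]), the legal action
+ # set is [1, 3]; an action_index_in_legal_action_set of 1 then maps back to action 3 via
+ # np.where(action_mask[i] == 1.0)[0][1].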
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # Convert the index back to the action in the full action space. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: This section is for visualization purposes only and should be removed for training. ============== + # It forces deterministic action selection during collection. + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== End of visualization section. ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: This logic is currently for the `muzero_segment_collector`. ========= + if active_collect_env_num < self.collector_env_num: + # When one environment in `collect_env` finishes early, the length of `self.last_batch_obs` is reduced. + # The transformer needs the `env_id` to retrieve from the KV cache, which is complex to manage with a dynamic batch size. + # Therefore, we reset `self.last_batch_action` for all environments to -1, forcing the transformer + # to start from scratch and avoid retrieval errors. + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the eval mode. This method is called by ``self.__init__``. + It sets up the eval model and MCTS utilities for evaluation. + """ + self._eval_model = self._model + + # Create a copy of the configuration for eval MCTS and set a specific number of simulations. + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(mcts_eval_cfg) + else: + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the policy. 
It uses the model to perform MCTS search and + selects actions deterministically (choosing the one with the highest visit count). + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`): A list of action masks for each environment. + - to_play (:obj:`int`, optional): The player ID for the current turn. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # ========================== 核心修复 ========================== + # C++ 绑定需要一个 list,即使它在 MuZero 中代表奖励。 + reward_roots = reward_roots.detach().cpu().numpy().tolist() # TODO============================= + # =============================================================== + + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # C++ MCTS tree implementation. + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # Python MCTS tree implementation. + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + # During evaluation, no noise is added to the root policy. + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + # NOTE: `deterministic=True` means we select the action with the highest visit count (argmax) + # rather than sampling, which is standard for evaluation. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # Convert the index back to the action in the full action space. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + #@profile + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the collection process for a specific environment or all environments. + It can clear caches and reset initial data to ensure optimal performance and prevent state leakage. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, the reset applies more broadly. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count in the environment, used to trigger periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID, currently unused in this method. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('Collector: last_batch_obs and last_batch_action have been reset.') + + # Return immediately if env_id is not a single integer (e.g., None or a list). + # if env_id is None or isinstance(env_id, list): + # return + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + + # Determine the clear interval based on the environment's sample type. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically to manage memory. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the collect model's world model. + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Collector: Caches cleared for collect_model at step {current_steps} for env {env_id}.') + + # TODO: Check if resetting the target model here is correct and necessary. 
+ self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + Resets the target model by clearing its internal caches. This is crucial for managing memory, + especially when using transformer-based models with KV caching. + """ + # Clear various KV caches in the target model's world model. + world_model = self._target_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + print('Collector: Target model past_kv_cache cleared.') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the evaluation process for a specific environment or all environments. + Clears caches and resets initial data to ensure clean evaluation runs. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count, used for periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The task ID. Can be used to handle different observation shapes per task. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + # print(f'Evaluator reset: last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # Determine the clear interval. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the eval model's world model. 
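+ # (Illustrative note) As used throughout this file, `past_kv_cache_init_infer_envs` is a
+ # per-environment list of cache dicts, while `past_kv_cache_recurrent_infer` and
+ # `keys_values_wm_list` are shared across environments, so all of them are cleared together here.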
+ world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Evaluator: Caches cleared for eval_model at step {current_steps} for env {env_id}.') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clears all KV caches and precomputes positional embedding matrices in the model. + This is typically called when the maximum sequence length changes. + """ + # NOTE: This must be done for both the collect and target models. + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learn mode. + This typically includes the model, target model, and optimizer states, + which are necessary for saving and resuming training. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary for the current learning progress. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== NOTE: This is the original version which loads all parameters from the state_dict. ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads the state_dict into the policy's learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary saved from a previous training session. + # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== NOTE: This is a pretrain-finetune version that selectively loads parameters and freezes layers. ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: + """ + Overview: + Loads a state_dict for fine-tuning. It excludes multi-task specific parameters + and can freeze parts of the model (e.g., encoder, transformer) based on `finetune_components`. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + - finetune_components (:obj:`List[str]`, optional): A list of component names (e.g., "encoder", "transformer") + that will remain trainable. Components not in this list will have their parameters frozen. + """ + # Example configurations for fine-tuning: + # finetune_components = [] # Loads encoder & transformer, fine-tunes only heads. + # finetune_components = ['transformer'] # Loads encoder & transformer, fine-tunes transformer & heads. + finetune_components = ["representation_network", "encoder"] # Loads encoder & transformer, fine-tunes encoder & heads. + + # Define prefixes of parameters to be excluded from loading (typically multi-task heads). + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' 
+ ] + + # Define specific parameter keys to be excluded (for special cases like task embeddings). + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + Filters out parameters from a state_dict based on prefixes and specific keys. + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # For debugging + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # For debugging + continue + filtered[k] = v + return filtered + + # Filter and load the 'model' state_dict. + if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # Filter and load the 'target_model' state_dict. + if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # Handle freezing/unfreezing of parameters in _learn_model based on finetune_components. + # This assumes a naming convention where component names are present in parameter names. + for name, param in self._learn_model.named_parameters(): + # Freeze the encoder if "encoder" is not in finetune_components. + if "encoder" in name and "encoder" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the representation network if "representation_network" is not in finetune_components. + elif "representation_network" in name and "representation_network" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the transformer if "transformer" is not in finetune_components. + elif "transformer" in name and "transformer" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + else: + # Other parameters remain trainable by default. + print(f"Parameter remains trainable: {name}") + + # NOTE: For more complex model structures, it might be better to identify modules by their class + # rather than relying on parameter names. For example: + # for module in self._learn_model.modules(): + # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: + # for param in module.parameters(): + # param.requires_grad = False + + # ========== NOTE: Another pretrain-finetune version. The main difference from the above is the freezing logic and comments. 
========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads a state_dict into the policy's learn mode, excluding multi-task related parameters. + # This is intended for fine-tuning a pre-trained model on new tasks. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + # """ + # # Define prefixes of parameters to be excluded. + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # Define specific parameter keys to be excluded. + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # Filters out parameters that should not be loaded. + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") + # continue + # filtered[k] = v + # return filtered + + # # Filter and load the 'model' part. + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # Filter and load the 'target_model' part. + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # Do not load the optimizer's state_dict when fine-tuning, as it contains state (like momentum) + # # specific to the pre-training task, which can hinder adaptation to new tasks. + # # A fresh optimizer is usually preferred. + # # if 'optimizer_world_model' in state_dict: + # # ... 
\ No newline at end of file diff --git a/lzero/policy/unizero_multitask_alpha_indep.py b/lzero/policy/unizero_multitask_alpha_indep.py new file mode 100644 index 000000000..db2b4c513 --- /dev/null +++ b/lzero/policy/unizero_multitask_alpha_indep.py @@ -0,0 +1,2000 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized +from .utils import configure_optimizers_nanogpt +import sys + +# Please replace the path with the actual location of your LibMTL library. +sys.path.append('/path/to/your/LibMTL') + +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo +from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg + +import torch.distributed as dist + +# ------------------------------------------------------------ +# 1. Add a dedicated process-group for the learner. +# (This function should be called once during the initialization of the main process or the learner.) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: + """ + Overview: + Builds and returns a new process group containing only the learner ranks. + This is used for methods like GenericMoCo that require collective communication + only among the ranks performing training. + Arguments: + - learner_ranks (:obj:`list[int]`): A list of world ranks that are designated as learners. + These are the ranks that will perform the backward pass. + e.g., if CUDA_VISIBLE_DEVICES=0,1, then learner_ranks=[0,1]. + Returns: + - pg (:obj:`dist.ProcessGroup`): A new process group containing only the learner ranks. + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting ID of the tasks. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Get the scalar value of the loss if it's a tensor. 
+ task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +# # 修改后的函数: +# def generate_task_loss_dict( +# multi_task_losses: List[Union[torch.Tensor, float]], +# task_name_template: str, +# global_task_ids: List[int] +# ) -> Dict[str, float]: +# """ +# Overview: +# Generates a dictionary for the losses of each task using their explicit global IDs. +# Arguments: +# - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. +# - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. +# - global_task_ids (:obj:`List[int]`): A list of global task IDs corresponding to each loss in multi_task_losses. +# Returns: +# - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. +# """ +# task_loss_dict = {} +# # 使用 zip 将每个损失与其正确的全局ID配对 +# for task_loss, global_id in zip(multi_task_losses, global_task_ids): +# task_name = task_name_template.format(global_id) +# try: +# task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss +# except Exception as e: +# task_loss_dict[task_name] = task_loss +# return task_loss_dict + + +class WrappedModel: + """ + Overview: + A wrapper class for the world model to conveniently access its parameters and zero its gradients. + This version wraps the entire world model. + """ + def __init__(self, world_model: torch.nn.Module): + """ + Arguments: + - world_model (:obj:`torch.nn.Module`): The world model instance. + """ + self.world_model = world_model + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the entire world model. + """ + return self.world_model.parameters() + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all world model parameters to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + """ + Overview: + A wrapper for specific components of the world model. + This version is designed to group parameters that are considered "shared" + across tasks for gradient correction methods like MoCo, excluding the prediction heads. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the wrapped components (tokenizer, transformer, embeddings). + These are typically the shared parts of the model whose gradients need to be managed for multi-task learning. 
+ """ + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + # TODO: Decide whether to include task embeddings in shared parameters. + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO: Match the decision made in the parameters() method. + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + """ + Overview: + An alternative wrapper for world model components. + This version excludes the tokenizer from the shared parameters, focusing gradient correction + on the transformer and embedding layers. + """ + def __init__(self, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the transformer and various embedding layers. + """ + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of the wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +# def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): +# """ +# 为UniZero模型配置带有差异化学习率的优化器。 +# """ +# # 1. 定义需要特殊处理的参数 +# param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + +# # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads +# transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} +# tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + +# # Heads的参数是那些既不属于transformer也不属于tokenizer的 +# head_params = { +# pn: p for pn, p in param_dict.items() +# if 'transformer' not in pn and 'tokenizer' not in pn +# } + +# # 3. 
Set different optimizer hyperparameters (especially the learning rate) for each group.
+# # We still use AdamW here, but with a more reasonable learning-rate setup.
+# optim_groups = [
+#     {
+#         'params': list(transformer_params.values()),
+#         'lr': learning_rate,  # 1e-4
+#         # 'lr': learning_rate * 0.2,  # Use a smaller learning rate for the transformer backbone, e.g., 1e-5.
+#         'weight_decay': weight_decay
+#         # 'weight_decay': weight_decay * 5.0
+#     },
+#     {
+#         'params': list(tokenizer_params.values()),
+#         'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g., 1e-4.
+#         # 'lr': learning_rate * 0.1,  # Use a smaller learning rate for the encoder, e.g., 1e-5.
+#         'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer.
+
+#     },
+#     {
+#         'params': list(head_params.values()),
+#         'lr': learning_rate,  # The heads also use the base learning rate, e.g., 1e-4.
+#         'weight_decay': 0.0  # Head weights are usually not decayed.
+#         # 'weight_decay': weight_decay
+
+#     }
+# ]
+
+# print("--- Optimizer Groups ---")
+# print(f"Transformer LR: {learning_rate}")
+# print(f"Tokenizer/Heads LR: {learning_rate}")
+
+# optimizer = torch.optim.AdamW(optim_groups, betas=betas)
+# return optimizer
+
+def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
+    """
+    Configure an optimizer with differentiated learning rates for the UniZero model.
+    (Corrected version: guarantees that the parameter groups are mutually exclusive.)
+    """
+    # 1. Create empty parameter lists for grouping.
+    transformer_params = []
+    tokenizer_params = []
+    head_params = []
+
+    # 2. Iterate over all trainable parameters and use an if/elif/else structure so that
+    #    each parameter is assigned to exactly one group.
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if 'transformer' in name:
+            transformer_params.append(param)
+        elif 'tokenizer' in name:
+            tokenizer_params.append(param)
+        else:
+            head_params.append(param)
+
+    # 3. Set different optimizer hyperparameters for each group.
+    #    We still use AdamW here, but with a more reasonable learning-rate setup.
+    optim_groups = [
+        {
+            'params': transformer_params,
+            'lr': learning_rate,  # 1e-4
+            'weight_decay': weight_decay
+        },
+        {
+            'params': tokenizer_params,
+            'lr': learning_rate,  # The tokenizer uses the base learning rate, e.g., 1e-4.
+            # 'weight_decay': weight_decay * 5.0  # <-- 5x weight decay for the encoder: a strong regularizer.
+            'weight_decay': weight_decay  # Use the base weight decay for the encoder (the 5x variant above is a stronger regularizer).
+        },
+        {
+            'params': head_params,
+            'lr': learning_rate,  # The heads also use the base learning rate, e.g., 1e-4.
+            # 'weight_decay': 0.0  # Head weights are usually not decayed.
+            'weight_decay': weight_decay
+
+        }
+    ]
+
+    print("--- Optimizer Groups ---")
+    # Print the number of parameters in each group for debugging.
+    print(f"Transformer params: {len(transformer_params)}")
+    print(f"Tokenizer params: {len(tokenizer_params)}")
+    print(f"Head params: {len(head_params)}")
+    print(f"Transformer LR: {learning_rate}")
+    print(f"Tokenizer/Heads LR: {learning_rate}")
+
+    optimizer = torch.optim.AdamW(optim_groups, betas=betas)
+    return optimizer
+
+@POLICY_REGISTRY.register('unizero_multitask')
+class UniZeroMTPolicy(UniZeroPolicy):
+    """
+    Overview:
+        The policy class for multi-task UniZero, an official implementation for the paper "UniZero: Generalized and Efficient Planning
+        with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of reinforcement learning agents
+        by addressing the limitations of MuZero-style algorithms, particularly in environments requiring the
+        capture of long-term dependencies. More details can be found at: https://arxiv.org/abs/2406.10667.
+    """
+
+    # The default_config for UniZero multi-task policy.
+    config = dict(
+        type='unizero_multitask',
+        model=dict(
+            # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model.
+            model_type='conv',  # options={'mlp', 'conv'}
+            # (bool) If True, the action space of the environment is continuous, otherwise discrete.
+            continuous_action_space=False,
+            # (tuple) The obs shape.
+            observation_shape=(3, 64, 64),
+            # (bool) Whether to use the self-supervised learning loss.
+ self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: LayerNorm is used in the transformer-based world model. + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: for sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, + # (float) The threshold for a dormant neuron. 
+ dormant_threshold=0.01, + + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. 
+ weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. 
+ end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + - model_type (:obj:`str`): The model type used in this algorithm, registered in ModelRegistry. + - import_names (:obj:`List[str]`): The list of model class paths used in this algorithm. + .. note:: + Users can define and use customized network models, but they must adhere to the same interface definition + as indicated by the import_names path. For multi-task UniZero, this is ``lzero.model.unizero_model_multitask.UniZeroMTModel``. + """ + # NOTE: This specifies the default multi-task model. + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learn mode. This method is called by ``self.__init__``. + It sets up the learn model, optimizer, target model, and other utilities required for training. + """ + if self._cfg.optim_type == 'SGD': + # --- 改为SGD优化器 --- + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, # 初始学习率,在配置中设为 0.2 + momentum=self._cfg.momentum, # 在配置中设为 0.9 + weight_decay=self._cfg.weight_decay # 在配置中设为 1e-4 + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, # 使用一个合理的AdamW基础学习率 + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps + # self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.get('total_iterations', 500000) # 500k iter + # final_lr = self._cfg.get('final_learning_rate', 0.0) + final_lr = self._cfg.get('final_learning_rate', 1e-6) + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + + + # Use a deep copy for the target model. + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is >= 2.0 for torch.compile. + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). 
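+        # The 'target' wrapper below keeps the target network as a momentum (EMA-style) copy of the learn model,
+        # mixing in the online weights with coefficient theta = target_update_theta instead of a hard periodic copy.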
+ self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + # Create a WrappedModel instance. + # This is used for gradient correction methods where gradients of shared parameters are managed. + # In this setup, all parameters are considered shared and subject to correction. + # wrapped_model = WrappedModel( + # self._learn_model.world_model, + # ) + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # The prediction heads' gradients are not corrected. + self.wrapped_model = WrappedModelV2( + # TODO: This assumes the tokenizer has an encoder attribute which is a list. This might need to be more robust. + self._learn_model.world_model.tokenizer.encoder[0], + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # Alternative setup: The head and tokenizer.encoder gradients are not corrected. + # wrapped_model = WrappedModelV3( + # self._learn_model.world_model.transformer, + # self._learn_model.world_model.pos_emb, + # self._learn_model.world_model.task_emb, + # self._learn_model.world_model.act_embedding_table, + # ) + + # Pass the wrapped_model as `shared_module` to the gradient correction method. + # ========= Initialize MoCo/CAGrad parameters ========= + if self._cfg.moco_version=="v0": + # This version is only compatible with single-GPU training. + self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + elif self._cfg.moco_version=="v1": + cfg_moco = MoCoCfg( + beta0=0.9, beta_sigma=0.95, + gamma0=0.1, gamma_sigma=0.95, + rho=0.01, stat_interval=10000) + self.grad_correct = FastMoCo( + shared_module=self.wrapped_model, + world_task_num=self._cfg.total_task_num, # Total number of tasks globally + device=self._cfg.device, + multi_gpu=self._cfg.multi_gpu, + cfg=cfg_moco, + ) + + # Cache for plasticity-related metrics from the previous frame. 
+        self._prev_plasticity_metrics = dict(
+            dormant_ratio_encoder = 0.0,
+            dormant_ratio_transformer = 0.0,
+            dormant_ratio_head = 0.0,
+            avg_weight_mag_encoder = 0.0,
+            avg_weight_mag_transformer = 0.0,
+            avg_weight_mag_head = 0.0,
+            e_rank_last_linear = 0.0,
+            e_rank_sim_norm = 0.0,
+        )
+
+        # ==================== START: Target-entropy regularization (adaptive alpha) initialization ====================
+        # Read from the config whether the adaptive entropy weight is enabled, with a default value.
+        self.use_adaptive_entropy_weight = self._cfg.get('use_adaptive_entropy_weight', True)
+
+        # Annealing schedule for the target-entropy ratio, configured here in _init_learn.
+        self.target_entropy_start_ratio = self._cfg.get('target_entropy_start_ratio', 0.98)
+        self.target_entropy_end_ratio = self._cfg.get('target_entropy_end_ratio', 0.7)
+        self.target_entropy_decay_steps = self._cfg.get('target_entropy_decay_steps', 200000)  # e.g., anneal over 200k train iters (~2M env steps).
+
+        if self.use_adaptive_entropy_weight:
+            # 1. Set the target entropy. For discrete action spaces, a common heuristic is the log of the
+            #    action-space size multiplied by a coefficient; the coefficient (e.g., 0.98) is a hyperparameter.
+            #    The annealed per-iteration target is recomputed in _forward_learn from the start/end ratios above.
+            action_space_size = self._cfg.model.action_space_size
+            self.target_entropy = -np.log(1.0 / action_space_size) * 0.98
+
+            # 2. Initialize a learnable log_alpha parameter.
+            #    Initializing it to 0 means the initial alpha = exp(0) = 1.0.
+            self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True)
+
+            # 3. Create a dedicated optimizer for log_alpha.
+            #    A smaller learning rate than the main optimizer (e.g., 1e-4) is usually more stable.
+            alpha_lr = self._cfg.get('adaptive_entropy_alpha_lr', 1e-4)
+            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
+
+            print("="*20)
+            print(">>> Target-entropy regularization (adaptive alpha) enabled <<<")
+            print(f"    Target entropy: {self.target_entropy:.4f}")
+            print(f"    Alpha optimizer learning rate: {alpha_lr:.2e}")
+            print("="*20)
+        # ===================== END: Target-entropy regularization initialization =====================
+
+        self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0)
+        # ==================== START: Encoder-Clip annealing parameters ====================
+        self.use_encoder_clip_annealing = self._cfg.get('use_encoder_clip_annealing', False)
+        if self.use_encoder_clip_annealing:
+            self.encoder_clip_anneal_type = self._cfg.get('encoder_clip_anneal_type', 'cosine')
+            self.encoder_clip_start = self._cfg.get('encoder_clip_start_value', 30.0)
+            self.encoder_clip_end = self._cfg.get('encoder_clip_end_value', 10.0)
+            self.encoder_clip_anneal_steps = self._cfg.get('encoder_clip_anneal_steps', 200000)
+
+            print("="*20)
+            print(">>> Encoder-Clip annealing enabled <<<")
+            print(f"    Type: {self.encoder_clip_anneal_type}")
+            print(f"    Range: {self.encoder_clip_start} -> {self.encoder_clip_end}")
+            print(f"    Steps: {self.encoder_clip_anneal_steps}")
+            print("="*20)
+        else:
+            # If annealing is disabled, fall back to the fixed clip threshold.
+            self.latent_norm_clip_threshold = self._cfg.get('latent_norm_clip_threshold', 30.0)
+        # ===================== END: Encoder-Clip annealing parameters =====================
+
+        # --- NEW: Policy Label Smoothing Parameters ---
+        self.policy_ls_eps_start = self._cfg.get('policy_ls_eps_start', 0.05)  # TODO: larger action spaces need a larger label-smoothing eps.
+        self.policy_ls_eps_end = self._cfg.get('policy_label_smoothing_eps_end', 0.01)
+        self.policy_ls_eps_decay_steps = self._cfg.get('policy_ls_eps_decay_steps', 50000)  # TODO: 50k
+        print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}")
+
+    @staticmethod
+    def _is_zero(x: Union[float, torch.Tensor], eps: float = 1e-8) -> bool:
+        """
+        Overview:
+            Checks if a scalar or a 0-D tensor can be considered zero within a small tolerance.
+        Arguments:
+            - x (:obj:`Union[float, torch.Tensor]`): The input value to check.
+            - eps (:obj:`float`): The tolerance for checking against zero.
+ Returns: + - (:obj:`bool`): True if the value is close to zero, False otherwise. + """ + if isinstance(x, torch.Tensor): + return torch.all(torch.abs(x) < eps).item() + return abs(x) < eps + + def _retain_prev_if_zero(self, name: str, + value: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Overview: + If the current `value` is close to zero, returns the cached value from the previous frame. + Otherwise, it updates the cache with the current value and returns it. This is useful for + metrics that are computed intermittently. + Arguments: + - name (:obj:`str`): The name of the metric to cache. + - value (:obj:`Union[float, torch.Tensor]`): The current value of the metric. + Returns: + - (:obj:`Union[float, torch.Tensor]`): The retained or current value. + """ + if self._is_zero(value): + # Directly return the previous value (can be float or tensor). + return self._prev_plasticity_metrics[name] + else: + # Update the cache and return the current value. + self._prev_plasticity_metrics[name] = value + return value + + + #@profile + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_iter=None, ignore_grad=False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning in the policy. This is the core of the training process. + Data is sampled from the replay buffer, losses are calculated, and the model is updated via backpropagation. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, where each element corresponds to a different task. + - task_weights (:obj:`Any`, optional): Optional weights for each task's loss. Not currently used. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary containing current learning losses and statistics for logging. + """ + self._learn_model.train() + self._target_model.train() + + # Lists to store metrics for each task within the batch. + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # Initialize to 0.0 to avoid in-place operations. + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + # Metrics for network plasticity analysis. + dormant_ratio_encoder_multi_task = [] + dormant_ratio_transformer_multi_task = [] + dormant_ratio_head_multi_task = [] + avg_weight_mag_encoder_multi_task = [] + avg_weight_mag_transformer_multi_task = [] + avg_weight_mag_head_multi_task = [] + e_rank_last_linear_multi_task = [] + e_rank_sim_norm_multi_task = [] + + # --- NEW: Calculate current epsilon for policy --- + # if self.policy_ls_eps_start > 0: + # progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + # current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + # else: + # current_policy_label_eps = 0.0 + current_policy_label_eps = 0.01 + + # 新增一个列表来收集当前批次中所有任务的真实全局ID + global_task_ids_in_batch = [] + alpha_loss = None + + losses_list = [] # Used to store the loss tensor for each task, required by gradient correction methods. 
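+        # Each element of `data` holds one task's mini-batch as (current_batch, target_batch, global_task_id).
+        # The loop below computes the per-task losses, appends them to the lists above, and accumulates the
+        # scalar `weighted_total_loss` that is backpropagated once after the loop.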
+ for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task # task_id 是真实的全局ID + + # 将真实的全局ID添加到列表中 + global_task_ids_in_batch.append(task_id) + + # TODO: Adapt RoPE for multitask settings (using timestep_batch). + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to a torch tensor. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space. + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + cur_batch_size = target_reward.size(0) # Run-time batch size. + + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar rewards and values to their scaled representations. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert scaled representations to categorical distributions. + # target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + + + # Prepare the batch for the transformer-based world model. + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + + # Extract valid target policy data and compute its entropy. 
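+            # Target-policy entropy is H(p) = -sum_a p(a) * log p(a), computed over valid (non-padding) positions only.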
+            valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
+            target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
+            average_target_policy_entropy = target_policy_entropy.mean().item()
+
+            # Update world model and compute losses.
+            intermediate_losses = defaultdict(float)
+            # losses = self._learn_model.world_model.compute_loss(
+            #     batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, task_id=task_id
+            # )
+
+            losses = self._learn_model.world_model.compute_loss(
+                batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id
+            )
+
+            # ==================== START MODIFICATION 2 ====================
+            # Extract the calculated value_priority from the returned losses.
+            value_priority_tensor = losses.intermediate_losses['value_priority']
+            # Convert to numpy array for the replay buffer, adding a small epsilon.
+            value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6
+            # ===================== END MODIFICATION 2 =====================
+
+            # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted.
+            weighted_total_loss += losses.loss_total  # NOTE: accumulate with +=
+
+            # TODO: Add assertions to check for NaN or Inf values in the loss if needed for debugging.
+            # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+            # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+
+            # TODO: Append the total loss for this task, used by MoCo.
+            losses_list.append(losses.loss_total)
+
+            for loss_name, loss_value in losses.intermediate_losses.items():
+                intermediate_losses[f"{loss_name}"] = loss_value
+
+            obs_loss = intermediate_losses['loss_obs']
+            reward_loss = intermediate_losses['loss_rewards']
+            policy_loss = intermediate_losses['loss_policy']
+            orig_policy_loss = intermediate_losses['orig_policy_loss']
+            policy_entropy = intermediate_losses['policy_entropy']
+            value_loss = intermediate_losses['loss_value']
+            latent_recon_loss = intermediate_losses['latent_recon_loss']
+            perceptual_loss = intermediate_losses['perceptual_loss']
+            latent_state_l2_norms = intermediate_losses['latent_state_l2_norms']
+
+            # Extract the policy entropy from the losses object.
+            # ==================== START: Target-entropy regularization (adaptive alpha) update ====================
+            current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight  # Fall back to the fixed weight by default.
+            if self.use_adaptive_entropy_weight:
+                # --- Dynamically compute the annealed target entropy (this logic is kept as-is). ---
+                progress = min(1.0, train_iter / self.target_entropy_decay_steps)
+                current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress
+                action_space_size = self._cfg.model.action_space_size
+                # Note: target_entropy is defined as a positive quantity, which is more intuitive.
+                current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio
+
+                # --- Compute alpha_loss (with the corrected sign). ---
+                # The key fix is that the leading negative sign has been removed.
+                # detach() remains essential: it ensures the gradient of alpha_loss flows only into log_alpha.
+                alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean()  # NOTE: plain assignment, not accumulation.
+
+                # --- Update log_alpha. ---
+                self.alpha_optimizer.zero_grad()
+                alpha_loss.backward()
+                self.alpha_optimizer.step()
+                # --- [Safety measure] Clamp log_alpha. ---
+                with torch.no_grad():
+                    # Constrain alpha to a range such as [1e-4, 10.0].
+                    self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+
+                # --- Use the freshly updated alpha (detached to cut the gradient flow). ---
+                current_alpha = self.log_alpha.exp().detach()
+
+                # Recompute the entropy-weighted policy loss and the total loss.
+                # Note: policy_entropy here is already averaged over the batch.
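+                # The entropy-regularized policy loss below subtracts alpha * H(pi) from the original policy loss,
+                # so a larger alpha pushes the learned policy toward higher entropy (more exploration).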
+                weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy
+                # Rebuild the total loss (instead of using losses.loss_total).
+                # Make sure these weights stay consistent with the computation in the LossWithIntermediateLosses class.
+                self.obs_loss_weight = 10
+                self.value_loss_weight = 0.5
+                self.reward_loss_weight = 1.
+                self.policy_loss_weight = 1.
+                self.ends_loss_weight = 0.
+                total_loss = (
+                    self.reward_loss_weight * reward_loss +
+                    self.value_loss_weight * value_loss +
+                    self.policy_loss_weight * weighted_policy_loss +
+                    self.obs_loss_weight * obs_loss  # Assuming ssl_loss_weight is the weight applied to obs_loss.
+                    # ... add any remaining loss terms here as well ...
+                )
+                weighted_total_loss += (weights * total_loss).mean()  # NOTE: accumulate with +=
+            # ===================== END: Target-entropy regularization (adaptive alpha) update =====================
+
+            # ============ For value-based priority calculation ============
+            # TODO: The following section for calculating value_priority is commented out.
+            # If re-enabled, ensure it correctly computes L1 loss between predicted and target values
+            # and handles CPU/Numpy conversion properly.
+            # original_value = self.value_inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape(
+            #     batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1)
+            # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0])
+            # value_priority = value_priority.data.cpu().numpy() + 1e-6
+            # value_priority = torch.tensor(0., device=self._cfg.device)
+            # ============ End of value priority section ============
+
+            # Metrics related to network plasticity.
+            # Use the helper function to retain the previous value if the current one is zero.
+            dormant_ratio_encoder = self._retain_prev_if_zero(
+                'dormant_ratio_encoder',
+                intermediate_losses['dormant_ratio_encoder'])
+            dormant_ratio_transformer = self._retain_prev_if_zero(
+                'dormant_ratio_transformer',
+                intermediate_losses['dormant_ratio_transformer'])
+            dormant_ratio_head = self._retain_prev_if_zero(
+                'dormant_ratio_head',
+                intermediate_losses['dormant_ratio_head'])
+            avg_weight_mag_encoder = self._retain_prev_if_zero(
+                'avg_weight_mag_encoder',
+                intermediate_losses['avg_weight_mag_encoder'])
+            avg_weight_mag_transformer = self._retain_prev_if_zero(
+                'avg_weight_mag_transformer',
+                intermediate_losses['avg_weight_mag_transformer'])
+            avg_weight_mag_head = self._retain_prev_if_zero(
+                'avg_weight_mag_head',
+                intermediate_losses['avg_weight_mag_head'])
+            e_rank_last_linear = self._retain_prev_if_zero(
+                'e_rank_last_linear',
+                intermediate_losses['e_rank_last_linear'])
+            e_rank_sim_norm = self._retain_prev_if_zero(
+                'e_rank_sim_norm',
+                intermediate_losses['e_rank_sim_norm'])
+
+            # Append all metrics for this task to their respective lists.
+            obs_loss_multi_task.append(obs_loss)
+            reward_loss_multi_task.append(reward_loss)
+            policy_loss_multi_task.append(policy_loss)
+            orig_policy_loss_multi_task.append(orig_policy_loss)
+            policy_entropy_multi_task.append(policy_entropy)
+            value_loss_multi_task.append(value_loss)
+            latent_recon_loss_multi_task.append(latent_recon_loss)
+            perceptual_loss_multi_task.append(perceptual_loss)
+            latent_state_l2_norms_multi_task.append(latent_state_l2_norms)
+            average_target_policy_entropy_multi_task.append(average_target_policy_entropy)
+            value_priority_multi_task.append(value_priority_tensor)
+            value_priority_mean_multi_task.append(value_priority_tensor.mean().item())
+
+            # Append plasticity metrics.
+            dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder)
+            dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer)
+            dormant_ratio_head_multi_task.append(dormant_ratio_head)
+            avg_weight_mag_encoder_multi_task.append(avg_weight_mag_encoder)
+            avg_weight_mag_transformer_multi_task.append(avg_weight_mag_transformer)
+            avg_weight_mag_head_multi_task.append(avg_weight_mag_head)
+            e_rank_last_linear_multi_task.append(e_rank_last_linear)
+            e_rank_sim_norm_multi_task.append(e_rank_sim_norm)
+
+
+        # Core learn model update step.
+        self._optimizer_world_model.zero_grad()
+
+        # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...].
+        if self._cfg.use_moco:
+            # Call MoCo's backward method, which handles gradient correction internally.
+            if self._cfg.moco_version == "v0":
+                lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            elif self._cfg.moco_version == "v1":
+                lambd, stats = self.grad_correct.backward(losses_list)
+
+        elif self._cfg.only_use_moco_stats:
+            # Only compute MoCo stats without applying gradient correction.
+            lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            # Each rank performs its own backpropagation.
+            weighted_total_loss.backward()
+        else:
+            # If not using gradient correction, each rank performs standard backpropagation.
+            lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
+            weighted_total_loss.backward()
+
+
+        # -----------------------------------------------------------------
+        # The following runs under torch.no_grad().
+        # =================================================================
+        with torch.no_grad():
+            # 1. Encoder-Clip
+            # ==================== START: Dynamically compute the current clip threshold ====================
+            current_clip_value = self.latent_norm_clip_threshold  # Use the fixed value by default.
+            if self.use_encoder_clip_annealing:
+                progress = min(1.0, train_iter / self.encoder_clip_anneal_steps)
+
+                if self.encoder_clip_anneal_type == 'cosine':
+                    # Cosine schedule: smoothly decays from 1 to 0.
+                    cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
+                    current_clip_value = self.encoder_clip_end + \
+                        (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress
+                else:  # Defaults to a linear schedule.
+                    current_clip_value = self.encoder_clip_start * (1 - progress) + \
+                        self.encoder_clip_end * progress
+            # ===================== END: Dynamically compute the current clip threshold =====================
+
+            # 2. Encoder-Clip (using the dynamically computed current_clip_value).
+            if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                if obs_embeddings is not None:
+                    max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max()
+                    if max_latent_norm > current_clip_value:
+                        scale_factor = current_clip_value / max_latent_norm.item()
+                        # Avoid printing too frequently; log roughly every N iterations instead.
+                        if train_iter % 1000 == 0:
+                            print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.")
+                        scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor)
+
+
+        # For debugging purposes.
+        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
+        #     if param.requires_grad:
+        #         print(name, param.grad.norm())
+
+        if self._cfg.analysis_sim_norm:
+            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+            self._target_model.encoder_hook.clear_data()
+
+        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+                                                                        self._cfg.grad_clip_value)
+
+        if ignore_grad:
+            # NOTE: For cases where all tasks on a GPU are solved, `train` is still called for DDP synchronization,
+            # but gradients should be zeroed out to prevent updates.
+            self._optimizer_world_model.zero_grad()
+
+        if self._cfg.multi_gpu:
+            # If not using a gradient correction method that handles it, sync gradients manually.
+            if not self._cfg.use_moco:
+                self.sync_gradients(self._learn_model)
+
+        self._optimizer_world_model.step()
+
+        if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
+            self.lr_scheduler.step()
+
+        # Core target model update step.
+        self._target_model.update(self._learn_model.state_dict())
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            current_memory_allocated = torch.cuda.memory_allocated()
+            max_memory_allocated = torch.cuda.max_memory_allocated()
+            current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+            max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+        else:
+            current_memory_allocated_gb = 0.
+            max_memory_allocated_gb = 0.
+
+        # Build the dictionary of return values for logging.
+        return_log_dict = {
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+        }
+
+        # ==================== START: New adaptive-alpha log entries ====================
+        if self.use_adaptive_entropy_weight:
+            return_log_dict['adaptive_alpha'] = current_alpha.item()
+            return_log_dict['adaptive_target_entropy_ratio'] = current_ratio
+            return_log_dict['alpha_loss'] = alpha_loss.item()
+        # ===================== END: New adaptive-alpha log entries =====================
+
+        # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_".
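+        # The 'noreduce_' prefix marks values that are logged per task for this rank only, i.e. without any
+        # cross-rank reduction applied at this point.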
+ multi_task_loss_dicts = { + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), #global_task_ids=global_task_ids_in_batch), # task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + return_log_dict.update(multi_task_loss_dicts) + + + if self._learn_model.world_model.do_analysis: + # Include plasticity metrics if analysis is enabled. + plasticity_loss_dicts = { + **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id), + } + # Merge the dictionaries. + return_log_dict.update(plasticity_loss_dicts) + + # Return the final loss dictionary. + return return_log_dict + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to print the mean and standard deviation of weights and their gradients for each layer in a model. + Useful for debugging training issues like exploding or vanishing gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to monitor. 
+ """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collect mode. This method is called by ``self.__init__``. + It sets up the collect model and MCTS utilities for data collection. + """ + self._collect_model = self._model + + # Create a copy of the configuration for collect MCTS and set a specific number of simulations. + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(mcts_collect_cfg) + else: + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: The num_tasks parameter is hardcoded. It should ideally be derived from the config. + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Registers variables to be monitored during training. These variables will be logged in TensorBoard. + It dynamically creates variable names for each task if `num_tasks` is provided. + Arguments: + - num_tasks (:obj:`int`): The number of tasks being trained on the current rank. + Returns: + - monitored_vars (:obj:`List[str]`): A list of strings, where each string is the name of a variable to be logged. + """ + # Basic monitored variables that do not depend on the number of tasks. + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + + # 'value_priority', + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'alpha_loss', + ] + + + + # Task-specific variables to be monitored. + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + # Metrics related to network plasticity. + 'noreduce_dormant_ratio_encoder', + 'noreduce_dormant_ratio_transformer', + 'noreduce_dormant_ratio_head', + 'noreduce_avg_weight_mag_encoder', + 'noreduce_avg_weight_mag_transformer', + 'noreduce_avg_weight_mag_head', + 'noreduce_e_rank_last_linear', + 'noreduce_e_rank_sim_norm' + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variable names. 
+        if num_tasks is not None:
+            for var in task_specific_vars:
+                for task_idx in range(num_tasks):
+                    monitored_vars.append(f'{var}_task{self.task_id+task_idx}')
+        else:
+            # If num_tasks is not provided, assume a single task and use the original variable names.
+            monitored_vars.extend(task_specific_vars)
+
+        return monitored_vars
+
+    #@profile
+    def _forward_collect(
+            self,
+            data: torch.Tensor,
+            action_mask: list = None,
+            temperature: float = 1,
+            to_play: List = [-1],
+            epsilon: float = 0.25,
+            ready_env_id: np.array = None,
+            timestep: List = [0],
+            task_id: int = None,
+    ) -> Dict:
+        """
+        Overview:
+            The forward function for collecting data. It uses the model to perform MCTS search and
+            selects actions via sampling to encourage exploration.
+        Arguments:
+            - data (:obj:`torch.Tensor`): The input data, i.e., the current observation.
+            - action_mask (:obj:`list`, optional): A list of action masks for each environment.
+            - temperature (:obj:`float`, optional): The temperature for MCTS action selection.
+            - to_play (:obj:`List`, optional): A list of player IDs for each environment.
+            - epsilon (:obj:`float`, optional): The probability for epsilon-greedy exploration.
+            - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action.
+            - timestep (:obj:`List`, optional): The current timestep in each environment.
+            - task_id (:obj:`int`, optional): The ID of the task for the current environments.
+        Returns:
+            - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries
+              containing the selected action and other MCTS statistics.
+        """
+        self._collect_model.eval()
+
+        self._collect_mcts_temperature = temperature
+        self._collect_epsilon = epsilon
+        active_collect_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_collect_env_num)
+        output = {i: None for i in ready_env_id}
+
+        with torch.no_grad():
+            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id)
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
+            latent_state_roots = latent_state_roots.detach().cpu().numpy()
+
+            # ========================== Core fix ==========================
+            # The C++ MCTS bindings expect a Python list here, even though in MuZero this tensor represents the reward.
+            reward_roots = reward_roots.detach().cpu().numpy().tolist()
+            # ===============================================================
+
+            policy_logits = policy_logits.detach().cpu().numpy().tolist()
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
+            # The main difference between collect and eval is the addition of Dirichlet noise at the root.
+            noises = [
+                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
+                                    ).astype(np.float32).tolist() for j in range(active_collect_env_num)
+            ]
+            if self._cfg.mcts_ctree:
+                # C++ MCTS tree implementation.
+                roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+            else:
+                # Python MCTS tree implementation.
+                roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
+
+            # # A global flag defined at the top of this file can toggle an interactive debugging mode.
+            # global DEBUG_ENABLED; DEBUG_ENABLED = True
+            # import torch.distributed as dist
+            # if dist.get_rank() == 0 and DEBUG_ENABLED:
+            #     print(f"rank {dist.get_rank()} entering debug mode; type 'interact' in the ipdb shell to run "
+            #           f"arbitrary Python code. Set DEBUG_ENABLED = False to skip debugging.")
+            #     import ipdb; ipdb.set_trace()
+            #     # Synchronization point to keep the other processes from running ahead.
+            #     dist.barrier()
+
+            roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
+            self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id)
+
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()
+
+            batch_action = []
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+
+                if self._cfg.eps.eps_greedy_exploration_in_collect:
+                    # Epsilon-greedy collection strategy.
+                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                        distributions, temperature=self._collect_mcts_temperature, deterministic=True
+                    )
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    if np.random.rand() < self._collect_epsilon:
+                        action = np.random.choice(legal_actions[i])
+                else:
+                    # Standard collection strategy (sampling from the MCTS policy).
+                    # NOTE: `action_index_in_legal_action_set` is the index within the set of legal actions.
+                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                        distributions, temperature=self._collect_mcts_temperature, deterministic=False
+                    )
+                    # Convert the index back to the action in the full action space.
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                # ============== TODO: This section is for visualization purposes only and should be removed for training. ==============
+                # It forces deterministic action selection during collection.
+                # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                #     distributions, temperature=self._collect_mcts_temperature, deterministic=True
+                # )
+                # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                # ============== End of visualization section. ==============
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+                batch_action.append(action)
+
+            self.last_batch_obs = data
+            self.last_batch_action = batch_action
+
+            # ========= TODO: This logic is currently for the `muzero_segment_collector`. =========
+            if active_collect_env_num < self.collector_env_num:
+                # When one environment in `collect_env` finishes early, the length of `self.last_batch_obs` is reduced.
+                # The transformer needs the `env_id` to retrieve from the KV cache, which is complex to manage with a dynamic batch size.
+                # Therefore, we reset `self.last_batch_action` for all environments to -1, forcing the transformer
+                # to start from scratch and avoid retrieval errors.
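+                # Editorial note, inferred from `_reset_collect` below rather than new behaviour: calling
+                # `_reset_collect(reset_init_data=True, ...)` rebuilds `self.last_batch_obs` as a zero batch of
+                # size `collector_env_num` via `initialize_zeros_batch` and resets every entry of
+                # `self.last_batch_action` to -1, so the next `initial_inference` starts from a clean, full-sized batch.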
+ print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the eval mode. This method is called by ``self.__init__``. + It sets up the eval model and MCTS utilities for evaluation. + """ + self._eval_model = self._model + + # Create a copy of the configuration for eval MCTS and set a specific number of simulations. + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(mcts_eval_cfg) + else: + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the policy. It uses the model to perform MCTS search and + selects actions deterministically (choosing the one with the highest visit count). + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`): A list of action masks for each environment. + - to_play (:obj:`int`, optional): The player ID for the current turn. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The ID of the task for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. 
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_eval_env_num)
+        output = {i: None for i in ready_env_id}
+        with torch.no_grad():
+            network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id)
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
+            latent_state_roots = latent_state_roots.detach().cpu().numpy()
+            policy_logits = policy_logits.detach().cpu().numpy().tolist()
+
+            # ========================== Core fix ==========================
+            # The C++ MCTS bindings expect a Python list here, even though in MuZero this tensor represents the reward.
+            reward_roots = reward_roots.detach().cpu().numpy().tolist()  # TODO
+            # ===============================================================
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # C++ MCTS tree implementation.
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # Python MCTS tree implementation.
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+
+            # During evaluation, no noise is added to the root policy.
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id)
+
+            roots_visit_count_distributions = roots.get_distributions()
+            roots_values = roots.get_values()
+
+            batch_action = []
+
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+
+                # NOTE: `deterministic=True` means we select the action with the highest visit count (argmax)
+                # rather than sampling, which is standard for evaluation.
+                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=1, deterministic=True
+                )
+                # Convert the index back to the action in the full action space.
+                action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                output[env_id] = {
+                    'action': action,
+                    'visit_count_distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
+                }
+                batch_action.append(action)
+
+            self.last_batch_obs_eval = data
+            self.last_batch_action = batch_action
+
+        return output
+
+    #@profile
+    def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None:
+        """
+        Overview:
+            Resets the collection process for a specific environment or all environments.
+            It can clear caches and reset initial data to ensure optimal performance and prevent state leakage.
+        Arguments:
+            - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, the reset applies more broadly. Defaults to None.
+            - current_steps (:obj:`int`, optional): The current step count in the environment, used to trigger periodic cache clearing. Defaults to 0.
+            - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True.
+            - task_id (:obj:`int`, optional): The task ID, currently unused in this method. Defaults to None.
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('Collector: last_batch_obs and last_batch_action have been reset.') + + # Return immediately if env_id is not a single integer (e.g., None or a list). + # if env_id is None or isinstance(env_id, list): + # return + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + + # Determine the clear interval based on the environment's sample type. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically to manage memory. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the collect model's world model. + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Collector: Caches cleared for collect_model at step {current_steps} for env {env_id}.') + + # TODO: Check if resetting the target model here is correct and necessary. + self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + Resets the target model by clearing its internal caches. This is crucial for managing memory, + especially when using transformer-based models with KV caching. + """ + # Clear various KV caches in the target model's world model. + world_model = self._target_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + print('Collector: Target model past_kv_cache cleared.') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the evaluation process for a specific environment or all environments. + Clears caches and resets initial data to ensure clean evaluation runs. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count, used for periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. 
+ - task_id (:obj:`int`, optional): The task ID. Can be used to handle different observation shapes per task. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + # print(f'Evaluator reset: last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # --- BEGIN ROBUST FIX --- + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + # --- END ROBUST FIX --- + + # Determine the clear interval. + # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically. + # if current_steps % clear_interval == 0: + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the eval model's world model. + world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Evaluator: Caches cleared for eval_model at step {current_steps} for env {env_id}.') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clears all KV caches and precomputes positional embedding matrices in the model. + This is typically called when the maximum sequence length changes. + """ + # NOTE: This must be done for both the collect and target models. + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learn mode. + This typically includes the model, target model, and optimizer states, + which are necessary for saving and resuming training. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary for the current learning progress. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== NOTE: This is the original version which loads all parameters from the state_dict. 
========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads the state_dict into the policy's learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary saved from a previous training session. + # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== NOTE: This is a pretrain-finetune version that selectively loads parameters and freezes layers. ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: + """ + Overview: + Loads a state_dict for fine-tuning. It excludes multi-task specific parameters + and can freeze parts of the model (e.g., encoder, transformer) based on `finetune_components`. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + - finetune_components (:obj:`List[str]`, optional): A list of component names (e.g., "encoder", "transformer") + that will remain trainable. Components not in this list will have their parameters frozen. + """ + # Example configurations for fine-tuning: + # finetune_components = [] # Loads encoder & transformer, fine-tunes only heads. + # finetune_components = ['transformer'] # Loads encoder & transformer, fine-tunes transformer & heads. + finetune_components = ["representation_network", "encoder"] # Loads encoder & transformer, fine-tunes encoder & heads. + + # Define prefixes of parameters to be excluded from loading (typically multi-task heads). + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' + ] + + # Define specific parameter keys to be excluded (for special cases like task embeddings). + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + Filters out parameters from a state_dict based on prefixes and specific keys. + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # For debugging + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # For debugging + continue + filtered[k] = v + return filtered + + # Filter and load the 'model' state_dict. + if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # Filter and load the 'target_model' state_dict. 
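+        # Editorial note (standard PyTorch behaviour, not specific to this patch): with `strict=False`,
+        # `load_state_dict` does not raise on missing or unexpected keys; it returns them as
+        # `missing_keys` / `unexpected_keys`, which are printed above and below for inspection.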
+ if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # Handle freezing/unfreezing of parameters in _learn_model based on finetune_components. + # This assumes a naming convention where component names are present in parameter names. + for name, param in self._learn_model.named_parameters(): + # Freeze the encoder if "encoder" is not in finetune_components. + if "encoder" in name and "encoder" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the representation network if "representation_network" is not in finetune_components. + elif "representation_network" in name and "representation_network" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the transformer if "transformer" is not in finetune_components. + elif "transformer" in name and "transformer" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + else: + # Other parameters remain trainable by default. + print(f"Parameter remains trainable: {name}") + + # NOTE: For more complex model structures, it might be better to identify modules by their class + # rather than relying on parameter names. For example: + # for module in self._learn_model.modules(): + # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: + # for param in module.parameters(): + # param.requires_grad = False + + # ========== NOTE: Another pretrain-finetune version. The main difference from the above is the freezing logic and comments. ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads a state_dict into the policy's learn mode, excluding multi-task related parameters. + # This is intended for fine-tuning a pre-trained model on new tasks. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. + # """ + # # Define prefixes of parameters to be excluded. + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # Define specific parameter keys to be excluded. + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # Filters out parameters that should not be loaded. + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") + # continue + # filtered[k] = v + # return filtered + + # # Filter and load the 'model' part. 
+ # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # Filter and load the 'target_model' part. + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # Do not load the optimizer's state_dict when fine-tuning, as it contains state (like momentum) + # # specific to the pre-training task, which can hinder adaptation to new tasks. + # # A fresh optimizer is usually preferred. + # # if 'optimizer_world_model' in state_dict: + # # ... \ No newline at end of file diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 8b25c98b7..1dd85d259 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -211,29 +211,69 @@ def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) -# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 -def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type): - # start with all of the candidate parameters +# The following code is modified from the original implementation at: +# https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 + +def configure_optimizers_nanogpt( + model: nn.Module, + weight_decay: float, + learning_rate: float, + betas: Tuple[float, float], + device_type: str +) -> torch.optim.AdamW: + """ + Overview: + Configures the AdamW optimizer for the nanoGPT model. This function separates model + parameters into two groups: one that will be subject to weight decay and one that will not. + Typically, 2D and higher-dimensional tensors (e.g., weights of linear layers) are decayed, + while 1D tensors (e.g., biases and LayerNorm weights) are not. + + Arguments: + - model (:obj:`nn.Module`): The model for which to configure optimizers. + - weight_decay (:obj:`float`): The weight decay coefficient to apply. + - learning_rate (:obj:`float`): The learning rate for the optimizer. + - betas (:obj:`Tuple[float, float]`): The beta coefficients for the AdamW optimizer (e.g., (0.9, 0.95)). + - device_type (:obj:`str`): The type of device being used, e.g., 'cuda' or 'cpu'. + + Returns: + (:obj:`torch.optim.AdamW`): The configured AdamW optimizer instance. + """ + # Start with all of the candidate parameters from the model. param_dict = {pn: p for pn, p in model.named_parameters()} - # filter out those that do not require grad - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. - # i.e. 
all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + + # TODO: The following code is commented out, which is crucial for a balanced pipeline. + # We do not filter out parameters with `requires_grad=False` because their `requires_grad` + # attribute might be set to `True` at a later stage during training. + # param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + # Create optimizer parameter groups. Any parameter that is 2D or higher will be weight decayed, + # otherwise no. i.e. all weight tensors in matrix multiplications and embeddings will be decayed, + # while all biases and layernorm weights will not. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] optim_groups = [ {'params': decay_params, 'weight_decay': weight_decay}, {'params': nodecay_params, 'weight_decay': 0.0} ] + num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") - # Create AdamW optimizer and use the fused version if it is available + + # Create the AdamW optimizer. + # Check if a fused version of AdamW is available in the current PyTorch installation. fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + + # Note: The current logic creates a standard AdamW optimizer on CUDA-enabled systems. + # The 'fused' version is only considered on non-CUDA systems, where it will ultimately not be used + # because `device_type` would not be 'cuda'. if torch.cuda.is_available(): + # On a CUDA-enabled system, create a standard AdamW optimizer. optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) else: + # On a non-CUDA system, check if the fused optimizer can be used. + # This will be False if device_type is not 'cuda'. use_fused = fused_available and device_type == 'cuda' extra_args = dict(fused=True) if use_fused else dict() optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) @@ -372,7 +412,7 @@ def prepare_obs_stack_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> T return obs_batch, obs_target_batch -def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict, task_id = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: Prepare the observations for the model by converting the original batch of observations @@ -395,9 +435,12 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Calculate the dimension size to slice based on the model configuration. # For convolutional models ('conv'), use the number of frames to stack times the number of channels. # For multi-layer perceptron models ('mlp'), use the number of frames to stack times the size of the observation space. 
- stack_dim = cfg.model.frame_stack_num * ( + if task_id is None: + stack_dim = cfg.model.frame_stack_num * ( cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape) - + else: + stack_dim = cfg.model.frame_stack_num * ( + cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id]) # Slice the original observation tensor to obtain the batch for the initial inference. obs_batch = obs_batch_ori[:, :stack_dim] @@ -408,7 +451,10 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Determine the starting dimension to exclude based on the model type. # For 'conv', exclude the first 'image_channel' dimensions. # For 'mlp', exclude the first 'observation_shape' dimensions. - exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + if task_id is None: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + else: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id] # Slice the original observation tensor to obtain the batch for consistency loss calculation. obs_target_batch = obs_batch_ori[:, exclude_dim:] @@ -565,6 +611,10 @@ def concat_output_value(output_lst: List) -> np.ndarray: for output in output_lst: value_lst.append(output.value) + # print(f'value_lst:{value_lst}') + # print(f'value_lst[0]:{value_lst[0]}') + # print(f'value_lst[0].shape:{value_lst[0].shape}') + value_lst = np.concatenate(value_lst) return value_lst diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 4d3b1b740..06fa3b580 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,7 +1,6 @@ -import os import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict, Set import numpy as np import torch @@ -16,70 +15,77 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation -from lzero.policy.utils import compute_bleu @SERIAL_COLLECTOR_REGISTRY.register('episode_muzero') class MuZeroCollector(ISerialCollector): """ Overview: - The Episode Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. + The episode-based collector for MCTS-based reinforcement learning algorithms, + including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. + It orchestrates the data collection process in a serial manner, managing interactions + between the policy and the environment to generate game segments for training. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, + ``_compute_priorities``, ``pad_and_save_last_trajectory``, ``_output_log``, ``close``, ``__del__``. Properties: - ``envstep`` + ``envstep``. """ - # TO be compatible with ISerialCollector + # Default configuration for the collector. To be compatible with ISerialCollector. 
config = dict() def __init__( self, collect_print_freq: int = 100, - env: BaseEnvManager = None, - policy: namedtuple = None, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector', + exp_name: str = 'default_experiment', + instance_name: str = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the MuZeroCollector with the given parameters. + Initializes the MuZeroCollector with the given configuration. Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. - - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - collect_print_freq (:obj:`int`): The frequency (in training iterations) at which to print collection statistics. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the policy's forward pass and other methods. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance for logging metrics. + - exp_name (:obj:`str`): The name of the experiment, used for organizing logs. + - instance_name (:obj:`str`): A unique name for this collector instance. + - policy_config (:obj:`'policy_config'`): The configuration object for the policy. + - task_id (:obj:`Optional[int]`): The identifier for the current task in a multi-task setting. If None, operates in single-task mode. """ + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq self._timer = EasyTimer() self._end_flag = False + # Get distributed training info self._rank = get_rank() self._world_size = get_world_size() + + # Logger setup: only rank 0 creates the main logger and TensorBoard logger. if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = None @@ -91,12 +97,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. 
+ Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. If None, resets the current environment. """ if _env is not None: self._env = _env @@ -108,42 +113,39 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. + If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy to be used. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env first before resetting policy." if _policy is not None: self._policy = _policy self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) self._logger.debug( - 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" ) self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector, including the environment and policy. Also re-initializes + internal state variables for tracking collection progress. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. """ if _env is not None: self.reset_env(_env) if _policy is not None: self.reset_policy(_policy) - self._env_info = {env_id: {'time': 0., 'step': 0, 'text_bleu': 0.} for env_id in range(self._env_num)} + # Initialize per-environment tracking info + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + # Reset overall statistics self._episode_info = [] self._total_envstep_count = 0 self._total_episode_count = 0 @@ -151,36 +153,35 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A pool to store completed game segments, implemented using a deque. 
self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment, identified by `env_id`. + This is typically called when an episode in that environment ends. Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset statistics for. """ - self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0.} + self._env_info[env_id] = {'time': 0., 'step': 0} @property def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected since the last reset. Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total environment step count. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and any loggers. + Ensures that all resources are properly released. """ if self._end_flag: return @@ -193,627 +194,454 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Destructor for the collector instance, ensuring that `close` is called + to clean up resources. """ self.close() # ============================================================== - # MCTS+RL related core code + # MCTS+RL Core Collection Logic # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for experience replay based on the discrepancy between + predicted values and MCTS search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the environment's data in the lists. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values for each environment. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS for each environment. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of priorities for the transitions. Returns None if priority is not used. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. 
We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted values and search values. + # 'reduction=none' ensures the loss is calculated for each element individually. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + + # A small epsilon is added to avoid zero priorities. + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If priority is not used, return None. The replay buffer will use max priority for new data. priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[Optional[GameSegment]], + last_game_priorities: List[Optional[np.ndarray]], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the end of the `last_game_segment` with data from the start of the current `game_segment`. + This is necessary to compute target values for the final transitions of a segment. After padding, + the completed segment is stored in the `game_segment_pool`. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. + - i (:obj:`int`): The index of the environment being processed. + - last_game_segments (:obj:`List[Optional[GameSegment]]`): List of game segments from the previous collection chunk. + - last_game_priorities (:obj:`List[Optional[np.ndarray]]`): List of priorities corresponding to the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of game segments from the current collection chunk. + - done (:obj:`np.ndarray`): Array indicating if the episode has terminated for each environment. Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + An implicit assumption is that the start of the new segment's observation history overlaps with the + end of the last segment's, e.g., `(last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all()` is True. 
""" - # pad over last segment trajectory - beg_index = self.policy_config.model.frame_stack_num - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs - pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - - # NOTE: for unizero - beg_index = 0 - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_action_lst = game_segments[i].action_segment[beg_index:end_index] - - # NOTE: for unizero - pad_child_visits_lst = game_segments[i].child_visit_segment[ - :self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - 1 + # --- Prepare padding data from the current game segment --- + # Observations for padding are taken from the start of the new segment. + beg_index_obs = self.policy_config.model.frame_stack_num + end_index_obs = beg_index_obs + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_obs_lst = game_segments[i].obs_segment[beg_index_obs:end_index_obs] + + # Actions for padding. + beg_index_ac = 0 + end_index_ac = beg_index_ac + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_action_lst = game_segments[i].action_segment[beg_index_ac:end_index_ac] + + # Child visits for padding. + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + # Rewards for padding. + beg_index_rew = 0 + end_index_rew = beg_index_rew + self.unroll_plus_td_steps - 1 + pad_reward_lst = game_segments[i].reward_segment[beg_index_rew:end_index_rew] + + # Root values for padding. 
+ beg_index_val = 0 + end_index_val = beg_index_val + self.unroll_plus_td_steps + pad_root_values_lst = game_segments[i].root_value_segment[beg_index_val:end_index_val] - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_lst = game_segments[i].chance_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] - + chance_lst = game_segments[i].chance_segment[beg_index_rew:end_index_rew] + if self.policy_config.gumbel_algo: - pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index_val:end_index_val] - # pad over and save + # --- Pad the last game segment and save it --- if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ - + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) + + # Convert the segment's lists to NumPy arrays for efficient storage. last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment and its associated data to the pool. self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments + # Reset the placeholder for the last game segment. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[Dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. + Collects `n_episode` episodes of data. It manages the entire lifecycle of an episode, + from getting actions from the policy, stepping the environment, storing transitions, + and saving completed game segments. 
Arguments: - - n_episode (:obj:`Optional[int]`): Number of episodes to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - n_episode (:obj:`Optional[int]`): The number of episodes to collect. If None, uses the default from the policy config. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[Dict]`): Additional keyword arguments to pass to the policy's forward method, like temperature for exploration. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy (e.g., greedy action) without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and metadata. """ - # TODO: collect_with_pure_policy as a separate collector + # TODO(author): Consider implementing `collect_with_pure_policy` as a separate, more streamlined collector for clarity and modularity. if n_episode is None: if self._default_n_episode is None: - raise RuntimeError("Please specify collect n_episode") + raise RuntimeError("Please specify `n_episode` for collection.") else: n_episode = self._default_n_episode - assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + assert n_episode >= self._env_num, f"Please ensure n_episode ({n_episode}) >= env_num ({self._env_num})." + if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # --- Initializations --- collected_episode = 0 - collected_step = 0 env_nums = self._env_num retry_waiting_time = 0.05 - # initializations + # Wait for all environments to be ready and get initial observations. init_obs = self._env.ready_obs while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.warning(f"Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs + # Prepare initial state dictionaries from observations. action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - - timestep_dict = {} - for i in range(env_nums): - if 'timestep' not in init_obs[i]: - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. 
Please note that the unizero algorithm may require the 'timestep' key in init_obs.") - timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - + timestep_dict = {i: to_ndarray(init_obs[i].get('timestep', -1)) for i in range(env_nums)} if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] + # Initialize game segments and observation stacks for each environment. + game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] + observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(to_ndarray(init_obs[env_id]['observation'])) game_segments[env_id].reset(observation_window_stack[env_id]) + # State tracking variables for the collection loop. dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] - # for priorities in self-play + last_game_segments: List[Optional[GameSegment]] = [None for _ in range(env_nums)] + last_game_priorities: List[Optional[np.ndarray]] = [None for _ in range(env_nums)] + + # Buffers for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + # Logging variables. + eps_steps_lst = np.zeros(env_nums) + visit_entropies_lst = np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 - ready_env_id = set() + ready_env_id: Set[int] = set() remain_episode = n_episode if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + # Dummy visit counts for pure policy collection. + temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] + # --- Main Collection Loop --- while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + ready_env_id.update(list(new_available_env_id)[:remain_episode]) remain_episode -= min(len(new_available_env_id), remain_episode) - - # NOTE: If waiting for N environments to synchronize, it may result in some environments not being completed (done) by the time of return. 
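The observation window above is a `deque(maxlen=frame_stack_num)`, so appending a new frame automatically evicts the oldest one and `get_obs()` always sees a fixed-size stack. A self-contained sketch of that behaviour; the array shapes are illustrative, not taken from a specific config:

    from collections import deque

    import numpy as np

    frame_stack_num = 4
    reset_obs = np.zeros((1, 96, 96), dtype=np.float32)

    # Fill the window with copies of the reset observation, as done above.
    window = deque(maxlen=frame_stack_num)
    for _ in range(frame_stack_num):
        window.append(reset_obs)

    # Later steps only append the newest frame; the deque drops the oldest one.
    window.append(np.ones((1, 96, 96), dtype=np.float32))
    stacked = np.concatenate(list(window), axis=0)
    print(stacked.shape)  # (4, 96, 96)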
- # However, the current muzero_collector does not properly maintain the global self.last_game_segments, leading to some data not being collected. - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} + # Prepare policy inputs. + stack_obs_list = [game_segments[env_id].get_obs() for env_id in ready_env_id] action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} - - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Policy Forward Pass # ============================================================== - # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) - - pred_next_text_with_env_id = {k: v['predicted_next_text'] if 'predicted_next_text' in v else -1 for k, v in policy_output.items()} - - # Extract relevant policy outputs - actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} - value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} - pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() + policy_input = { + 'x': stack_obs_tensor, + 'action_mask': action_mask, + 'temperature': temperature, + 'to_play': to_play, + 'epsilon': epsilon, + 'ready_env_id': ready_env_id, + 'timestep': timestep } - + if self.task_id is not None: + policy_input['task_id'] = self.task_id + + policy_output = self._policy.forward(**policy_input) + + # --- Unpack policy outputs --- + actions, value_dict, pred_value_dict = {}, {}, {} + distributions_dict, visit_entropy_dict = {}, {} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } - - if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} + root_sampled_actions_dict = {} + if self.policy_config.gumbel_algo: + improved_policy_dict, completed_value_dict = {}, {} - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in 
policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - pred_next_text = {} - - if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - - if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - pred_next_text[env_id] = pred_next_text_with_env_id.pop(env_id) - + output = policy_output[env_id] + actions[env_id] = output['action'] + value_dict[env_id] = output['searched_value'] + pred_value_dict[env_id] = output['predicted_value'] + if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - + distributions_dict[env_id] = output['visit_count_distributions'] + visit_entropy_dict[env_id] = output['visit_count_distribution_entropy'] if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) - - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - + root_sampled_actions_dict[env_id] = output['root_sampled_actions'] if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) - + improved_policy_dict[env_id] = output['improved_policy_probs'] + completed_value_dict[env_id] = output['roots_completed_value'] + # ============================================================== - # Interact with the environment + # Environment Interaction # ============================================================== timesteps = self._env.step(actions) - interaction_duration = self._timer.value / len(timesteps) - - groundtrut_next_text = {} + interaction_duration = self._timer.value / len(timesteps) if timesteps else 0 + for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). 
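The unpacking above assumes `policy_output` maps each ready env_id to a dictionary holding the MCTS results for that environment. A toy illustration of that structure and of the per-environment remapping; the keys mirror the ones read above, while the values are fabricated:

    import numpy as np

    # Fabricated output for two ready environments; real entries come from the MCTS policy.
    policy_output = {
        0: {'action': 2, 'searched_value': 0.7, 'predicted_value': 0.6,
            'visit_count_distributions': np.array([0.1, 0.2, 0.7]),
            'visit_count_distribution_entropy': 0.80},
        3: {'action': 0, 'searched_value': -0.1, 'predicted_value': 0.0,
            'visit_count_distributions': np.array([0.5, 0.3, 0.2]),
            'visit_count_distribution_entropy': 1.03},
    }

    actions, value_dict, pred_value_dict, visit_entropy_dict = {}, {}, {}, {}
    for env_id, output in policy_output.items():
        actions[env_id] = output['action']
        value_dict[env_id] = output['searched_value']
        pred_value_dict[env_id] = output['predicted_value']
        visit_entropy_dict[env_id] = output['visit_count_distribution_entropy']
    print(actions)  # {0: 2, 3: 0}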
- # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f"Environment {env_id} returned an abnormal step, info: {episode_timestep.info}") continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info - - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - obs_input_ids = torch.tensor(obs['observation'], dtype=torch.long) # shape: [L] - obs_attn_mask = torch.tensor(obs['obs_attn_mask'][0], dtype=torch.long) - valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() - - groundtrut_next_text[env_id] = self._env._envs[env_id].tokenizer.decode(valid_input_ids, skip_special_tokens=True) - text_bleu = compute_bleu(reference=groundtrut_next_text[env_id], prediction=pred_next_text[env_id]) - # Whether to output text comparisons with high BLEU scores to evaluate the effectiveness of decoding the next latent. - if text_bleu > 0.85: - os.makedirs("./log", exist_ok=True) - with open("./log/bleu_match.txt", "a", encoding="utf-8") as f: - f.write(f"pred_text={pred_next_text[env_id]}\ngroundtruth_text={groundtrut_next_text[env_id]}\ntext_bleu={text_bleu:.4f}\n\n") - + # Store MCTS search statistics. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy=improved_policy_dict[env_id]) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the current transition to the game segment, preserving the argument order of + # ``GameSegment.append``: (..., to_play, timestep[, chance]). + append_args = (actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], to_play_dict[env_id], timestep_dict[env_id]) if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + append_args += (chance_dict[env_id],) + game_segments[env_id].append(*append_args) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to the next action + # Update state dictionaries for the next step.
action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - dones[env_id] = False - else: - dones[env_id] = done - + dones[env_id] = done if not self.policy_config.ignore_done else False + + # Update logging and priority data. if not collect_with_pure_policy: visit_entropies_lst[env_id] += visit_entropy_dict[env_id] if self.policy_config.gumbel_algo: completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) - + eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - - total_transitions += 1 - if self.policy_config.use_priority: pred_values_lst[env_id].append(pred_value_dict[env_id]) search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) - # append the newest obs + # Update the observation window with the new observation. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Game Segment Saving Logic # ============================================================== - - # if game segment is full, we will save the last game segment + # If a segment is full, pad and save the previous segment. if game_segments[env_id].is_full(): - # pad over last segment trajectory if last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) - # calculate priority + # Calculate priorities for the now-completed `last_game_segment`. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] - # the current game_segments become last_game_segment + # The current segment becomes the `last_game_segment`. last_game_segments[env_id] = game_segments[env_id] last_game_priorities[env_id] = priorities - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + # Start a new game segment. 
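`_compute_priorities` itself is outside this hunk; the `pred_values_lst`/`search_values_lst` buffers cleared above suggest the usual MuZero-style scheme in which a transition's replay priority grows with the gap between the network's predicted value and the MCTS search value. A sketch under that assumption only; the exponent and epsilon are illustrative hyperparameters, not values taken from this PR:

    import numpy as np

    def compute_priorities_sketch(pred_values, search_values, alpha=0.6, eps=1e-6):
        # Illustrative only: L1 gap between predicted and searched values, raised to alpha.
        pred = np.asarray(pred_values, dtype=np.float32)
        search = np.asarray(search_values, dtype=np.float32)
        return np.abs(pred - search) ** alpha + eps

    print(compute_priorities_sketch([0.1, 0.5], [0.2, 0.9]))  # larger gap -> larger priority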
+ game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - self._env_info[env_id]['text_bleu'] += text_bleu - collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration - if episode_timestep.done: - reward = episode_timestep.info['eval_episode_return'] - info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], - } - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - info.update({'text_bleu':self._env_info[env_id]['text_bleu'] / self._env_info[env_id]['step']}) - + + # --- Episode Termination Handling --- + if done: + collected_episode += 1 + reward = info['eval_episode_return'] + log_info = {'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step']} if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + log_info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + log_info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 + self._episode_info.append(log_info) - collected_episode += 1 - self._episode_info.append(info) - - # ============================================================== - # if it is the end of the game, we will save the game segment - # ============================================================== - - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # Pad and save the segment before the final one. if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # store current segment trajectory + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) + + # Process and save the final segment of the episode. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - - # NOTE: put the last game segment in one episode into the trajectory_pool game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) - # print(game_segments[env_id].reward_segment) - # reset the finished env and init game_segments + # Reset environment-specific states for a new episode. if n_episode > self._env_num: - # Get current ready env obs. + # Re-initialize the state for this env_id. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. 
- self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while env_id not in init_obs: + self._logger.warning(f"Waiting for env {env_id} to reset...") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + # Reset game segment and observation stack. + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + observation_window_stack[env_id].clear() + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(init_obs[env_id]['observation']) game_segments[env_id].reset(observation_window_stack[env_id]) last_game_segments[env_id] = None last_game_priorities[env_id] = None - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 + # Reset tracking and logging variables. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] = 0 - # Env reset is done by env_manager automatically - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # Reset policy and collector stats for the finished environment. + self._policy.reset([env_id]) self._reset_stat(env_id) ready_env_id.remove(env_id) + # --- Check for Collection Completion --- if collected_episode >= n_episode: - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], + # Prepare data for returning. 
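As the code right below shows, `collect` returns a pair of parallel lists: the game segments themselves and one metadata dictionary per segment. A compact sketch of consuming that structure; the segment objects and the buffer push are stand-ins, not the actual replay-buffer API:

    # Fabricated stand-ins for the (segments, metadata) pair assembled below.
    return_data = [
        ['segment_0', 'segment_1'],
        [{'priorities': None, 'done': False, 'unroll_plus_td_steps': 8},
         {'priorities': None, 'done': True, 'unroll_plus_td_steps': 8}],
    ]

    game_segments, meta = return_data
    for seg, info in zip(game_segments, meta):
        # A training entry would push (seg, info) into the task's replay buffer here.
        print(seg, info['done'], info['unroll_plus_td_steps'])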
+ return_data = [ + [item[0] for item in self.game_segment_pool], + [{ + 'priorities': item[1], + 'done': item[2], 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + } for item in self.game_segment_pool] ] self.game_segment_pool.clear() break - + + # --- Finalize and Log --- collected_duration = sum([d['time'] for d in self._episode_info]) - # reduce data when enables DDP + # In DDP, aggregate statistics across all processes. if self._world_size > 1: - # Before allreduce - self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") collected_step = allreduce_data(collected_step, 'sum') collected_episode = allreduce_data(collected_episode, 'sum') collected_duration = allreduce_data(collected_duration, 'sum') - # After allreduce - self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration - # log self._output_log(train_iter) return return_data def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Aggregates and logs collection statistics to the console, TensorBoard, and WandB. + This method is only executed by the rank 0 process in a distributed setup. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration number, used as the logging step. """ if self._rank != 0: return + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - episode_bleu = [d['text_bleu'] for d in self._episode_info] - - if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] - else: - visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -822,22 +650,32 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), } - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - info.update({'text_avg_bleu':np.mean(episode_bleu)}) + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + info['visit_entropy_mean'] = 
np.mean(visit_entropy) if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d['completed_value'] for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + # Log to console + self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) + + # Log to TensorBoard and WandB for k, v in info.items(): - if k in ['each_reward']: - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: - continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) - + if self.task_id is None: + tb_prefix_iter = f'{self._instance_name}_iter/' + tb_prefix_step = f'{self._instance_name}_step/' + else: + tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' + tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + + self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) + self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) + wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} + wandb.log(wandb_log_data, step=self._total_envstep_count) \ No newline at end of file diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 2a70feea5..01fabd38c 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -15,80 +15,92 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +import threading class MuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. + The Evaluator for MCTS-based reinforcement learning algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval Properties: env, policy """ + # Default configuration for the MuZeroEvaluator. + config = dict( + # The frequency of evaluation, measured in training iterations. + eval_freq=50, + ) + @classmethod def default_config(cls: type) -> EasyDict: """ Overview: - Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other - defaults and any user-provided configuration. + Get the default configuration of the MuZeroEvaluator. Returns: - - cfg (:obj:`EasyDict`): The default configuration for the evaluator. + - cfg (:obj:`EasyDict`): An EasyDict object representing the default configuration. """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg - config = dict( - # Evaluate every "eval_freq" training iterations. 
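The logging branch above derives the TensorBoard prefix from `task_id`, so single-task runs keep the original key layout while multi-task runs get one namespace per task. A small sketch of the resulting keys; the instance name and task id are examples:

    def tb_prefixes(instance_name: str, task_id=None):
        # Mirrors the prefix convention used in _output_log above.
        if task_id is None:
            return f'{instance_name}_iter/', f'{instance_name}_step/'
        return f'{instance_name}_iter_task{task_id}/', f'{instance_name}_step_task{task_id}/'

    print(tb_prefixes('collector'))     # ('collector_iter/', 'collector_step/')
    print(tb_prefixes('collector', 2))  # ('collector_iter_task2/', 'collector_step_task2/')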
- eval_freq=50, - ) - def __init__( self, eval_freq: int = 1000, n_evaluator_episode: int = 3, - stop_value: int = 1e6, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'evaluator', - policy_config: 'policy_config' = None, # noqa + stop_value: float = 1e6, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, + tb_logger: Optional['SummaryWriter'] = None, + exp_name: str = 'default_experiment', + instance_name: str = 'evaluator', + policy_config: Optional[EasyDict] = None, + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the evaluator with configuration settings for various components such as logger helper and timer. + Initialize the MuZeroEvaluator. Arguments: - - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. - - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total. - - stop_value (:obj:`float`): A reward threshold above which the training is considered converged. - - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager. - - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation. - - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - - instance_name (:obj:`str`): Name of this evaluator instance. - - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. + - eval_freq (:obj:`int`): The frequency, in training iterations, at which to run evaluation. + - n_evaluator_episode (:obj:`int`): The total number of episodes to run during each evaluation. + - stop_value (:obj:`float`): The reward threshold at which training is considered converged and will stop. + - env (:obj:`Optional[BaseEnvManager]`): An optional environment manager for evaluation. + - policy (:obj:`Optional[namedtuple]`): An optional policy for evaluation. + - tb_logger (:obj:`Optional['SummaryWriter']`): An optional TensorBoard logger. + - exp_name (:obj:`str`): The name of the experiment, used for logging. + - instance_name (:obj:`str`): The name of this evaluator instance. + - policy_config (:obj:`Optional[EasyDict]`): Configuration for the policy. + - task_id (:obj:`Optional[int]`): The unique identifier for the task. If None, it operates in single-task mode. """ + self.stop_event = threading.Event() # Event to signal a stop, e.g., due to a timeout. + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name - # Logger (Monitor will be initialized in policy setter) - # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output. + # Initialize logger. Only rank 0 needs a full logger with TensorBoard. 
if get_rank() == 0: if tb_logger is not None: self._logger, _ = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name ) else: - self._logger, self._tb_logger = None, None # for close elegantly + # TODO(username): Refine logger setup for UniZero multitask with DDP v2. + if tb_logger is not None: + self._logger, _ = build_logger( + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + + self._rank = get_rank() + print(f'rank {self._rank}, self.task_id: {self.task_id}') self.reset(policy, env) @@ -97,18 +109,16 @@ def __init__( self._stop_value = stop_value # ============================================================== - # MCTS+RL related core code + # MCTS+RL related core properties # ============================================================== self.policy_config = policy_config def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment for the evaluator, optionally replacing it with a new environment. - If _env is None, reset the old environment. If _env is not None, replace the old environment - in the evaluator with the new passed in environment and launch. + Reset the environment. If a new environment is provided, it replaces the old one. Arguments: - - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. If None, resets the existing environment. """ if _env is not None: self._env = _env @@ -120,29 +130,22 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy for the evaluator, optionally replacing it with a new policy. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset the policy. If a new policy is provided, it replaces the old one. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. If None, resets the existing policy. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set environment first." if _policy is not None: self._policy = _policy - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset both the policy and environment for the evaluator, optionally replacing them. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the new passed in \ - environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset both the policy and the environment. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. 
- - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. """ if _env is not None: self.reset_env(_env) @@ -151,37 +154,36 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._max_episode_return = float("-inf") self._last_eval_iter = 0 self._end_flag = False - def close(self) -> None: """ Overview: - Close the evaluator, the environment, flush and close the TensorBoard logger if applicable. + Close the evaluator, including the environment and the TensorBoard logger. """ if self._end_flag: return self._end_flag = True - self._env.close() + if hasattr(self, '_env'): + self._env.close() if self._tb_logger: self._tb_logger.flush() self._tb_logger.close() - def __del__(self): + def __del__(self) -> None: """ Overview: - Execute the close command and close the evaluator. __del__ is automatically called \ - to destroy the evaluator instance when the evaluator finishes its work + Destructor that ensures `close` is called to clean up resources. """ self.close() def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether to initiate evaluation based on the training iteration count and evaluation frequency. + Determine whether it's time to run an evaluation based on the training iteration. Arguments: - - train_iter (:obj:`int`): The current count of training iterations. + - train_iter (:obj:`int`): The current training iteration. Returns: - - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`. + - (:obj:`bool`): True if evaluation should be run, otherwise False. """ if train_iter == self._last_eval_iter: return False @@ -192,54 +194,56 @@ def should_eval(self, train_iter: int) -> bool: def eval( self, - save_ckpt_fn: Callable = None, + save_ckpt_fn: Optional[Callable] = None, train_iter: int = -1, envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, - ) -> Tuple[bool, float]: + ) -> Tuple[bool, Dict[str, Any]]: """ Overview: - Evaluate the current policy, storing the best policy if it achieves the highest historical reward. + Run a full evaluation process. It will evaluate the current policy, log the results, + and save a checkpoint if a new best performance is achieved. Arguments: - - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved. - - train_iter (:obj:`int`): The current training iteration count. - - envstep (:obj:`int`): The current environment step count. - - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting. - - return_trajectory (:obj:`bool`): Return the evaluated trajectory `game_segments` in `episode_info` if True. + - save_ckpt_fn (:obj:`Optional[Callable]`): A function to save a checkpoint. Called when a new best reward is achieved. + - train_iter (:obj:`int`): The current training iteration. + - envstep (:obj:`int`): The current total environment steps. + - n_episode (:obj:`Optional[int]`): The number of episodes to evaluate. Defaults to the value set in `__init__`. + - return_trajectory (:obj:`bool`): Whether to return the collected `game_segments` in the result dictionary. Returns: - - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. 
- - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. + - stop_flag (:obj:`bool`): A flag indicating whether the training should stop (e.g., if the stop value is reached). + - episode_info (:obj:`Dict[str, Any]`): A dictionary containing evaluation results, such as rewards and episode lengths. """ - # the evaluator only works on rank0 + if torch.cuda.is_available(): + print(f"=========in eval() Rank {get_rank()} ===========") + device = torch.cuda.current_device() + print(f"Current default GPU device id: {device}") + torch.cuda.set_device(get_rank()) + print(f"GPU device id after set_device: {get_rank()}") + + # The evaluator is designed to work on rank 0, but DDP support is being developed. episode_info = None stop_flag = False - if get_rank() == 0: + # TODO(username): Refine evaluation logic for UniZero multitask with DDP v2. + if get_rank() >= 0: if n_episode is None: n_episode = self._default_n_episode - assert n_episode is not None, "please indicate eval n_episode" + assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." envstep_count = 0 eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) env_nums = self._env.env_num self._env.reset() - self._policy.reset() + self._policy.reset(task_id=self.task_id) - # initializations + # Initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in subprocess-based environment managers. retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.info(f"Waiting for all environments to reset. Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, - self._env._env_states) - ) init_obs = self._env.ready_obs action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} @@ -248,20 +252,17 @@ def eval( timestep_dict = {} for i in range(env_nums): if 'timestep' not in init_obs[i]: - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. Please note that the unizero algorithm may require the 'timestep' key in init_obs.") + print(f"Warning: 'timestep' key is missing in init_obs[{i}], assigning value -1") timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - dones = np.array([False for _ in range(env_nums)]) game_segments = [ GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] for i in range(env_nums): @@ -272,73 +273,54 @@ def eval( ready_env_id = set() remain_episode = n_episode eps_steps_lst = np.zeros(env_nums) - with self._timer: while not eval_monitor.is_finished(): - # Get current ready env obs.
+ # Check if a timeout has occurred. + if self.stop_event.is_set(): + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + + # Get observations from ready environments. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - # In a parallel evaluation setting, it's possible for all active environments to finish their - # episodes simultaneously. This can leave `ready_env_id` temporarily empty while the environments - # are being reset by the manager. - # To prevent processing an empty batch, which would cause an IndexError or other errors downstream, - # we check if `ready_env_id` is empty. If so, we sleep briefly to prevent a busy-wait, - # and `continue` to the next loop iteration to wait for newly reset environments to become available. - if not ready_env_id: - time.sleep(0.01) - continue - + # Prepare stacked observations and other inputs for the policy. stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} - stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== - # policy forward + # Policy Forward Pass # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) - + if self.task_id is None: + # Single-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) + else: + # Multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) + + # Unpack policy outputs. 
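The `self.stop_event.is_set()` check at the top of the loop is the evaluator's half of a cooperative timeout: a watchdog sets the event from another thread, and the loop exits cleanly at the next iteration instead of being killed mid-step. A self-contained sketch of the pattern; the worker body and timings are invented for illustration:

    import threading
    import time

    stop_event = threading.Event()

    def eval_loop_sketch():
        for step in range(1000):
            if stop_event.is_set():
                print(f"aborting evaluation at step {step}")
                return
            time.sleep(0.01)  # stand-in for one environment/search step

    worker = threading.Thread(target=eval_loop_sketch)
    worker.start()
    time.sleep(0.05)
    stop_event.set()  # what a timeout watchdog would do
    worker.join()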
actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - visit_entropy_dict_with_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - actions = {} - distributions_dict = {} + timestep_dict_with_env_id = {k: v.get('timestep', -1) for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} + + # Remap outputs from policy's internal IDs to environment IDs. + actions, distributions_dict, value_dict, pred_value_dict, timestep_dict, visit_entropy_dict = {}, {}, {}, {}, {}, {} if self.policy_config.sampled_algo: root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - visit_entropy_dict = {} for index, env_id in enumerate(ready_env_id): actions[env_id] = actions_with_env_id.pop(env_id) @@ -351,45 +333,30 @@ def eval( visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) # ============================================================== - # Interact with env. + # Environment Interaction # ============================================================== timesteps = self._env.step(actions) timesteps = to_tensor(timesteps, dtype=torch.float32) - for env_id, episode_timestep in timesteps.items(): obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info - # obs_input_ids = obs['observation'].long() - # obs_attn_mask = obs['obs_attn_mask'][0].long() - # valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() - eps_steps_lst[env_id] += 1 + # This reset logic is specific to UniZero-like models. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id] + ) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to next action + # IMPORTANT: The action_mask and to_play from the new observation correspond to the *next* state. 
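The IMPORTANT note above is easy to trip over: the transition stored at step t must keep the action mask and to_play that constrained the action just taken, while the mask carried by the new observation belongs to the next decision. A short sketch of that bookkeeping with illustrative names:

    # State carried across steps for one environment.
    mask_t, to_play_t = [1, 1, 0], 1

    def fake_env_step(action):
        # Returns the next observation, which carries the *next* mask and player.
        return {'observation': [0.0], 'action_mask': [0, 1, 1], 'to_play': 2}, 0.5, False

    obs_tp1, reward_t, done = fake_env_step(action=0)

    transition = {
        'action': 0,
        'next_obs': obs_tp1['observation'],
        'reward': reward_t,
        'action_mask': mask_t,   # the mask that was valid when the action was chosen
        'to_play': to_play_t,
    }
    # Only after the transition is stored do we roll the per-env state forward.
    mask_t, to_play_t = obs_tp1['action_mask'], obs_tp1['to_play']
    print(transition['action_mask'], mask_t)  # [1, 1, 0] [0, 1, 1]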
action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(obs['chance']) dones[env_id] = done if episode_timestep.done: - # Env reset is done by env_manager automatically. self._policy.reset([env_id]) reward = episode_timestep.info['eval_episode_return'] saved_info = {'eval_episode_return': episode_timestep.info['eval_episode_return']} @@ -398,117 +365,106 @@ def eval( eval_monitor.update_info(env_id, saved_info) eval_monitor.update_reward(env_id, reward) self._logger.info( - "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( - env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() - ) + f"[EVALUATOR] env {env_id} finished episode, final reward: {eval_monitor.get_latest_reward(env_id)}, " + f"current episode count: {eval_monitor.get_current_episode()}" ) - # reset the finished env and init game_segments + # If there are more episodes to run than available environments, reset and reuse this one. if n_episode > self._env_num: - # Get current ready env obs. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 + # Wait for the environment to be ready again. while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info( - 'Before sleeping, the _env_states is {}'.format(self._env._env_states) - ) + self._logger.info(f"Waiting for env {env_id} to reset. Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs new_available_env_id = set(init_obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) + # Re-initialize state for the new episode. action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)] ) eps_steps_lst[env_id] = 0 - - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # NOTE: Reset the policy state for this env_id. `reset_init_data` defaults to True. 
+ self._policy.reset([env_id]) ready_env_id.remove(env_id) envstep_count += 1 - + duration = self._timer.value episode_return = eval_monitor.get_episode_return() info = { 'train_iter': train_iter, - 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'ckpt_name': f'iteration_{train_iter}.pth.tar', 'episode_count': n_episode, 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode, + 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_time_per_episode': n_episode / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_time_per_episode': n_episode / duration if duration > 0 else 0, 'reward_mean': np.mean(episode_return), 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return) - # 'each_reward': episode_return, + 'reward_min': np.min(episode_return), } episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) + + print(f'rank {self._rank}, self.task_id: {self.task_id}') self._logger.info(self._logger.get_tabulate_vars_hor(info)) + + # Log to TensorBoard and WandB. for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward']: + if k in ['train_iter', 'ckpt_name', 'each_reward'] or not np.isscalar(v): continue - if not np.isscalar(v): - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, envstep) + else: + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, envstep) if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) + wandb.log({f'{self._instance_name}_step/{k}': v}, step=envstep) - episode_return = np.mean(episode_return) - if episode_return > self._max_episode_return: + # Check for new best performance and save checkpoint. + mean_episode_return = np.mean(episode_return) + if mean_episode_return > self._max_episode_return: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = episode_return - stop_flag = episode_return >= self._stop_value and train_iter > 0 + self._max_episode_return = mean_episode_return + + # Check if the stop condition is met. + stop_flag = mean_episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( - "[LightZero serial pipeline] " + - "Current episode_return: {} is greater than stop_value: {}".format(episode_return, - self._stop_value) + - ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." + f"[LightZero serial pipeline] Current episode_return: {mean_episode_return} is greater than " + f"stop_value: {self._stop_value}. The agent is considered converged." ) - if get_world_size() > 1: - objects = [stop_flag, episode_info] - broadcast_object_list(objects, src=0) - stop_flag, episode_info = objects + # TODO(username): Finalize DDP synchronization for evaluation results. 
+ # if get_world_size() > 1: + # objects = [stop_flag, episode_info] + # print(f'rank {self._rank}, self.task_id: {self.task_id}') + # print('before broadcast_object_list') + # broadcast_object_list(objects, src=0) + # print('evaluator after broadcast_object_list') + # stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: episode_info['trajectory'] = game_segments - return stop_flag, episode_info + return stop_flag, episode_info \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 46cc016bc..ad7f91bf9 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -1,7 +1,7 @@ import logging import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict import numpy as np import torch @@ -20,21 +20,20 @@ class MuZeroSegmentCollector(ISerialCollector): """ Overview: - MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. - - The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a specified number of segments, - whereas MuZeroCollector returns after collecting a complete game. This provides more extensibility and flexibility in data collection. + MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, + Sampled EfficientZero, and Gumbel MuZero. It manages the data collection process for training these + algorithms using a serial mechanism. + The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a + specified number of segments, whereas MuZeroCollector returns after collecting a complete game. + This provides more extensibility and flexibility in data collection. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` - + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, ``close``, ``__del__`` Properties: - ``envstep``: Counter for the current number of environment steps. + - envstep (:obj:`int`): The total number of environment steps collected. """ - # To be compatible with ISerialCollector + # To be compatible with ISerialCollector. config = dict() def __init__( @@ -46,19 +45,22 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: - Initialize the MuZeroSegmentCollector with the given parameters. + Initializes the MuZeroSegmentCollector. Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): Namedtuple of the collection mode policy API. - - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. 
- - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - collect_print_freq (:obj:`int`): The frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the collect mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance. + - exp_name (:obj:`str`): The name of the experiment, used for logging and saving. + - instance_name (:obj:`str`): A unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): The configuration object for the policy. + - task_id (:obj:`int`): The ID of the task, used in multi-task learning settings. """ + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -67,23 +69,23 @@ def __init__( self._rank = get_rank() self._world_size = get_world_size() + if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), - name=self._instance_name, - need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) - self._tb_logger = None + # TODO(author): This part is for UniZero multi-task DDP v2 compatibility. + self._tb_logger = tb_logger self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -93,12 +95,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. + Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. Defaults to None. """ if _env is not None: self._env = _env @@ -110,35 +111,28 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. + If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. 
Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy's API in a namedtuple format. Defaults to None. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env before resetting policy." if _policy is not None: self._policy = _policy - - self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + self._default_num_segments = self._policy.get_attribute('cfg').get('num_segments', None) self._logger.debug( - 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) + f'Set default num_segments mode(num_segments({self._default_num_segments}), env_num({self._env_num}))' ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector state, including the environment and policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. Defaults to None. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. Defaults to None. """ if _env is not None: self.reset_env(_env) @@ -147,13 +141,12 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} - # Initialize action_mask_dict, to_play_dict, and chance_dict here to ensure they contain values for all env_id + # Initialize dictionaries to store environment-specific states. self.action_mask_dict = {i: None for i in range(self._env_num)} self.to_play_dict = {i: None for i in range(self._env_num)} + self.timestep_dict = {i: None for i in range(self._env_num)} if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict = {i: None for i in range(self._env_num)} - - self.timestep_dict = {i: None for i in range(self._env_num)} self.dones = np.array([False for _ in range(self._env_num)]) self.last_game_segments = [None for _ in range(self._env_num)] @@ -166,18 +159,16 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A deque-based pool for storing game segments. self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment. 
Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset. """ self._env_info[env_id] = {'time': 0., 'step': 0} @@ -185,17 +176,16 @@ def _reset_stat(self, env_id: int) -> None: def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected. Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total count of environment steps. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and the TensorBoard logger. """ if self._end_flag: return @@ -208,79 +198,63 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Ensures that the `close` method is called when the collector instance is deleted. """ self.close() - # ============================================================== - # MCTS+RL related core code - # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for transitions based on the discrepancy between predicted and search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the values list to process. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of computed priorities, or None if priority is disabled. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted and search values. + # The reduction is 'none' to get per-element losses. + # A small epsilon (1e-6) is added to prevent zero priorities. 
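                # For example, pred_values = [0.5, 1.0] and search_values = [0.7, 0.4] yield priorities of
                # |0.5 - 0.7| + 1e-6 and |1.0 - 0.4| + 1e-6, i.e. roughly [0.2, 0.6].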
pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If not using priority, all new data will use the maximum priority in the replay buffer. priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the last game segment with data from the current segment and saves it to the pool. + This is done when a game ends or a segment becomes full. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. - Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + - i (:obj:`int`): The index of the current game segment (and environment). + - last_game_segments (:obj:`List[GameSegment]`): The list of previous game segments to be padded. + - last_game_priorities (:obj:`List[np.ndarray]`): The list of priorities for the previous game segments. + - game_segments (:obj:`List[GameSegment]`): The list of current game segments, used for padding data. + - done (:obj:`np.ndarray`): An array indicating whether each game has terminated. """ - # pad over last segment trajectory + # Pad the trajectory of the last segment. beg_index = self.policy_config.model.frame_stack_num end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + # The initial observations are zero-padded, so we take observations from + # [ : + ] for padding. pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - # NOTE: for unizero + # NOTE: Specific padding logic for UniZero. 
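        # For example, with frame_stack_num=4, num_unroll_steps=5 and td_steps=5, the observation padding
        # above covers indices [4:14] of the current segment, while the action / child-visit padding below
        # uses [:10].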
pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # NOTE: for unizero pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - 1 - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: @@ -288,101 +262,87 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] if self.policy_config.gumbel_algo: pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] - # pad over and save + # Pad and finalize the last game segment. if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment to the pool. self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments and last game_priorities for the next collection + # Reset placeholders for the next collection cycle. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - num_segments: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + num_segments: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `num_segments` segments of data with policy_kwargs, trained for `train_iter` iterations. + Collects a specified number of game segments using the policy. 
Arguments: - - num_segments (:obj:`Optional[int]`): Number of segments to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - num_segments (:obj:`Optional[int]`): The number of segments to collect. If None, uses the default. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[dict]`): Additional arguments for the policy forward pass. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and their metadata. """ if num_segments is None: if self._default_num_segments is None: - raise RuntimeError("Please specify collect num_segments") + raise RuntimeError("Please specify num_segments for collection.") else: num_segments = self._default_num_segments - assert num_segments == self._env_num, "Please make sure num_segments == env_num{}/{}".format(num_segments, self._env_num) + assert num_segments == self._env_num, f"num_segments({num_segments}) must be equal to env_num({self._env_num})." if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # Initializations collected_episode = 0 collected_step = 0 env_nums = self._env_num - - # initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in a subprocess setup. retry_waiting_time = 0.05 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while len(init_obs.keys()) != env_nums: + self._logger.info(f'Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}') time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs for env_id in range(env_nums): - if env_id in init_obs.keys(): + if env_id in init_obs: self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - if 'timestep' not in init_obs[env_id]: - print(f"Warning: 'timestep' key is missing in init_obs[{env_id}], assigning value -1") self.timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - + if 'timestep' not in init_obs[env_id]: + self._logger.warning(f"'timestep' key missing in init_obs[{env_id}], assigning default -1.") if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) @@ -390,151 +350,95 @@ def collect(self, GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] - for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + # Stacked observation windows for initializing game segments. + observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] + for env_id in range(env_nums): + initial_frames = [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + observation_window_stack[env_id].extend(initial_frames) game_segments[env_id].reset(observation_window_stack[env_id]) - # for priorities in self-play + # Lists for storing values for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs + # Logging variables. eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs ready_env_id = set(obs.keys()) if len(ready_env_id) < self._env_num: - logging.info(f'muzero_segment_collector: len(ready_env_id) < self._env_num, ready_env_id: {ready_env_id}, self._env_num: {self._env_num}') - - # TODO: For UniZero, during the init-infer process, it is necessary to retrieve the current kv_cache from the kv_cache_dict corresponding to each env_id. - # In theory, this requires waiting for all environments to be ready. However, in practice, - # waiting for all environments to be ready can have a significant negative impact on UniZero's performance, - # whereas the impact on MuZero is relatively small. 
+ self._logger.debug(f'Only {len(ready_env_id)}/{self._env_num} envs are ready.') + + # TODO(author): For UniZero, waiting for all environments to be ready can negatively impact performance. + # This wait loop is currently commented out, but its impact should be considered. # while len(obs.keys()) != self._env_num: - # # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # # len(self._env.ready_obs), especially in tictactoe env. - # self._logger.info('The current init_obs.keys() is {}'.format(obs.keys())) - # self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) # time.sleep(retry_waiting_time) - # self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - # self._logger.info( - # 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - # ) # obs = self._env.ready_obs # ready_env_id = set(obs.keys()) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) + # Prepare stacked observations for the policy network. + stack_obs_dict = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs_list = [stack_obs_dict[env_id] for env_id in sorted(list(ready_env_id))] self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} self.timestep_dict_tmp = {env_id: self.timestep_dict[env_id] for env_id in ready_env_id} - - action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id] - to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id] - timestep = [self.timestep_dict_tmp[env_id] for env_id in ready_env_id] + + action_mask = [self.action_mask_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + to_play = [self.to_play_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + timestep = [self.timestep_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Perform a forward pass with the policy. 
# ============================================================== - # logging.info(f'ready_env_id:{ready_env_id}') - # logging.info(f'timestep:{timestep}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + policy_args = (stack_obs_tensor, action_mask, temperature, to_play, epsilon) + policy_kwargs_forward = {'ready_env_id': sorted(list(ready_env_id)), 'timestep': timestep} + if self.task_id is not None: + policy_kwargs_forward['task_id'] = self.task_id + + policy_output = self._policy.forward(*policy_args, **policy_kwargs_forward) - # Extract relevant policy outputs + # Extract policy outputs. actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - - if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} - - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - - if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries - for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - - if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - - if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + # Populate the result dictionaries, mapping outputs to original env_ids. 
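                # Only the action field needs to be pulled out into a plain {env_id: action} dict, because
                # self._env.step() below consumes exactly that mapping; the remaining outputs (searched /
                # predicted values, visit distributions, etc.) are read per env_id from the *_with_env_id
                # dicts further down.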
+ actions: Dict[int, Any] = {env_id: actions_with_env_id.pop(env_id) for env_id in ready_env_id} # ============================================================== - # Interact with the environment + # Step the environments with the chosen actions. # ============================================================== timesteps = self._env.step(actions) @@ -542,108 +446,98 @@ def collect(self, for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). - # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f'Env {env_id} had an abnormal step, info: {episode_timestep.info}') continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + # Store search statistics in the game segment. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], root_sampled_actions_dict_with_env_id[env_id] ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats( + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], + improved_policy=improved_policy_dict_with_env_id[env_id] + ) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the new transition to the game segment. + append_kwargs = {'timestep': to_ndarray(obs.get('timestep', -1))} if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']), chance=self.chance_dict_tmp[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']) - ) - - # NOTE: the position of code snippet is very important. 
- # the obs['action_mask'] and obs['to_play'] are corresponding to the next action - self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask']) - self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play']) - # self.timestep_dict_tmp[env_id] = to_ndarray(obs['timestep']) - self.timestep_dict_tmp[env_id] = to_ndarray(obs.get('timestep', -1)) - - + append_kwargs['chance'] = self.chance_dict_tmp[env_id] + + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, + self.action_mask_dict_tmp[env_id], self.to_play_dict_tmp[env_id], **append_kwargs + ) + + # NOTE: This position is crucial. The action_mask and to_play from the new observation correspond to the *next* state. + self.action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict[env_id] = to_ndarray(obs['to_play']) + self.timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: - self.chance_dict_tmp[env_id] = to_ndarray(obs['chance']) + self.chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - self.dones[env_id] = False - else: - self.dones[env_id] = done + self.dones[env_id] = False if self.policy_config.ignore_done else done if not collect_with_pure_policy: - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + visit_entropies_lst[env_id] += visit_entropy_dict_with_env_id[env_id] if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + completed_value_lst[env_id] += np.mean(np.array(completed_value_with_env_id[env_id])) eps_steps_lst[env_id] += 1 + + # NOTE: Specific reset logic for UniZero. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # ============ only for UniZero now ============ self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - total_transitions += 1 - if self.policy_config.use_priority: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) + pred_values_lst[env_id].append(pred_value_dict_with_env_id[env_id]) + search_values_lst[env_id].append(value_dict_with_env_id[env_id]) if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + improved_policy_lst[env_id].append(improved_policy_dict_with_env_id[env_id]) - # append the newest obs + # Append the newest observation to the observation window. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Save a game segment if it is full or the game has ended. # ============================================================== - - # if game segment is full, we will save the last game segment if game_segments[env_id].is_full(): - # pad over last segment trajectory + # If there's a previous segment, pad and save it. if self.last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment + # TODO(pu): This logic pads and saves one game segment at a time. self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # calculate priority + # Calculate priorities for the collected transitions. 
priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] if self.policy_config.gumbel_algo and not collect_with_pure_policy: improved_policy_lst[env_id] = [] - # the current game_segments become last_game_segment + # The current segment now becomes the 'last' segment for the next padding operation. self.last_game_segments[env_id] = game_segments[env_id] self.last_game_priorities[env_id] = priorities - # create new GameSegment + # Create a new game segment to continue collection. game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -652,94 +546,83 @@ def collect(self, self._env_info[env_id]['time'] += self._timer.value + interaction_duration if episode_timestep.done: - logging.info(f'========env {env_id} done!========') + self._logger.info(f'======== Environment {env_id} episode finished! ========') self._total_episode_count += 1 - reward = episode_timestep.info['eval_episode_return'] info = { - 'reward': reward, + 'reward': episode_timestep.info['eval_episode_return'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], } if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] - + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 collected_episode += 1 self._episode_info.append(info) # ============================================================== - # if it is the end of the game, we will save the game segment + # At the end of a game, save all remaining game segments. # ============================================================== - - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # NOTE: Store the second-to-last game segment of the episode. if self.last_game_segments[env_id] is not None: self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # store current segment trajectory + # Calculate priorities for the final segment. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - # NOTE: put the last game segment in one episode into the trajectory_pool + # NOTE: Store the final game segment of the episode. 
game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id])) - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 - - # Env reset is done by env_manager automatically - # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. ================ - self._policy.reset([env_id]) + # Reset lists and stats for the new episode. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 + + # Environment reset is handled by the env_manager automatically. + # NOTE: Reset the policy state for the completed environment. + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) - ready_env_id.remove(env_id) - # ===== NOTE: if one episode done and not return, we should init its game_segments[env_id] ======= - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + # NOTE: If an episode finishes but collection continues, re-initialize its game segment. + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config, + task_id=self.task_id + ) game_segments[env_id].reset(observation_window_stack[env_id]) - - # NOTE: must after the for loop to make sure all env_id's data are collected + # Check if the required number of segments has been collected. 
if len(self.game_segment_pool) >= self._default_num_segments: - logging.info(f'env {env_id} collected {len(self.game_segment_pool)} segments now!') - - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], - 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + self._logger.info(f'Collected {len(self.game_segment_pool)} segments, reaching the target of {self._default_num_segments}.') + + # Format data for returning: [game_segments, metadata_list] + return_data = [ + [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], + [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], + 'unroll_plus_td_steps': self.unroll_plus_td_steps + } for i in range(len(self.game_segment_pool)) + ] ] self.game_segment_pool.clear() break + collected_duration = sum([d['time'] for d in self._episode_info]) + # TODO: for atari multitask new ddp pipeline # reduce data when enables DDP - if self._world_size > 1: - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration @@ -751,31 +634,31 @@ def collect(self, def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Logs collection statistics to the console and TensorBoard. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration for logging context. """ - if self._rank != 0: - return + # TODO(author): For multi-task DDP, logging might be restricted to rank 0. 
+ # if self._rank != 0: + # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] + if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] + visit_entropy = [d.get('visit_entropy', 0.0) for d in self._episode_info] else: visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -784,16 +667,25 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), + 'visit_entropy_mean': np.mean(visit_entropy), } if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d.get('completed_value', 0.0) for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + + self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()])) for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: - continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + # Log for single-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count) + else: + # Log for multi-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f53f1dd5c..2b56f5a3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ moviepy pytest line_profiler xxhash +simple_parsing einops openai nltk \ No newline at end of file diff --git a/zoo/atari/config/atari_env_action_space_map.py b/zoo/atari/config/atari_env_action_space_map.py index e2090586d..d40d12f41 100644 --- a/zoo/atari/config/atari_env_action_space_map.py +++ b/zoo/atari/config/atari_env_action_space_map.py @@ -27,4 +27,7 @@ 
'SeaquestNoFrameskip-v4': 18, 'BoxingNoFrameskip-v4': 18, 'BreakoutNoFrameskip-v4': 4, + 'SpaceInvadersNoFrameskip-v4': 6, + 'BeamRiderNoFrameskip-v4': 9, + 'GravitarNoFrameskip-v4': 18, }) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..7d640e1d7 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py @@ -0,0 +1,330 @@ +""" +Overview: + Configuration generation script for multi-task MuZero training on Atari environments. + This script defines and generates the necessary configuration files for a distributed training setup. +""" +from easydict import EasyDict +from copy import deepcopy +from typing import List, Union, Dict, Any + +# The 'atari_env_action_space_map' was not used in the original code, so it has been removed. + +class AtariMuZeroMultitaskConfig: + """ + Overview: + A class to generate and manage configurations for multi-task MuZero experiments on Atari. + It encapsulates the entire configuration logic, providing a clean and extensible interface. + """ + + def __init__( + self, + env_id_list: List[str], + seed: int, + num_unroll_steps: int, + num_simulations: int, + collector_env_num: int, + evaluator_env_num: int, + max_env_step: int, + batch_size: Union[List[int], int], + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + exp_path_prefix: str = 'YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari', + ) -> None: + """ + Overview: + Initializes the multi-task configuration generator. + Arguments: + - env_id_list (:obj:`List[str]`): A list of Atari environment IDs to be trained on. + - seed (:obj:`int`): The random seed for the experiment. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - num_simulations (:obj:`int`): The number of simulations to run in the MCTS search. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - max_env_step (:obj:`int`): The total number of environment steps to train for. + - batch_size (:obj:`Union[List[int], int]`): The batch size for training. Can be a list for per-task sizes or a single int. + - norm_type (:obj:`str`): The type of normalization to use in the model (e.g., 'BN', 'LN'). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the replay buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalysis. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalysis. + - num_segments (:obj:`int`): The number of segments for the replay buffer. + - exp_path_prefix (:obj:`str`): A template for the experiment's output path. 
+ """ + self.env_id_list = env_id_list + self.seed = seed + self.num_unroll_steps = num_unroll_steps + self.num_simulations = num_simulations + self.collector_env_num = collector_env_num + self.evaluator_env_num = evaluator_env_num + self.max_env_step = max_env_step + self.batch_size = batch_size + self.norm_type = norm_type + self.buffer_reanalyze_freq = buffer_reanalyze_freq + self.reanalyze_batch_size = reanalyze_batch_size + self.reanalyze_partition = reanalyze_partition + self.num_segments = num_segments + self.exp_path_prefix = exp_path_prefix + + # --- Derived attributes --- + self.num_tasks = len(self.env_id_list) + self.action_space_size = 18 # Default full action space for Atari + + def _create_base_config(self) -> EasyDict: + """ + Overview: + Creates the base configuration dictionary with shared settings for all tasks. + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + env=dict( + stop_value=int(self.max_env_step), + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + n_evaluator_episode=self.evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Very important for DDP + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000), + ), + ), + grad_correct_params=dict(), + task_num=self.num_tasks, + model=dict( + device='cuda', + num_res_blocks=2, + num_channels=256, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=self.action_space_size, + norm_type=self.norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=self.num_tasks, + ), + allocated_batch_sizes=False, + cuda=True, + env_type='not_board_games', + train_start_after_envsteps=2000, + # train_start_after_envsteps=0, # TODO: debug + game_segment_length=20, + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=self.num_unroll_steps, + update_per_collect=80, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=self.num_segments, + num_simulations=self.num_simulations, + policy_entropy_weight=5e-3, # TODO: Fine-tune this weight. + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + # ============= Reanalyze Parameters ============= + buffer_reanalyze_freq=self.buffer_reanalyze_freq, + reanalyze_batch_size=self.reanalyze_batch_size, + reanalyze_partition=self.reanalyze_partition, + ), + )) + + def _get_exp_name(self, env_id: str) -> str: + """ + Overview: + Generates a formatted experiment name for a given task. + Arguments: + - env_id (:obj:`str`): The environment ID for the specific task. + Returns: + - (:obj:`str`): The formatted experiment name. 
+ """ + # TODO: debug name + prefix = ( + f'{self.exp_path_prefix}/{self.num_tasks}games_brf{self.buffer_reanalyze_freq}/' + f'{self.num_tasks}games_brf{self.buffer_reanalyze_freq}_1-encoder-{self.norm_type}-res2-channel256_gsl20_' + f'{self.num_tasks}-pred-head_mbs-512_upc80_H{self.num_unroll_steps}_seed{self.seed}/' + ) + env_name = env_id.split('NoFrameskip')[0] + return f"{prefix}{env_name}_muzero-mt_seed{self.seed}" + + def generate_configs(self) -> List[List[Union[int, List[Any]]]]: + """ + Overview: + Generates the final list of configurations for all specified tasks, + ready to be used by the training entry point. + Returns: + - (:obj:`List[List[Union[int, List[Any]]]]`): A list where each element corresponds to a task, + containing the task_id and a list with the task's config and env_manager config. + """ + base_config = self._create_base_config() + env_manager_config = self._create_env_manager_config() + + configs = [] + for task_id, env_id in enumerate(self.env_id_list): + task_config = deepcopy(base_config) + + # --- Apply task-specific settings --- + task_config.env.env_id = env_id + task_config.policy.task_id = task_id + + # Handle per-task batch size if provided as a list + if isinstance(self.batch_size, list): + task_config.policy.batch_size = self.batch_size[task_id] + else: + task_config.policy.batch_size = self.batch_size + + task_config.exp_name = self._get_exp_name(env_id) + + configs.append([task_id, [task_config, env_manager_config]]) + + return configs + + @staticmethod + def _create_env_manager_config() -> EasyDict: + """ + Overview: + Creates a static configuration for the environment and policy managers. + Returns: + - (:obj:`EasyDict`): A dictionary containing manager configurations. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + + +if __name__ == "__main__": + # ============================================================== + # Hyperparameters for Multi-Task Training + # ============================================================== + + # --- List of Atari environments for multi-task learning --- + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + # --- Core Experiment Settings --- + seed = 0 + max_env_step = int(5e5) + + # --- Training & Model Parameters --- + num_unroll_steps = 5 + num_simulations = 50 + norm_type = 'BN' # 'BN' (Batch Normalization) or 'LN' (Layer Normalization) + + # --- Environment & Collector Settings --- + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + + # --- Batch Size Configuration --- + # The batch size is dynamically calculated per task to not exceed a maximum total batch size. 
+ max_batch_size = 512 + per_task_batch_size = int(min(64, max_batch_size / len(env_id_list))) + batch_size = [per_task_batch_size] * len(env_id_list) + + # --- Reanalyze Buffer Settings --- + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # --- (Optional) Debug Settings --- + # To use debug settings, uncomment the following lines. + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # num_simulations = 3 + # debug_batch_size = int(min(2, max_batch_size / len(env_id_list))) + # batch_size = [debug_batch_size] * len(env_id_list) + # print("--- RUNNING IN DEBUG MODE ---") + + print(f'=========== Batch size per task: {batch_size[0]} ===========') + + # ============================================================== + # Configuration Generation and Training Launch + # ============================================================== + + # --- Instantiate and generate configurations --- + experiment_config = AtariMuZeroMultitaskConfig( + env_id_list=env_id_list, + seed=seed, + max_env_step=max_env_step, + num_unroll_steps=num_unroll_steps, + num_simulations=num_simulations, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + batch_size=batch_size, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + # Note: Update this path to your desired location. + exp_path_prefix='YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari_20250228' + ) + + configs_to_run = experiment_config.generate_configs() + + # --- Launch Distributed Training --- + """ + Overview: + This script should be executed with GPUs. + Set the NCCL timeout and launch the script using one of the following commands. + + Command using torch.distributed.launch: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py + + Command using torchrun: + export NCCL_TIMEOUT=3600000 + torchrun --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + + with DDPContext(): + train_muzero_multitask_segment_ddp(configs_to_run, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_config.py b/zoo/atari/config/atari_muzero_segment_config.py index 4289fb957..03fffa39e 100644 --- a/zoo/atari/config/atari_muzero_segment_config.py +++ b/zoo/atari/config/atari_muzero_segment_config.py @@ -18,11 +18,14 @@ def main(env_id, seed): num_unroll_steps = 5 batch_size = 256 - max_env_step = int(5e5) + # max_env_step = int(5e5) + max_env_step = int(100e6) + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. # buffer_reanalyze_freq = 1/10 - buffer_reanalyze_freq = 1/10000 + buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10000 # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
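    # For instance, reanalyze_partition = 0.75 draws reanalysis samples from the first 75% of the buffer.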
@@ -43,7 +46,7 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, gray_scale=True, collector_env_num=collector_env_num, @@ -59,7 +62,7 @@ def main(env_id, seed): analysis_sim_norm=False, cal_dormant_ratio=False, model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), image_channel=1, frame_stack_num=4, gray_scale=True, @@ -123,7 +126,7 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_muzero_segment - main_config.exp_name = f'data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_lz/data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) if __name__ == "__main__": diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index c7787831b..91517afd5 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -18,6 +18,17 @@ reuse_search = True collect_with_pure_policy = True buffer_reanalyze_freq = 1 + +# ====== only for debug ===== +# collector_env_num = 8 +# num_segments = 8 +# evaluator_env_num = 2 +# num_simulations = 5 +# max_env_step = int(2e5) +# reanalyze_ratio = 0.1 +# batch_size = 64 +# num_unroll_steps = 10 +# replay_ratio = 0.01 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -32,6 +43,9 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( model=dict( diff --git a/zoo/atari/config/atari_unizero_ddp_config.py b/zoo/atari/config/atari_unizero_ddp_config.py index d64332d58..887e5f7cb 100644 --- a/zoo/atari/config/atari_unizero_ddp_config.py +++ b/zoo/atari/config/atari_unizero_ddp_config.py @@ -55,13 +55,20 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, num_layers=2, num_heads=8, embed_dim=768, obs_type='image', env_num=max(collector_env_num, evaluator_env_num), + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, ), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
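The hunk above exposes several head-selection switches (use_normal_head, use_softmoe_head, use_moe_head) plus MoE options in the UniZero world-model config. Below is a minimal sketch of how such flags could be sanity-checked before launching a run; the helper name and the assumption that exactly one head type should be active at a time are illustrative and not part of LightZero's API.

def check_head_flags(world_model_cfg: dict) -> None:
    # Hypothetical helper: verify that exactly one of the head-type flags is enabled.
    flags = {
        'use_normal_head': world_model_cfg.get('use_normal_head', False),
        'use_softmoe_head': world_model_cfg.get('use_softmoe_head', False),
        'use_moe_head': world_model_cfg.get('use_moe_head', False),
    }
    if sum(bool(v) for v in flags.values()) != 1:
        raise ValueError(f'Expected exactly one head type to be enabled, got {flags}')

# Matches the values added in the hunk above: normal head on, SoftMoE/MoE heads off.
check_head_flags(dict(use_normal_head=True, use_softmoe_head=False, use_moe_head=False))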
diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py new file mode 100644 index 000000000..9c4725f9f --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py @@ -0,0 +1,550 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script contains the configuration generation logic for a multi-task UniZero agent + designed for Atari environments. It sets up experiment parameters, computes batch sizes + for distributed training, and generates the final configuration objects required to + launch the training process. + +Execution Command Example: + To run this script using distributed training with GPUs, use the following command. + Replace with the number of GPUs per node (e.g., 8) and adjust paths and log files as needed. + + cd /path/to/your/project/LightZero + python -m torch.distributed.launch --nproc_per_node= --master_port= \ + /path/to/this/script.py 2>&1 | tee /path/to/your/logs/training.log +""" +import math +from typing import List, Tuple, Dict, Any + +from easydict import EasyDict +from ding.utils import DDPContext +# It is recommended to place entry point imports within the main execution block +# to avoid circular dependencies or premature initializations. +# from lzero.entry import train_unizero_multitask_balance_segment_ddp + + +# ============================================================== +# Configuration Computation and Generation +# ============================================================== + +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpus_per_node: int = 8, + max_micro_batch_per_gpu: int = 400 +) -> Tuple[List[int], int]: + """ + Overview: + Computes the micro-batch size for each environment and the number of gradient accumulation steps. + This is designed to balance the load across multiple environments and GPUs while respecting + memory constraints (max_micro_batch_per_gpu). + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs. + - effective_batch_size (:obj:`int`): The target total batch size after gradient accumulation. + - gpus_per_node (:obj:`int`): The number of GPUs available for training. Defaults to 8. + - max_micro_batch_per_gpu (:obj:`int`): The maximum micro-batch size that can fit on a single GPU. Defaults to 400. + + Returns: + - (:obj:`Tuple[List[int], int]`): A tuple containing: + - A list of micro-batch sizes, one for each environment. + - The number of gradient accumulation steps required. + """ + num_envs = len(env_id_list) + if num_envs == 0: + return [], 1 + + # To avoid division by zero, assume at least one environment is processed per GPU group. + envs_per_gpu_group = max(1, num_envs // gpus_per_node) + + # Calculate the maximum micro-batch size per environment based on GPU memory limits. + max_micro_batch_per_env = int(max_micro_batch_per_gpu / envs_per_gpu_group) + + # Calculate the theoretical batch size per environment if distributed evenly. + theoretical_env_batch = effective_batch_size / num_envs + + if theoretical_env_batch > max_micro_batch_per_env: + # If the theoretical batch size exceeds the per-environment limit, + # cap the micro-batch size at the maximum allowed value. + micro_batch_size = max_micro_batch_per_env + # Calculate gradient accumulation steps needed to reach the effective batch size. 
+ grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch_per_env) + else: + # If the theoretical batch size is within limits, use it directly. + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # Assign the same computed micro-batch size to all environments. + batch_sizes = [micro_batch_size] * num_envs + + # Logging for debugging purposes. + print(f"Number of environments: {num_envs}") + print(f"Effective total batch size: {effective_batch_size}") + print(f"Theoretical batch size per environment: {theoretical_env_batch:.2f}") + print(f"Micro-batch size per environment: {micro_batch_size}") + print(f"Gradient accumulation steps: {grad_accumulate_steps}") + + return batch_sizes, grad_accumulate_steps + + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return: int, + curriculum_stage_num: int, + num_envs: int, +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single UniZero task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment (e.g., 'PongNoFrameskip-v4'). + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations for MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The micro-batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for game episodes. + - total_batch_size (:obj:`int`): The effective total batch size. + - target_return (:obj:`int`): The target return for the environment. + - curriculum_stage_num (:obj:`int`): The number of stages in curriculum learning. + - num_envs (:obj:`int`): The total number of environments in the multi-task setup. + + Returns: + - (:obj:`EasyDict`): A configuration object for the agent. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Crucial for DDP + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + analysis_dormant_ratio_weight_rank=False, + dormant_threshold=0.025, + continuous_action_space=False, + task_embed_option=None, + use_task_embed=False, + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=num_envs, + task_num=num_envs, + encoder_type='vit', + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + moe_use_lora=True, + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=32, + lora_dropout=0.1, + lora_scale_init=1, + min_stage0_iters=50000, + max_stage_iters=20000, + apply_curriculum_to_encoder=False, + ), + ), + # --- Task and Learning Settings --- + total_task_num=num_envs, + task_num=num_envs, + task_id=0, # This will be overridden for each task. + target_return=target_return, + use_task_exploitation_weight=False, + task_complexity_weight=True, + balance_pipeline=True, + # --- Training Settings --- + cuda=True, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + update_per_collect=80, + replay_ratio=0.25, + optim_type='AdamW', + cos_lr_scheduler=False, + train_start_after_envsteps=int(0), + # --- Replay Buffer and Reanalysis --- + replay_buffer_size=int(5e5), + num_segments=num_segments, + use_priority=False, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + reanalyze_ratio=reanalyze_ratio, + # --- MCTS Settings --- + num_simulations=num_simulations, + collect_num_simulations=num_simulations, + eval_num_simulations=50, + # --- Collector and Evaluator Settings --- + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + eval_freq=int(1e4), + # --- Miscellaneous --- + print_task_priority_logs=False, + model_path=None, + game_segment_length=20, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + ), + )) + + +def _generate_experiment_name( + base_path_prefix: str, + num_envs: int, + curriculum_stage_num: int, + buffer_reanalyze_freq: float, + seed: int, + env_id: str +) -> str: + """ + Overview: + Helper function to generate a standardized experiment name. 
+ + Arguments: + - base_path_prefix (:obj:`str`): The prefix for the experiment path, e.g., 'data_unizero_atari_mt_balance_YYYYMMDD'. + - num_envs (:obj:`int`): The total number of environments. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): The buffer reanalyze frequency. + - seed (:obj:`int`): The random seed for the experiment. + - env_id (:obj:`str`): The environment ID for this specific task. + + Returns: + - (:obj:`str`): The generated experiment name. + """ + # Template for the experiment's parent directory. + brf_str = str(buffer_reanalyze_freq).replace('.', '') + parent_dir = ( + f"{base_path_prefix}/atari_{num_envs}games_balance-total-stage{curriculum_stage_num}_" + f"stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_backbone-attn-mlp-lora_no-lora-scale_" + f"brf{brf_str}_not-share-head_seed{seed}/" + ) + + # Clean the environment ID for the final part of the name. + env_name_part = env_id.split('NoFrameskip')[0] + + return f"{parent_dir}{env_name_part}_seed{seed}" + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_sizes: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return_dict: Dict[str, int], + curriculum_stage_num: int, +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generates a list of configuration tuples, one for each task/environment. + + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element is a tuple containing + the task_id and a list with the main config and the environment manager config. + """ + configs = [] + exp_name_base_prefix = 'data_unizero_atari_mt_balance_20250730' # YYYYMMDD format + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id=env_id, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_sizes[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return=target_return_dict[env_id], + curriculum_stage_num=curriculum_stage_num, + num_envs=len(env_id_list), + ) + config.policy.task_id = task_id + config.exp_name = _generate_experiment_name( + base_path_prefix=exp_name_base_prefix, + num_envs=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed, + env_id=env_id + ) + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and manager to be used. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + + +def get_atari_target_return_dict(ratio: float = 1.0) -> Dict[str, int]: + """ + Overview: + Calculates the target return for each Atari game based on a predefined score + and a scaling ratio. + + Arguments: + - ratio (:obj:`float`): A scaling factor for the target returns. Defaults to 1.0. + + Returns: + - (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their calculated target returns. + """ + # Pre-defined target scores for various Atari games. + target_scores = { + 'PongNoFrameskip-v4': 20, + 'MsPacmanNoFrameskip-v4': 6951.6, + 'SeaquestNoFrameskip-v4': 42054.7, + 'BoxingNoFrameskip-v4': 12.1, + 'AlienNoFrameskip-v4': 7127.7, + 'ChopperCommandNoFrameskip-v4': 7387.8, + 'HeroNoFrameskip-v4': 30826.4, + 'RoadRunnerNoFrameskip-v4': 7845.0, + 'AmidarNoFrameskip-v4': 100.5, + 'AssaultNoFrameskip-v4': 742.0, + 'AsterixNoFrameskip-v4': 1503.3, + 'BankHeistNoFrameskip-v4': 753.1, + 'BattleZoneNoFrameskip-v4': 12187.5, + 'CrazyClimberNoFrameskip-v4': 15829.4, + 'DemonAttackNoFrameskip-v4': 1971.0, + 'FreewayNoFrameskip-v4': 29.6, + 'FrostbiteNoFrameskip-v4': 334.7, + 'GopherNoFrameskip-v4': 2412.5, + 'JamesbondNoFrameskip-v4': 302.8, + 'KangarooNoFrameskip-v4': 3035.0, + 'KrullNoFrameskip-v4': 2665.5, + 'KungFuMasterNoFrameskip-v4': 12736.3, + 'PrivateEyeNoFrameskip-v4': 1001.3, + 'UpNDownNoFrameskip-v4': 11693.2, + 'QbertNoFrameskip-v4': 13455.0, + 'BreakoutNoFrameskip-v4': 30.5, + } + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + +def get_env_id_list(num_games: int) -> List[str]: + """ + Overview: + Returns a list of Atari environment IDs based on the specified number of games. + + Arguments: + - num_games (:obj:`int`): The number of games to include (e.g., 8 or 26). + + Returns: + - (:obj:`List[str]`): A list of environment ID strings. + """ + games_8 = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + games_26 = games_8 + [ + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + if num_games == 3: + return ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4'] + elif num_games == 8: + return games_8 + elif num_games == 26: + return games_26 + else: + raise ValueError(f"Unsupported number of games: {num_games}. Supported values are 3, 8, 26.") + + +def main(): + """ + Overview: + Main function to configure and launch the multi-task training process. 
+ """ + # ============================================================== + # Primary Hyperparameters + # ============================================================== + # --- Experiment --- + num_games = 8 # Options: 3, 8, 26 + seeds = [0] + max_env_step = int(4e5) + benchmark_name = "atari" + + # --- Curriculum --- + curriculum_stage_num = 5 + + # --- Environment and Agent --- + action_space_size = 18 + num_simulations = 50 + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + # --- Collector and Evaluator --- + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + num_segments = 8 + + # --- Reanalysis --- + reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ============================================================== + # Derived Configurations + # ============================================================== + env_id_list = get_env_id_list(num_games) + target_return_dict = get_atari_target_return_dict(ratio=1.0) + + # --- Batch Size Calculation --- + if num_games == 8: + effective_batch_size = 512 + elif num_games == 26: + effective_batch_size = 512 # For ViT-Base encoder + else: + # Default or other cases + effective_batch_size = 512 + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + # Note: `total_batch_size` is passed to the config but `effective_batch_size` is used for calculation. + # This maintains consistency with the original script's logic. + total_batch_size = effective_batch_size + + # ============================================================== + # Launch Training + # ============================================================== + from lzero.entry import train_unizero_multitask_balance_segment_ddp + + for seed in seeds: + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_sizes=batch_sizes, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return_dict=target_return_dict, + curriculum_stage_num=curriculum_stage_num + ) + + with DDPContext(): + train_unizero_multitask_balance_segment_ddp( + configs, + seed=seed, + max_env_step=max_env_step, + benchmark_name=benchmark_name + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..33de7eea0 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -0,0 +1,446 @@ +from easydict import EasyDict +import math +from typing import List, Tuple, Any, Dict, Union + +# ------------------------------------------------- +# 1. 
Refactored compute_batch_config +# ------------------------------------------------- +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpu_num: int = 8, + max_micro_batch_one_gpu: int = 400, +) -> Tuple[List[int], int]: + """ + Overview: + Calculate the micro-batch size for each environment and the number of gradient accumulation steps + to approach a target effective batch size across multiple GPUs and environments. + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs for all tasks. + - effective_batch_size (:obj:`int`): The target global batch size for one backward pass. + - gpu_num (:obj:`int`): The number of GPUs actually used. Defaults to 8. + - max_micro_batch_one_gpu (:obj:`int`): The maximum micro-batch size a single GPU can handle. Defaults to 400. + + Returns: + - batch_sizes (:obj:`List[int]`): A list of micro-batch sizes for each environment. + - grad_acc_steps (:obj:`int`): The number of gradient accumulation steps. + """ + n_env = len(env_id_list) + # Number of environments that each GPU needs to handle simultaneously. + envs_per_gpu = max(1, math.ceil(n_env / gpu_num)) + # Reduce the micro-batch limit if multiple environments share one GPU. + max_micro_batch = max(1, max_micro_batch_one_gpu // envs_per_gpu) + + # First, calculate a candidate micro-batch by distributing the effective batch size evenly. + candidate = max(1, effective_batch_size // n_env) + micro_batch = min(candidate, max_micro_batch) + + # Gradient accumulation steps = ceil(global_batch / (micro_batch * n_env)). + grad_acc_steps = max(1, math.ceil(effective_batch_size / (micro_batch * n_env))) + + # Fine-tune the micro-batch downwards to ensure: + # micro_batch * n_env * grad_acc_steps <= effective_batch_size + # This aims to get as close as possible to the target without exceeding it. + while micro_batch * n_env * grad_acc_steps > effective_batch_size: + micro_batch -= 1 + if micro_batch == 0: # Defensive check, should not happen in theory. + micro_batch = 1 + break + + batch_sizes = [micro_batch] * n_env + + # --- Debug Information --- # + real_total_batch_size = micro_batch * n_env * grad_acc_steps + print( + f"[BatchConfig] Envs={n_env}, TargetTotalBS={effective_batch_size}, " + f"MicroBS={micro_batch}, GradAccSteps={grad_acc_steps}, RealTotalBS={real_total_batch_size}" + ) + + return batch_sizes, grad_acc_steps + +def create_config( + env_id: str, action_space_size: int, collector_env_num: int, evaluator_env_num: int, n_episode: int, + num_simulations: int, reanalyze_ratio: float, batch_size: int, num_unroll_steps: int, + infer_context_length: int, norm_type: str, buffer_reanalyze_freq: float, reanalyze_batch_size: int, + reanalyze_partition: float, num_segments: int, total_batch_size: int, num_layers: int +) -> EasyDict: + """ + Overview: + Creates the main configuration structure for a single training task. + + Arguments: + - env_id (:obj:`str`): The environment ID. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for evaluation. + - num_simulations (:obj:`int`): Number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in a batch. + - batch_size (:obj:`int`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. 
+ - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for data collection. + - total_batch_size (:obj:`int`): The total effective batch size. + - num_layers (:obj:`int`): Number of layers in the transformer model. + + Returns: + - (:obj:`EasyDict`): A configuration object. + """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + + # collect_max_episode_steps=int(50), # debug + # eval_max_episode_steps=int(50), + ), + policy=dict( + multi_gpu=True, # Essential for DDP (Distributed Data Parallel) + only_use_moco_stats=False, + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + moco_version="v1", + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # This will be overridden for each task + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + continuous_action_space=False, + world_model_cfg=dict( + num_res_blocks=2, + num_channels=256, + norm_type=norm_type, + use_global_pooling=False, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + analysis_dormant_ratio_weight_rank=False, + # analysis_dormant_ratio_weight_rank=True, + # analysis_dormant_ratio_interval=5000, + continuous_action_space=False, + task_embed_option=None, + use_task_embed=False, + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + # num_heads=24, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=len(env_id_list), + task_num=len(env_id_list), + # game_segment_length=game_segment_length, + game_segment_length=20, # TODO + use_priority=True, + # use_priority=False, # TODO===== + priority_prob_alpha=1, + priority_prob_beta=1, + # encoder_type='vit', + encoder_type='resnet', + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + + multiplication_moe_in_transformer=True, + # multiplication_moe_in_transformer=False, # TODO===== + + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + # LoRA parameters + moe_use_lora=False, + lora_r=0, + lora_alpha=1, + lora_dropout=0.0, + + + optim_type='AdamW_mix_lr_wdecay', # only for tsne plot + ), + ), + optim_type='AdamW_mix_lr_wdecay', + weight_decay=1e-2, # TODO: encoder 5*wd, transformer wd, head 0 + learning_rate=0.0001, + + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=True, + # 
use_adaptive_entropy_weight=False,
+
+ # (float) Learning rate of the optimizer for the adaptive entropy weight (alpha).
+ adaptive_entropy_alpha_lr=1e-4,
+ target_entropy_start_ratio=0.98,
+ # target_entropy_end_ratio=0.9, # TODO=====
+ # target_entropy_end_ratio=0.7,
+ # target_entropy_decay_steps=100000, # e.g., reach the final value after 100k iterations
+
+ target_entropy_end_ratio=0.5, # for action_space=18
+ target_entropy_decay_steps=100000, # e.g., reach the final value after 100k iterations
+ # target_entropy_decay_steps=150000, # e.g., reach the final value after 150k iterations (~300k env steps)
+
+ # ==================== START: Encoder-Clip Annealing Config ====================
+ # (bool) Whether to enable annealing of the encoder-clip value.
+ use_encoder_clip_annealing=True,
+ # (str) Annealing type. Options: 'linear' or 'cosine'.
+ encoder_clip_anneal_type='cosine',
+ # (float) Starting clip value for annealing (looser, early in training).
+ encoder_clip_start_value=30.0,
+ # (float) Final clip value for annealing (stricter, late in training).
+ encoder_clip_end_value=10.0,
+ # (int) Number of training iterations over which to anneal from the start value to the end value.
+ encoder_clip_anneal_steps=100000, # e.g., reach the final value after 100k iterations
+ # encoder_clip_anneal_steps=50000, # e.g., reach the final value after 50k iterations
+
+ # ==================== START: Label Smoothing ====================
+ policy_ls_eps_start=0.05, # TODO: a good starting value for Pong and MsPacman
+ policy_ls_eps_end=0.01,
+ policy_ls_eps_decay_steps=50000, # 50k
+ label_smoothing_eps=0.1, # TODO: label smoothing for the value target
+
+ # ==================== Norm Monitoring Frequency ====================
+ # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable.
+ monitor_norm_freq=10000,
+ # monitor_norm_freq=2, # only for debug
+
+ use_task_exploitation_weight=False,
+ task_complexity_weight=False,
+ total_batch_size=total_batch_size,
+ allocated_batch_sizes=False,
+ train_start_after_envsteps=int(0),
+ # use_priority=False, # TODO=====
+ use_priority=True,
+ priority_prob_alpha=1,
+ priority_prob_beta=1,
+ print_task_priority_logs=False,
+ cuda=True,
+ model_path=None,
+ num_unroll_steps=num_unroll_steps,
+ game_segment_length=20,
+ update_per_collect=80, # 80 updates per collect corresponds to replay_ratio=0.5 for 8 games (20 * 8 * 0.5 = 80)
+ replay_ratio=0.25,
+ batch_size=batch_size,
+ # optim_type='AdamW',
+ cos_lr_scheduler=False,
+ num_segments=num_segments,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ replay_buffer_size=int(5e5),
+ # eval_freq=int(2e4), # Evaluation frequency for 26 games
+ eval_freq=int(1e4), # Evaluation frequency for 8 games
+ # eval_freq=int(2), # ======== TODO: only for debug ========
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ reanalyze_batch_size=reanalyze_batch_size,
+ reanalyze_partition=reanalyze_partition,
+ ),
+ ))
+
+def generate_configs(
+ env_id_list: List[str], action_space_size: int, collector_env_num: int, n_episode: int,
+ evaluator_env_num: int, num_simulations: int, reanalyze_ratio: float, batch_size: List[int],
+ num_unroll_steps: int, infer_context_length: int, norm_type: str, seed: int,
+ buffer_reanalyze_freq: float, reanalyze_batch_size: int, reanalyze_partition: float,
+ num_segments: int, total_batch_size: int, num_layers: int
+) -> List[List[Union[int, List[EasyDict]]]]:
+ """
+ Overview:
+ Generates a list of configurations for all specified tasks.
+
+ Arguments:
+ (See arguments for `create_config` function)
+ - seed (:obj:`int`): The random seed for the experiment.
+
+ Returns:
+ - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id
+ and its corresponding configuration objects.
+ """ + configs = [] + # --- Experiment Name Template --- + # Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name. + # benchmark_tag = "data_unizero_mt_refactor1010_debug" # e.g., unizero_atari_mt_20250612 + benchmark_tag = "data_unizero_mt_refactor1012" # e.g., unizero_atari_mt_20250612 + + # model_tag = f"vit-small_moe8_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head" + # model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + # model_tag = f"vit_prior_alpha-100k-098-07_encoder-100k-30-10_moe8_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-100k-098-07_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + model_tag = f"resnet_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln" + # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-150k-098-05_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}" + + exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. 
+ Run the following commands to launch the script. The number of GPUs made visible via
+ CUDA_VISIBLE_DEVICES should match --nproc_per_node and the gpu_num passed to
+ compute_batch_config below.
+
+ Example launch command:
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+ cd /path/to/your/project/
+ python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 \
+ /path/to/this/script.py 2>&1 | tee /path/to/your/log/file.log
+ """
+ from lzero.entry import train_unizero_multitask_segment_ddp
+ from ding.utils import DDPContext
+ import torch.distributed as dist
+ import os
+
+ # --- Main Experiment Settings ---
+ num_games = 8 # Options: 3, 8, 26
+ num_layers = 4
+ # num_layers = 2 # debug
+ action_space_size = 18
+ collector_env_num = 8
+ num_segments = 8
+ n_episode = 8
+ evaluator_env_num = 3
+ num_simulations = 50
+ # max_env_step = int(4e5)
+ max_env_step = int(5e6) # TODO
+ reanalyze_ratio = 0.0
+
+ if num_games == 3:
+ env_id_list = ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4']
+ elif num_games == 8:
+ env_id_list = [
+ 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4',
+ 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4',
+ ]
+ elif num_games == 26:
+ env_id_list = [
+ 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4',
+ 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4',
+ 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4',
+ 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4',
+ 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4',
+ 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4',
+ 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4',
+ ]
+ else:
+ raise ValueError(f"Unsupported number of environments: {num_games}")
+
+ # --- Batch Size Calculation ---
+ # The effective batch size is adjusted based on the number of games and model size (layers)
+ # to fit within GPU memory constraints.
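+ # For example, with 8 games, effective_batch_size = 512 and gpu_num=6 (see the compute_batch_config call below),
+ # each task gets a micro-batch of 512 // 8 = 64 and a single gradient-accumulation step (64 * 8 * 1 = 512).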
+ if len(env_id_list) == 8: + if num_layers in [2, 4]: + effective_batch_size = 512 + elif num_layers == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + effective_batch_size = 512 + elif len(env_id_list) == 18: + effective_batch_size = 1536 + elif len(env_id_list) == 3: + effective_batch_size = 10 # For debugging + else: + raise ValueError(f"Batch size not configured for {len(env_id_list)} environments.") + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=6) # TODO + total_batch_size = effective_batch_size # Currently for logging purposes + + # --- Model and Training Settings --- + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000000 # Effectively disable buffer reanalyze + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ====== only for debug ===== + # num_games = 8 # Options: 3, 8, 26 + # num_layers = 2 # debug + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 5 + # batch_sizes = [num_games] * len(env_id_list) + # buffer_reanalyze_freq = 1/100000000 + # total_batch_size = num_games * len(env_id_list) + + + # --- Training Loop --- + for seed in [0]: + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, num_layers + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="atari") + print(f"Seed: {seed} training finished!") + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py new file mode 100644 index 000000000..b7973ff87 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -0,0 +1,333 @@ +from easydict import EasyDict +from typing import List, Any, Dict + +# ============================================================== +# Environment and Policy Manager Configuration +# ============================================================== + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the configuration for the environment and policy managers. + This config specifies the types and import paths for core components + like the environment wrapper and the policy definition. + Returns: + - manager_config (:obj:`EasyDict`): A dictionary containing the types and import names + for the environment and policy managers. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +# ============================================================== +# Main Configuration Generation +# ============================================================== + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + env_id_list: List[str], +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single task in a multi-task setup. + Arguments: + - env_id (:obj:`str`): The ID of the environment for this specific task. + - action_space_size (:obj:`int`): The size of the action space for the model. + - collector_env_num (:obj:`int`): The number of environments for the data collector. + - evaluator_env_num (:obj:`int`): The number of environments for the evaluator. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations for the MCTS algorithm. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in the replay buffer. + - batch_size (:obj:`List[int]`): The batch size for training, specified per task. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN' for LayerNorm). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalyzing data. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalyzing data. + - num_segments (:obj:`int`): The number of segments for game data. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - env_id_list (:obj:`List[str]`): The list of all environment IDs in the multi-task setup. + Returns: + - config (:obj:`EasyDict`): The complete configuration for a single training task. + """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, + MoCo_rho=0, calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, # Placeholder, will be set in generate_configs + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + env_id_list=env_id_list, + # TODO: Implement and verify the t-SNE analysis functionality. 
+ analysis_tsne=True, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, # Transformer layers + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=len(env_id_list), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def _generate_exp_name_prefix( + exp_base_path: str, + num_games: int, + buffer_reanalyze_freq: float, + norm_type: str, + seed: int +) -> str: + """ + Overview: + Generates a standardized prefix for the experiment name based on key hyperparameters. + Arguments: + - exp_base_path (:obj:`str`): The base directory for the experiment logs. + - num_games (:obj:`int`): The number of games in the multi-task setup. + - buffer_reanalyze_freq (:obj:`float`): The frequency of buffer reanalysis. + - norm_type (:obj:`str`): The normalization type used in the model. + - seed (:obj:`int`): The random seed for the experiment. + Returns: + - (:obj:`str`): The generated experiment name prefix. + """ + # NOTE: This name is constructed based on a specific convention to encode hyperparameters. + # It includes details about the model architecture, training parameters, and environment setup. + return ( + f'{exp_base_path}/{num_games}games_brf{buffer_reanalyze_freq}_' + f'1-encoder-{norm_type}-res2-channel256_gsl20_{num_games}-pred-head_' + f'nlayer8-nh24-lsd768_seed{seed}/' + ) + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + exp_base_path: str, +) -> List[List[Any]]: + """ + Overview: + Generates a list of configurations for each task in a multi-task training setup. + Each configuration is paired with an environment manager config. + Arguments: + - (All arguments from create_config, plus): + - seed (:obj:`int`): The random seed for the experiment, used for naming. + - exp_base_path (:obj:`str`): The base path for saving experiment results. + Returns: + - configs (:obj:`List[List[Any]]`): A list where each item contains + [task_id, [task_specific_config, env_manager_config]]. 
+ """ + configs = [] + exp_name_prefix = _generate_exp_name_prefix( + exp_base_path, len(env_id_list), buffer_reanalyze_freq, norm_type, seed + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, env_id_list + ) + # Assign the specific task ID for this configuration + config.policy.task_id = task_id + # Set the full experiment name for logging and checkpointing + env_name = env_id.split('NoFrameskip')[0] + config.exp_name = exp_name_prefix + f"{env_name}_unizero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +# ============================================================== +# Main execution block +# ============================================================== + +if __name__ == "__main__": + """ + Overview: + This program is designed to obtain the t-SNE of the latent states in multi-task learning + across a set of Atari games (e.g., 8 games). + + This script should be executed with GPUs for Distributed Data Parallel (DDP) training. + Run one of the following commands to launch the script: + + Using `torch.distributed.launch` (deprecated): + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./path/to/this/script.py + + Using `torchrun` (recommended): + torchrun --nproc_per_node=8 ./path/to/this/script.py + """ + from lzero.entry import train_unizero_multitask_segment_eval + from ding.utils import DDPContext + + # --- Basic Environment and Model Setup --- + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + action_space_size = 18 # Standard action space size for Atari games + + # --- Hyperparameter Configuration --- + # Grouping hyperparameters for better readability and management. + main_hyperparams = { + 'seed': 0, + 'collector_env_num': 2, + 'evaluator_env_num': 2, + 'n_episode': 2, + 'num_simulations': 50, + 'max_env_step': int(4e5), + 'reanalyze_ratio': 0.0, + 'num_segments': 2, + 'num_unroll_steps': 10, + 'infer_context_length': 4, + 'norm_type': 'LN', + 'buffer_reanalyze_freq': 1/50, + 'reanalyze_batch_size': 160, + 'reanalyze_partition': 0.75, + 'total_batch_size': int(4 * len(env_id_list)), + 'batch_size_per_task': 4, + # --- Path for experiment logs and pretrained model --- + # NOTE: Please update these paths to your local directory structure. + 'exp_base_path': 'data/unizero_mt_ddp-8gpu_eval-latent_state_tsne', + # Example for an 8-game pretrained model + 'pretrained_model_path': '/path/to/your/pretrained_model.pth.tar', + # Example for a 26-game pretrained model + # 'pretrained_model_path': '/path/to/your/26_game_model.pth.tar', + } + + # --- Generate Configurations for each seed --- + # This loop allows running experiments with multiple seeds easily. + for seed in [main_hyperparams['seed']]: + # The batch size is a list, with one entry per task. 
+ batch_size_list = [main_hyperparams['batch_size_per_task']] * len(env_id_list) + + # Generate the list of configurations for the trainer + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=main_hyperparams['collector_env_num'], + n_episode=main_hyperparams['n_episode'], + evaluator_env_num=main_hyperparams['evaluator_env_num'], + num_simulations=main_hyperparams['num_simulations'], + reanalyze_ratio=main_hyperparams['reanalyze_ratio'], + batch_size=batch_size_list, + num_unroll_steps=main_hyperparams['num_unroll_steps'], + infer_context_length=main_hyperparams['infer_context_length'], + norm_type=main_hyperparams['norm_type'], + seed=seed, + buffer_reanalyze_freq=main_hyperparams['buffer_reanalyze_freq'], + reanalyze_batch_size=main_hyperparams['reanalyze_batch_size'], + reanalyze_partition=main_hyperparams['reanalyze_partition'], + num_segments=main_hyperparams['num_segments'], + total_batch_size=main_hyperparams['total_batch_size'], + exp_base_path=main_hyperparams['exp_base_path'], + ) + + # --- Launch Training --- + # Use DDPContext to manage the distributed training environment. + with DDPContext(): + train_unizero_multitask_segment_eval( + configs, + seed=seed, + model_path=main_hyperparams['pretrained_model_path'], + max_env_step=main_hyperparams['max_env_step'] + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py new file mode 100644 index 000000000..3581839b2 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -0,0 +1,409 @@ +from easydict import EasyDict +from typing import List, Tuple, Union, Any, Dict + +class UniZeroAtariConfig: + """ + Overview: + Default configuration class for UniZero Atari experiments. + This class centralizes all default parameters, making it easier to manage and extend. + """ + def __init__(self) -> None: + self.exp_name: str = '' + self.env: EasyDict = self._get_default_env_config() + self.policy: EasyDict = self._get_default_policy_config() + + @staticmethod + def _get_default_env_config() -> EasyDict: + """ + Overview: + Returns the default environment configuration. + """ + return EasyDict(dict( + stop_value=int(1e6), + env_id='PongNoFrameskip-v4', + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=8, + evaluator_env_num=3, + n_evaluator_episode=3, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + )) + + @staticmethod + def _get_default_policy_config() -> EasyDict: + """ + Overview: + Returns the default policy configuration. 
+ """ + return EasyDict(dict( + multi_gpu=True, + # ==============TODO============== + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=1, + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=18, + norm_type='LN', + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + # TODO: for latent state layer_norm + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + # TODO: only for latent state sim_norm + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', + share_head=False, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + # ==============TODO: none ============== + task_embed_option=None, + use_task_embed=False, + # ==============TODO============== + # task_embed_option='concat_task_embed', + # use_task_embed=True, + # task_embed_dim=96, + # task_embed_dim=128, + use_shared_projection=False, + max_blocks=10, # num_unroll_steps + max_tokens=20, # 2 * num_unroll_steps + context_length=8, # 2 * infer_context_length + device='cuda', + action_space_size=18, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + # LoRA parameters (enable LoRA by setting lora_r > 0) + lora_r=0, + # lora_r=8, + lora_alpha=32, + lora_dropout=0.1, + # Default target modules: attn and feed_forward + lora_target_modules=["attn", "feed_forward"], + ), + ), + # TODO + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=512, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=10, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=64, + optim_type='AdamW', + cos_lr_scheduler=True, + num_segments=8, + num_simulations=50, + reanalyze_ratio=0.0, + n_episode=8, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=8, + evaluator_env_num=3, + buffer_reanalyze_freq=1 / 10000000, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + )) + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + task_num: int +) -> EasyDict: + """ + Overview: + Creates and customizes a configuration for a specific Atari environment task. + + Arguments: + - env_id (:obj:`str`): The ID of the Atari environment. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for collecting data. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. 
+ - n_episode (:obj:`int`): Number of episodes to run for each collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in the replay buffer. + - batch_size (:obj:`Union[int, List[int]]`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for each game. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - task_num (:obj:`int`): The total number of tasks. + + Returns: + - (:obj:`EasyDict`): A fully configured EasyDict object for the experiment. + """ + cfg = UniZeroAtariConfig() + + # == Update Environment Config == + cfg.env.env_id = env_id + cfg.env.collector_env_num = collector_env_num + cfg.env.evaluator_env_num = evaluator_env_num + cfg.env.n_evaluator_episode = evaluator_env_num + + # == Update Policy Config == + policy = cfg.policy + policy.task_num = task_num + policy.action_space_size = action_space_size + policy.n_episode = n_episode + policy.num_simulations = num_simulations + policy.reanalyze_ratio = reanalyze_ratio + policy.batch_size = batch_size + policy.total_batch_size = total_batch_size + policy.num_unroll_steps = num_unroll_steps + policy.collector_env_num = collector_env_num + policy.evaluator_env_num = evaluator_env_num + policy.buffer_reanalyze_freq = buffer_reanalyze_freq + policy.reanalyze_batch_size = reanalyze_batch_size + policy.reanalyze_partition = reanalyze_partition + policy.num_segments = num_segments + + # == Update Model Config == + model = policy.model + model.action_space_size = action_space_size + model.norm_type = norm_type + + # == Update World Model Config == + world_model = model.world_model_cfg + world_model.max_blocks = num_unroll_steps + world_model.max_tokens = 2 * num_unroll_steps + world_model.context_length = 2 * infer_context_length + world_model.action_space_size = action_space_size + world_model.task_num = task_num + + return EasyDict(cfg) + + +def generate_experiment_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> List[Tuple[int, List[Union[EasyDict, Any]]]]: + """ + Overview: + Generates a list of configurations for multi-task experiments. + + Arguments: + - env_id_list (:obj:`List[str]`): List of environment IDs for the tasks. + - ... (same as create_config): Other experiment parameters. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[Tuple[int, List[Union[EasyDict, Any]]]]`): A list where each element contains a task_id and its + corresponding configuration and environment manager setup. + """ + configs = [] + task_num = len(env_id_list) + + # --- Experiment Name Prefix --- + # This prefix defines the storage path for experiment data and logs. 
+ # Please replace `` with your actual data storage path. + exp_name_prefix_template = ( + "/data_unizero_atari_mt_finetune_{timestamp}/" + "experiment_name/{task_num}games_brf{brf}_1-encoder-{norm}-res2-channel256_" + "gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/" + ) + exp_name_prefix = exp_name_prefix_template.format( + timestamp="20250308", + task_num=task_num, + brf=buffer_reanalyze_freq, + norm=norm_type, + seed=seed + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, task_num + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment and policy manager configuration. + This specifies the types and import paths for the environment and policy used in the experiment. + + Returns: + - (:obj:`EasyDict`): An EasyDict object containing manager configurations. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run one of the following commands to launch the script: + - Using torch.distributed.launch: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29507 ./path/to/this/script.py + - Using torchrun: + torchrun --nproc_per_node=8 ./path/to/this/script.py + """ + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + # --- Main Experiment Settings --- + # Use DEBUG mode for fast iteration and debugging. 
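+ # When DEBUG is True, the branch below shrinks the collector/evaluator env counts, simulation count,
+ # batch sizes and max_env_step so that a full collect-train-eval cycle finishes quickly; the else
+ # branch holds the standard experiment settings.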
+ DEBUG = False + + # --- Environment and Task Settings --- + env_id_list = ['AmidarNoFrameskip-v4'] + action_space_size = 18 + + # --- Distributed Training Settings --- + os.environ["NCCL_TIMEOUT"] = "3600000000" + + # --- Loop over seeds for multiple runs --- + for seed in [0]: + # --- Core Algorithm Parameters --- + if DEBUG: + # Settings for quick debugging + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 2 + total_batch_size = 32 + batch_size = [int(total_batch_size / len(env_id_list))] * len(env_id_list) + reanalyze_batch_size = 4 + max_env_step = int(1e3) + else: + # Standard experiment settings + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + reanalyze_batch_size = 160 + max_env_step = int(4e5) + + # --- Shared Parameters --- + reanalyze_ratio = 0.0 + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 10000000 # Effectively disabled + reanalyze_partition = 0.75 + + # --- Generate Configurations --- + configs = generate_experiment_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size + ) + + # --- Pretrained Model Path --- + # Please replace `` with the actual path to your model. + pretrained_model_path = ( + "/data_unizero_atari_mt_20250307/" + "atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar" + ) + + # --- Start Training --- + with DDPContext(): + train_unizero_multitask_segment_ddp( + configs, + seed=seed, + model_path=pretrained_model_path, + max_env_step=max_env_step + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py index dec2ee4d2..fa115e459 100644 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -10,29 +10,44 @@ def main(env_id, seed): # ============================================================== collector_env_num = 8 num_segments = 8 + game_segment_length = 20 - evaluator_env_num = 10 + # game_segment_length = 400 # TODO + + evaluator_env_num = 3 num_simulations = 50 - max_env_step = int(5e5) - batch_size = 64 + # max_env_step = int(4e5) + max_env_step = int(5e6) # TODO + # max_env_step = int(1e6) # TODO pong + + # batch_size = 2 # only for debug + # batch_size = 64 + batch_size = 256 num_layers = 2 - replay_ratio = 0.25 + replay_ratio = 0.1 + # replay_ratio = 0.25 num_unroll_steps = 10 infer_context_length = 4 # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. - buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/50 + buffer_reanalyze_freq = 1/5000000000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. 
E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
reanalyze_partition = 0.75
+ norm_type = "LN"
# ====== only for debug =====
# collector_env_num = 2
# num_segments = 2
# evaluator_env_num = 2
- # num_simulations = 10
+ # num_simulations = 5
# batch_size = 5
+ # buffer_reanalyze_freq = 1/1000000
+ # replay_ratio = 1
+
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -47,9 +62,11 @@ def main(env_id, seed):
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # collect_max_episode_steps=int(5e3),
+ # eval_max_episode_steps=int(5e3),
# TODO: only for debug
- # collect_max_episode_steps=int(50),
- # eval_max_episode_steps=int(50),
+ # collect_max_episode_steps=int(20),
+ # eval_max_episode_steps=int(20),
),
policy=dict(
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000
@@ -58,7 +75,28 @@ def main(env_id, seed):
action_space_size=action_space_size,
reward_support_range=(-300., 301., 1.),
value_support_range=(-300., 301., 1.),
+ norm_type=norm_type,
+ num_res_blocks=1,
+ num_channels=64,
+ # num_res_blocks=2,
+ # num_channels=128,
world_model_cfg=dict(
+ norm_type=norm_type,
+ final_norm_option_in_obs_head='LayerNorm',
+ final_norm_option_in_encoder='LayerNorm',
+ predict_latent_loss_type='mse', # TODO: only for latent state layer_norm
+
+ # final_norm_option_in_obs_head='SimNorm',
+ # final_norm_option_in_encoder='SimNorm',
+ # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm
+
+ # analysis_dormant_ratio_weight_rank=True, # TODO
+
+ analysis_dormant_ratio_weight_rank=False, # TODO
+ dormant_threshold=0.025,
+ task_embed_option=None, # ==============TODO: none ==============
+ use_task_embed=False, # ==============TODO==============
+ use_shared_projection=False,
support_size=601,
policy_entropy_weight=5e-3,
continuous_action_space=False,
@@ -73,28 +111,90 @@ def main(env_id, seed):
obs_type='image',
env_num=max(collector_env_num, evaluator_env_num),
num_simulations=num_simulations,
+ game_segment_length=game_segment_length,
+ # use_priority=False,
+ use_priority=True,
rotary_emb=False,
+ encoder_type='resnet',
+ use_normal_head=True,
+ use_softmoe_head=False,
+ use_moe_head=False,
+ num_experts_in_moe_head=4,
+ moe_in_transformer=False,
+ multiplication_moe_in_transformer=False,
+ num_experts_of_moe_in_transformer=4,
+ # LoRA parameters:
+ lora_r=0,
+ lora_alpha=1,
+ lora_dropout=0.0,
+ optim_type='AdamW_mix_lr_wdecay', # only for tsne plot
+ ),
),
+ optim_type='AdamW_mix_lr_wdecay',
+ weight_decay=1e-2, # TODO: encoder 5*wd, transformer wd, head 0
+ learning_rate=0.0001,
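+ # The 'AdamW_mix_lr_wdecay' option together with the "encoder 5*wd, transformer wd, head 0" note above
+ # is assumed to build optimizer parameter groups roughly like the following sketch (the actual grouping
+ # lives in the policy/optimizer code; the group names here are illustrative only):
+ #     torch.optim.AdamW([
+ #         {'params': encoder_params,     'weight_decay': 5 * 1e-2},
+ #         {'params': transformer_params, 'weight_decay': 1e-2},
+ #         {'params': head_params,        'weight_decay': 0.0},
+ #     ], lr=0.0001)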
+ # (str) The path of the pretrained model. If None, the model will be initialized by the default model.
model_path=None,
+
+ # (bool) Whether to enable the adaptive policy-entropy weight (alpha).
+ use_adaptive_entropy_weight=True,
+ # (float) Learning rate of the optimizer for the adaptive alpha.
+ adaptive_entropy_alpha_lr=1e-4,
+ # adaptive_entropy_alpha_lr=1e-3,
+ target_entropy_start_ratio=0.98,
+ # target_entropy_end_ratio=0.9,
+ target_entropy_end_ratio=0.7,
+ target_entropy_decay_steps=100000, # e.g., reach the final value after 100k iterations; tune together with the replay ratio.
+ # target_entropy_end_ratio=0.5, # TODO=====
+ # target_entropy_decay_steps=400000, # e.g., reach the final value after 400k iterations; tune together with the replay ratio.
+
+
+ # ==================== START: Encoder-Clip Annealing Config ====================
+ # (bool) Whether to anneal the encoder-clip value.
+ use_encoder_clip_annealing=True,
+ # (str) Annealing type. Either 'linear' or 'cosine'.
+ encoder_clip_anneal_type='cosine',
+ # (float) Starting clip value of the annealing (looser, early in training).
+ encoder_clip_start_value=30.0,
+ # (float) Final clip value of the annealing (tighter, late in training).
+ encoder_clip_end_value=10.0,
+ # (int) Number of training iterations needed to anneal from the start value to the end value.
+ # encoder_clip_anneal_steps=400000, # e.g., reach the final value after 400k iterations
+ encoder_clip_anneal_steps=100000, # e.g., reach the final value after 100k iterations
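+ # A minimal sketch of the cosine schedule these options are assumed to drive (the actual update is
+ # performed inside the policy, not in this config):
+ #     progress = min(train_iter / encoder_clip_anneal_steps, 1.0)
+ #     clip = encoder_clip_end_value + 0.5 * (encoder_clip_start_value - encoder_clip_end_value) * (1 + math.cos(math.pi * progress))
+ # i.e. the clip value decays smoothly from 30.0 down to 10.0 over the first 100k training iterations.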
+
+ # ==================== START: label smoothing ====================
+ policy_ls_eps_start=0.05, # TODO============= good start in Pong and MsPacman
+ policy_ls_eps_end=0.01,
+ policy_ls_eps_decay_steps=50000, # 50k
+ label_smoothing_eps=0.1, # TODO============= for value
+
+ # ==================== [New] Norm monitoring frequency ====================
+ # Monitor the norms of the model parameters every this many training iterations. Set to 0 to disable.
+ # monitor_norm_freq=10000,
+ monitor_norm_freq=5000, # TODO
+ # monitor_norm_freq=2, # only for debug
+
use_augmentation=False,
manual_temperature_decay=False,
threshold_training_steps_for_final_temperature=int(2.5e4),
- use_priority=False,
+ # use_priority=False,
+ use_priority=True,
+ priority_prob_alpha=1,
+ priority_prob_beta=1,
num_unroll_steps=num_unroll_steps,
update_per_collect=None,
replay_ratio=replay_ratio,
batch_size=batch_size,
- optim_type='AdamW',
- learning_rate=0.0001,
num_simulations=num_simulations,
num_segments=num_segments,
td_steps=5,
- train_start_after_envsteps=0,
+ train_start_after_envsteps=0, # only for debug
+ # train_start_after_envsteps=2000,
game_segment_length=game_segment_length,
grad_clip_value=5,
- replay_buffer_size=int(1e6),
+ replay_buffer_size=int(5e5),
eval_freq=int(5e3),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
@@ -126,7 +226,11 @@ def main(env_id, seed):
# ============ use muzero_segment_collector instead of muzero_collector =============
from lzero.entry import train_unizero_segment
- main_config.exp_name = f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+ main_config.exp_name = f'data_unizero_st_refactor1023/{env_id[:-14]}/{env_id[:-14]}_uz_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_adamw-wd1e-2-encoder5-trans1-head0_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+
+ # main_config.exp_name = f'data_unizero_st_refactor1023/{env_id[:-14]}/{env_id[:-14]}_uz_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder5-trans1-head0-true_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+
+ # main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch128-res2_targetentropy-alpha-100k-098-07-encoder-clip30-10-400k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
@@ -137,4 +241,33 @@ def main(env_id, seed):
parser.add_argument('--seed', type=int, help='The seed to use', default=0)
args = parser.parse_args()
+
+
+ # 4 base environments from the Atari-8 benchmark used for testing
+ # args.env = 'PongNoFrameskip-v4' # reactive environment, dense reward
+ # args.env = 'MsPacmanNoFrameskip-v4' # memory/planning environment, sparse reward
+
+ args.env = 'SeaquestNoFrameskip-v4' # memory/planning environment, sparse reward
+ # args.env = 'HeroNoFrameskip-v4' # memory/planning environment, sparse reward
+
+ # args.env = 'AlienNoFrameskip-v4'
+
+ # Two representative environments outside the Atari-8 set
+ # args.env = 'QbertNoFrameskip-v4' # memory/planning environment, sparse reward
+ # args.env = 'SpaceInvadersNoFrameskip-v4' # memory/planning environment, sparse reward
+
+ # Environments on which results are already good
+ # args.env = 'BoxingNoFrameskip-v4' # reactive environment, dense reward
+ # args.env = 'ChopperCommandNoFrameskip-v4'
+ # args.env = 'RoadRunnerNoFrameskip-v4'
+ main(args.env, args.seed)
+
+ """
+ tmux new -s uz-st-refactor-boxing
+
+ conda activate /mnt/nfs/zhangjinouwen/puyuan/conda_envs/lz
+ export CUDA_VISIBLE_DEVICES=1
+ cd /mnt/nfs/zhangjinouwen/puyuan/LightZero
+ python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/202510/20251023_uz_st_seaq.log
+ """
diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py
index 8bc491674..d40f35033 100644
--- a/zoo/atari/envs/atari_lightzero_env.py
+++ b/zoo/atari/envs/atari_lightzero_env.py
@@ -24,6 +24,8 @@ class AtariEnvLightZero(BaseEnv):
_reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed
"""
config = dict(
+ # (bool) Whether to use the full action space of the environment. Default is False. If set to True, the action space size is 18 for Atari.
+ full_action_space=False,
# (int) The number of environment instances used for data collection.
collector_env_num=8,
# (int) The number of environment instances used for evaluator.
@@ -175,11 +177,14 @@ def step(self, action: int) -> BaseEnvTimestep:
self.reward = np.array(reward).astype(np.float32)
self._eval_episode_return += self.reward
self._timestep += 1
- # logging.info(f'self._timestep: {self._timestep}')
+ if self._timestep % 200 == 0:
+ logging.info(f'self._timestep: {self._timestep}')
observation = self.observe()
if done:
logging.info(f'one episode done! 
total episode length is: {self._timestep}') info['eval_episode_return'] = self._eval_episode_return + print(f'one episode of {self.cfg.env_id} done') + return BaseEnvTimestep(observation, self.reward, done, info) def observe(self) -> dict: diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index f38aa24d6..265ef31ac 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -93,9 +93,9 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gym.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') + env = gym.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) assert 'NoFrameskip' in env.spec.id if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py new file mode 100644 index 000000000..4f5ca5bda --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py @@ -0,0 +1,132 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +env_id = 'cartpole-swingup' # You can specify any DMC task here +action_space_size = dmc_state_env_action_space_map[env_id] +obs_space_size = dmc_state_env_obs_space_map[env_id] +print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + +domain_name = env_id.split('-')[0] +task_name = env_id.split('-')[1] + +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e6) +reanalyze_ratio = 0 +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'LN' +seed = 0 + +# for debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 2 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +dmc2gym_pixels_cont_sampled_unizero_config = dict( + exp_name=f'data_sampled_unizero_0901/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}', + env=dict( + env_id='dmc2gym-v0', + continuous=True, + domain_name=domain_name, + task_name=task_name, + from_pixels=True, # pixel/image obs + frame_skip=2, + warp_frame=True, + scale=True, + frame_stack_num=1, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 84, 84), + action_space_size=action_space_size, + 
continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + world_model_cfg=dict( + obs_type='image', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=5e-3, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + fixed_sigma_value=0.3, + bound_type=None, + model_type='conv', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + # device='cpu', + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + game_segment_length=100, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +dmc2gym_pixels_cont_sampled_unizero_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_config) +main_config = dmc2gym_pixels_cont_sampled_unizero_config + +dmc2gym_pixels_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), +) +dmc2gym_pixels_cont_sampled_unizero_create_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_create_config) +create_config = dmc2gym_pixels_cont_sampled_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + + train_unizero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/dmc2gym/config/dmc2gym_state_smz_config.py b/zoo/dmc2gym/config/dmc2gym_state_smz_config.py index 95456d56e..c99d3960b 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_smz_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_smz_config.py @@ -30,7 +30,7 @@ # ============================================================== dmc2gym_state_cont_sampled_muzero_config = dict( - exp_name=f'data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}', + exp_name=f'/oss/niuyazhe/puyuan/data/data_lz_202505/data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}', env=dict( env_id='dmc2gym-v0', domain_name=domain_name, diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_config.py similarity index 100% rename from zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py rename to zoo/dmc2gym/config/dmc2gym_state_suz_config.py diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py new file mode 100644 index 000000000..ba979b1c6 --- /dev/null +++ 
b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py @@ -0,0 +1,531 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script defines the configuration for a multi-task reinforcement learning experiment + using the UniZero model on DeepMind Control Suite (DMC) environments. + It is designed to be launched with PyTorch's Distributed Data Parallel (DDP) for multi-GPU training. +""" +from __future__ import annotations + +import logging +from typing import Any, Dict, List + +from easydict import EasyDict + +# ============================================================== +# Global setup: Logging +# ============================================================== +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s', + handlers=[ + logging.FileHandler("output.log", encoding="utf-8"), # Log to file + logging.StreamHandler() # Log to console + ] +) + + +def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_env_num: int, + num_unroll_steps: int, infer_context_length: int, curriculum_stage_num: int) -> EasyDict: + """ + Overview: + Creates the base configuration EasyDict with default settings for the experiment. + These settings are shared across all tasks but can be overridden. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of environment IDs for all tasks. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - num_unroll_steps (:obj:`int`): The number of game steps to unroll in the model. + - infer_context_length (:obj:`int`): The context length for inference. + - curriculum_stage_num (:obj:`int`): The number of stages in the curriculum learning. + + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + # Environment-specific settings + env=dict( + stop_value=int(5e5), + from_pixels=False, + continuous=True, # Assuming all DMC tasks use continuous action spaces + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO(user): For debugging only. Uncomment to use smaller segments and episodes. + # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + # Policy-specific settings + policy=dict( + multi_gpu=True, # TODO(user): Enable multi-GPU for DDP. + # TODO(user): Configure MoCo settings. + only_use_moco_stats=False, + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + # Model configuration + model=dict( + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO(user): Loss type for latent state with LayerNorm. + + share_head=False, # TODO(user): Whether to share the prediction head across tasks. + use_shared_projection=False, + + # TODO(user): analysis_dormant_ratio needs to be corrected for the DMC encoder. + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + # analysis_dormant_ratio_interval=20, # For debugging + + # TODO(user): Configure task embedding options. 
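+ # task_embed_option=None keeps the world model task-agnostic; the commented 'concat_task_embed'
+ # variant below presumably concatenates a learned per-task embedding (of size task_embed_dim) to
+ # the token representation instead.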
+ task_embed_option=None, + use_task_embed=False, + # task_embed_option='concat_task_embed', + # use_task_embed=True, + # task_embed_dim=128, + + policy_loss_type='kl', + obs_type='vector', + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + + # TODO(user): For debugging only. Use a smaller model. + # num_layers=1, + num_layers=4, + # num_layers=8, + + num_heads=24, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + + # Mixture of Experts (MoE) head configuration + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # MoE in Transformer configuration + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA (Low-Rank Adaptation) parameters + # TODO(user): Enable or disable LoRA for MoE layers. + moe_use_lora=True, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + # Curriculum learning stage iteration counts + curriculum_stage_num=curriculum_stage_num, + min_stage0_iters=10000, # Corresponds to 400k envsteps, 40k iters + max_stage_iters=5000, + + # TODO(user): For debugging only. Use very short stage iterations. + # min_stage0_iters=2, + # max_stage_iters=5, + ), + ), + # TODO(user): Enable or disable task exploitation weight. + use_task_exploitation_weight=False, + balance_pipeline=True, + # TODO(user): Enable or disable task complexity weight. + task_complexity_weight=True, + allocated_batch_sizes=False, + # TODO(user): Set the number of environment steps to collect before training starts. + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + + # TODO(user): For debugging only. Set a smaller update_per_collect. + # update_per_collect=3, + update_per_collect=200, # e.g., 8 envs * 100 steps/env * 0.25 replay_ratio = 200 + replay_buffer_size=int(1e6), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + ), + )) + + +def create_task_config( + base_config: EasyDict, + env_id: str, + observation_shape_list: list[int], + action_space_size_list: list[int], + target_return_dict: dict[str, int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> EasyDict: + """ + Overview: + Creates a specialized configuration for a single task by updating the base config. + + Arguments: + - base_config (:obj:`EasyDict`): The base configuration dictionary. + - env_id (:obj:`str`): The ID of the environment for this specific task. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. 
+ - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - target_return_dict (:obj:`dict[str, int]`): A dictionary mapping env_id to its target return. + - collector_env_num (:obj:`int`): The number of collector environments. + - evaluator_env_num (:obj:`int`): The number of evaluator environments. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The batch size for training this task. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): The number of segments in the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): The final configuration for the specified task. + """ + domain_name, task_name = env_id.split('-', 1) + frame_skip = 8 if domain_name == "pendulum" else 4 + + config = base_config + + # Update environment settings + config.env.update(dict( + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + frame_skip=frame_skip, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + )) + + # Update model settings + config.policy.model.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + )) + config.policy.model.world_model_cfg.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + norm_type=norm_type, + )) + + # Update policy settings + config.policy.update(dict( + target_return=target_return_dict.get(env_id), + total_batch_size=total_batch_size, + num_unroll_steps=num_unroll_steps, + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + )) + + return config + + +def create_env_manager_config() -> EasyDict: + """ + Overview: + Creates the configuration for the environment manager and policy type. + + Returns: + - (:obj:`EasyDict`): A dictionary with environment manager and policy import settings. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +def generate_experiment_name(num_tasks: int, curriculum_stage_num: int, buffer_reanalyze_freq: float, seed: int) -> str: + """ + Overview: + Generates a descriptive name for the experiment. + + Arguments: + - num_tasks (:obj:`int`): Number of tasks in the experiment. 
+ - curriculum_stage_num (:obj:`int`): Number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`str`): The generated experiment name prefix. + """ + # NOTE: This is a template for the experiment name. + # Users should customize it to reflect their specific experiment settings. + return ( + f'data_suz_dmc_mt_balance_20250625/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}' + f'_stage0-10k-5k_fix-lora-update-stablescale_moe8-uselora_nlayer4_not-share-head' + f'_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + +def generate_all_task_configs( + env_id_list: list[str], + target_return_dict: dict[str, int], + action_space_size_list: list[int], + observation_shape_list: list[int], + curriculum_stage_num: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: list[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> list[tuple[int, list[EasyDict, EasyDict]]]: + """ + Overview: + Generates a list of configurations, one for each task in the experiment. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of all environment IDs. + - target_return_dict (:obj:`dict[str, int]`): Mapping from env_id to target return. + - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - (other args): Hyperparameters for the experiment. See `create_task_config` for details. + + Returns: + - (:obj:`list`): A list where each element is `[task_id, [task_config, env_manager_config]]`. 
+ """ + configs = [] + exp_name_prefix = generate_experiment_name( + num_tasks=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed + ) + + base_config = get_base_config( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + curriculum_stage_num=curriculum_stage_num + ) + + for task_id, env_id in enumerate(env_id_list): + task_specific_config = create_task_config( + base_config=base_config.clone(), # Use a clone to avoid modifying the base config + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + task_specific_config.policy.task_id = task_id + task_specific_config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + + env_manager_cfg = create_env_manager_config() + configs.append([task_id, [task_specific_config, env_manager_cfg]]) + + return configs + + +def main(): + """ + Overview: + Main function to set up and launch the multi-task UniZero training experiment. + This script should be executed with GPUs. + + Example launch commands: + 1. Using `torch.distributed.launch`: + cd /LightZero/ + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 \\ + ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee \\ + ./logs/uz_mt_dmc18_balance_moe8_seed0.log + + 2. Using `torchrun`: + cd /LightZero/ + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py + """ + from lzero.entry import train_unizero_multitask_balance_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # ============================================================== + # Experiment-level settings + # ============================================================== + # NOTE: You can switch between different sets of environments by uncommenting them. 
+ # DMC 8-task benchmark + # env_id_list = [ + # 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', + # 'cartpole-swingup', 'cartpole-swingup_sparse', 'cheetah-run', + # "ball_in_cup-catch", "finger-spin", + # ] + # target_return_dict = { + # 'acrobot-swingup': 500, 'cartpole-balance': 950, 'cartpole-balance_sparse': 950, + # 'cartpole-swingup': 800, 'cartpole-swingup_sparse': 750, 'cheetah-run': 650, + # "ball_in_cup-catch": 950, "finger-spin": 800, + # } + + # DMC 18-task benchmark + env_id_list = [ + 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', 'cartpole-swingup', + 'cartpole-swingup_sparse', 'cheetah-run', "ball_in_cup-catch", "finger-spin", + "finger-turn_easy", "finger-turn_hard", 'hopper-hop', 'hopper-stand', + 'pendulum-swingup', 'reacher-easy', 'reacher-hard', 'walker-run', + 'walker-stand', 'walker-walk', + ] + target_return_dict = { + 'acrobot-swingup': 500, 'cartpole-balance': 900, 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 750, 'cartpole-swingup_sparse': 750, 'cheetah-run': 550, + "ball_in_cup-catch": 950, "finger-spin": 800, "finger-turn_easy": 950, + "finger-turn_hard": 950, 'hopper-hop': 150, 'hopper-stand': 600, + 'pendulum-swingup': 800, 'reacher-easy': 900, 'reacher-hard': 900, + 'walker-run': 500, 'walker-stand': 900, 'walker-walk': 900, + } + + # ============================================================== + # Hyperparameters + # ============================================================== + # NOTE: For debugging, you can use smaller values. + # collector_env_num, num_segments, n_episode = 2, 2, 2 + # evaluator_env_num, num_simulations, total_batch_size = 2, 1, 8 + # batch_size = [3] * len(env_id_list) + # max_env_step = int(1e3) + + # Production settings + curriculum_stage_num = 5 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + seed = 0 # You can iterate over multiple seeds if needed + + # Fetch observation and action space info from predefined maps + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + # ============================================================== + # Generate configurations and start training + # ============================================================== + configs = generate_all_task_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + curriculum_stage_num=curriculum_stage_num, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + # To train only a subset of tasks for debugging, you 
can slice the configs list. + # e.g., train_unizero_multitask_balance_segment_ddp(configs[:1], ...) + train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="dmc") + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py new file mode 100644 index 000000000..de2c09fa2 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -0,0 +1,480 @@ +from easydict import EasyDict +from typing import List, Any, Dict, Tuple + +import logging + +# Set up logging configuration +# Configure logging to output to both a file and the console. +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s', + handlers=[ + logging.FileHandler("output.log", encoding="utf-8"), # Log to file + logging.StreamHandler() # Log to console + ] +) + + +def create_config( + env_id: str, + env_id_list: List[str], + target_return_dict: Dict[str, int], + observation_shape_list: List[Tuple[int, ...]], + action_space_size_list: List[int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, +) -> EasyDict: + """ + Overview: + Create a configuration EasyDict for a single reinforcement learning task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment, e.g., 'cartpole-swingup'. + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. + - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - observation_shape_list (:obj:`List[Tuple[int, ...]]`): List of observation shapes for all tasks. + - action_space_size_list (:obj:`List[int]`): List of action space sizes for all tasks. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): A configuration object for the specified task. + """ + domain_name, task_name = env_id.split('-') + + # Specific frame_skip settings for certain domains. 
+ if domain_name == "pendulum": + frame_skip = 8 + else: + frame_skip = 4 + + # --- Environment Configuration --- + env_cfg = dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + frame_skip=frame_skip, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO: Settings for debugging purposes. + # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ) + + # --- World Model Configuration --- + world_model_cfg = dict( + # --- Normalization and Loss --- + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # --- Architecture --- + share_head=False, # TODO + use_shared_projection=False, + obs_type='vector', + model_type='mlp', + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + norm_type=norm_type, + device='cuda', + + # --- Transformer/MOE Settings --- + num_layers=8, # TODO: 8 for standard, 1 for debug + num_heads=24, + embed_dim=768, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + num_experts_of_moe_in_transformer=8, + n_shared_experts=1, + num_experts_per_tok=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # --- LoRA Parameters --- + moe_use_lora=False, # TODO + curriculum_stage_num=3, + lora_target_modules=["attn", "feed_forward"], + lora_r=0, + lora_alpha=1, + lora_dropout=0.0, + + # --- Multi-task Settings --- + task_embed_option=None, # TODO: 'concat_task_embed' or None + use_task_embed=False, # TODO + # task_embed_dim=128, + task_num=len(env_id_list), + + # --- Analysis --- + analysis_dormant_ratio_weight_rank=False, # TODO + analysis_dormant_ratio_interval=5000, + + # --- Dynamic Properties --- + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + env_num=max(collector_env_num, evaluator_env_num), + + # --- Loss Weights --- + policy_loss_type='kl', + policy_entropy_weight=5e-2, + ) + + # --- Policy Configuration --- + policy_cfg = dict( + # --- Hardware & Distribution --- + multi_gpu=True, # TODO: enable multi-GPU for DDP + cuda=True, + + # --- Model --- + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=world_model_cfg, + ), + + # --- Learning --- + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + optim_type='AdamW', + learning_rate=1e-4, + grad_clip_value=5, + cos_lr_scheduler=True, + piecewise_decay_lr_scheduler=False, + + # --- Training Loop --- + train_start_after_envsteps=int(0), # TODO: 2e3 for standard, 0 for quick debug 
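+ # update_per_collect=200 follows the same rule of thumb noted in the balance config above:
+ # 8 collector envs * 100 steps per game segment * 0.25 replay ratio = 200 gradient updates per collect.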
+ update_per_collect=200, + replay_ratio=reanalyze_ratio, + + # --- Batch Sizes --- + batch_size=batch_size, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + + # --- Replay Buffer --- + replay_buffer_size=int(1e6), + num_segments=num_segments, + use_priority=False, + + # --- Reanalyze --- + reanalyze_ratio=reanalyze_ratio, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + + # --- Algorithm Hyperparameters --- + num_simulations=num_simulations, + num_unroll_steps=num_unroll_steps, + td_steps=5, + discount_factor=0.99, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + + # --- MoCo (Momentum Contrast) --- + use_moco=False, # TODO + only_use_moco_stats=False, + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + + # --- Multi-task Specific --- + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # To be set per task + target_return=target_return_dict.get(env_id), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=True, # TODO + balance_pipeline=True, + print_task_priority_logs=False, + + # --- Environment Interaction --- + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + eval_freq=int(4e3), + + # --- Checkpointing --- + model_path=None, + ) + + # --- Combine configurations into the final EasyDict object --- + main_config = EasyDict(dict( + env=env_cfg, + policy=policy_cfg, + )) + + return main_config + + +def generate_configs( + env_id_list: List[str], + target_return_dict: Dict[str, int], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + dmc_state_env_action_space_map: Dict[str, int], + dmc_state_env_obs_space_map: Dict[str, Tuple[int, ...]], +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generate a list of configurations for all specified multi-task environments. + + Arguments: + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. + - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - seed (:obj:`int`): The random seed. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. 
+ - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - dmc_state_env_action_space_map (:obj:`Dict[str, int]`): Map from env_id to action space size. + - dmc_state_env_obs_space_map (:obj:`Dict[str, Tuple[int, ...]]`): Map from env_id to observation shape. + + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element contains the task ID and its corresponding + configuration objects. + """ + configs = [] + + # Define the experiment name prefix. This helps in organizing experiment logs and results. + exp_name_prefix = ( + f'data_suz_dmc_mt_20250601/dmc_{len(env_id_list)}tasks_frameskip4-pendulum-skip8_ln-mse' + f'_nlayer8_trans-moe8_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + # Get action_space_size and observation_shape for each environment. + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id=env_id, + env_id_list=env_id_list, + target_return_dict=target_return_dict, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Create the environment and policy manager configuration. This specifies the types + of environment, policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment and policy managers. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + Main script to configure and launch a multi-task training session for DeepMind Control Suite (DMC) + environments using Distributed Data Parallel (DDP). + + Usage: + This script should be executed with GPUs. + Navigate to the project root directory and run the launch command. + + Example command: + cd + # Using torch.distributed.launch (deprecated) + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 \\ + /dmc2gym_state_suz_multitask_ddp_config.py 2>&1 | tee \\ + /uz_mt_dmc18_train.log + + # Using torchrun (recommended) + torchrun --nproc_per_node=8 /dmc2gym_state_suz_multitask_ddp_config.py + """ + # --- Import necessary components for training --- + # It's good practice to place imports inside the main guard + # if they are only used for script execution. 
+ from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # --- Experiment constants --- + BENCHMARK_NAME = 'dmc' + + # --- Environment and Task Definitions --- + # Target return values for each DMC task, used for evaluation and potential curriculum. + target_return_dict = { + 'acrobot-swingup': 500, + 'cartpole-balance': 950, + 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 800, + 'cartpole-swingup_sparse': 750, + 'cheetah-run': 650, + "ball_in_cup-catch": 950, + "finger-spin": 800, + "finger-turn_easy": 950, + "finger-turn_hard": 950, + 'hopper-hop': 150, + 'hopper-stand': 600, + 'pendulum-swingup': 800, + 'reacher-easy': 950, + 'reacher-hard': 950, + 'walker-run': 600, + 'walker-stand': 950, + 'walker-walk': 950, + } + + # List of DMC environments to be used in the multi-task setup. + env_id_list = list(target_return_dict.keys()) + + # --- Hyperparameters for the training session --- + # Environment and Collector settings + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + max_env_step = int(4e5) + + # Replay Buffer and Reanalyze settings + num_segments = 8 + reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # Model and Training settings + total_batch_size = 512 + # Allocate batch size per task, ensuring a minimum of 64 or distributing the total size. + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + num_simulations = 50 + + # --- Main training loop --- + # Iterate over different random seeds for multiple runs. + for seed in [1, 2]: + # Generate the specific configurations for each task for the current run. + configs = generate_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + dmc_state_env_action_space_map=dmc_state_env_action_space_map, + dmc_state_env_obs_space_map=dmc_state_env_obs_space_map, + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, + benchmark_name=BENCHMARK_NAME) + # If you only want to train a subset of tasks, you can slice the configs list. 
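Note on the per-task batch size computed above: `int(min(64, total_batch_size / len(env_id_list)))` takes the smaller of 64 and the even share of the total, i.e. it caps each task at 64 rather than guaranteeing a minimum of 64. With the 18 DMC tasks listed above and total_batch_size = 512, each task gets int(512 / 18) = 28. A standalone check (editor's sketch, not part of the patch):

# Editor's sketch: reproduce the per-task batch-size arithmetic used above.
total_batch_size = 512
num_tasks = 18  # len(env_id_list) for the 18 DMC tasks

per_task = int(min(64, total_batch_size / num_tasks))  # min() caps the value at 64
batch_size = [per_task for _ in range(num_tasks)]

print(per_task)         # 28  (512 / 18 = 28.44..., truncated by int())
print(sum(batch_size))  # 504, slightly below total_batch_size due to truncation
assert per_task == 28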
+ # For example, to train only the first four tasks: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step, benchmark_name=BENCHMARK_NAME) + dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py index 4fcfb209a..068790293 100644 --- a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py +++ b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py @@ -18,6 +18,8 @@ from gym.spaces import Box from matplotlib import animation import imageio +import logging + def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box: @@ -268,6 +270,8 @@ def __init__(self, cfg: dict = {}) -> None: self._save_replay_gif = cfg.save_replay_gif self._replay_path_gif = cfg.replay_path_gif self._save_replay_count = 0 + self._timestep = 0 + self.max_episode_steps = cfg.max_episode_steps def reset(self) -> Dict[str, np.ndarray]: """ @@ -409,11 +413,15 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: if self._save_replay_gif: self._frames.append(image_obs) + + if self._timestep > self.max_episode_steps: + done = True if self._timestep > self._cfg.max_episode_steps: done = True if done: + logging.info(f'one episode done! episode return: {self._eval_episode_return}, episode_steps:{self._timestep}') info['eval_episode_return'] = self._eval_episode_return if self._save_replay_gif: @@ -422,7 +430,8 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', + self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -487,7 +496,7 @@ def __repr__(self) -> str: String representation of the environment. 
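The env-side change in the hunk above adds a step counter that truncates an episode once `self._timestep` exceeds `max_episode_steps`, and logs the episode return when the episode ends. A self-contained sketch of that truncation pattern (editor's illustration with a dummy environment; the real class also tracks `_eval_episode_return` from real rewards and handles GIF frames as shown in the diff):

# Editor's sketch: toy stand-in for the step-count truncation added to the DMC2Gym env.
import logging

logging.basicConfig(level=logging.INFO)


class TruncatedEpisodeSketch:
    """Minimal environment that ends an episode after a fixed number of steps."""

    def __init__(self, max_episode_steps: int = 5) -> None:
        self.max_episode_steps = max_episode_steps
        self._timestep = 0
        self._eval_episode_return = 0.0

    def reset(self) -> float:
        self._timestep = 0
        self._eval_episode_return = 0.0
        return 0.0  # dummy observation

    def step(self, action):
        self._timestep += 1
        reward = 1.0  # dummy reward
        self._eval_episode_return += reward
        done = False
        # Truncate once the counter exceeds the configured limit, mirroring
        # `if self._timestep > self.max_episode_steps: done = True` in the diff.
        if self._timestep > self.max_episode_steps:
            done = True
        if done:
            logging.info(
                f'one episode done! episode return: {self._eval_episode_return}, '
                f'episode_steps: {self._timestep}'
            )
        return 0.0, reward, done, {}


env = TruncatedEpisodeSketch(max_episode_steps=3)
env.reset()
done = False
while not done:
    _, _, done, _ = env.step(action=None)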
""" return "LightZero DMC2Gym Env({}:{})".format(self._cfg["domain_name"], self._cfg["task_name"]) - + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') @@ -502,4 +511,4 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps cfg.is_eval = True - return [cfg for _ in range(evaluator_env_num)] + return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index cc66e045b..8d5ac7fc1 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -132,6 +132,23 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e embed_dim=embed_dim, obs_type="text", env_num=max(collector_env_num, evaluator_env_num), + + task_embed_option=None, + use_task_embed=False, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. latent_recon_loss_weight=0.1 ), @@ -149,7 +166,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e manual_temperature_decay=False, num_simulations=num_simulations, n_episode=n_episode, - train_start_after_envsteps=0, + train_start_after_envsteps=0, # TODO: Adjust training start trigger if needed. replay_buffer_size=int(5e5), eval_freq=int(3e4), collector_env_num=collector_env_num, diff --git a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py index 7725d8409..4f7fef3ea 100644 --- a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py +++ b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py @@ -1,7 +1,9 @@ from easydict import EasyDict # options={'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Humanoid-v3'} -env_id = 'Hopper-v3' +# env_id = 'Hopper-v3' +env_id = 'Ant-v3' + if env_id == 'Hopper-v3': action_space_size = 3