From a65c4452367357269f38199ed4a3bb8b4c61a244 Mon Sep 17 00:00:00 2001 From: TAJh Date: Tue, 3 Feb 2026 15:49:54 +0800 Subject: [PATCH 1/6] add msprobe --- .../_generated_ppo_megatron_trainer.yaml | 12 ++++ .../config/_generated_ppo_trainer.yaml | 12 ++++ verl/trainer/config/ppo_megatron_trainer.yaml | 8 +++ verl/trainer/config/ppo_trainer.yaml | 21 ++++++ verl/trainer/ppo/ray_trainer.py | 14 ++++ verl/utils/precision_debugger.py | 64 +++++++++++++++++++ verl/workers/actor/dp_actor.py | 9 +++ verl/workers/actor/megatron_actor.py | 9 +++ verl/workers/engine_workers.py | 18 +++++- verl/workers/fsdp_workers.py | 11 ++++ verl/workers/megatron_workers.py | 11 ++++ 11 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 verl/utils/precision_debugger.py diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index bc418f0b1fa..4982f26cbbd 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -711,6 +711,18 @@ global_profiler: context: all stacks: all kw_args: {} +precision_debugger: + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout + - train_fwd + - train_bwd + - update_actor + - ref_model + fail_open: true transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index a3baaf52af3..166e3cb81be 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -645,6 +645,18 @@ global_profiler: context: all stacks: all kw_args: {} +precision_debugger: + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout + - train_fwd + - train_bwd + - update_actor + - ref_model + fail_open: true transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 76ba4c57575..83840b38e75 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -237,6 +237,14 @@ global_profiler: # devices, record_context etc. kw_args: {} +precision_debugger: + enable: False + config_path: null + data_dir: "outputs/precision_debug" + steps: null + stages: ["rollout", "train_fwd", "train_bwd", "update_actor", "ref_model"] + fail_open: True + # configs for TransferQueue transfer_queue: # Whether to enable transfer queue diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7489b522fa2..3023e34f8d3 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -301,6 +301,27 @@ global_profiler: # devices, record_context etc. 
kw_args: {} +# precision debugger configs +precision_debugger: + + # Whether to enable precision debugger + enable: False + + # Path to msprobe config.json + config_path: null + + # Dump root directory + data_dir: "outputs/precision_debug" + + # Steps to collect, null means all + steps: null + + # Stages to collect + stages: ["rollout", "train_fwd", "train_bwd", "update_actor", "ref_model"] + + # Fail open on errors + fail_open: True + # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 46271507934..cef91bec801 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -271,6 +271,7 @@ def __init__( self.config = config self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn + self._propagate_precision_debugger_config() self.hybrid_engine = config.actor_rollout_ref.hybrid_engine assert self.hybrid_engine, "Currently, only support hybrid engine" @@ -1073,6 +1074,18 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm and not self.use_reward_loop: self.rm_wg.stop_profile() + def _propagate_precision_debugger_config(self) -> None: + precision_cfg = OmegaConf.select(self.config, "precision_debugger") + if precision_cfg is None: + return + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref") is not None: + self.config.actor_rollout_ref.precision_debugger = precision_cfg + if OmegaConf.select(self.config, "critic") is not None: + self.config.critic.precision_debugger = precision_cfg + if OmegaConf.select(self.config, "reward_model") is not None: + self.config.reward_model.precision_debugger = precision_cfg + def _get_dp_size(self, worker_group, role: str) -> int: """Get data parallel size from worker group dispatch info. 
@@ -1369,6 +1382,7 @@ def fit(self): else curr_step_profile ) batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["global_steps"] = self.global_steps batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature # add uid to batch diff --git a/verl/utils/precision_debugger.py b/verl/utils/precision_debugger.py new file mode 100644 index 00000000000..c3917767917 --- /dev/null +++ b/verl/utils/precision_debugger.py @@ -0,0 +1,64 @@ +import os +import threading +from dataclasses import dataclass + +_lock = threading.Lock() +_last_key = None + + +@dataclass(frozen=True) +class PrecisionHandle: + started: bool + + +def _should_profile(cfg, stage, global_step): + if not cfg or not cfg.get("enable", False): + return False + stages = cfg.get("stages", None) + if stages is not None and stage not in set(stages): + return False + steps = cfg.get("steps", None) + if steps is not None and global_step is not None: + if int(global_step) not in set(steps): + return False + config_path = cfg.get("config_path", None) + data_dir = cfg.get("data_dir", None) + return bool(config_path and data_dir) + + +def _dump_path(cfg, global_step, stage): + root = cfg.get("data_dir", "outputs/precision_debug") + step_dir = str(global_step) if global_step is not None else "unknown" + return os.path.join(root, step_dir, stage) + + +def _reset_instance(PrecisionDebugger): + PrecisionDebugger._instance = None + + +def precision_start(cfg, stage, global_step, model=None) -> PrecisionHandle: + if not _should_profile(cfg, stage, global_step): + return PrecisionHandle(False) + from msprobe.pytorch import PrecisionDebugger + + dump_path = _dump_path(cfg, global_step, stage) + os.makedirs(dump_path, exist_ok=True) + config_path = cfg.get("config_path") + key = (config_path, dump_path) + with _lock: + global _last_key + if _last_key != key: + _reset_instance(PrecisionDebugger) + PrecisionDebugger(config_path=config_path, dump_path=dump_path) + _last_key = key + PrecisionDebugger.start(model=model) + return PrecisionHandle(True) + + +def precision_stop(handle: PrecisionHandle) -> None: + if not handle.started: + return + from msprobe.pytorch import PrecisionDebugger + + PrecisionDebugger.step() + PrecisionDebugger.stop() diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index d524f0e2ba1..06f5e7c77f6 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -37,6 +37,7 @@ from verl.utils.torch_dtypes import PrecisionType from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.utils.precision_debugger import precision_start, precision_stop from verl.workers.actor import BasePPOActor from verl.workers.config import ActorConfig @@ -546,6 +547,8 @@ def update_policy(self, data: DataProto): "actor/pg_loss": 0.0, "actor/kl_loss": 0.0, } + precision_cfg = getattr(self, "precision_debugger_cfg", None) + global_step = data.meta_info.get("global_steps", None) for _ in range(self.config.ppo_epochs): for batch_idx, mini_batch in enumerate(mini_batches): if self.config.use_dynamic_bsz: @@ -578,9 +581,11 @@ def update_policy(self, data: DataProto): loss_scale_factor = 1 / self.gradient_accumulation # all return: (bsz, response_length) + handle = precision_start(precision_cfg, "train_fwd", global_step, self.actor_module) outputs = self._forward_micro_batch( model_inputs, temperature=temperature, calculate_entropy=calculate_entropy ) + 
precision_stop(handle) log_prob = outputs["log_probs"] entropy = outputs["entropys"] if calculate_entropy else None @@ -654,15 +659,19 @@ def update_policy(self, data: DataProto): loss = policy_loss * loss_scale_factor else: loss = policy_loss * loss_scale_factor + handle = precision_start(precision_cfg, "train_bwd", global_step, self.actor_module) if self.scaler is not None: self.scaler.scale(loss).backward() else: loss.backward() + precision_stop(handle) metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor append_to_dict(metrics, micro_batch_metrics) + handle = precision_start(precision_cfg, "update_actor", global_step, self.actor_module) grad_norm = self._optimizer_step() + precision_stop(handle) mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 7fdaa6e9811..706cf24f996 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -54,6 +54,7 @@ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor +from verl.utils.precision_debugger import precision_start, precision_stop from verl.workers.actor import BasePPOActor from verl.workers.config import MtpConfig @@ -768,7 +769,9 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals """ metrics = {} + precision_cfg = getattr(self, "precision_debugger_cfg", None) for data in dataloader: + global_step = data.meta_info.get("global_steps", None) if self.config.router_replay.mode in ["R2", "R3"]: RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) self.actor_optimizer.zero_grad() @@ -785,6 +788,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals max_token_len = None if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + handle = precision_start(precision_cfg, "train_fwd", global_step, self.actor_module) metric_micro_batch = self.forward_backward_batch( data, calculate_entropy=calculate_entropy, @@ -793,6 +797,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size, ) + precision_stop(handle) mtp_losses = metric_micro_batch.get("mtp_losses", None) if mtp_losses is not None: @@ -805,7 +810,11 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. 
+ handle_bwd = precision_start(precision_cfg, "train_bwd", global_step, self.actor_module) + handle_update = precision_start(precision_cfg, "update_actor", global_step, self.actor_module) update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() + precision_stop(handle_bwd) + precision_stop(handle_update) data = {"actor/grad_norm": grad_norm} append_to_dict(metrics, data) diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index a4f0d9f4c77..933a95bbe9d 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -37,6 +37,7 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.metric.utils import Metric +from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage from verl.utils.py_functional import append_to_dict from verl.utils.tensordict_utils import maybe_fix_3d_position_ids @@ -69,6 +70,7 @@ def __init__(self, config: TrainingWorkerConfig): initialize_global_process_group_ray(timeout_second=None) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) self.model_config = self.config.model_config self.engine_config = self.config.engine_config self.optimizer_config = self.config.optimizer_config @@ -553,8 +555,14 @@ def init_model(self): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: + global_step = data.get("global_steps", None) + handle = precision_start( + self.precision_debugger_cfg, "ref_model", global_step, getattr(self, "ref", None) + ) output = self.ref.infer_batch(data=data) - return output.cpu() if output is not None else None + output = output.cpu() if output is not None else None + precision_stop(handle) + return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") @@ -565,8 +573,14 @@ def compute_log_prob(self, data: TensorDict) -> TensorDict: @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: TensorDict) -> TensorDict: + global_step = data.get("global_steps", None) + handle = precision_start( + self.precision_debugger_cfg, "update_actor", global_step, getattr(self, "actor", None) + ) output = self.actor.train_mini_batch(data=data) - return output.cpu() if output is not None else None + output = output.cpu() if output is not None else None + precision_stop(handle) + return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index a5e72f84f92..c19c3616372 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -80,6 +80,7 @@ from verl.utils.import_utils import import_external_libs from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import compute_position_id_with_mask, convert_weight_keys +from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer from 
verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types @@ -147,6 +148,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): Worker.__init__(self) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) import torch.distributed if not torch.distributed.is_initialized(): @@ -821,6 +823,7 @@ def init_model(self): self.actor = DataParallelPPOActor( config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer ) + self.actor.precision_debugger_cfg = self.precision_debugger_cfg if self._is_rollout: self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) @@ -957,8 +960,12 @@ def generate_sequences(self, prompts: DataProto): loop.run_until_complete(self.rollout_mode()) log_gpu_memory_usage("After switch to rollout mode", logger=logger) + global_step = prompts.meta_info.get("global_steps", None) + model_for_debug = getattr(self, "actor_module_fsdp", None) or getattr(self, "actor_module", None) + handle = precision_start(self.precision_debugger_cfg, "rollout", global_step, model_for_debug) with simple_timer("generate_sequences", timing_generate): output = self.rollout.generate_sequences(prompts=prompts) + precision_stop(handle) if self._is_actor: loop.run_until_complete(self.trainer_mode()) @@ -1053,10 +1060,14 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) + global_step = data.meta_info.get("global_steps", None) + model_for_debug = getattr(self, "ref_module_fsdp", None) or getattr(self, "actor_module_fsdp", None) + handle = precision_start(self.precision_debugger_cfg, "ref_model", global_step, model_for_debug) with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on ref.compute_log_prob outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) + precision_stop(handle) output = output.to("cpu") diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 14aa17949f9..0e17708944b 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -64,6 +64,7 @@ ) from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -253,6 +254,7 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config: DictConfig, role: str, **kwargs): Worker.__init__(self) self.config = config + self.precision_debugger_cfg = config.get("precision_debugger", None) if repatch is not None: # NPU MindSpeed patch, will be refactored with MindSpeedEngine. 
repatch(self.config.actor.megatron.get("override_transformer_config", {})) @@ -612,6 +614,7 @@ def init_model(self): actor_optimizer=self.actor_optimizer, mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None, ) + self.actor.precision_debugger_cfg = self.precision_debugger_cfg print(f"routing replay layers: {len(RouterReplay.router_instances)}") log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) @@ -808,8 +811,12 @@ def generate_sequences(self, prompts: DataProto): loop.run_until_complete(self.rollout_mode()) log_gpu_memory_usage("After switch to rollout mode", logger=logger) + global_step = prompts.meta_info.get("global_steps", None) + model_for_debug = getattr(self, "actor_module", None) + handle = precision_start(self.precision_debugger_cfg, "rollout", global_step, model_for_debug) with simple_timer("generate_sequences", timing_generate): output = self.rollout.generate_sequences(prompts=prompts) + precision_stop(handle) if self._is_actor: loop.run_until_complete(self.trainer_mode()) @@ -851,9 +858,13 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature + global_step = data.meta_info.get("global_steps", None) + model_for_debug = getattr(self, "ref_module", None) or getattr(self, "actor_module", None) + handle = precision_start(self.precision_debugger_cfg, "ref_model", global_step, model_for_debug) output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = output.to("cpu") + precision_stop(handle) if self._ref_is_offload_param: offload_megatron_model_to_cpu(self.ref_module) log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) From 45b9a7f0c514b66dee78ecc76f5738cf252a743a Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 4 Feb 2026 09:55:44 +0800 Subject: [PATCH 2/6] fix review --- .../_generated_ppo_megatron_trainer.yaml | 3 +- .../config/_generated_ppo_trainer.yaml | 3 +- verl/trainer/config/ppo_megatron_trainer.yaml | 2 +- verl/trainer/config/ppo_trainer.yaml | 2 +- verl/utils/precision_debugger.py | 64 ------------- verl/utils/profiler/profile.py | 95 +++++++++++++++++++ verl/workers/actor/dp_actor.py | 13 +-- verl/workers/actor/megatron_actor.py | 17 ++-- verl/workers/engine_workers.py | 19 +--- verl/workers/fsdp_workers.py | 12 +-- verl/workers/megatron_workers.py | 12 +-- 11 files changed, 120 insertions(+), 122 deletions(-) delete mode 100644 verl/utils/precision_debugger.py diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 4982f26cbbd..642172169ac 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -718,8 +718,7 @@ precision_debugger: steps: null stages: - rollout - - train_fwd - - train_bwd + - train - update_actor - ref_model fail_open: true diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 166e3cb81be..e52536cd324 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -652,8 +652,7 @@ precision_debugger: steps: null stages: - rollout - - train_fwd - - train_bwd + - train - update_actor - 
ref_model fail_open: true diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 83840b38e75..d17dcd34886 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -242,7 +242,7 @@ precision_debugger: config_path: null data_dir: "outputs/precision_debug" steps: null - stages: ["rollout", "train_fwd", "train_bwd", "update_actor", "ref_model"] + stages: ["rollout", "train", "update_actor", "ref_model"] fail_open: True # configs for TransferQueue diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 3023e34f8d3..60148ec39ba 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -317,7 +317,7 @@ precision_debugger: steps: null # Stages to collect - stages: ["rollout", "train_fwd", "train_bwd", "update_actor", "ref_model"] + stages: ["rollout", "train", "update_actor", "ref_model"] # Fail open on errors fail_open: True diff --git a/verl/utils/precision_debugger.py b/verl/utils/precision_debugger.py deleted file mode 100644 index c3917767917..00000000000 --- a/verl/utils/precision_debugger.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import threading -from dataclasses import dataclass - -_lock = threading.Lock() -_last_key = None - - -@dataclass(frozen=True) -class PrecisionHandle: - started: bool - - -def _should_profile(cfg, stage, global_step): - if not cfg or not cfg.get("enable", False): - return False - stages = cfg.get("stages", None) - if stages is not None and stage not in set(stages): - return False - steps = cfg.get("steps", None) - if steps is not None and global_step is not None: - if int(global_step) not in set(steps): - return False - config_path = cfg.get("config_path", None) - data_dir = cfg.get("data_dir", None) - return bool(config_path and data_dir) - - -def _dump_path(cfg, global_step, stage): - root = cfg.get("data_dir", "outputs/precision_debug") - step_dir = str(global_step) if global_step is not None else "unknown" - return os.path.join(root, step_dir, stage) - - -def _reset_instance(PrecisionDebugger): - PrecisionDebugger._instance = None - - -def precision_start(cfg, stage, global_step, model=None) -> PrecisionHandle: - if not _should_profile(cfg, stage, global_step): - return PrecisionHandle(False) - from msprobe.pytorch import PrecisionDebugger - - dump_path = _dump_path(cfg, global_step, stage) - os.makedirs(dump_path, exist_ok=True) - config_path = cfg.get("config_path") - key = (config_path, dump_path) - with _lock: - global _last_key - if _last_key != key: - _reset_instance(PrecisionDebugger) - PrecisionDebugger(config_path=config_path, dump_path=dump_path) - _last_key = key - PrecisionDebugger.start(model=model) - return PrecisionHandle(True) - - -def precision_stop(handle: PrecisionHandle) -> None: - if not handle.started: - return - from msprobe.pytorch import PrecisionDebugger - - PrecisionDebugger.step() - PrecisionDebugger.stop() diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py index 8e3145a66bb..1ff85d070ba 100644 --- a/verl/utils/profiler/profile.py +++ b/verl/utils/profiler/profile.py @@ -17,6 +17,10 @@ from ..memory_utils import MemorySnapshotSampler, enable_memory_visualize from .config import ProfilerConfig, TorchMemoryToolConfig +import os +import threading + +_precision_lock = threading.Lock() def mark_start_range( @@ -151,6 +155,97 @@ def stop(self): self._this_step = False return getattr(self._impl, "stop", lambda: None)() + 
def precision_start(self, precision_cfg, stage: str, global_step: Optional[int], model=None) -> bool: + if not precision_cfg or not precision_cfg.get("enable", False): + return False + stages = precision_cfg.get("stages", None) + if stages is not None and stage not in set(stages): + return False + steps = precision_cfg.get("steps", None) + if steps is not None and global_step is not None: + if int(global_step) not in set(steps): + return False + config_path = precision_cfg.get("config_path", None) + data_dir = precision_cfg.get("data_dir", None) + if not config_path or not data_dir: + return False + dump_path = os.path.join(data_dir, str(global_step) if global_step is not None else "unknown", stage) + os.makedirs(dump_path, exist_ok=True) + with _precision_lock: + from msprobe.pytorch import PrecisionDebugger + + debugger = PrecisionDebugger._instance + if debugger is None: + PrecisionDebugger(config_path=config_path, dump_path=dump_path) + debugger = PrecisionDebugger._instance + if debugger is None: + return False + debugger.service.config.dump_path = dump_path + debugger.start(model) + return True + + def precision_stop(self, precision_cfg, stage: str, started: bool) -> None: + if not started: + return + from msprobe.pytorch import PrecisionDebugger + + debugger = PrecisionDebugger._instance + if debugger is None: + return + debugger.stop() + if stage == "update_actor": + debugger.step() + + @classmethod + def precision( + cls, + stage: str, + model_attr: Optional[object] = None, + ) -> Callable: + def _get_model(self_instance): + if model_attr is None: + return None + if isinstance(model_attr, (list, tuple)): + for attr in model_attr: + val = getattr(self_instance, attr, None) + if val is not None: + return val + return None + return getattr(self_instance, model_attr, None) + + def _get_global_step(args, kwargs): + for val in list(args) + list(kwargs.values()): + if hasattr(val, "meta_info"): + meta = getattr(val, "meta_info") + if isinstance(meta, dict) and "global_steps" in meta: + return meta.get("global_steps") + if hasattr(val, "get"): + try: + step = val.get("global_steps", None) + if step is not None: + return step + except Exception: + pass + return None + + def decorator(func): + @functools.wraps(func) + def wrapper(self_instance, *args, **kwargs): + profiler = getattr(self_instance, "profiler", None) + precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) + if not profiler or not precision_cfg: + return func(self_instance, *args, **kwargs) + global_step = _get_global_step(args, kwargs) + model = _get_model(self_instance) + started = profiler.precision_start(precision_cfg, stage, global_step, model) + result = func(self_instance, *args, **kwargs) + profiler.precision_stop(precision_cfg, stage, started) + return result + + return wrapper + + return decorator + @classmethod def annotate( cls, diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 06f5e7c77f6..ecd80c5c020 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -37,7 +37,7 @@ from verl.utils.torch_dtypes import PrecisionType from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs -from verl.utils.precision_debugger import precision_start, precision_stop + from verl.workers.actor import BasePPOActor from verl.workers.config import ActorConfig @@ -389,6 +389,7 @@ def _forward_micro_batch( outputs["sum_pi_squared"] = sum_pi_squared return 
outputs + @DistProfiler.precision(stage="update_actor", model_attr="actor_module") def _optimizer_step(self): assert self.config.grad_clip is not None if self.scaler is not None: @@ -500,6 +501,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> return outputs @GPUMemoryLogger(role="dp actor", logger=logger) + @DistProfiler.precision(stage="train", model_attr="actor_module") def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() @@ -547,8 +549,6 @@ def update_policy(self, data: DataProto): "actor/pg_loss": 0.0, "actor/kl_loss": 0.0, } - precision_cfg = getattr(self, "precision_debugger_cfg", None) - global_step = data.meta_info.get("global_steps", None) for _ in range(self.config.ppo_epochs): for batch_idx, mini_batch in enumerate(mini_batches): if self.config.use_dynamic_bsz: @@ -581,11 +581,10 @@ def update_policy(self, data: DataProto): loss_scale_factor = 1 / self.gradient_accumulation # all return: (bsz, response_length) - handle = precision_start(precision_cfg, "train_fwd", global_step, self.actor_module) outputs = self._forward_micro_batch( model_inputs, temperature=temperature, calculate_entropy=calculate_entropy ) - precision_stop(handle) + # keep handle active across backward log_prob = outputs["log_probs"] entropy = outputs["entropys"] if calculate_entropy else None @@ -659,19 +658,15 @@ def update_policy(self, data: DataProto): loss = policy_loss * loss_scale_factor else: loss = policy_loss * loss_scale_factor - handle = precision_start(precision_cfg, "train_bwd", global_step, self.actor_module) if self.scaler is not None: self.scaler.scale(loss).backward() else: loss.backward() - precision_stop(handle) metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor append_to_dict(metrics, micro_batch_metrics) - handle = precision_start(precision_cfg, "update_actor", global_step, self.actor_module) grad_norm = self._optimizer_step() - precision_stop(handle) mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 706cf24f996..27bd63a8b08 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -54,7 +54,7 @@ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import broadcast_dict_tensor -from verl.utils.precision_debugger import precision_start, precision_stop + from verl.workers.actor import BasePPOActor from verl.workers.config import MtpConfig @@ -754,6 +754,7 @@ def logits_processor(logits, label, label_mask): return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) + @DistProfiler.precision(stage="train", model_attr="actor_module") def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = False) -> dict: """Update the policy with an iterator of DataProto @@ -769,9 +770,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals """ metrics = {} - precision_cfg = getattr(self, "precision_debugger_cfg", None) for data in dataloader: - global_step = data.meta_info.get("global_steps", None) if self.config.router_replay.mode in ["R2", "R3"]: RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) self.actor_optimizer.zero_grad() @@ -788,7 +787,6 @@ def update_policy(self, 
dataloader: Iterable[DataProto], enable_mtp: bool = Fals max_token_len = None if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size - handle = precision_start(precision_cfg, "train_fwd", global_step, self.actor_module) metric_micro_batch = self.forward_backward_batch( data, calculate_entropy=calculate_entropy, @@ -797,7 +795,6 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size, ) - precision_stop(handle) mtp_losses = metric_micro_batch.get("mtp_losses", None) if mtp_losses is not None: @@ -810,11 +807,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. - handle_bwd = precision_start(precision_cfg, "train_bwd", global_step, self.actor_module) - handle_update = precision_start(precision_cfg, "update_actor", global_step, self.actor_module) - update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() - precision_stop(handle_bwd) - precision_stop(handle_update) + update_successful, grad_norm, num_zeros_in_grad = self._optimizer_step_with_precision() data = {"actor/grad_norm": grad_norm} append_to_dict(metrics, data) @@ -831,3 +824,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals self.actor_optimizer.zero_grad() get_torch_device().empty_cache() return metrics + + @DistProfiler.precision(stage="update_actor", model_attr="actor_module") + def _optimizer_step_with_precision(self): + return self.actor_optimizer.step() diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 933a95bbe9d..f5cbf7ab94b 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -37,7 +37,6 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.metric.utils import Metric -from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage from verl.utils.py_functional import append_to_dict from verl.utils.tensordict_utils import maybe_fix_3d_position_ids @@ -554,15 +553,10 @@ def init_model(self): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + @DistProfiler.precision(stage="ref_model", model_attr="ref") def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: - global_step = data.get("global_steps", None) - handle = precision_start( - self.precision_debugger_cfg, "ref_model", global_step, getattr(self, "ref", None) - ) output = self.ref.infer_batch(data=data) - output = output.cpu() if output is not None else None - precision_stop(handle) - return output + return output.cpu() if output is not None else None @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") @@ -572,15 +566,10 @@ def compute_log_prob(self, data: TensorDict) -> TensorDict: @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") + @DistProfiler.precision(stage="update_actor", model_attr="actor") def 
update_actor(self, data: TensorDict) -> TensorDict: - global_step = data.get("global_steps", None) - handle = precision_start( - self.precision_debugger_cfg, "update_actor", global_step, getattr(self, "actor", None) - ) output = self.actor.train_mini_batch(data=data) - output = output.cpu() if output is not None else None - precision_stop(handle) - return output + return output.cpu() if output is not None else None @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c19c3616372..18c8f42b350 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -80,7 +80,6 @@ from verl.utils.import_utils import import_external_libs from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import compute_position_id_with_mask, convert_weight_keys -from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types @@ -824,6 +823,7 @@ def init_model(self): config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer ) self.actor.precision_debugger_cfg = self.precision_debugger_cfg + self.actor.profiler = self.profiler if self._is_rollout: self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) @@ -939,6 +939,7 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @DistProfiler.annotate(color="red", role="rollout_generate") + @DistProfiler.precision(stage="rollout", model_attr=("actor_module_fsdp", "actor_module")) def generate_sequences(self, prompts: DataProto): # Support all hardwares assert self._is_rollout @@ -960,12 +961,8 @@ def generate_sequences(self, prompts: DataProto): loop.run_until_complete(self.rollout_mode()) log_gpu_memory_usage("After switch to rollout mode", logger=logger) - global_step = prompts.meta_info.get("global_steps", None) - model_for_debug = getattr(self, "actor_module_fsdp", None) or getattr(self, "actor_module", None) - handle = precision_start(self.precision_debugger_cfg, "rollout", global_step, model_for_debug) with simple_timer("generate_sequences", timing_generate): output = self.rollout.generate_sequences(prompts=prompts) - precision_stop(handle) if self._is_actor: loop.run_until_complete(self.trainer_mode()) @@ -1045,6 +1042,7 @@ def compute_log_prob(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + @DistProfiler.precision(stage="ref_model", model_attr=("ref_module_fsdp", "actor_module_fsdp")) def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora applied is the ref @@ -1060,14 +1058,10 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) - global_step = data.meta_info.get("global_steps", None) - model_for_debug = getattr(self, "ref_module_fsdp", None) or getattr(self, 
"actor_module_fsdp", None) - handle = precision_start(self.precision_debugger_cfg, "ref_model", global_step, model_for_debug) with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on ref.compute_log_prob outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) - precision_stop(handle) output = output.to("cpu") diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 0e17708944b..ed65191ae3f 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -64,7 +64,6 @@ ) from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights -from verl.utils.precision_debugger import precision_start, precision_stop from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -615,6 +614,7 @@ def init_model(self): mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None, ) self.actor.precision_debugger_cfg = self.precision_debugger_cfg + self.actor.profiler = self.profiler print(f"routing replay layers: {len(RouterReplay.router_instances)}") log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) @@ -790,6 +790,7 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @GPUMemoryLogger(role="generate_sequences", logger=logger) @DistProfiler.annotate(color="red", role="rollout_generate") + @DistProfiler.precision(stage="rollout", model_attr="actor_module") def generate_sequences(self, prompts: DataProto): assert self._is_rollout prompts = prompts.to(get_device_name()) @@ -811,12 +812,8 @@ def generate_sequences(self, prompts: DataProto): loop.run_until_complete(self.rollout_mode()) log_gpu_memory_usage("After switch to rollout mode", logger=logger) - global_step = prompts.meta_info.get("global_steps", None) - model_for_debug = getattr(self, "actor_module", None) - handle = precision_start(self.precision_debugger_cfg, "rollout", global_step, model_for_debug) with simple_timer("generate_sequences", timing_generate): output = self.rollout.generate_sequences(prompts=prompts) - precision_stop(handle) if self._is_actor: loop.run_until_complete(self.trainer_mode()) @@ -844,6 +841,7 @@ def generate_sequences(self, prompts: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + @DistProfiler.precision(stage="ref_model", model_attr=("ref_module", "actor_module")) def compute_ref_log_prob(self, data: DataProto): if self.peft_cls is not None: # if is lora, actor without lora applied is the ref @@ -858,13 +856,9 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature - global_step = data.meta_info.get("global_steps", None) - model_for_debug = getattr(self, "ref_module", None) or getattr(self, "actor_module", None) - handle = precision_start(self.precision_debugger_cfg, "ref_model", global_step, model_for_debug) output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = 
DataProto.from_dict(tensors={"ref_log_prob": output}) output = output.to("cpu") - precision_stop(handle) if self._ref_is_offload_param: offload_megatron_model_to_cpu(self.ref_module) log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) From fd3918ff34349877be3cc3c946f4e37cc2db6d3d Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 4 Feb 2026 10:14:42 +0800 Subject: [PATCH 3/6] fix review --- .../_generated_ppo_megatron_trainer.yaml | 22 +++++++------- .../config/_generated_ppo_trainer.yaml | 22 +++++++------- verl/trainer/config/ppo_megatron_trainer.yaml | 15 +++++----- verl/trainer/config/ppo_trainer.yaml | 29 +++++++++---------- verl/trainer/ppo/ray_trainer.py | 2 +- verl/utils/profiler/config.py | 21 ++++++++++++++ verl/utils/profiler/profile.py | 6 ++-- verl/workers/actor/dp_actor.py | 2 +- verl/workers/actor/megatron_actor.py | 1 + verl/workers/megatron_workers.py | 1 + 10 files changed, 72 insertions(+), 49 deletions(-) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 642172169ac..cf155fbb9c4 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -711,17 +711,17 @@ global_profiler: context: all stacks: all kw_args: {} -precision_debugger: - enable: false - config_path: null - data_dir: outputs/precision_debug - steps: null - stages: - - rollout - - train - - update_actor - - ref_model - fail_open: true + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout + - train + - update_actor + - ref_model transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index e52536cd324..cee5bc01feb 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -645,17 +645,17 @@ global_profiler: context: all stacks: all kw_args: {} -precision_debugger: - enable: false - config_path: null - data_dir: outputs/precision_debug - steps: null - stages: - - rollout - - train - - update_actor - - ref_model - fail_open: true + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: false + config_path: null + data_dir: outputs/precision_debug + steps: null + stages: + - rollout + - train + - update_actor + - ref_model transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index d17dcd34886..f9604147564 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -236,14 +236,13 @@ global_profiler: stacks: "all" # devices, record_context etc. 
kw_args: {} - -precision_debugger: - enable: False - config_path: null - data_dir: "outputs/precision_debug" - steps: null - stages: ["rollout", "train", "update_actor", "ref_model"] - fail_open: True + precision_debugger: + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + enable: False + config_path: null + data_dir: "outputs/precision_debug" + steps: null + stages: ["rollout", "train", "update_actor", "ref_model"] # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 60148ec39ba..7c8f708bfaf 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -300,27 +300,26 @@ global_profiler: # devices, record_context etc. kw_args: {} + # precision debugger config + precision_debugger: -# precision debugger configs -precision_debugger: - - # Whether to enable precision debugger - enable: False + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig - # Path to msprobe config.json - config_path: null + # Whether to enable precision debugger + enable: False - # Dump root directory - data_dir: "outputs/precision_debug" + # Path to msprobe config.json + config_path: null - # Steps to collect, null means all - steps: null + # Dump root directory + data_dir: "outputs/precision_debug" - # Stages to collect - stages: ["rollout", "train", "update_actor", "ref_model"] + # Steps to collect, null means all + steps: null - # Fail open on errors - fail_open: True + # Stages to collect + stages: ["rollout", "train", "update_actor", "ref_model"] # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index cef91bec801..4d01d4e9eca 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1075,7 +1075,7 @@ def _stop_profiling(self, do_profile: bool) -> None: self.rm_wg.stop_profile() def _propagate_precision_debugger_config(self) -> None: - precision_cfg = OmegaConf.select(self.config, "precision_debugger") + precision_cfg = OmegaConf.select(self.config, "global_profiler.global_tool_config.precision_debugger") if precision_cfg is None: return with open_dict(self.config): diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index 4430d758698..fe99e0d71ae 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -78,6 +78,27 @@ def __post_init__(self) -> None: assert self.stack_depth > 0, f"stack_depth must be positive, got {self.stack_depth}" +@dataclass +class PrecisionDebuggerToolConfig(BaseConfig): + """Precision debugger tool config (msprobe).""" + + enable: bool = False + config_path: Optional[str] = None + data_dir: str = "outputs/precision_debug" + steps: Optional[list[int]] = None + stages: Optional[list[str]] = None + + def __post_init__(self) -> None: + assert isinstance(self.enable, bool), f"enable must be bool, got {type(self.enable)}" + if self.config_path is not None: + assert isinstance(self.config_path, str), f"config_path must be str, got {type(self.config_path)}" + assert isinstance(self.data_dir, str), f"data_dir must be str, got {type(self.data_dir)}" + if self.steps is not None: + assert isinstance(self.steps, list), f"steps must be list[int], got {type(self.steps)}" + if self.stages is not None: + assert isinstance(self.stages, list), f"stages must be list[str], got {type(self.stages)}" + + @dataclass class 
NPUToolConfig(NsightToolConfig): """NPU profiler too; config.""" diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py index 1ff85d070ba..e0a7e7e4811 100644 --- a/verl/utils/profiler/profile.py +++ b/verl/utils/profiler/profile.py @@ -213,7 +213,7 @@ def _get_model(self_instance): return None return getattr(self_instance, model_attr, None) - def _get_global_step(args, kwargs): + def _get_global_step(self_instance, args, kwargs): for val in list(args) + list(kwargs.values()): if hasattr(val, "meta_info"): meta = getattr(val, "meta_info") @@ -226,6 +226,8 @@ def _get_global_step(args, kwargs): return step except Exception: pass + if hasattr(self_instance, "precision_global_step"): + return getattr(self_instance, "precision_global_step") return None def decorator(func): @@ -235,7 +237,7 @@ def wrapper(self_instance, *args, **kwargs): precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) if not profiler or not precision_cfg: return func(self_instance, *args, **kwargs) - global_step = _get_global_step(args, kwargs) + global_step = _get_global_step(self_instance, args, kwargs) model = _get_model(self_instance) started = profiler.precision_start(precision_cfg, stage, global_step, model) result = func(self_instance, *args, **kwargs) diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index ecd80c5c020..63ce74f9f49 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -31,7 +31,7 @@ from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input from verl.utils.device import get_device_id, get_device_name from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ -from verl.utils.profiler import GPUMemoryLogger +from verl.utils.profiler import DistProfiler, GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.utils.torch_dtypes import PrecisionType diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 27bd63a8b08..3b51e6b0082 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -39,6 +39,7 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_id, get_torch_device +from verl.utils.profiler import DistProfiler from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction from verl.utils.megatron.router_replay_utils import ( diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index ed65191ae3f..2d8c1444b95 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -756,6 +756,7 @@ def update_actor(self, data: DataProto): micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size dataloader = self.actor.make_minibatch_iterator(data=data) + self.actor.precision_global_step = data.meta_info.get("global_steps", None) with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(dataloader=dataloader) delta_time = timer.last From 2354a2efbd6d013a6e65e9e6864035c70a447439 Mon Sep 17 00:00:00 2001 From: TAJh Date: Fri, 6 Feb 2026 16:40:32 +0800 Subject: [PATCH 4/6] support for rollout separate mode --- .../_generated_ppo_megatron_trainer.yaml | 10 +- 
.../config/_generated_ppo_trainer.yaml | 10 +- verl/trainer/config/ppo_megatron_trainer.yaml | 2 +- verl/trainer/config/ppo_trainer.yaml | 2 +- verl/utils/import_utils.py | 9 + verl/utils/profiler/config.py | 7 + .../profiler/precision_debugger_profile.py | 148 +++++++++++++ verl/utils/profiler/profile.py | 203 +++++++++--------- verl/workers/actor/dp_actor.py | 4 +- verl/workers/actor/megatron_actor.py | 4 +- verl/workers/engine_workers.py | 16 +- verl/workers/fsdp_workers.py | 44 +++- verl/workers/megatron_workers.py | 44 +++- .../sglang_rollout/async_sglang_server.py | 12 ++ .../sglang_rollout/http_server_engine.py | 6 +- .../rollout/sglang_rollout/sglang_rollout.py | 5 + verl/workers/rollout/vllm_rollout/utils.py | 4 + .../rollout/vllm_rollout/vllm_async_server.py | 12 ++ 18 files changed, 405 insertions(+), 137 deletions(-) create mode 100644 verl/utils/profiler/precision_debugger_profile.py diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index cf155fbb9c4..b506e9d610b 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -718,10 +718,14 @@ global_profiler: data_dir: outputs/precision_debug steps: null stages: - - rollout - - train + - rollout_generate - update_actor - - ref_model + - actor_compute_log_prob + - ref_compute_log_prob + - compute_values + - critic_update + - compute_rm_score + - train transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index cee5bc01feb..20263fd3efd 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -652,10 +652,14 @@ global_profiler: data_dir: outputs/precision_debug steps: null stages: - - rollout - - train + - rollout_generate - update_actor - - ref_model + - actor_compute_log_prob + - ref_compute_log_prob + - compute_values + - critic_update + - compute_rm_score + - train transfer_queue: enable: false ray_kwargs: diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index f9604147564..296a94b041e 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -242,7 +242,7 @@ global_profiler: config_path: null data_dir: "outputs/precision_debug" steps: null - stages: ["rollout", "train", "update_actor", "ref_model"] + stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"] # configs for TransferQueue transfer_queue: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7c8f708bfaf..f88a8ba5123 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -319,7 +319,7 @@ global_profiler: steps: null # Stages to collect - stages: ["rollout", "train", "update_actor", "ref_model"] + stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"] # configs for TransferQueue transfer_queue: diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py index ee78b580675..4e97ba070e8 100644 --- a/verl/utils/import_utils.py +++ b/verl/utils/import_utils.py @@ -69,6 +69,15 @@ def is_trl_available(): return trl_spec is not None +@cache +def 
is_msprobe_available():
+    try:
+        msprobe_spec = importlib.util.find_spec("msprobe")
+    except ModuleNotFoundError:
+        msprobe_spec = None
+    return msprobe_spec is not None
+
+
 def import_external_libs(external_libs=None):
     if external_libs is None:
         return
diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py
index fe99e0d71ae..2c18e930d88 100644
--- a/verl/utils/profiler/config.py
+++ b/verl/utils/profiler/config.py
@@ -87,6 +87,8 @@ class PrecisionDebuggerToolConfig(BaseConfig):
     data_dir: str = "outputs/precision_debug"
     steps: Optional[list[int]] = None
     stages: Optional[list[str]] = None
+    concurrency: str = "serialize"  # serialize | per_thread | per_request
+    strict: bool = False
 
     def __post_init__(self) -> None:
         assert isinstance(self.enable, bool), f"enable must be bool, got {type(self.enable)}"
@@ -97,6 +99,11 @@ def __post_init__(self) -> None:
             assert isinstance(self.steps, list), f"steps must be list[int], got {type(self.steps)}"
         if self.stages is not None:
             assert isinstance(self.stages, list), f"stages must be list[str], got {type(self.stages)}"
+        assert isinstance(self.concurrency, str), f"concurrency must be str, got {type(self.concurrency)}"
+        assert self.concurrency in {"serialize", "per_thread", "per_request"}, (
+            "concurrency must be one of serialize, per_thread, per_request"
+        )
+        assert isinstance(self.strict, bool), f"strict must be bool, got {type(self.strict)}"
 
 
 @dataclass
diff --git a/verl/utils/profiler/precision_debugger_profile.py b/verl/utils/profiler/precision_debugger_profile.py
new file mode 100644
index 00000000000..897243017ee
--- /dev/null
+++ b/verl/utils/profiler/precision_debugger_profile.py
@@ -0,0 +1,148 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+from dataclasses import asdict
+from typing import Optional
+
+from verl.utils.import_utils import is_msprobe_available
+from verl.utils.profiler.config import PrecisionDebuggerToolConfig
+
+
+_GLOBAL_LOCK = threading.Lock()
+_THREAD_LOCKS: dict[int, threading.Lock] = {}
+_THREAD_LOCKS_LOCK = threading.Lock()
+
+
+def _get_thread_lock() -> threading.Lock:
+    tid = threading.get_ident()
+    with _THREAD_LOCKS_LOCK:
+        lock = _THREAD_LOCKS.get(tid)
+        if lock is None:
+            lock = threading.Lock()
+            _THREAD_LOCKS[tid] = lock
+    return lock
+
+
+class PrecisionDebuggerProfiler:
+    """Precision debugger wrapper for msprobe.
+
+    This class implements a minimal start/stop contract and is intentionally
+    not a DistProfiler subclass to keep the dependency one-way.
+    """
+
+    def __init__(self, precision_cfg, rank: Optional[int] = None):
+        self.rank = rank
+        self.precision_cfg = self._normalize_config(precision_cfg)
+        self._active_lock: Optional[threading.Lock] = None
+        self._enabled = self._is_enabled(self.precision_cfg)
+        self._available = is_msprobe_available()
+
+    @staticmethod
+    def _normalize_config(precision_cfg) -> PrecisionDebuggerToolConfig:
+        if precision_cfg is None:
+            return PrecisionDebuggerToolConfig()
+        if isinstance(precision_cfg, PrecisionDebuggerToolConfig):
+            return precision_cfg
+        if hasattr(precision_cfg, "to_container"):
+            precision_cfg = precision_cfg.to_container(resolve=True)
+        if isinstance(precision_cfg, dict):
+            return PrecisionDebuggerToolConfig(**precision_cfg)
+        return PrecisionDebuggerToolConfig(**asdict(precision_cfg))
+
+    @staticmethod
+    def _is_enabled(precision_cfg: PrecisionDebuggerToolConfig) -> bool:
+        return bool(precision_cfg.enable)
+
+    def _should_collect(self, stage: str, global_step: Optional[int]) -> bool:
+        if not self._enabled:
+            return False
+        if self.precision_cfg.stages is not None and stage not in set(self.precision_cfg.stages):
+            return False
+        if self.precision_cfg.steps is not None and global_step is not None:
+            if int(global_step) not in set(self.precision_cfg.steps):
+                return False
+        return True
+
+    def _get_lock(self) -> threading.Lock:
+        if self.precision_cfg.concurrency == "serialize":
+            return _GLOBAL_LOCK
+        if self.precision_cfg.concurrency == "per_thread":
+            return _get_thread_lock()
+        return threading.Lock()
+
+    def start(self, stage: str, global_step: Optional[int] = None, model=None) -> bool:
+        if not self._should_collect(stage=stage, global_step=global_step):
+            return False
+        if not self._available:
+            if self.precision_cfg.strict:
+                raise ImportError("msprobe is not available but precision_debugger.strict is True")
+            return False
+
+        config_path = self.precision_cfg.config_path
+        data_dir = self.precision_cfg.data_dir
+        if not config_path or not data_dir:
+            return False
+
+        step_tag = f"step_{global_step}" if global_step is not None else "step_unknown"
+        rank_tag = f"rank_{self.rank}" if self.rank is not None else "rank_unknown"
+        dump_path = os.path.join(data_dir, step_tag, stage, rank_tag)
+        os.makedirs(dump_path, exist_ok=True)
+
+        lock = self._get_lock()
+        lock.acquire()
+        self._active_lock = lock
+        try:
+            from msprobe.pytorch import PrecisionDebugger
+
+            debugger = PrecisionDebugger._instance
+            if debugger is None or self.precision_cfg.concurrency == "per_request":
+                PrecisionDebugger(config_path=config_path, dump_path=dump_path)
+                debugger = PrecisionDebugger._instance
+            if debugger is None:
+                return False
+            debugger.service.config.dump_path = dump_path
+            debugger.start(model)
+            return True
+        except Exception:
+            self._release_lock()
+            if self.precision_cfg.strict:
+                raise
+            return False
+
+    def stop(self, started: bool = False, step: bool = False) -> None:
+        if not started:
+            self._release_lock()
+            return
+        if not self._available:
+            self._release_lock()
+            return
+        try:
+            from msprobe.pytorch import PrecisionDebugger
+
+            debugger = PrecisionDebugger._instance
+            if debugger is None:
+                return
+            debugger.stop()
+            if step:
+                debugger.step()
+        finally:
+            self._release_lock()
+
+    def _release_lock(self) -> None:
+        lock = self._active_lock
+        self._active_lock = None
+        if lock is not None:
+            lock.release()
diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py
index e0a7e7e4811..60de4259aa9 100644
--- a/verl/utils/profiler/profile.py
+++ 
b/verl/utils/profiler/profile.py @@ -13,14 +13,11 @@ # limitations under the License. import functools +import inspect from typing import Callable, Optional from ..memory_utils import MemorySnapshotSampler, enable_memory_visualize from .config import ProfilerConfig, TorchMemoryToolConfig -import os -import threading - -_precision_lock = threading.Lock() def mark_start_range( @@ -81,6 +78,7 @@ class DistProfiler: - npu: NPUProfiler (Ascend) - torch: PyTorch torch.profiler wrapper - torch_memory: Torch CUDA memory snapshot dump + - precision_debugger: msprobe precision debugger """ def __init__( @@ -100,6 +98,7 @@ def __init__( self._tool = getattr(config, "tool", None) self._enable = config.enable self._this_step = False + self.rank = rank # Normalize rank selection self._this_rank = False @@ -129,6 +128,10 @@ def __init__( self._impl = _Torch(rank=rank, config=config, tool_config=tool_config) elif self._tool == "torch_memory": self._impl = TorchMemoryProfiler(rank=rank, config=config, tool_config=tool_config) + elif self._tool == "precision_debugger": + from .precision_debugger_profile import PrecisionDebuggerProfiler as _Precision + + self._impl = _Precision(precision_cfg=tool_config, rank=rank) else: # Fallback to a no-op impl self._impl = _NoOpProfiler() @@ -155,63 +158,42 @@ def stop(self): self._this_step = False return getattr(self._impl, "stop", lambda: None)() - def precision_start(self, precision_cfg, stage: str, global_step: Optional[int], model=None) -> bool: - if not precision_cfg or not precision_cfg.get("enable", False): - return False - stages = precision_cfg.get("stages", None) - if stages is not None and stage not in set(stages): - return False - steps = precision_cfg.get("steps", None) - if steps is not None and global_step is not None: - if int(global_step) not in set(steps): - return False - config_path = precision_cfg.get("config_path", None) - data_dir = precision_cfg.get("data_dir", None) - if not config_path or not data_dir: - return False - dump_path = os.path.join(data_dir, str(global_step) if global_step is not None else "unknown", stage) - os.makedirs(dump_path, exist_ok=True) - with _precision_lock: - from msprobe.pytorch import PrecisionDebugger - - debugger = PrecisionDebugger._instance - if debugger is None: - PrecisionDebugger(config_path=config_path, dump_path=dump_path) - debugger = PrecisionDebugger._instance - if debugger is None: - return False - debugger.service.config.dump_path = dump_path - debugger.start(model) - return True - - def precision_stop(self, precision_cfg, stage: str, started: bool) -> None: - if not started: - return - from msprobe.pytorch import PrecisionDebugger - - debugger = PrecisionDebugger._instance - if debugger is None: - return - debugger.stop() - if stage == "update_actor": - debugger.step() @classmethod - def precision( + def annotate( cls, - stage: str, - model_attr: Optional[object] = None, + message: Optional[str] = None, + color: Optional[str] = None, + domain: Optional[str] = None, + category: Optional[str] = None, + precision_stage: Optional[str] = None, + precision_model_attr: Optional[object] = None, + precision_global_step_attr: Optional[str] = None, + precision_step: bool = False, + **kwargs_outer, ) -> Callable: def _get_model(self_instance): - if model_attr is None: + if precision_model_attr is None: return None - if isinstance(model_attr, (list, tuple)): - for attr in model_attr: - val = getattr(self_instance, attr, None) + if isinstance(precision_model_attr, (list, tuple)): + for attr in precision_model_attr: 
+ val = _resolve_attr(self_instance, attr) if val is not None: return val return None - return getattr(self_instance, model_attr, None) + return _resolve_attr(self_instance, precision_model_attr) + + def _resolve_attr(obj, attr): + if not isinstance(attr, str): + return None + if "." in attr: + current = obj + for part in attr.split("."): + current = getattr(current, part, None) + if current is None: + return None + return current + return getattr(obj, attr, None) def _get_global_step(self_instance, args, kwargs): for val in list(args) + list(kwargs.values()): @@ -219,67 +201,78 @@ def _get_global_step(self_instance, args, kwargs): meta = getattr(val, "meta_info") if isinstance(meta, dict) and "global_steps" in meta: return meta.get("global_steps") - if hasattr(val, "get"): - try: - step = val.get("global_steps", None) - if step is not None: - return step - except Exception: - pass + if isinstance(val, dict) and "global_steps" in val: + return val.get("global_steps") + if precision_global_step_attr and hasattr(self_instance, precision_global_step_attr): + return getattr(self_instance, precision_global_step_attr) if hasattr(self_instance, "precision_global_step"): return getattr(self_instance, "precision_global_step") return None - def decorator(func): - @functools.wraps(func) - def wrapper(self_instance, *args, **kwargs): - profiler = getattr(self_instance, "profiler", None) - precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) - if not profiler or not precision_cfg: - return func(self_instance, *args, **kwargs) - global_step = _get_global_step(self_instance, args, kwargs) - model = _get_model(self_instance) - started = profiler.precision_start(precision_cfg, stage, global_step, model) - result = func(self_instance, *args, **kwargs) - profiler.precision_stop(precision_cfg, stage, started) - return result + def _build_precision_impl(self_instance): + precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) + if not precision_cfg or not precision_stage: + return None + from .precision_debugger_profile import PrecisionDebuggerProfiler - return wrapper + rank = getattr(getattr(self_instance, "profiler", None), "rank", None) + return PrecisionDebuggerProfiler(precision_cfg, rank=rank) - return decorator + def _precision_start(precision_impl, self_instance, args, kwargs_inner): + if precision_impl is None: + return False + global_step = _get_global_step(self_instance, args, kwargs_inner) + model = _get_model(self_instance) + return precision_impl.start(stage=precision_stage, global_step=global_step, model=model) + + def _decorate_with_profiler(impl, func_inner): + if hasattr(impl, "annotate"): + return impl.annotate(message=message, color=color, domain=domain, category=category, **kwargs_outer)( + func_inner + ) + return func_inner + + def _should_profile(self_instance) -> bool: + profiler = getattr(self_instance, "profiler", None) + return ( + profiler + and profiler.check_enable() + and profiler.check_this_step() + and profiler.check_this_rank() + ) + + if inspect.iscoroutinefunction(func): - @classmethod - def annotate( - cls, - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, - **kwargs_outer, - ) -> Callable: - def decorator(func): @functools.wraps(func) + async def async_wrapper(self_instance, *args, **kwargs_inner): + precision_impl = _build_precision_impl(self_instance) + precision_started = _precision_start(precision_impl, self_instance, args, kwargs_inner) + try: + if 
_should_profile(self_instance): + impl = self_instance.profiler._impl + wrapped = _decorate_with_profiler(impl, func) + return await wrapped(self_instance, *args, **kwargs_inner) + return await func(self_instance, *args, **kwargs_inner) + finally: + if precision_impl is not None and precision_stage: + precision_impl.stop(started=precision_started, step=precision_step) + + return async_wrapper + + def decorator(func_inner): + @functools.wraps(func_inner) def wrapper(self_instance, *args, **kwargs_inner): - profiler = getattr(self_instance, "profiler", None) - if ( - not profiler - or not profiler.check_enable() - or not profiler.check_this_step() - or not profiler.check_this_rank() - ): - return func(self_instance, *args, **kwargs_inner) - - impl = profiler._impl - if hasattr(impl, "annotate"): - try: - actual_decorator = impl.annotate( - message=message, color=color, domain=domain, category=category, **kwargs_outer - ) - - return actual_decorator(func)(self_instance, *args, **kwargs_inner) - except Exception: - return func(self_instance, *args, **kwargs_inner) - return func(self_instance, *args, **kwargs_inner) + precision_impl = _build_precision_impl(self_instance) + precision_started = _precision_start(precision_impl, self_instance, args, kwargs_inner) + try: + if _should_profile(self_instance): + impl = self_instance.profiler._impl + wrapped = _decorate_with_profiler(impl, func_inner) + return wrapped(self_instance, *args, **kwargs_inner) + return func_inner(self_instance, *args, **kwargs_inner) + finally: + if precision_impl is not None and precision_stage: + precision_impl.stop(started=precision_started, step=precision_step) return wrapper @@ -294,6 +287,8 @@ def stop(self): return + + class TorchMemoryProfiler: """Profiler that dumps CUDA memory snapshots at step boundaries. 
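A minimal, standalone sketch of how the new PrecisionDebuggerProfiler can be driven outside the decorator plumbing above. The msprobe config.json path, the Linear stand-in for the actor module, and the step/stage filters below are illustrative placeholders, not values used anywhere in this series; if msprobe is not installed, start() simply returns False and the forward pass runs undumped.

import torch

from verl.utils.profiler.config import PrecisionDebuggerToolConfig
from verl.utils.profiler.precision_debugger_profile import PrecisionDebuggerProfiler

# Hand-built tool config; in a real run this comes from the
# `precision_debugger` section of the trainer YAML.
tool_cfg = PrecisionDebuggerToolConfig(
    enable=True,
    config_path="configs/msprobe_config.json",  # placeholder msprobe task config
    data_dir="outputs/precision_debug",
    steps=[0],          # only dump global step 0
    stages=["train"],   # only dump the training stage
)

model = torch.nn.Linear(8, 8)  # stand-in for the actor module
profiler = PrecisionDebuggerProfiler(tool_cfg, rank=0)

# start() returns False when msprobe is unavailable or the stage/step
# filters do not match, so callers can treat dumping as best-effort.
started = profiler.start(stage="train", global_step=0, model=model)
try:
    out = model(torch.randn(2, 8))  # the forward pass being inspected
finally:
    profiler.stop(started=started, step=False)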
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 63ce74f9f49..c22de374010 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -389,7 +389,7 @@ def _forward_micro_batch( outputs["sum_pi_squared"] = sum_pi_squared return outputs - @DistProfiler.precision(stage="update_actor", model_attr="actor_module") + @DistProfiler.annotate(precision_stage="update_actor", precision_model_attr="actor_module", precision_step=True) def _optimizer_step(self): assert self.config.grad_clip is not None if self.scaler is not None: @@ -501,7 +501,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> return outputs @GPUMemoryLogger(role="dp actor", logger=logger) - @DistProfiler.precision(stage="train", model_attr="actor_module") + @DistProfiler.annotate(precision_stage="train", precision_model_attr="actor_module") def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 3b51e6b0082..a4450b25ad1 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -755,7 +755,7 @@ def logits_processor(logits, label, label_mask): return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) - @DistProfiler.precision(stage="train", model_attr="actor_module") + @DistProfiler.annotate(precision_stage="train", precision_model_attr="actor_module") def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = False) -> dict: """Update the policy with an iterator of DataProto @@ -826,6 +826,6 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals get_torch_device().empty_cache() return metrics - @DistProfiler.precision(stage="update_actor", model_attr="actor_module") + @DistProfiler.annotate(precision_stage="update_actor", precision_model_attr="actor_module", precision_step=True) def _optimizer_step_with_precision(self): return self.actor_optimizer.step() diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index f5cbf7ab94b..6b0e8c406a9 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -552,21 +552,29 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") - @DistProfiler.precision(stage="ref_model", model_attr="ref") + @DistProfiler.annotate( + color="olive", + role="ref_compute_log_prob", + precision_stage="ref_compute_log_prob", + precision_model_attr="ref", + ) def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: output = self.ref.infer_batch(data=data) return output.cpu() if output is not None else None @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + @DistProfiler.annotate( + color="blue", + role="actor_compute_log_prob", + precision_stage="actor_compute_log_prob", + precision_model_attr="actor", + ) def compute_log_prob(self, data: TensorDict) -> TensorDict: output = self.actor.infer_batch(data) return output.cpu() if output is not None else None @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") - @DistProfiler.precision(stage="update_actor", model_attr="actor") def update_actor(self, data: TensorDict) -> TensorDict: 
output = self.actor.train_mini_batch(data=data) return output.cpu() if output is not None else None diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 18c8f42b350..ae701b98247 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -938,8 +938,12 @@ def update_actor(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) - @DistProfiler.annotate(color="red", role="rollout_generate") - @DistProfiler.precision(stage="rollout", model_attr=("actor_module_fsdp", "actor_module")) + @DistProfiler.annotate( + color="red", + role="rollout_generate", + precision_stage="rollout_generate", + precision_model_attr=("actor_module_fsdp", "actor_module"), + ) def generate_sequences(self, prompts: DataProto): # Support all hardwares assert self._is_rollout @@ -989,7 +993,12 @@ def generate_sequences(self, prompts: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + @DistProfiler.annotate( + color="blue", + role="actor_compute_log_prob", + precision_stage="actor_compute_log_prob", + precision_model_attr=("actor_module_fsdp", "actor_module"), + ) def compute_log_prob(self, data: DataProto): # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation @@ -1041,8 +1050,12 @@ def compute_log_prob(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") - @DistProfiler.precision(stage="ref_model", model_attr=("ref_module_fsdp", "actor_module_fsdp")) + @DistProfiler.annotate( + color="olive", + role="ref_compute_log_prob", + precision_stage="ref_compute_log_prob", + precision_model_attr=("ref_module_fsdp", "actor_module_fsdp"), + ) def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora applied is the ref @@ -1541,7 +1554,12 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="cyan", role="compute_values") + @DistProfiler.annotate( + color="cyan", + role="compute_values", + precision_stage="compute_values", + precision_model_attr="critic_module_fsdp", + ) def compute_values(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1561,7 +1579,12 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="pink", role="critic_update") + @DistProfiler.annotate( + color="pink", + role="critic_update", + precision_stage="critic_update", + precision_model_attr="critic_module_fsdp", + ) def update_critic(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -1929,7 +1952,12 @@ def _switch_chat_template(self, data: DataProto): return DataProto.from_dict(rm_inputs) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) - @DistProfiler.annotate(color="brown", role="compute_rm_score") + @DistProfiler.annotate( + color="brown", + role="compute_rm_score", + precision_stage="compute_rm_score", + precision_model_attr="reward_model_module_fsdp", + ) def compute_rm_score(self, data: DataProto): import itertools diff --git 
a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 2d8c1444b95..846a28da83a 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -790,8 +790,12 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @GPUMemoryLogger(role="generate_sequences", logger=logger) - @DistProfiler.annotate(color="red", role="rollout_generate") - @DistProfiler.precision(stage="rollout", model_attr="actor_module") + @DistProfiler.annotate( + color="red", + role="rollout_generate", + precision_stage="rollout_generate", + precision_model_attr="actor_module", + ) def generate_sequences(self, prompts: DataProto): assert self._is_rollout prompts = prompts.to(get_device_name()) @@ -841,8 +845,12 @@ def generate_sequences(self, prompts: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") - @DistProfiler.precision(stage="ref_model", model_attr=("ref_module", "actor_module")) + @DistProfiler.annotate( + color="olive", + role="ref_compute_log_prob", + precision_stage="ref_compute_log_prob", + precision_model_attr=("ref_module", "actor_module"), + ) def compute_ref_log_prob(self, data: DataProto): if self.peft_cls is not None: # if is lora, actor without lora applied is the ref @@ -868,7 +876,12 @@ def compute_ref_log_prob(self, data: DataProto): @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @GPUMemoryLogger(role="compute_log_prob", logger=logger) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + @DistProfiler.annotate( + color="blue", + role="actor_compute_log_prob", + precision_stage="actor_compute_log_prob", + precision_model_attr="actor_module", + ) def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: @@ -1218,7 +1231,12 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="cyan", role="compute_values") + @DistProfiler.annotate( + color="cyan", + role="compute_values", + precision_stage="compute_values", + precision_model_attr="critic_module", + ) def compute_values(self, data: DataProto): micro_batch_size = self.config.ppo_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -1235,7 +1253,12 @@ def compute_values(self, data: DataProto): return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) - @DistProfiler.annotate(color="pink", role="critic_update") + @DistProfiler.annotate( + color="pink", + role="critic_update", + precision_stage="critic_update", + precision_model_attr="critic_module", + ) def update_critic(self, data: DataProto): data = data.to(get_device_id()) @@ -1459,7 +1482,12 @@ def init_model(self): # TODO: reward model use itself tokenizer instead of sft tokenizer # the input_ids, responses, attention_mask and position_ids may be different! 
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) - @DistProfiler.annotate(color="brown", role="compute_rm_score") + @DistProfiler.annotate( + color="brown", + role="compute_rm_score", + precision_stage="compute_rm_score", + precision_model_attr="reward_model_module", + ) def compute_rm_score(self, data: DataProto): data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 21a620dbc35..1294c8bb046 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -168,6 +168,8 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) + self.precision_debugger_cfg = getattr(self.config, "precision_debugger", None) + self.precision_global_step = None max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config) if self.config.max_model_len is None: self.config.max_model_len = max_position_embeddings @@ -220,6 +222,9 @@ def get_master_address(self): """Get master address and port for init NCCL process group.""" return self._master_address, self._master_port + def set_precision_global_step(self, global_step: int) -> None: + self.precision_global_step = global_step + def get_server_address(self): """Get http server address and port.""" assert self._server_port is not None, "http server is not launched, port is None" @@ -394,6 +399,11 @@ async def clear_kv_cache(self): obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache"]) await self.tokenizer_manager.release_memory_occupation(obj, None) + @DistProfiler.annotate( + precision_stage="rollout_generate", + precision_model_attr=["model", "model_runner.model"], + precision_global_step_attr="precision_global_step", + ) async def generate( self, prompt_ids: torch.Tensor, @@ -403,6 +413,8 @@ async def generate( video_data: Optional[list[Any]] = None, ) -> TokenOutput: """Generate sequence with token-in-token-out.""" + if "_precision_global_step" in sampling_params: + self.precision_global_step = sampling_params.pop("_precision_global_step") # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready. max_possible_tokens = self.config.max_model_len - len(prompt_ids) diff --git a/verl/workers/rollout/sglang_rollout/http_server_engine.py b/verl/workers/rollout/sglang_rollout/http_server_engine.py index 6822a9e52da..f3a01b842b9 100644 --- a/verl/workers/rollout/sglang_rollout/http_server_engine.py +++ b/verl/workers/rollout/sglang_rollout/http_server_engine.py @@ -103,6 +103,10 @@ async def _read_async_response(resp: aiohttp.ClientResponse) -> dict[str, Any]: } +def _launch_server_with_precision(server_args: ServerArgs) -> None: + launch_server(server_args) + + def launch_server_process( server_args: ServerArgs, timeout: float = DEFAULT_TIMEOUT, @@ -134,7 +138,7 @@ def launch_server_process( This is for consistency; except for the process obtained by node_rank = 0, other processes have no actual effect. 
""" - p = multiprocessing.Process(target=launch_server, args=(server_args,)) + p = multiprocessing.Process(target=_launch_server_with_precision, args=(server_args,)) if server_args.node_rank != 0 or not first_rank_in_node: logger.info(f"Server process started with PID {p.pid} for node rank {server_args.node_rank}", flush=True) return p diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 2be15fc5b05..c160efd4cb7 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -190,6 +190,11 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 """ await self._init_server_adapter() + if not hasattr(self, "_precision_global_step"): + self._precision_global_step = -1 + self._precision_global_step += 1 + if hasattr(self, "server_actor") and self.server_actor is not None: + await self.server_actor.set_precision_global_step.remote(self._precision_global_step) update_weights_bucket_bytes = int(self.config.checkpoint_engine.update_weights_bucket_megabytes) << 20 if self.config.get("quantization", None) == "fp8": diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index a0e738a25d4..c3aee4713af 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -177,6 +177,10 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False """Update the weights of the rollout model.""" from vllm.platforms import current_platform + if not hasattr(self, "_precision_global_step"): + self._precision_global_step = -1 + self._precision_global_step += 1 + if current_platform.device_type == "npu" and self.device is None: self.device = torch.device(f"npu:{self.local_rank}") diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index f4e26f13fde..888e254aa8b 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -112,6 +112,8 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) + self.precision_debugger_cfg = getattr(self.config, "precision_debugger", None) + self.precision_global_step = None max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config) if self.config.max_model_len is None: self.config.max_model_len = max_position_embeddings @@ -178,6 +180,9 @@ def get_master_address(self): """ return self._master_address, self._master_port, self._dp_rpc_port + def set_precision_global_step(self, global_step: int) -> None: + self.precision_global_step = global_step + def get_server_address(self): """Get http server address and port.""" assert self._server_port is not None, "http server is not launched, port is None" @@ -445,6 +450,11 @@ def on_run_headless_done(future: asyncio.Future): self.task = asyncio.create_task(asyncio.to_thread(run_headless_wrapper)) self.task.add_done_callback(on_run_headless_done) + @DistProfiler.annotate( + precision_stage="rollout_generate", + precision_model_attr=["engine", "engine.model", "engine.model_runner.model"], + precision_global_step_attr="precision_global_step", + ) async def generate( 
self, prompt_ids: list[int], @@ -455,6 +465,8 @@ async def generate( priority: int = 0, ) -> TokenOutput: """Generate sequence with token-in-token-out.""" + if "_precision_global_step" in sampling_params: + self.precision_global_step = sampling_params.pop("_precision_global_step") # Calculate the maximum possible new tokens based on available context space # This serves as a safety upper bound max_possible_tokens = self.config.max_model_len - len(prompt_ids) From 67f2d1fb995d659676db7be04c7447618c07e00b Mon Sep 17 00:00:00 2001 From: TAJh Date: Fri, 6 Feb 2026 17:15:06 +0800 Subject: [PATCH 5/6] fix review --- verl/utils/profiler/profile.py | 19 --------- .../sglang_rollout/async_sglang_server.py | 7 ---- verl/workers/rollout/vllm_rollout/utils.py | 39 +++++++++++++++++++ .../rollout/vllm_rollout/vllm_async_server.py | 25 ++++++------ 4 files changed, 53 insertions(+), 37 deletions(-) diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py index 60de4259aa9..f824d51a735 100644 --- a/verl/utils/profiler/profile.py +++ b/verl/utils/profiler/profile.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import inspect from typing import Callable, Optional from ..memory_utils import MemorySnapshotSampler, enable_memory_visualize @@ -241,24 +240,6 @@ def _should_profile(self_instance) -> bool: and profiler.check_this_rank() ) - if inspect.iscoroutinefunction(func): - - @functools.wraps(func) - async def async_wrapper(self_instance, *args, **kwargs_inner): - precision_impl = _build_precision_impl(self_instance) - precision_started = _precision_start(precision_impl, self_instance, args, kwargs_inner) - try: - if _should_profile(self_instance): - impl = self_instance.profiler._impl - wrapped = _decorate_with_profiler(impl, func) - return await wrapped(self_instance, *args, **kwargs_inner) - return await func(self_instance, *args, **kwargs_inner) - finally: - if precision_impl is not None and precision_stage: - precision_impl.stop(started=precision_started, step=precision_step) - - return async_wrapper - def decorator(func_inner): @functools.wraps(func_inner) def wrapper(self_instance, *args, **kwargs_inner): diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 1294c8bb046..f1224160d0b 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -399,11 +399,6 @@ async def clear_kv_cache(self): obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache"]) await self.tokenizer_manager.release_memory_occupation(obj, None) - @DistProfiler.annotate( - precision_stage="rollout_generate", - precision_model_attr=["model", "model_runner.model"], - precision_global_step_attr="precision_global_step", - ) async def generate( self, prompt_ids: torch.Tensor, @@ -413,8 +408,6 @@ async def generate( video_data: Optional[list[Any]] = None, ) -> TokenOutput: """Generate sequence with token-in-token-out.""" - if "_precision_global_step" in sampling_params: - self.precision_global_step = sampling_params.pop("_precision_global_step") # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready. 
max_possible_tokens = self.config.max_model_len - len(prompt_ids) diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index c3aee4713af..fe0b32cbf6a 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -172,6 +172,45 @@ def monkey_patch_model(self, vocab_size: int): monkey_patch_compute_logits(self.model_runner.model, vocab_size) # patch weight loader to support MoE model patch_vllm_moe_model_weight_loader(self.model_runner.model) + self._attach_precision_debugger() + + def _attach_precision_debugger(self) -> None: + cfg_json = os.getenv("VERL_PRECISION_DEBUGGER_CONFIG_JSON", None) + if not cfg_json: + return + try: + precision_cfg = json.loads(cfg_json) + except Exception: + return + if not precision_cfg or not precision_cfg.get("enable", False): + return + + model = self.model_runner.model + if not hasattr(model, "forward"): + return + + original_forward = model.forward + extension_self = self + + def precision_forward(self, *args, **kwargs): + from verl.utils.profiler.precision_debugger_profile import PrecisionDebuggerProfiler + + if not hasattr(extension_self, "_precision_global_step"): + extension_self._precision_global_step = None + profiler = PrecisionDebuggerProfiler( + precision_cfg, rank=getattr(extension_self, "local_rank", None) + ) + started = profiler.start( + stage="rollout_generate", + global_step=getattr(extension_self, "_precision_global_step", None), + model=model, + ) + try: + return original_forward(*args, **kwargs) + finally: + profiler.stop(started=started) + + model.forward = MethodType(precision_forward, model) def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False, use_shm: bool = False): """Update the weights of the rollout model.""" diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 888e254aa8b..4f327699547 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -113,7 +113,6 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) self.precision_debugger_cfg = getattr(self.config, "precision_debugger", None) - self.precision_global_step = None max_position_embeddings = get_max_position_embeddings(self.model_config.hf_config) if self.config.max_model_len is None: self.config.max_model_len = max_position_embeddings @@ -173,6 +172,19 @@ def __init__( f"data_parallel_rpc_port: {self._dp_rpc_port}, data_parallel_master_port: {self._dp_master_port}" ) + def _export_precision_debugger_env(self) -> None: + precision_cfg = self.precision_debugger_cfg + if not precision_cfg or not getattr(precision_cfg, "enable", False): + return + try: + if hasattr(precision_cfg, "to_container"): + precision_cfg = precision_cfg.to_container(resolve=True) + if isinstance(precision_cfg, dict): + os.environ["VERL_PRECISION_DEBUGGER_CONFIG_JSON"] = json.dumps(precision_cfg) + except Exception: + # Best-effort only; precision debugger should not block server launch + return + def get_master_address(self): """Get master address and port for data parallel. 
Returns: @@ -180,9 +192,6 @@ def get_master_address(self): """ return self._master_address, self._master_port, self._dp_rpc_port - def set_precision_global_step(self, global_step: int) -> None: - self.precision_global_step = global_step - def get_server_address(self): """Get http server address and port.""" assert self._server_port is not None, "http server is not launched, port is None" @@ -212,6 +221,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non self._dp_rpc_port = dp_rpc_port # 1. setup vllm serve cli args + self._export_precision_debugger_env() engine_kwargs = self.config.get("engine_kwargs", {}).get("vllm", {}) or {} engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} if self.config.get("limit_images", None): # support for multi-image data @@ -450,11 +460,6 @@ def on_run_headless_done(future: asyncio.Future): self.task = asyncio.create_task(asyncio.to_thread(run_headless_wrapper)) self.task.add_done_callback(on_run_headless_done) - @DistProfiler.annotate( - precision_stage="rollout_generate", - precision_model_attr=["engine", "engine.model", "engine.model_runner.model"], - precision_global_step_attr="precision_global_step", - ) async def generate( self, prompt_ids: list[int], @@ -465,8 +470,6 @@ async def generate( priority: int = 0, ) -> TokenOutput: """Generate sequence with token-in-token-out.""" - if "_precision_global_step" in sampling_params: - self.precision_global_step = sampling_params.pop("_precision_global_step") # Calculate the maximum possible new tokens based on available context space # This serves as a safety upper bound max_possible_tokens = self.config.max_model_len - len(prompt_ids) From 7c440ecd8c28eadd0fdb80616f124c1cd8fe2194 Mon Sep 17 00:00:00 2001 From: TAJh Date: Fri, 6 Feb 2026 17:51:47 +0800 Subject: [PATCH 6/6] fix review 2.0 --- verl/utils/profiler/__init__.py | 2 + verl/utils/profiler/config.py | 5 - .../profiler/precision_debugger_profile.py | 45 +++--- verl/utils/profiler/precision_hook.py | 131 ++++++++++++++++++ verl/utils/profiler/profile.py | 67 +-------- verl/workers/actor/dp_actor.py | 8 +- verl/workers/actor/megatron_actor.py | 8 +- verl/workers/engine_workers.py | 24 ++-- verl/workers/fsdp_workers.py | 9 +- 9 files changed, 183 insertions(+), 116 deletions(-) create mode 100644 verl/utils/profiler/precision_hook.py diff --git a/verl/utils/profiler/__init__.py b/verl/utils/profiler/__init__.py index 73edb01a02c..84fc6a63d93 100644 --- a/verl/utils/profiler/__init__.py +++ b/verl/utils/profiler/__init__.py @@ -15,6 +15,7 @@ from ..device import is_npu_available from ..import_utils import is_nvtx_available from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer +from .precision_hook import PrecisionDebuggerLogger from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig # Select marker implementations by availability, but keep DistProfiler as our dispatcher @@ -37,4 +38,5 @@ "ProfilerConfig", "simple_timer", "marked_timer", + "PrecisionDebuggerLogger", ] diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index 2c18e930d88..442982d9f12 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -87,7 +87,6 @@ class PrecisionDebuggerToolConfig(BaseConfig): data_dir: str = "outputs/precision_debug" steps: Optional[list[int]] = None stages: Optional[list[str]] = None - concurrency: str = "serialize" # serialize | per_thread | per_request strict: bool = False def __post_init__(self) -> None: @@ -99,10 
+98,6 @@ def __post_init__(self) -> None: assert isinstance(self.steps, list), f"steps must be list[int], got {type(self.steps)}" if self.stages is not None: assert isinstance(self.stages, list), f"stages must be list[str], got {type(self.stages)}" - assert isinstance(self.concurrency, str), f"concurrency must be str, got {type(self.concurrency)}" - assert self.concurrency in {"serialize", "per_thread", "per_request"}, ( - "concurrency must be one of serialize, per_thread, per_request" - ) assert isinstance(self.strict, bool), f"strict must be bool, got {type(self.strict)}" diff --git a/verl/utils/profiler/precision_debugger_profile.py b/verl/utils/profiler/precision_debugger_profile.py index 897243017ee..be49a9ac8db 100644 --- a/verl/utils/profiler/precision_debugger_profile.py +++ b/verl/utils/profiler/precision_debugger_profile.py @@ -22,18 +22,6 @@ _GLOBAL_LOCK = threading.Lock() -_THREAD_LOCKS: dict[int, threading.Lock] = {} -_THREAD_LOCKS_LOCK = threading.Lock() - - -def _get_thread_lock() -> threading.Lock: - tid = threading.get_ident() - with _THREAD_LOCKS_LOCK: - lock = _THREAD_LOCKS.get(tid) - if lock is None: - lock = threading.Lock() - _THREAD_LOCKS[tid] = lock - return lock class PrecisionDebuggerProfiler: @@ -49,6 +37,7 @@ def __init__(self, precision_cfg, rank: Optional[int] = None): self._active_lock: Optional[threading.Lock] = None self._enabled = self._is_enabled(self.precision_cfg) self._available = is_msprobe_available() + self._debugger = None @staticmethod def _normalize_config(precision_cfg) -> PrecisionDebuggerToolConfig: @@ -77,11 +66,7 @@ def _should_collect(self, stage: str, global_step: Optional[int]) -> bool: return True def _get_lock(self) -> threading.Lock: - if self.precision_cfg.concurrency == "serialize": - return _GLOBAL_LOCK - if self.precision_cfg.concurrency == "per_thread": - return _get_thread_lock() - return threading.Lock() + return _GLOBAL_LOCK def start(self, stage: str, global_step: Optional[int] = None, model=None) -> bool: if not self._should_collect(stage=stage, global_step=global_step): @@ -107,13 +92,18 @@ def start(self, stage: str, global_step: Optional[int] = None, model=None) -> bo try: from msprobe.pytorch import PrecisionDebugger - debugger = PrecisionDebugger._instance - if debugger is None or self.precision_cfg.concurrency == "per_request": - PrecisionDebugger(config_path=config_path, dump_path=dump_path) - debugger = PrecisionDebugger._instance - if debugger is None: - return False - debugger.service.config.dump_path = dump_path + debugger = None + if self._debugger is None: + debugger = PrecisionDebugger(config_path=config_path, dump_path=dump_path) + if debugger is None: + if self.precision_cfg.strict: + raise RuntimeError("Failed to create PrecisionDebugger instance") + return False + self._debugger = debugger + else: + debugger = self._debugger + if hasattr(debugger, "service") and hasattr(debugger.service, "config"): + debugger.service.config.dump_path = dump_path debugger.start(model) return True except Exception: @@ -130,14 +120,13 @@ def stop(self, started: bool = False, step: bool = False) -> None: self._release_lock() return try: - from msprobe.pytorch import PrecisionDebugger - - debugger = PrecisionDebugger._instance + debugger = self._debugger if debugger is None: return debugger.stop() if step: - debugger.step() + if hasattr(debugger, "step"): + debugger.step() finally: self._release_lock() diff --git a/verl/utils/profiler/precision_hook.py b/verl/utils/profiler/precision_hook.py new file mode 100644 index 
00000000000..b38f3fc510b --- /dev/null +++ b/verl/utils/profiler/precision_hook.py @@ -0,0 +1,131 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from .precision_debugger_profile import PrecisionDebuggerProfiler + + +def _resolve_attr(obj, attr): + if not isinstance(attr, str): + return None + if "." in attr: + current = obj + for part in attr.split("."): + current = getattr(current, part, None) + if current is None: + return None + return current + return getattr(obj, attr, None) + + +def _get_model(self_instance, precision_model_attr): + if precision_model_attr is None: + return None + if isinstance(precision_model_attr, (list, tuple)): + for attr in precision_model_attr: + val = _resolve_attr(self_instance, attr) + if val is not None: + return val + return None + return _resolve_attr(self_instance, precision_model_attr) + + +def _get_global_step(self_instance, args, kwargs, precision_global_step_attr: Optional[str]): + for val in list(args) + list(kwargs.values()): + if hasattr(val, "meta_info"): + meta = getattr(val, "meta_info") + if isinstance(meta, dict) and "global_steps" in meta: + return meta.get("global_steps") + if isinstance(val, dict) and "global_steps" in val: + return val.get("global_steps") + if precision_global_step_attr and hasattr(self_instance, precision_global_step_attr): + return getattr(self_instance, precision_global_step_attr) + if hasattr(self_instance, "precision_global_step"): + return getattr(self_instance, "precision_global_step") + return None + + +def build_precision_impl(self_instance, precision_stage: Optional[str]): + precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) + if not precision_cfg or not precision_stage: + return None + rank = getattr(getattr(self_instance, "profiler", None), "rank", None) + return PrecisionDebuggerProfiler(precision_cfg, rank=rank) + + +def start_precision( + precision_impl: Optional[PrecisionDebuggerProfiler], + self_instance, + args, + kwargs, + precision_stage: Optional[str], + precision_model_attr, + precision_global_step_attr: Optional[str], +) -> bool: + if precision_impl is None: + return False + global_step = _get_global_step(self_instance, args, kwargs, precision_global_step_attr) + model = _get_model(self_instance, precision_model_attr) + return precision_impl.start(stage=precision_stage, global_step=global_step, model=model) + + +def stop_precision( + precision_impl: Optional[PrecisionDebuggerProfiler], + started: bool, + precision_step: bool, +) -> None: + if precision_impl is None: + return + precision_impl.stop(started=started, step=precision_step) + + +class PrecisionDebuggerLogger: + """Decorator to run PrecisionDebugger around a method call. + + Example: + >>> @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + >>> def update_policy(self, batch): ... 
+ """ + + def __init__( + self, + stage: str, + model_attr: Optional[object] = None, + global_step_attr: Optional[str] = None, + step: bool = False, + ): + self.stage = stage + self.model_attr = model_attr + self.global_step_attr = global_step_attr + self.step = step + + def __call__(self, decorated_function: callable): + def wrapper(self_instance, *args, **kwargs): + precision_impl = build_precision_impl(self_instance, self.stage) + started = start_precision( + precision_impl, + self_instance, + args, + kwargs, + self.stage, + self.model_attr, + self.global_step_attr, + ) + try: + return decorated_function(self_instance, *args, **kwargs) + finally: + stop_precision(precision_impl, started, self.step) + + return wrapper diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py index f824d51a735..7984181ce9f 100644 --- a/verl/utils/profiler/profile.py +++ b/verl/utils/profiler/profile.py @@ -165,65 +165,8 @@ def annotate( color: Optional[str] = None, domain: Optional[str] = None, category: Optional[str] = None, - precision_stage: Optional[str] = None, - precision_model_attr: Optional[object] = None, - precision_global_step_attr: Optional[str] = None, - precision_step: bool = False, **kwargs_outer, ) -> Callable: - def _get_model(self_instance): - if precision_model_attr is None: - return None - if isinstance(precision_model_attr, (list, tuple)): - for attr in precision_model_attr: - val = _resolve_attr(self_instance, attr) - if val is not None: - return val - return None - return _resolve_attr(self_instance, precision_model_attr) - - def _resolve_attr(obj, attr): - if not isinstance(attr, str): - return None - if "." in attr: - current = obj - for part in attr.split("."): - current = getattr(current, part, None) - if current is None: - return None - return current - return getattr(obj, attr, None) - - def _get_global_step(self_instance, args, kwargs): - for val in list(args) + list(kwargs.values()): - if hasattr(val, "meta_info"): - meta = getattr(val, "meta_info") - if isinstance(meta, dict) and "global_steps" in meta: - return meta.get("global_steps") - if isinstance(val, dict) and "global_steps" in val: - return val.get("global_steps") - if precision_global_step_attr and hasattr(self_instance, precision_global_step_attr): - return getattr(self_instance, precision_global_step_attr) - if hasattr(self_instance, "precision_global_step"): - return getattr(self_instance, "precision_global_step") - return None - - def _build_precision_impl(self_instance): - precision_cfg = getattr(self_instance, "precision_debugger_cfg", None) - if not precision_cfg or not precision_stage: - return None - from .precision_debugger_profile import PrecisionDebuggerProfiler - - rank = getattr(getattr(self_instance, "profiler", None), "rank", None) - return PrecisionDebuggerProfiler(precision_cfg, rank=rank) - - def _precision_start(precision_impl, self_instance, args, kwargs_inner): - if precision_impl is None: - return False - global_step = _get_global_step(self_instance, args, kwargs_inner) - model = _get_model(self_instance) - return precision_impl.start(stage=precision_stage, global_step=global_step, model=model) - def _decorate_with_profiler(impl, func_inner): if hasattr(impl, "annotate"): return impl.annotate(message=message, color=color, domain=domain, category=category, **kwargs_outer)( @@ -243,17 +186,15 @@ def _should_profile(self_instance) -> bool: def decorator(func_inner): @functools.wraps(func_inner) def wrapper(self_instance, *args, **kwargs_inner): - precision_impl = 
_build_precision_impl(self_instance) - precision_started = _precision_start(precision_impl, self_instance, args, kwargs_inner) try: if _should_profile(self_instance): impl = self_instance.profiler._impl wrapped = _decorate_with_profiler(impl, func_inner) - return wrapped(self_instance, *args, **kwargs_inner) + try: + return wrapped(self_instance, *args, **kwargs_inner) + except Exception: + return func_inner(self_instance, *args, **kwargs_inner) return func_inner(self_instance, *args, **kwargs_inner) - finally: - if precision_impl is not None and precision_stage: - precision_impl.stop(started=precision_started, step=precision_step) return wrapper diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index c22de374010..df98ac58b14 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -31,7 +31,7 @@ from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input from verl.utils.device import get_device_id, get_device_name from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ -from verl.utils.profiler import DistProfiler, GPUMemoryLogger +from verl.utils.profiler import DistProfiler, GPUMemoryLogger, PrecisionDebuggerLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.utils.torch_dtypes import PrecisionType @@ -389,7 +389,8 @@ def _forward_micro_batch( outputs["sum_pi_squared"] = sum_pi_squared return outputs - @DistProfiler.annotate(precision_stage="update_actor", precision_model_attr="actor_module", precision_step=True) + @PrecisionDebuggerLogger(stage="update_actor", model_attr="actor_module", step=True) + @DistProfiler.annotate() def _optimizer_step(self): assert self.config.grad_clip is not None if self.scaler is not None: @@ -501,7 +502,8 @@ def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> return outputs @GPUMemoryLogger(role="dp actor", logger=logger) - @DistProfiler.annotate(precision_stage="train", precision_model_attr="actor_module") + @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + @DistProfiler.annotate() def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index a4450b25ad1..f32a097f6e0 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -39,7 +39,7 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_id, get_torch_device -from verl.utils.profiler import DistProfiler +from verl.utils.profiler import DistProfiler, PrecisionDebuggerLogger from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction from verl.utils.megatron.router_replay_utils import ( @@ -755,7 +755,8 @@ def logits_processor(logits, label, label_mask): return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) - @DistProfiler.annotate(precision_stage="train", precision_model_attr="actor_module") + @PrecisionDebuggerLogger(stage="train", model_attr="actor_module") + @DistProfiler.annotate() def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = False) -> dict: """Update the policy with an iterator of DataProto @@ -826,6 +827,7 @@ def update_policy(self, 
dataloader: Iterable[DataProto], enable_mtp: bool = Fals get_torch_device().empty_cache() return metrics - @DistProfiler.annotate(precision_stage="update_actor", precision_model_attr="actor_module", precision_step=True) + @PrecisionDebuggerLogger(stage="update_actor", model_attr="actor_module", step=True) + @DistProfiler.annotate() def _optimizer_step_with_precision(self): return self.actor_optimizer.step() diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 6b0e8c406a9..bd95c65643e 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -37,7 +37,13 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.metric.utils import Metric -from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + PrecisionDebuggerLogger, + ProfilerConfig, + log_gpu_memory_usage, +) from verl.utils.py_functional import append_to_dict from verl.utils.tensordict_utils import maybe_fix_3d_position_ids from verl.utils.torch_functional import allgather_dict_into_dict @@ -552,23 +558,15 @@ def init_model(self): ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) - @DistProfiler.annotate( - color="olive", - role="ref_compute_log_prob", - precision_stage="ref_compute_log_prob", - precision_model_attr="ref", - ) + @PrecisionDebuggerLogger(stage="ref_compute_log_prob", model_attr="ref") + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: output = self.ref.infer_batch(data=data) return output.cpu() if output is not None else None @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) - @DistProfiler.annotate( - color="blue", - role="actor_compute_log_prob", - precision_stage="actor_compute_log_prob", - precision_model_attr="actor", - ) + @PrecisionDebuggerLogger(stage="actor_compute_log_prob", model_attr="actor") + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: TensorDict) -> TensorDict: output = self.actor.infer_batch(data) return output.cpu() if output is not None else None diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index ae701b98247..d76e7f8ce29 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -80,7 +80,14 @@ from verl.utils.import_utils import import_external_libs from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import compute_position_id_with_mask, convert_weight_keys -from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + PrecisionDebuggerLogger, + ProfilerConfig, + log_gpu_memory_usage, + simple_timer, +) from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types from verl.utils.ray_utils import get_event_loop
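Taken together, the end state of this series splits the two concerns: DistProfiler.annotate keeps its original profiling-only signature, while PrecisionDebuggerLogger wraps it for msprobe dumps. The sketch below shows roughly how a worker method composes them; MyActor, its attributes, and the batch dict are illustrative placeholders, and it assumes the instance exposes precision_debugger_cfg (as the rollout servers in this series do) so the decorator can find the tool config.

import torch

from verl.utils.profiler import DistProfiler, PrecisionDebuggerLogger
from verl.utils.profiler.config import PrecisionDebuggerToolConfig


class MyActor:  # illustrative worker, not a verl class
    def __init__(self, actor_module, precision_debugger_cfg):
        self.actor_module = actor_module
        # PrecisionDebuggerLogger reads this attribute off the instance;
        # DistProfiler.annotate runs the method unprofiled when no
        # self.profiler is set.
        self.precision_debugger_cfg = precision_debugger_cfg

    @PrecisionDebuggerLogger(stage="train", model_attr="actor_module")
    @DistProfiler.annotate(color="red")
    def update_policy(self, data):
        # The decorator picks up the global step from a "global_steps"
        # entry in the positional arguments when one is present.
        return self.actor_module(data["input"])


actor = MyActor(
    actor_module=torch.nn.Linear(8, 8),
    precision_debugger_cfg=PrecisionDebuggerToolConfig(
        enable=True,
        config_path="configs/msprobe_config.json",  # placeholder
        stages=["train"],
    ),
)
out = actor.update_policy({"input": torch.randn(2, 8), "global_steps": 0})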