15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -711,6 +711,21 @@ global_profiler:
context: all
stacks: all
kw_args: {}
precision_debugger:
_target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig
enable: false
config_path: null
data_dir: outputs/precision_debug
steps: null
stages:
- rollout_generate
- update_actor
- actor_compute_log_prob
- ref_compute_log_prob
- compute_values
- critic_update
- compute_rm_score
- train
transfer_queue:
enable: false
ray_kwargs:
15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -645,6 +645,21 @@ global_profiler:
context: all
stacks: all
kw_args: {}
precision_debugger:
_target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig
enable: false
config_path: null
data_dir: outputs/precision_debug
steps: null
stages:
- rollout_generate
- update_actor
- actor_compute_log_prob
- ref_compute_log_prob
- compute_values
- critic_update
- compute_rm_score
- train
transfer_queue:
enable: false
ray_kwargs:
7 changes: 7 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -236,6 +236,13 @@ global_profiler:
stacks: "all"
# devices, record_context etc.
kw_args: {}
precision_debugger:
_target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig
enable: False
config_path: null
data_dir: "outputs/precision_debug"
steps: null
stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"]

# configs for TransferQueue
transfer_queue:
20 changes: 20 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -300,6 +300,26 @@ global_profiler:

# devices, record_context etc.
kw_args: {}
  # precision debugger (msprobe) config
precision_debugger:

# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig

# Whether to enable precision debugger
enable: False

# Path to msprobe config.json
config_path: null

# Dump root directory
data_dir: "outputs/precision_debug"

# Steps to collect, null means all
steps: null

# Stages to collect
stages: ["rollout_generate", "update_actor", "actor_compute_log_prob", "ref_compute_log_prob", "compute_values", "critic_update", "compute_rm_score", "train"]

# configs for TransferQueue
transfer_queue:
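The stanza above is consumed through verl.utils.omega_conf_to_dataclass, as the _target_ comment notes. A minimal sketch of that round trip, assuming the helper resolves the dataclass type from _target_ when called with just the config node (field values are illustrative):

from omegaconf import OmegaConf

from verl.utils import omega_conf_to_dataclass

# Illustrative stanza mirroring the YAML above; the _target_ key tells
# omega_conf_to_dataclass which dataclass to instantiate.
yaml_cfg = OmegaConf.create(
    """
    _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig
    enable: true
    config_path: /path/to/msprobe_config.json
    data_dir: outputs/precision_debug
    steps: [1, 2]
    stages: [rollout_generate, update_actor]
    """
)
precision_cfg = omega_conf_to_dataclass(yaml_cfg)
assert precision_cfg.enable and precision_cfg.steps == [1, 2]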
14 changes: 14 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
@@ -271,6 +271,7 @@ def __init__(
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self._propagate_precision_debugger_config()

self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, "Currently, only the hybrid engine is supported"
@@ -1073,6 +1074,18 @@ def _stop_profiling(self, do_profile: bool) -> None:
if self.use_rm and not self.use_reward_loop:
self.rm_wg.stop_profile()

def _propagate_precision_debugger_config(self) -> None:
precision_cfg = OmegaConf.select(self.config, "global_profiler.global_tool_config.precision_debugger")
if precision_cfg is None:
return
with open_dict(self.config):
if OmegaConf.select(self.config, "actor_rollout_ref") is not None:
self.config.actor_rollout_ref.precision_debugger = precision_cfg
if OmegaConf.select(self.config, "critic") is not None:
self.config.critic.precision_debugger = precision_cfg
if OmegaConf.select(self.config, "reward_model") is not None:
self.config.reward_model.precision_debugger = precision_cfg

def _get_dp_size(self, worker_group, role: str) -> int:
"""Get data parallel size from worker group dispatch info.

@@ -1369,6 +1382,7 @@ def fit(self):
else curr_step_profile
)
batch: DataProto = DataProto.from_single_dict(batch_dict)
batch.meta_info["global_steps"] = self.global_steps
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature

# add uid to batch
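A standalone sketch of the propagation pattern in _propagate_precision_debugger_config above, runnable without verl (the section names match the method; the config values are illustrative):

from omegaconf import OmegaConf, open_dict

config = OmegaConf.create(
    {
        "global_profiler": {
            "global_tool_config": {"precision_debugger": {"enable": True, "steps": [3]}}
        },
        "actor_rollout_ref": {},
        "critic": {},
    }
)

precision_cfg = OmegaConf.select(config, "global_profiler.global_tool_config.precision_debugger")
if precision_cfg is not None:
    # open_dict temporarily lifts struct mode so a new key can be added.
    with open_dict(config):
        for section in ("actor_rollout_ref", "critic", "reward_model"):
            # Sections absent from the config (here, reward_model) are skipped.
            if OmegaConf.select(config, section) is not None:
                config[section].precision_debugger = precision_cfg

print(OmegaConf.select(config, "actor_rollout_ref.precision_debugger.steps"))  # [3]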
9 changes: 9 additions & 0 deletions verl/utils/import_utils.py
@@ -69,6 +69,15 @@ def is_trl_available():
return trl_spec is not None


@cache
def is_msprobe_available():
try:
msprobe_spec = importlib.util.find_spec("msprobe")
except ModuleNotFoundError:
msprobe_spec = None
return msprobe_spec is not None


def import_external_libs(external_libs=None):
if external_libs is None:
return
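A hedged usage sketch of the new guard; the msprobe.pytorch import path matches the one used in precision_debugger_profile.py below, while the fallback branch is illustrative:

from verl.utils.import_utils import is_msprobe_available

if is_msprobe_available():
    # Optional dependency: only import when the package is installed.
    from msprobe.pytorch import PrecisionDebugger
else:
    PrecisionDebugger = None  # callers treat precision debugging as a no-op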
2 changes: 2 additions & 0 deletions verl/utils/profiler/__init__.py
@@ -15,6 +15,7 @@
from ..device import is_npu_available
from ..import_utils import is_nvtx_available
from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from .precision_hook import PrecisionDebuggerLogger
from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig

# Select marker implementations by availability, but keep DistProfiler as our dispatcher
@@ -37,4 +38,5 @@
"ProfilerConfig",
"simple_timer",
"marked_timer",
"PrecisionDebuggerLogger",
]
23 changes: 23 additions & 0 deletions verl/utils/profiler/config.py
@@ -78,6 +78,29 @@ def __post_init__(self) -> None:
assert self.stack_depth > 0, f"stack_depth must be positive, got {self.stack_depth}"


@dataclass
class PrecisionDebuggerToolConfig(BaseConfig):
"""Precision debugger tool config (msprobe)."""

enable: bool = False
config_path: Optional[str] = None
data_dir: str = "outputs/precision_debug"
steps: Optional[list[int]] = None
stages: Optional[list[str]] = None
strict: bool = False

def __post_init__(self) -> None:
assert isinstance(self.enable, bool), f"enable must be bool, got {type(self.enable)}"
if self.config_path is not None:
assert isinstance(self.config_path, str), f"config_path must be str, got {type(self.config_path)}"
assert isinstance(self.data_dir, str), f"data_dir must be str, got {type(self.data_dir)}"
if self.steps is not None:
assert isinstance(self.steps, list), f"steps must be list[int], got {type(self.steps)}"
if self.stages is not None:
assert isinstance(self.stages, list), f"stages must be list[str], got {type(self.stages)}"
assert isinstance(self.strict, bool), f"strict must be bool, got {type(self.strict)}"


@dataclass
class NPUToolConfig(NsightToolConfig):
"""NPU profiler too; config."""
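A minimal sketch of constructing PrecisionDebuggerToolConfig directly and of the __post_init__ type checks above, assuming BaseConfig adds no required constructor arguments (field values are illustrative):

from verl.utils.profiler.config import PrecisionDebuggerToolConfig

cfg = PrecisionDebuggerToolConfig(
    enable=True,
    config_path="/path/to/msprobe_config.json",
    data_dir="outputs/precision_debug",
    steps=[1, 5],
    stages=["update_actor"],
)

# __post_init__ rejects wrongly typed fields with an AssertionError.
try:
    PrecisionDebuggerToolConfig(enable="yes")
except AssertionError as err:
    print(err)  # enable must be bool, got <class 'str'>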
137 changes: 137 additions & 0 deletions verl/utils/profiler/precision_debugger_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import threading
from dataclasses import asdict
from typing import Optional

from omegaconf import DictConfig, OmegaConf

from verl.utils.import_utils import is_msprobe_available
from verl.utils.profiler.config import PrecisionDebuggerToolConfig


_GLOBAL_LOCK = threading.Lock()


class PrecisionDebuggerProfiler:
"""Precision debugger wrapper for msprobe.

This class implements a minimal start/stop contract and is intentionally
not a DistProfiler subclass to keep the dependency one-way.
"""

def __init__(self, precision_cfg, rank: Optional[int] = None):
self.rank = rank
self.precision_cfg = self._normalize_config(precision_cfg)
self._active_lock: Optional[threading.Lock] = None
self._enabled = self._is_enabled(self.precision_cfg)
self._available = is_msprobe_available()
self._debugger = None

@staticmethod
def _normalize_config(precision_cfg) -> PrecisionDebuggerToolConfig:
if precision_cfg is None:
return PrecisionDebuggerToolConfig()
if isinstance(precision_cfg, PrecisionDebuggerToolConfig):
return precision_cfg
        if isinstance(precision_cfg, DictConfig):
            # DictConfig has no to_container() instance method; use the
            # OmegaConf static helper to obtain a plain dict.
            precision_cfg = OmegaConf.to_container(precision_cfg, resolve=True)
if isinstance(precision_cfg, dict):
return PrecisionDebuggerToolConfig(**precision_cfg)
return PrecisionDebuggerToolConfig(**asdict(precision_cfg))

@staticmethod
def _is_enabled(precision_cfg: PrecisionDebuggerToolConfig) -> bool:
return bool(precision_cfg.enable)

def _should_collect(self, stage: str, global_step: Optional[int]) -> bool:
if not self._enabled:
return False
if self.precision_cfg.stages is not None and stage not in set(self.precision_cfg.stages):
return False
if self.precision_cfg.steps is not None and global_step is not None:
if int(global_step) not in set(self.precision_cfg.steps):
return False
return True

def _get_lock(self) -> threading.Lock:
return _GLOBAL_LOCK

def start(self, stage: str, global_step: Optional[int] = None, model=None) -> bool:
if not self._should_collect(stage=stage, global_step=global_step):
return False
if not self._available:
if self.precision_cfg.strict:
raise ImportError("msprobe is not available but precision_debugger.strict is True")
return False

config_path = self.precision_cfg.config_path
data_dir = self.precision_cfg.data_dir
if not config_path or not data_dir:
return False

step_tag = f"step_{global_step}" if global_step is not None else "step_unknown"
rank_tag = f"rank_{self.rank}" if self.rank is not None else "rank_unknown"
dump_path = os.path.join(data_dir, step_tag, stage, rank_tag)
os.makedirs(dump_path, exist_ok=True)

lock = self._get_lock()
lock.acquire()
self._active_lock = lock
try:
from msprobe.pytorch import PrecisionDebugger

            if self._debugger is None:
                # The constructor either raises or returns an instance, so no
                # None check on its result is needed.
                self._debugger = PrecisionDebugger(config_path=config_path, dump_path=dump_path)
            elif hasattr(self._debugger, "service") and hasattr(self._debugger.service, "config"):
                # Reused instance: redirect its dumps to the new stage directory.
                self._debugger.service.config.dump_path = dump_path
            debugger = self._debugger
debugger.start(model)
return True
except Exception:
self._release_lock()
if self.precision_cfg.strict:
raise
return False

def stop(self, started: bool = False, step: bool = False) -> None:
if not started:
self._release_lock()
return
if not self._available:
self._release_lock()
return
try:
debugger = self._debugger
if debugger is None:
return
debugger.stop()
            if step and hasattr(debugger, "step"):
                debugger.step()
finally:
self._release_lock()

def _release_lock(self) -> None:
lock = self._active_lock
self._active_lock = None
if lock is not None:
lock.release()
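A hedged usage sketch of the start/stop contract; the dump layout data_dir/step_<n>/<stage>/rank_<r> follows the path construction in start(), and the config values are illustrative:

from verl.utils.profiler.config import PrecisionDebuggerToolConfig
from verl.utils.profiler.precision_debugger_profile import PrecisionDebuggerProfiler

cfg = PrecisionDebuggerToolConfig(
    enable=True,
    config_path="/path/to/msprobe_config.json",
    data_dir="outputs/precision_debug",
    steps=[1],
    stages=["update_actor"],
)
profiler = PrecisionDebuggerProfiler(cfg, rank=0)

# start() returns False (and collects nothing) when the stage or step is
# filtered out, or when msprobe is missing and strict is False.
started = profiler.start(stage="update_actor", global_step=1, model=None)
try:
    pass  # run the training stage; dumps land under
          # outputs/precision_debug/step_1/update_actor/rank_0
finally:
    profiler.stop(started=started, step=True)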