3 changes: 2 additions & 1 deletion README.md
@@ -100,7 +100,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-
The model will be saved in the `checkpoints` directory as specified in the config file.

> [!NOTE]
> You can use `examples/config_tiny_llama.py` to generate your own training config
> You can use `examples/config_tiny_llama.py` to generate your own training config

For detailed instructions on training your first model, check out our [Your First Training guide](docs/your-first-training.md). For multi-node training with Slurm, see our [Multi-Node Training guide](docs/multi-node-training.md).

@@ -175,6 +175,7 @@ We currently support the following features:
- [x] Custom module checkpointing for large models
- [x] Spectral µTransfer parametrization for scaling up neural networks
- [x] Mamba example
- [x] CUDA event-based timing for accurate GPU performance measurement

And we have on our roadmap:
- [ ] FP8 training
92 changes: 92 additions & 0 deletions docs/cuda_event_timing.md
@@ -0,0 +1,92 @@
# CUDA Event-Based Timing in Nanotron

## Overview

Nanotron now uses CUDA events for timing GPU operations instead of CPU-based timing with `time.time()`. This change provides several benefits:

1. **More accurate measurement of GPU execution time**: CUDA events are recorded directly on the GPU timeline, so they measure when kernels actually execute rather than when the CPU happens to observe them.
2. **Reduced need for explicit CUDA synchronization**: CPU-based timing requires synchronization between CPU and GPU to get accurate measurements, which can introduce overhead and affect performance.
3. **Lower overhead**: CUDA event-based timing has minimal impact on the execution of GPU operations.
4. **Better performance monitoring**: More accurate timing leads to better performance analysis and optimization.

## Implementation Details

The implementation uses `torch.cuda.Event` with `enable_timing=True` to create start and end events that are recorded on the GPU timeline. The elapsed time is then calculated using `start_event.elapsed_time(end_event)`, which returns the time in milliseconds.
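
For reference, the underlying mechanism looks roughly like the plain-PyTorch sketch below. This illustrates `torch.cuda.Event` itself rather than Nanotron's timer API; the matmul is just a stand-in workload:

```python
import torch

# Sketch: bracket a GPU operation with two CUDA events and read the elapsed
# time in milliseconds. The events are enqueued on the GPU stream, so they
# measure when the kernel actually runs, not when the CPU launched it.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")

start_event.record()   # marker placed on the GPU timeline
y = x @ x              # operation being timed
end_event.record()

# elapsed_time() needs both events to have completed, so synchronize once
# at the point where the measurement is read rather than around every op.
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)  # milliseconds
print(f"matmul took {elapsed_ms / 1000:.6f} s")
```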

### Key Changes

1. **Default Timer Type**: The default timer type in `nanotron/src/nanotron/logging/timers.py` has been changed from `TimerType.CPU` to `TimerType.CUDA`.

2. **Iteration Timing**: The iteration timing in `trainer.py` now uses CUDA events instead of `time.time()`.

3. **Synchronization Control**: By default, CUDA event-based timers do not force synchronization; it can be requested explicitly with `cuda_sync=True` (see the sketch below).
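
To make the difference concrete, here is a rough side-by-side sketch. The helper names are hypothetical and this is not the actual timer code: CPU wall-clock timing has to synchronize around the timed region, while event-based timing can defer that wait until the result is read; forcing the synchronization eagerly corresponds to what `cuda_sync=True` requests.

```python
import time
import torch

def time_with_cpu_clock(fn) -> float:
    """CPU wall-clock timing: must synchronize around the region,
    stalling the CPU until all queued GPU work has finished."""
    torch.cuda.synchronize()
    t0 = time.time()
    fn()
    torch.cuda.synchronize()
    return time.time() - t0  # seconds

def time_with_cuda_events(fn, force_sync: bool = False) -> float:
    """Event-based timing: markers are recorded on the GPU stream and the
    wait can be deferred (force_sync mimics an eager cuda_sync=True)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    if force_sync:
        torch.cuda.synchronize()  # wait for the whole device
    else:
        end.synchronize()         # wait only for the recorded end event
    return start.elapsed_time(end) / 1000  # ms -> seconds

x = torch.randn(2048, 2048, device="cuda")
workload = lambda: x @ x
print(time_with_cpu_clock(workload), time_with_cuda_events(workload))
```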

## Usage

### Basic Usage

```python
# Create and use a CUDA timer (default)
with nanotron_timer("my_operation"):
    # Your GPU operation here
    ...

# Explicitly specify CUDA timing
with nanotron_timer("my_operation", timer_type="cuda"):
    # Your GPU operation here
    ...

# For CPU-only operations, you can still use CPU-based timing
with nanotron_timer("cpu_operation", timer_type="cpu"):
    # Your CPU operation here
    ...

# As a decorator with default CUDA timing
@nanotron_timer
def my_function():
    # Your GPU operation here
    ...

# As a decorator with a custom name
@nanotron_timer("custom_name")
def my_function():
    # Your GPU operation here
    ...

# As a decorator with CPU timing
@nanotron_timer(timer_type=TimerType.CPU)
def my_cpu_function():
    # Your CPU operation here
    ...
```

### Advanced Usage

```python
# Start and end a timer manually
timer = nanotron_timer("my_operation")
timer.start()
# Your operation here
timer.end()

# Get the elapsed time in seconds
elapsed_time = timer.elapsed

# Get the total time across all calls
total_time = timer.total_time

# Get the average time per call
avg_time = timer.average_time
```

## Considerations

1. **Synchronization**: By default, CUDA event-based timers do not force synchronization to avoid overhead. If you need more accurate timing at the cost of performance, you can set `cuda_sync=True`.

2. **Units**: CUDA events measure time in milliseconds, but the timer API converts this to seconds for consistency with the previous CPU-based timing.

3. **Fallback**: If CUDA is not available, the timer automatically falls back to CPU-based timing, as sketched below.
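
Putting these considerations together, a timer along these lines might be structured as follows. This is an illustrative approximation rather than the actual `nanotron_timer` implementation, and the class and attribute names are hypothetical:

```python
import time
import torch

class _SketchTimer:
    """Illustrative timer: CUDA events when available, time.time() otherwise."""

    def __init__(self, cuda_sync: bool = False):
        self.use_cuda = torch.cuda.is_available()
        self.cuda_sync = cuda_sync
        self.elapsed = 0.0  # seconds, matching the documented API

    def start(self):
        if self.use_cuda:
            self._start = torch.cuda.Event(enable_timing=True)
            self._end = torch.cuda.Event(enable_timing=True)
            self._start.record()
        else:
            self._t0 = time.time()

    def end(self):
        if self.use_cuda:
            self._end.record()
            if self.cuda_sync:
                torch.cuda.synchronize()  # more accurate, more overhead
            else:
                self._end.synchronize()   # wait only for the recorded events
            self.elapsed = self._start.elapsed_time(self._end) / 1000  # ms -> s
        else:
            self.elapsed = time.time() - self._t0
```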

## Performance Impact

Using CUDA events for timing instead of CPU-based timing with synchronization can significantly reduce overhead, especially in distributed training scenarios with thousands of GPUs.
22 changes: 15 additions & 7 deletions examples/config_qwen.py
@@ -30,7 +30,7 @@
"410m": (24, 1024, 16, 16, 4096), # ~410M params
# Small to medium models
"1b": (16, 2048, 16, 16, 5632), # ~1B params
"3b": (28, 2048, 16, 2, 11008), # ~3B params
"3b": (36, 2048, 16, 4, 11008), # ~3B params
# Standard sizes
"7b": (32, 4096, 32, 32, 11008), # ~7B params
"13b": (40, 5120, 40, 40, 13824), # ~13B params
@@ -47,7 +47,7 @@ def get_args():
parser.add_argument(
"--model",
choices=MODEL_SIZES.keys(),
default="custom",
default="3b",
help="Model size to generate config for (e.g., 7b, 13b)",
)
parser.add_argument(
@@ -76,6 +76,10 @@ def get_args():
tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size")
tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica")

# checkpoints
checkpoints_group = parser.add_argument_group("checkpoints")
checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval")

args = parser.parse_args()
return args

@@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
is_qwen2_config=True,
pad_token_id=None,
_attn_implementation="flash_attention_2",
sliding_window_size=20,
_use_doc_masking=True,
)


@@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str:

def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config:
learning_rate = LRSchedulerArgs(
learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0
)
parallelism = ParallelismArgs(
dp=args.dp,
@@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
)
optimizer = OptimizerArgs(
zero_stage=args.zero,
weight_decay=0.01,
weight_decay=0.1,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
learning_rate_scheduler=learning_rate,
@@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config

return Config(
general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save),
parallelism=parallelism,
model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
# tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"),
@@ -219,7 +223,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
world_size = args.dp * args.tp * args.pp * args.cp
if world_size <= 8:
print(
f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
)
print("You can also use environment variables for more debugging:")
print(" - ENABLE_TIMERS=1: Enable detailed timing information")
print(" - DEBUG_CPU=1: Log CPU and memory usage statistics")
print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection")
else:
print("Checkout slurm_launcher.py to launch a multi-node job")
20 changes: 10 additions & 10 deletions examples/config_qwen.yaml
@@ -1,5 +1,5 @@
checkpoints:
checkpoint_interval: 10
checkpoint_interval: 100000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
load_lr_scheduler: true
@@ -30,9 +30,9 @@ data_stages:
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: false
ignore_sanity_checks: true
project: debug
run: qwen_20250410_014907_16027793
run: qwen_20250424_120835_16423158
seed: 42
step: null
lighteval: null
@@ -45,6 +45,7 @@ model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
scaling_method: NUM_LAYERS
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
@@ -58,23 +59,23 @@ model:
eos_token_id: 2
flex_attention_mask: null
hidden_act: silu
hidden_size: 256
hidden_size: 2048
initializer_range: 0.02
intermediate_size: 768
intermediate_size: 11008
is_qwen2_config: true
max_position_embeddings: 4096
moe_config: null
no_rope_layer: null
num_attention_heads: 4
num_hidden_layers: 12
num_attention_heads: 16
num_hidden_layers: 36
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-06
rope_interleaved: false
rope_scaling: null
rope_theta: 10000.0
sliding_window_size: 20
sliding_window_size: null
tie_word_embeddings: true
use_cache: true
vocab_size: 128256
@@ -104,11 +105,10 @@ parallelism:
context_parallel_size: 1
dp: 2
expert_parallel_size: 1
moe_layer_recompute: false
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
132 changes: 132 additions & 0 deletions examples/config_qwen_with_moe.yaml
@@ -0,0 +1,132 @@
checkpoints:
checkpoint_interval: 1000
checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test
checkpoints_path_is_shared_file_system: false
load_lr_scheduler: true
load_optimizer: true
resume_checkpoint_path: null
save_final_state: true
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder:
- /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged
dataset_max_tokens: null
dataset_read_path: null
dataset_weights: null
pad_samples_to_global_batch_size: false
return_positions: true
shuffle_files: false
skip_in_stream: false
token_size_in_bytes: 4
tokenizer_name: meta-llama/Llama-3.2-1B
use_old_brrr_dataloader: false
vocab_size: 128256
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: false
project: qwen_moe
run: qwen_20250410_014907_16027793
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
metrics_logging: null
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
_attn_implementation: flash_attention_2
_fused_rms_norm: true
_fused_rotary_emb: true
_use_doc_masking: true
_use_qkv_packed: true
attention_bias: false
bos_token_id: 1
eos_token_id: 2
flex_attention_mask: null
hidden_act: silu
hidden_size: 256
initializer_range: 0.02
intermediate_size: 768
is_qwen2_config: true
max_position_embeddings: 4096
moe_config: null
no_rope_layer: null
num_attention_heads: 4
num_hidden_layers: 12
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-06
rope_interleaved: false
rope_scaling: null
rope_theta: 10000.0
sliding_window_size: 20
tie_word_embeddings: true
use_cache: true
vocab_size: 128256
z_loss_coefficient: 0.0001
z_loss_enabled: false
moe_config:
num_experts: 8
top_k: 1
enable_shared_expert: true
token_dispatcher_type: alltoall
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 31998
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
weight_decay_exclude_named_params: []
zero_stage: 0
parallelism:
context_parallel_size: 1
dp: 2
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
s3_upload: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: meta-llama/Llama-3.2-1B
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 3
sequence_length: 4096
train_steps: 32000
val_check_interval: -1