[BUG] [ROCm] Fine-tuning DeepSeek-Coder-V2-Lite-Instruct with 8 MI300X GPUs results in c10::DistBackendError #6725

Open
nikhil-tensorwave opened this issue Nov 8, 2024 · 2 comments
Labels: bug, rocm, training



nikhil-tensorwave commented Nov 8, 2024

Describe the bug
I am trying to fine-tune DeepSeek-Coder-V2-Lite-Instruct (16B) on a system with 8 MI300X GPUs. Running on any number of GPUs fewer than 8 works as expected and runs to completion. When running on all 8 GPUs, training starts, hangs, and then fails with one of two errors. The first error is:

Memory access fault by GPU node-4 (Agent handle: 0x56387095fa40) on address (nil). Reason: Unknown.

where the GPU node is different from run to run.
The second error (truncated) is:

[rank4]:[E1107 22:03:59.159782320 ProcessGroupNCCL.cpp:692] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=473096, OpType=_ALLGATHER_BASE, NumelIn=360448, NumelOut=2883584, Timeout(ms)=1800000) ran for 1800004 milliseconds before timing out.
[rank4]:[E1107 22:03:59.159832649 ProcessGroupNCCL.cpp:715] Stack trace of the timedout collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value.
[rank4]:[E1107 22:03:59.160250339 ProcessGroupNCCL.cpp:1996] [PG ID 0 PG GUID 0(default_pg) Rank 4] Work timeout is detected by watchdog at work: 473096, last enqueued NCCL work: 473328, last completed NCCL work: 473095.
[rank4]:[E1107 22:03:59.160271219 ProcessGroupNCCL.cpp:734] [Rank 4] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank4]:[E1107 22:03:59.160280148 ProcessGroupNCCL.cpp:748] [Rank 4] To avoid data inconsistency, we are taking the entire process down.
[rank4]:[E1107 22:03:59.163317046 ProcessGroupNCCL.cpp:1791] [PG ID 0 PG GUID 0(default_pg) Rank 4] Process group watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=473096, OpType=_ALLGATHER_BASE, NumelIn=360448, NumelOut=2883584, Timeout(ms)=1800000) ran for 1800004 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:722 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f64c4bc9206 in /home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x8be (0x7f651dca876e in /home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch/lib/libtorch_hip.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x930 (0x7f651dcafff0 in /home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch/lib/libtorch_hip.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f651dcb19cd in /home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch/lib/libtorch_hip.so)
frame #4: <unknown function> + 0x145c0 (0x7f65780195c0 in /home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x94ac3 (0x7f6621676ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x126850 (0x7f6621708850 in /lib/x86_64-linux-gnu/libc.so.6)
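
Incidentally, NumelOut (2883584) is exactly 8 × NumelIn (360448), which is consistent with the 8-rank _allgather_base that ZeRO-3 issues when re-gathering a partitioned parameter. The watchdog message also points at the flight recorder; a minimal sketch of enabling it, plus more verbose RCCL logging, is to set these environment variables before any process group is created, e.g. at the very top of finetune_deepseek_ds.py (the buffer size of 2000 and the NCCL_DEBUG setting are illustrative choices on my part, not something the repro below already does):

import os

# Illustrative debugging aid (not part of the original repro): enable the
# NCCL/RCCL flight recorder so a timed-out collective leaves a stack trace.
# Must run before torch.distributed creates the default process group.
os.environ.setdefault("TORCH_NCCL_TRACE_BUFFER_SIZE", "2000")  # 0 disables; >0 = number of entries to retain
os.environ.setdefault("NCCL_DEBUG", "INFO")                    # optional: verbose NCCL/RCCL logging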

To Reproduce
Run command:

deepspeed --num_nodes 1 --num_gpus 8 finetune_deepseek_ds.py --exp_id amd_baseline --deepspeed_config ds_config2_zero3.json --config dp.yaml --jsonl_path train_prompts.jsonl

Training script and config files are in the first comment below.

ds_report output

[2024-11-08 00:55:32,256] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [YES] ...... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.1.0+cf34004b8a), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
 [WARNING]  gds is not compatible with ROCM
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn is not compatible with ROCM
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/torch']
torch version .................... 2.6.0.dev20241106+rocm6.2
deepspeed install path ........... ['/home/tensorwave/nikhil/poc_validation/multigpu-deepseek-train-master/nightly62/lib/python3.10/site-packages/deepspeed-0.15.2+unknown-py3.10-linux-x86_64.egg/deepspeed']
deepspeed info ................... 0.15.2+unknown, unknown, unknown
torch cuda version ............... None
torch hip version ................ 6.2.41133-dd7f95766
nvcc version ..................... None
deepspeed wheel compiled w. ...... torch 2.6, hip 6.2
shared memory (/dev/shm) size .... 1.11 TB

System info:

  • OS: Ubuntu 22.04
  • GPU count and types: 1 machine with 8 MI300X
  • Python version: 3.10
  • Relevant package versions:
    pytorch-triton-rocm==3.1.0+cf34004b8a
    torch==2.6.0.dev20241106+rocm6.2
    torchaudio==2.5.0.dev20241106+rocm6.2
    torchvision==0.20.0.dev20241106+rocm6.2

Launcher context
Launching with deepspeed

Additional context
Running the same fine-tuning with smaller DeepSeek models (1B and 7B) on 8 GPUs runs to completion. I am currently trying the largest DeepSeek model (200B).

@rraminen @jithunnair-amd

nikhil-tensorwave added the bug and training labels on Nov 8, 2024
nikhil-tensorwave (Author) commented:

finetune_deepseek_ds.py

import pandas as pd
import json
import os
import argparse
import yaml
from pprint import pprint
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

os.environ["WANDB_DISABLED"] = "true"

# Argument parser for command-line inputs
parser = argparse.ArgumentParser(description='Fine-tune DeepSeek-Coder-V2 model')
parser.add_argument('--config', type=str, help='Path to config YAML file (optional)', default="config.yaml")
parser.add_argument('--jsonl_path', type=str, help='Path to input JSONL file (optional)', required=False)
parser.add_argument('--exp_id', type=str, help='Unique experiment ID for the run', required=True)
parser.add_argument('--deepspeed_config', type=str, help='Path to DeepSpeed config JSON file', required=False)
parser.add_argument('--local_rank', type=int, help='Local rank for distributed training', default=-1)
args = parser.parse_args()

# Load YAML config
config_path = args.config if os.path.isabs(args.config) else os.path.join(os.getcwd(), args.config)
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Configuration file not found at {config_path}")
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Extract base configurations
base_dir = config['run']['base_dir']
run_name_prefix = config['run']['run_name_prefix']
run_name = f"{run_name_prefix}-{args.exp_id}"
run_dir = os.path.join(base_dir, args.exp_id)

# Determine JSONL file path
jsonl_path = args.jsonl_path if args.jsonl_path else os.path.join(run_dir, "train_prompts.jsonl")

# Check if the JSONL file exists
if not os.path.exists(jsonl_path):
    raise FileNotFoundError(f"JSONL file not found at {jsonl_path}. Please provide a valid file path.")

# Configuration
MODEL_NAME = config['model']['model_name']
"""
bnb_config = BitsAndBytesConfig(
    load_in_4bit=config['quantization']['load_in_4bit'],
    bnb_4bit_use_double_quant=config['quantization']['use_double_quant'],
    bnb_4bit_quant_type=config['quantization']['quant_type'],
    bnb_4bit_compute_dtype=torch.bfloat16 if config['quantization']['compute_dtype'] == "bfloat16" else torch.float16
)
"""
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    #quantization_config=bnb_config
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load data
data_list = []

# Read JSONL and extract the 'text' field
with open(jsonl_path, 'r') as file:
    for line in file:
        if not line.strip():
            continue  # Skip empty lines
        try:
            data_point = json.loads(line)
            text = data_point.get('text', '')
            if text:
                data_list.append({'text': text})
        except json.JSONDecodeError as e:
            print(f"Skipping invalid JSON line: {line} with error: {e}")

# Convert to pandas DataFrame
df = pd.DataFrame(data_list)
print("Total number of training samples:", len(df))

# Create Hugging Face Dataset
train_data = Dataset.from_pandas(df[['text']])

# Shuffle dataset
train_data = train_data.shuffle()

# Set max_seq_length
max_seq_length = config['training'].get('max_seq_length', 512)

# Helper: infer the index of the last transformer layer from the model's parameter names
def get_num_layers(model):
    import re
    numbers = set()
    for name, _ in model.named_parameters():
        for number in re.findall(r'\d+', name):
            numbers.add(int(number))
    return max(numbers)

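# Heuristic: collect every torch.nn.Linear whose module path contains the
# highest layer index found above (i.e. modules in the last decoder block),
# skipping anything under an "encoder" prefix; these become the LoRA target_modules.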
def get_last_layer_linears(model):
    names = []
    num_layers = get_num_layers(model)
    for name, module in model.named_modules():
        if str(num_layers) in name and not "encoder" in name:
            if isinstance(module, torch.nn.Linear):
                names.append(name)
    return names

# LoRA configuration (applied by SFTTrainer via peft_config)
config_lora = LoraConfig(
    r=config['model']['lora_r'],
    lora_alpha=config['model']['lora_alpha'],
    target_modules=get_last_layer_linears(model),
    lora_dropout=config['model']['lora_dropout'],
    bias="none",
    task_type=config['model']['task_type']
)

# Get DeepSpeed config path
if args.deepspeed_config:
    deepspeed_config_path = args.deepspeed_config
else:
    deepspeed_config_path = config['training'].get('deepspeed_config', None)

# Define training arguments
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=config['training']['batch_size'],
    gradient_accumulation_steps=config['training']['grad_accumulation_steps'],
    num_train_epochs=config['training']['num_epochs'],
    learning_rate=float(config['training']['learning_rate']),
    gradient_checkpointing=True,
    fp16=config['training']['fp16'],
    bf16=config['training'].get('bf16', False),
    output_dir=os.path.join(run_dir, config['training']['output_subdir']),
    # Remove optim and lr_scheduler_type when using DeepSpeed
    # optim="paged_adamw_8bit",
    # lr_scheduler_type=config['training']['scheduler'],
    warmup_ratio=config['training']['warmup_ratio'],
    logging_steps=config['training']['logging_steps'],
    report_to=None,  # Set to None to disable reporting to wandb
    logging_dir=run_dir,
    deepspeed=deepspeed_config_path
)

# Initialize SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    peft_config=config_lora,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    packing=False,
)

# Disable cache for training and start training
model.config.use_cache = False
trainer.train()

# Save the fine-tuned model in the run directory
model_output_dir = os.path.join(run_dir, "trained-model")
trainer.save_model(model_output_dir)
tokenizer.save_pretrained(model_output_dir)
print(f"Model fine-tuned and saved to '{model_output_dir}'.")

dp.yaml

run:
  project_name: "text2sql-train"
  base_dir: ""
  run_name_prefix: "fine-tune-run"

model:
  model_name: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
  lora_r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  task_type: "CAUSAL_LM"
  # Added target_modules as a list
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"

training:
  batch_size: 1
  grad_accumulation_steps: 4
  num_epochs: 1
  learning_rate: 1e-4
  bf16: true
  fp16: false
  scheduler: "cosine"
  warmup_ratio: 0.01
  logging_steps: 10
  output_subdir: "checkpoints"
  max_seq_length: 8192  # Added max_seq_length

quantization:
  load_in_4bit: true
  use_double_quant: true
  quant_type: "nf4"
  compute_dtype: "bfloat16"

fsdp:
  fsdp: "full_shard auto_wrap offload"
  fsdp_config:
    backward_prefetch: "backward_pre"
    forward_prefetch: "false"
    use_orig_params: "false"

ds_config2_zero3.json

{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e7,  
        "reduce_bucket_size": 2e7,  
        "stage3_prefetch_bucket_size": 3774874,
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 5e8,  
        "stage3_max_reuse_distance": 5e8,  
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": 4, 
    "gradient_clipping": "auto",
    "steps_per_print": 20,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false,

    "activation_checkpointing": {
        "partition_activations": true,  
        "contiguous_memory_optimization": true
    }
}
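
To help isolate whether the hang lives in RCCL itself or in the DeepSpeed/trainer stack, a stand-alone stress test of the same collective might be useful. The sketch below is hypothetical and not part of the original report; the file name allgather_check.py, the iteration count, and the reuse of the tensor sizes from the timed-out WorkNCCL entry are all illustrative choices.

# allgather_check.py -- hypothetical 8-rank sanity check, not part of the original report.
# Launch with: torchrun --nproc_per_node=8 allgather_check.py
import os
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")  # maps to RCCL on ROCm builds of PyTorch
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    world_size = dist.get_world_size()

    # Sizes copied from the timed-out op: NumelIn=360448, NumelOut=8*360448=2883584
    shard = torch.randn(360448, device="cuda")
    gathered = torch.empty(360448 * world_size, device="cuda")

    for step in range(10_000):
        dist.all_gather_into_tensor(gathered, shard)  # same collective as _ALLGATHER_BASE
        if dist.get_rank() == 0 and step % 1_000 == 0:
            print(f"step {step} ok", flush=True)

    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

If this loop also hangs on 8 ranks, the problem is more likely in RCCL or the platform; if it runs clean, the interaction with ZeRO-3 CPU offload in the config above becomes the better suspect.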

nikhil-tensorwave changed the title from "[BUG] Fine-tuning DeepSeek-Coder-V2-Lite-Instruct with 8 MI300X GPUs results in c10::DistBackendError" to "[BUG] [ROCm] Fine-tuning DeepSeek-Coder-V2-Lite-Instruct with 8 MI300X GPUs results in c10::DistBackendError" on Nov 8, 2024
loadams (Contributor) commented Nov 8, 2024

Thanks @nikhil-tensorwave. Tagging @rraminen and @jithunnair-amd as well for help on the AMD side.

loadams self-assigned this on Nov 8, 2024
loadams added the rocm label on Nov 8, 2024