Skip to content

model.gradient_checkpointing_enable() makes loss.requires_grad be False #35826

@ZCWei51

Description

@ZCWei51

System Info

Python 3.9.19
transformers 4.42.0
torch 2.2.2+cu118
peft 0.12.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I tried using model.gradient_checkpointing_enable() to reduce memory consumption during training, I encountered an error: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn." After troubleshooting, I found that the issue seems to be caused by loss.requires_grad being False, which prevents backpropagation. The following code reproduces loss.requires_grad == False directly:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import torch
from transformers import  AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

def main():
    """Reproduce ``loss.requires_grad == False`` with LoRA + gradient checkpointing.

    Loads a causal LM in fp16, builds a single-sample batch on CUDA, wraps the
    model with a LoRA adapter, enables gradient checkpointing *after* wrapping,
    runs one forward pass and prints whether the resulting loss requires grad.
    """
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token, on both the tokenizer and the model config.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    # Single training sequence: prompt tokens + answer tokens + EOS.
    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    # Tensor.clone() gives an independent copy without relying on the `copy`
    # module, which this script never imports.
    labels = model_inputs_output.clone()
    labels[: len(input_ids)] = -1  # mark prompt positions so they are masked below
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100  # masked label positions get the value -100
    train_dataset = {
        "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
        "attention_mask": example_mask.unsqueeze(0).to("cuda"),
        "labels": labels.unsqueeze(0).to("cuda"),
    }

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # same as LLaMA-Factory
        lora_dropout=0.05,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    # NOTE(review): checkpointing is deliberately enabled after the PEFT wrap;
    # this ordering is what the report says yields a loss without grad.
    # Presumably calling model.enable_input_require_grads() as well (or enabling
    # checkpointing before get_peft_model) avoids it -- TODO confirm.
    model.gradient_checkpointing_enable()
    model.train()
    model.print_trainable_parameters()
    model.to("cuda")

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()

Output is

loss: False

This is confusing because model.gradient_checkpointing_enable() is designed to reduce memory consumption, but if loss.requires_grad is False, normal training is impossible. Meanwhile, when I use similar code from LLaMA-Factory to achieve the effect of model.gradient_checkpointing_enable(), I find that loss.requires_grad is True. Below is the code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import torch
from transformers import  AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import copy
from types import MethodType
from functools import partial
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import PreTrainedModel

def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__

        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        print("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)


def main():
    """Run one forward pass with the patched checkpointing and report the loss.

    Gradient checkpointing is switched on through the patched
    ``_gradient_checkpointing_enable`` before the LoRA adapter is attached;
    the script then prints whether the computed loss requires grad.
    """
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
    output_dir = "./test_debug"

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # EOS doubles as the padding token on the tokenizer and in the model config.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    if getattr(model, "supports_gradient_checkpointing", False):
        # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
        # According to: https://github.com/huggingface/transformers/issues/28339
        model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
        print("Gradient checkpointing enabled.")
    else:
        print("Current model does not support gradient checkpointing.")

    # One sequence: prompt tokens, then answer tokens, then EOS.
    prompt_ids = tokenizer.encode(train_data["input"])
    answer_ids = tokenizer.encode(train_data["output"])
    sequence = torch.tensor(prompt_ids + answer_ids + [tokenizer.eos_token_id], dtype=torch.int64)
    labels = copy.deepcopy(sequence)
    labels[: len(prompt_ids)] = -1  # prompt positions are masked below
    example_mask = sequence.ge(0)
    label_mask = labels.ge(0)
    sequence[~example_mask] = 0
    labels[~label_mask] = -100

    # Add a leading batch dimension and move every tensor to the GPU.
    tensors = (("input_ids", sequence), ("attention_mask", example_mask), ("labels", labels))
    train_dataset = {key: value.unsqueeze(0).to("cuda") for key, value in tensors}

    lora_targets = ["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"]  # same as LLaMA-Factory
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=lora_targets,
        lora_dropout=0.05,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    # The stock model.gradient_checkpointing_enable() is intentionally not called here.
    model.train()
    model.print_trainable_parameters()
    model.to("cuda")

    loss = model(**train_dataset)["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()

output is

loss: True

Expected behavior

I am not entirely sure if this is a bug in the implementation of model.gradient_checkpointing_enable(). If it is not, please feel free to close the issue directly and let me know. Thank you for taking the time to look into this issue :)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions