3838 print (f"Warning: { e } . Moving ahead without these qaic modules." )
3939
4040
41- from transformers import AutoModelForCausalLM , AutoTokenizer
41+ from transformers import AutoModelForCausalLM , AutoModelForSequenceClassification , AutoTokenizer
4242
4343# Suppress all warnings
4444warnings .filterwarnings ("ignore" )
@@ -56,6 +56,7 @@ def main(**kwargs):
5656 # update the configuration for the training process
5757 train_config = TRAIN_CONFIG ()
5858 update_config (train_config , ** kwargs )
59+ dataset_config = generate_dataset_config (train_config , kwargs )
5960 device = train_config .device
6061
6162 # dist init
@@ -78,12 +79,30 @@ def main(**kwargs):
7879 # Load the pre-trained model and setup its configuration
7980 # config = AutoConfig.from_pretrained(train_config.model_name)
8081 pretrained_model_path = login_and_download_hf_lm (train_config .model_name )
81- model = AutoModelForCausalLM .from_pretrained (
82- pretrained_model_path ,
83- use_cache = False ,
84- attn_implementation = "sdpa" ,
85- torch_dtype = torch .float16 ,
86- )
82+ if train_config .task_type == "seq_classification" :
83+ model = AutoModelForSequenceClassification .from_pretrained (
84+ pretrained_model_path ,
85+ num_labels = dataset_config .num_labels ,
86+ attn_implementation = "sdpa" ,
87+ torch_dtype = torch .float16 ,
88+ )
89+
90+ if not hasattr (model , "base_model_prefix" ):
91+ raise RuntimeError ("Given huggingface model does not have 'base_model_prefix' attribute." )
92+
93+ for param in getattr (model , model .base_model_prefix ).parameters ():
94+ param .requires_grad = False
95+
96+ for param in model .parameters ():
97+ if param .requires_grad :
98+ param .data = param .data .to (torch .float32 )
99+ else :
100+ model = AutoModelForCausalLM .from_pretrained (
101+ pretrained_model_path ,
102+ use_cache = False ,
103+ attn_implementation = "sdpa" ,
104+ torch_dtype = torch .float16 ,
105+ )
87106
88107 # Load the tokenizer and add special tokens
89108 tokenizer = AutoTokenizer .from_pretrained (
@@ -127,7 +146,6 @@ def main(**kwargs):
127146 model .print_trainable_parameters ()
128147
129148 # Get the dataset utils
130- dataset_config = generate_dataset_config (train_config , kwargs )
131149 dataset_processer = tokenizer
132150
133151 # Load and preprocess the dataset for training and validation
0 commit comments