 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
 import inspect
+import json
+import os
 from dataclasses import asdict
+from typing import Any, Dict, Optional
 
 import torch.distributed as dist
 import torch.utils.data as data_utils
+import yaml
 from peft import (
     AdaptionPromptConfig,
-    LoraConfig,
     PrefixTuningConfig,
 )
+from peft import LoraConfig as PeftLoraConfig
+from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
 
 import QEfficient.finetune.configs.dataset_config as datasets
-from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
-from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.configs.peft_config import LoraConfig
+from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
 
 
 def update_config(config, **kwargs):
+    """Update the attributes of a config object based on provided keyword arguments.
+
+    Args:
+        config: The configuration object (e.g., TrainConfig, LoraConfig) or a list/tuple of such objects.
+        **kwargs: Keyword arguments representing attributes to update.
+
+    Raises:
+        ValueError: If a dotted parameter matches the config's class name but names
+            an attribute the config does not have.
+    """
     if isinstance(config, (tuple, list)):
         for c in config:
             update_config(c, **kwargs)
@@ -33,40 +46,68 @@ def update_config(config, **kwargs):
             if hasattr(config, k):
                 setattr(config, k, v)
             elif "." in k:
-                # allow --some_config.some_param=True
-                config_name, param_name = k.split(".")
-                if type(config).__name__ == config_name:
+                config_name, param_name = k.split(".", 1)
+                if type(config).__name__.lower() == config_name.lower():
                     if hasattr(config, param_name):
                         setattr(config, param_name, v)
                     else:
-                        # In case of specialized config we can warn user
-                        assert False, f"Warning: {config_name} does not accept parameter: {k}"
-            elif isinstance(config, train_config):
-                assert False, f"Warning: unknown parameter {k}"
+                        raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+            else:
+                config_type = type(config).__name__
+                print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")
 
 
-def generate_peft_config(train_config, kwargs):
-    configs = (lora_config, prefix_config)
-    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
-    names = tuple(c.__name__.rstrip("_config") for c in configs)
+def generate_peft_config(train_config: TrainConfig, custom_config: Any) -> Any:
+    """Generate a PEFT-compatible configuration from a custom config based on peft_method.
 
-    if train_config.peft_method not in names:
-        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+    Args:
+        train_config (TrainConfig): Training configuration with peft_method.
+        custom_config: Custom configuration object (e.g., LoraConfig).
 
-    config = configs[names.index(train_config.peft_method)]()
+    Returns:
+        Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).
 
-    update_config(config, **kwargs)
+    Raises:
+        RuntimeError: If the peft_method is not supported.
+    """
+    # Define supported PEFT methods and their corresponding configs
+    method_to_configs = {
+        "lora": (LoraConfig, PeftLoraConfig),
+        "adaption_prompt": (None, AdaptionPromptConfig),  # Placeholder; add custom config if needed
+        "prefix_tuning": (None, PrefixTuningConfig),  # Placeholder; add custom config if needed
+    }
+
+    peft_method = train_config.peft_method.lower()
+    if peft_method not in method_to_configs:
+        raise RuntimeError(f"PEFT config not found for method: {train_config.peft_method}")
+
+    custom_config_class, peft_config_class = method_to_configs[peft_method]
+
+    # Use the provided custom_config (e.g., LoraConfig instance)
+    config = custom_config
     params = asdict(config)
-    peft_config = peft_configs[names.index(train_config.peft_method)](**params)
 
+    # Create the PEFT-compatible config
+    peft_config = peft_config_class(**params)
     return peft_config
 
 
-def generate_dataset_config(train_config, kwargs):
+def generate_dataset_config(train_config: TrainConfig, kwargs: Optional[Dict[str, Any]] = None) -> Any:
+    """Generate a dataset configuration based on the specified dataset in train_config.
+
+    Args:
+        train_config (TrainConfig): Training configuration with dataset name.
+        kwargs (Dict[str, Any], optional): Additional arguments (currently unused).
+
+    Returns:
+        Any: A dataset configuration object.
+
+    Raises:
+        AssertionError: If the dataset name is not recognized.
+    """
     names = tuple(DATASET_PREPROC.keys())
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
    dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-    update_config(dataset_config, **kwargs)
     return dataset_config
 
 
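A quick usage sketch of the pieces above (not part of this diff; it assumes `TrainConfig` and the custom `LoraConfig` are default-constructible dataclasses). Note that dotted overrides are routed by class name, matched case-insensitively, so the key prefix must spell the class name:

```python
from QEfficient.finetune.configs.peft_config import LoraConfig
from QEfficient.finetune.configs.training import TrainConfig

train_config = TrainConfig()  # assumed default-constructible
lora_config = LoraConfig()

# Flat keys update any matching attribute on the target config.
update_config(train_config, peft_method="lora")

# Dotted keys are matched case-insensitively against the class name, so
# "LoraConfig.r" (or "loraconfig.r") updates only the LoraConfig instance;
# configs with non-matching names are silently skipped.
update_config((train_config, lora_config), **{"LoraConfig.r": 16})

# Convert the custom config into a peft-compatible LoraConfig.
peft_config = generate_peft_config(train_config, lora_config)
```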
@@ -98,3 +139,84 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
         kwargs["drop_last"] = True
         kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs
+
+
+def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
+    """Validate the provided YAML/JSON configuration for required fields and types.
+
+    Args:
+        config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
+        config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
+
+    Raises:
+        ValueError: If required fields are missing or have incorrect types.
+
+    Notes:
+        - Missing or unreadable config files are handled upstream by load_config_file.
+        - Validates required fields for LoraConfig: r, lora_alpha, target_modules.
+        - Ensures types match expected values (int, float, list, etc.).
+    """
+    if config_type.lower() != "lora":
+        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
+
+    required_fields = {
+        "r": int,
+        "lora_alpha": int,
+        "target_modules": list,
+    }
+    optional_fields = {
+        "bias": str,
+        "task_type": str,
+        "lora_dropout": float,
+        "inference_mode": bool,
+    }
+
+    # Check for missing required fields
+    missing_fields = [field for field in required_fields if field not in config_data]
+    if missing_fields:
+        raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
+
+    # Validate types of required fields
+    for field, expected_type in required_fields.items():
+        if not isinstance(config_data[field], expected_type):
+            raise ValueError(
+                f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
+                f"got {type(config_data[field]).__name__}"
+            )
+
+    # Validate target_modules contains strings
+    if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
+        raise ValueError("All elements in 'target_modules' must be strings")
+
+    # Validate types of optional fields if present
+    for field, expected_type in optional_fields.items():
+        if field in config_data and not isinstance(config_data[field], expected_type):
+            raise ValueError(
+                f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
+                f"got {type(config_data[field]).__name__}"
+            )
+
+
+def load_config_file(config_path: str) -> Dict[str, Any]:
+    """Load a configuration from a YAML or JSON file.
+
+    Args:
+        config_path (str): Path to the YAML or JSON file.
+
+    Returns:
+        Dict[str, Any]: The loaded configuration as a dictionary.
+
+    Raises:
+        FileNotFoundError: If the file does not exist.
+        ValueError: If the file format is unsupported.
+    """
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+
+    with open(config_path, "r") as f:
+        if config_path.endswith((".yaml", ".yml")):
+            return yaml.safe_load(f)
+        elif config_path.endswith(".json"):
+            return json.load(f)
+        else:
+            raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")
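A minimal end-to-end sketch for the new loader/validator pair (illustrative file name and contents; `PeftLoraConfig` is the `peft.LoraConfig` alias imported above):

```python
# lora_config.yaml (illustrative):
#   r: 8
#   lora_alpha: 32
#   target_modules: ["q_proj", "v_proj"]
#   lora_dropout: 0.05

config_data = load_config_file("lora_config.yaml")  # FileNotFoundError if missing
validate_config(config_data, config_type="lora")    # ValueError on missing/mistyped fields
peft_config = PeftLoraConfig(**config_data)         # peft.LoraConfig accepts these kwargs
```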