Mitigating Forgetting in LLM Fine-Tuning via Low-Perplexity Token Learning

TLDR:

We propose Selective Token Masking (STM), a simple yet effective way to filter high-perplexity tokens out of human-styled training data, sustaining target-task training and non-target-task performance at the same time and identifying one of the root causes of performance degradation after fine-tuning.
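
As a rough sketch of the idea (not the exact implementation in this repository), the snippet below scores each response token's perplexity under the base model and masks high-perplexity tokens by setting their labels to -100 so they are ignored during fine-tuning. The model name, example pair, and threshold are placeholders.

# Sketch only: perplexity-based token filtering in the spirit of STM.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()

prompt = "Write a Python function that adds two numbers.\n"
response = "def add(a, b):\n    return a + b\n"
threshold = 2.5  # per-token perplexity cutoff

prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
response_ids = tokenizer(response, add_special_tokens=False, return_tensors="pt").input_ids.to(model.device)
input_ids = torch.cat([prompt_ids, response_ids], dim=1)

with torch.no_grad():
    logits = model(input_ids).logits

# Log-probability each position assigns to the next token.
log_probs = torch.log_softmax(logits[:, :-1, :].float(), dim=-1)
targets = input_ids[:, 1:]
token_logp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

# Per-token perplexity over the response span only.
resp_start = prompt_ids.shape[1] - 1
token_ppl = torch.exp(-token_logp[0, resp_start:])

# Labels for fine-tuning: ignore the prompt and any high-perplexity response token.
labels = input_ids.clone()
labels[:, :prompt_ids.shape[1]] = -100
labels[0, prompt_ids.shape[1]:][token_ppl > threshold] = -100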

To get started...

Environment setup

  • Some basic commands to set up a training/inference environment:
    conda create -n robust-sft python=3.10
    conda activate robust-sft
    pip install -r pytorch_requirements.txt
    pip install -r requirements.txt
    
  • Note that you can install whichever CUDA build of PyTorch 2.3.1 fits your computing device.
  • Install the latest axolotl from its repository.
  • Please make sure your transformers version is <=4.46.3; otherwise training may fail (due to flash attention 2). A quick version check is sketched below.
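
If you want to sanity-check the environment before training, a minimal check along the following lines (assuming PyTorch and transformers are already installed) can catch the version issue above:

# Quick sanity check of the training environment (sketch; adjust to your setup).
import torch
import transformers
from packaging import version

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)

# Training may fail with flash attention 2 if transformers is newer than 4.46.3.
assert version.parse(transformers.__version__) <= version.parse("4.46.3"), \
    "Please install transformers<=4.46.3"

try:
    import flash_attn  # needed for flash attention 2
    print("flash_attn:", flash_attn.__version__)
except ImportError:
    print("flash_attn is not installed")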

To reproduce our results, you need to generate the training data first.

Generation of training task data.

SO (Self-Output generated data) for Llama3-8B Instruct on MBPP

python3 generate_self-output_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct  --task mbpp --mode "self-output"

Rephrase generated data for Llama3-8B Instruct on MBPP

python3 generate_self-output_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct  --task mbpp --mode "rephrase"

STM (Selective Token Masking data) for Llama3-8B Instruct on MBPP with naive STM

python3 generate_stm_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp

STM (Selective Token Masking data) for Llama3-8B Instruct on MBPP with DPF

  • Note that when applying the alternative STM (DPF), you need to fine-tune a base model with the data first and provide the adapter path; a minimal adapter-loading sketch follows the command below.
python3 generate_stm_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp --stm dfp --stm_adapter <adapter_path_on_disk>
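
A minimal sketch of that adapter-loading step, assuming the DPF variant scores tokens with the base model plus the first-pass LoRA adapter attached via peft (the adapter path is a placeholder):

# Sketch: attach the first-pass LoRA adapter before scoring per-token perplexities.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter_path = "trained_models/mbpp-first-pass-lora"  # placeholder for <adapter_path_on_disk>

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()
# ...then compute per-token perplexities as in the earlier sketch and mask tokens above the threshold.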

The data should be stored at dataset/self-output, dataset/stm, and dataset/rephrase. (We have already prepared the MBPP dataset.)

Training of tasks

Run all training for self-output, rephrase, GT, and STM data with Llama3-8B Instruct on the MBPP task

  • First create an axolotl training config such as training/mbpp-gt.yml.
  • You need to specify the training data path, base model, LoRA settings, and training settings such as learning rate, epochs, etc. (we have provided our settings in training/mbpp-gt.yml).
  • Then modify example_training_so_re_gt.sh to add your training config and run:
sh example_training_so_re_gt.sh
  • The trained models should be placed in trained_models/.

Train purely on STM data with Llama3-8B Instruct on MBPP

  • First make sure you have completed the STM data generation step above.
  • Run the following commands to train an STM model with perplexity threshold = 2.5 and learning rate = 2e-5 (a toy illustration of the masked loss follows the commands).
  • Please make sure transformers==4.46.0 is installed for STM training.
export WANDB_ENTITY="XXXX"
export WANDB_PROJECT="XXXX"
export HF_TOKEN="xxxx"
accelerate launch train_with_mask.py --learning_rate 2e-5 --threshold 2.5
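
For reference, labels set to -100 are ignored by the standard PyTorch cross-entropy loss, which is how the masked high-perplexity tokens drop out of the STM training objective. A toy illustration (values are made up):

# Toy example: positions labeled -100 contribute nothing to the loss.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32000)                  # 5 token positions, toy vocabulary size
labels = torch.tensor([42, -100, 7, -100, 99])  # two positions masked out

loss = F.cross_entropy(logits, labels, ignore_index=-100)
print(loss)  # averaged over the 3 unmasked positions only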

Inference of tasks

Inference with trained adapter models on the Llama3 base model on MBPP test data

  • First provide your WANDB_ENTITY, HF_TOKEN, OPENAI_API_KEY, and OPENAI_ORGANIZATION in example_inference.sh.
  • In example_inference.sh, set load_adapter to your trained adapter model and specify the base model name it should be applied to, then run:
sh example_inference.sh
  • The results should be placed at logging/<task>-test. A minimal manual adapter-loading sketch follows.
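
If you want to spot-check a trained adapter outside the script, a minimal sketch of loading the base model with a LoRA adapter via peft (the adapter directory and prompt are placeholders) looks like this:

# Sketch: manually load the base model plus a trained LoRA adapter and generate.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter = "trained_models/mbpp-stm-lora"  # placeholder: your trained adapter directory

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter)
model.eval()

messages = [{"role": "user", "content": "Write a Python function that checks whether a number is prime."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))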

Inference on the BIRD task

  • We use the inference pipeline from the stream-bench repository.
  • Please set up the same environment for stream-bench, and download the BIRD datasets using their download_textsql_dat.py.
    • Note that you should only use the download_bird(save_dir) function in download_textsql_dat.py, commenting out the other download_xxx() functions.
  • We provide a LoRA adapter loading version as stream_bench/llms/hf_model_lora.py; please put it in the stream-bench repo's stream_bench/llms/ directory as a customized model configuration.
  • We provide an example inference config as stream_bench/llms/example.yml; please move it to your stream-bench repo's configs/agent/.
    • Modify model_name and load_adapter to point to the base model and your trained adapter path.
  • Execute the text-to-SQL BIRD inference by running the following commands:
export WANDB_ENTITY="{your wandb entity}"
export OPENAI_API_KEY="{your openai api key}"
export OAI_KEY="{your openai api key}"
export OPENAI_ORGANIZATION="{your openai org}"
echo "llama example"
python -m stream_bench.pipelines.run_bench --agent_cfg "configs/agents/example.yml" --bench_cfg "configs/bench/bird.yml" --entity "{your wandb entity name}" --use_wandb

Citation

If you find our work helpful, please cite it as:

@misc{wu2025mitigatingforgettingllmfinetuning,
      title={Mitigating Forgetting in LLM Fine-Tuning via Low-Perplexity Token Learning}, 
      author={Chao-Chung Wu and Zhi Rui Tam and Chieh-Yen Lin and Yun-Nung Chen and Shao-Hua Sun and Hung-yi Lee},
      year={2025},
      eprint={2501.14315},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2501.14315}, 
}

About

Accepted to NeurIPS 2025
