We propose Selective Token Masking (STM), a simple yet effective way to filter high-perplexity tokens from human-styled training data, sustaining target-task training and non-target-task performance at the same time, and identifying one root cause of performance degradation after fine-tuning.
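As a rough illustration of the idea (a sketch, not the repo's actual implementation), a token is kept for fine-tuning only if its perplexity under the base model is below a threshold; the function names here are ours:

```python
import numpy as np

def token_perplexities(token_logprobs):
    """Per-token perplexity is exp(-log p(token | context))."""
    return np.exp(-np.asarray(token_logprobs, dtype=float))

def stm_loss_mask(token_logprobs, threshold=2.5):
    """Keep (mask=1) only low-perplexity tokens; high-perplexity
    tokens are filtered out of the fine-tuning loss."""
    return (token_perplexities(token_logprobs) <= threshold).astype(int)

# log-probs of three tokens under the base model: ppl ~ 1.11, 7.39, 1.65
print(stm_loss_mask([-0.1, -2.0, -0.5]).tolist())  # -> [1, 0, 1]
```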
- Some basic commands to set up a training/inference environment:
```bash
conda create -n robust-sft python=3.10
conda activate robust-sft
pip install -r pytorch_requirements.txt
pip install -r requirements.txt
```
- Note that you can install whichever CUDA build of PyTorch 2.3.1 suits your computing device.
- Install the latest axolotl from their repository.
- Please make sure your transformers version is <= 4.46.3, otherwise training may fail (due to Flash Attention 2).
To reproduce our results, you need to generate the training data first:
```bash
python3 generate_self-output_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp --mode "self-output"
python3 generate_self-output_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp --mode "rephrase"
python3 generate_stm_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp
```
- Note that when applying the alternative STM, you need to fine-tune a base model with the data first and provide the adapter path:
```bash
python3 generate_stm_training_data.py --base_model meta-llama/Meta-Llama-3-8B-Instruct --task mbpp --stm dfp --stm_adapter <adapter_path_on_disk>
```
- The data will be stored at `dataset/self-output`, `dataset/stm`, and `dataset/rephrase`. (We already provide the MBPP dataset.)
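The per-token log-probabilities that the generation scripts above rely on can be read off a model's next-token logits; a minimal numpy sketch (the function name is ours, not the repo's):

```python
import numpy as np

def per_token_logprobs(logits, target_ids):
    """log p(target_t | context_t) from next-token logits, one row per position."""
    logits = np.asarray(logits, dtype=float)
    # numerically stable log-softmax: subtract the row max before exponentiating
    m = logits.max(axis=-1, keepdims=True)
    log_z = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
    return logits[np.arange(len(target_ids)), target_ids] - log_z

# two positions over a 2-token vocab: p(target) = 0.5, then 0.75
lp = per_token_logprobs([[0.0, 0.0], [0.0, np.log(3.0)]], [0, 1])
print(np.round(np.exp(lp), 2))
```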
- First create an axolotl training config such as `training/mbpp-gt.yml`.
- You need to specify the training data path, base model, LoRA settings, and training settings such as learning rate, epochs, etc. (we have provided our settings in `training/mbpp-gt.yml`).
- Then modify `example_training_so_re_gt.sh` to add your training config and run:
```bash
sh example_training_so_re_gt.sh
```
- The trained model will be placed in `trained_models/`.
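For orientation, an axolotl LoRA config typically looks roughly like the sketch below; treat every value as a placeholder and defer to the actual `training/mbpp-gt.yml` shipped with the repo:

```yaml
# hypothetical sketch, not the shipped training/mbpp-gt.yml
base_model: meta-llama/Meta-Llama-3-8B-Instruct
datasets:
  - path: dataset/self-output   # your generated training data
    type: alpaca
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
learning_rate: 2e-5
num_epochs: 3
output_dir: trained_models/mbpp-gt
```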
- First make sure you have generated the STM data in the previous step.
- Run the following commands to train an STM model with perplexity threshold 2.5 and learning rate 2e-5.
- Please make sure transformers==4.46.0 is installed for STM training.
```bash
export WANDB_ENTITY="XXXX"
export WANDB_PROJECT="XXXX"
export HF_TOKEN="xxxx"
accelerate launch train_with_mask.py --learning_rate 2e-5 --threshold 2.5
```
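In Hugging Face-style training, filtering a token from the loss usually means setting its label to -100 so cross-entropy ignores it; a hypothetical sketch of what the `--threshold` option could amount to (the helper name is ours, not necessarily what `train_with_mask.py` does internally):

```python
import numpy as np

IGNORE_INDEX = -100  # HF convention: positions labeled -100 are skipped by the loss

def apply_stm_mask(labels, token_ppls, threshold=2.5):
    """Hypothetical helper: drop high-perplexity tokens from the loss
    by overwriting their labels with IGNORE_INDEX."""
    labels = np.asarray(labels).copy()
    labels[np.asarray(token_ppls, dtype=float) > threshold] = IGNORE_INDEX
    return labels

print(apply_stm_mask([5, 17, 9], [1.2, 8.0, 2.4]).tolist())  # -> [5, -100, 9]
```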
- First provide your WANDB_ENTITY, HF_TOKEN, OPENAI_API_KEY, and OPENAI_ORGANIZATION in `example_inference.sh`.
- In `example_inference.sh`, set `load_adapter` to your trained adapter model and the base model name it should be applied to, then run:
```bash
sh example_inference.sh
```
- The results will be placed at `logging/<task>-test`.
- We use the inference pipeline from the StreamBench repository.
- Please set up the same environment as StreamBench, and download the BIRD dataset using their `download_textsql_dat.py`.
- Note that you should only use the `download_bird(save_dir)` function in `download_textsql_dat.py`, commenting out the other `download_xxx()` functions.
- We provide a LoRA-adapter-loading version of the model as `stream_bench/llms/hf_model_lora.py`; please put it in stream-bench's `stream-bench/llms/` as a customized model configuration.
- We provide an example inference config as `stream_bench/llms/example.yml`; please move it to your stream-bench repo's `configs/agent/`.
- Modify `model_name` and `load_adapter` to our base model and your trained adapter path.
- Execute the text2sql BIRD inference by running the following commands:
```bash
export WANDB_ENTITY="{your wandb entity}"
export OPENAI_API_KEY="{your openai api key}"
export OAI_KEY="{your openai api key}"
export OPENAI_ORGANIZATION="{your openai org}"
echo "llama example"
python -m stream_bench.pipelines.run_bench --agent_cfg "configs/agents/example.yml" --bench_cfg "configs/bench/bird.yml" --entity "{your wandb entity}" --use_wandb
```
@misc{wu2025mitigatingforgettingllmfinetuning,
title={Mitigating Forgetting in LLM Fine-Tuning via Low-Perplexity Token Learning},
author={Chao-Chung Wu and Zhi Rui Tam and Chieh-Yen Lin and Yun-Nung Chen and Shao-Hua Sun and Hung-yi Lee},
year={2025},
eprint={2501.14315},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2501.14315},
}