Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning

This repository contains the implementation of the method presented in the paper "Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning".

This repository is based on facebookresearch/FiD (Fusion-in-Decoder).

Data Preparation

Follow these steps to prepare the necessary data. These instructions are based on the original FiD setup.

Dataset and Wikipedia Passage Download

Download the required datasets and Wikipedia passages.

bash data_download.sh

Pre-trained Retrieval Model Download

Download a pre-trained retriever model, such as the NQ Retriever from the FiD project.

Refer to the official FiD instructions: FiD Model Download Script

Passage Retrieval Index

Build the retrieval index for your knowledge source (e.g., the downloaded Wikipedia passages).

Detailed instructions can be found at: FiD Knowledge Source Indexing

Passage Retrieval

Retrieve relevant passages for your question-answering dataset using the indexed knowledge source and the pre-trained retriever.

Detailed instructions can be found at: FiD Passage Retrieval

Training

The following command launches a distributed training job for the reader model with dynamic token pruning.

PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)

NGPU=${PROC_PER_NODE} CUDA_LAUNCH_BLOCKING=1 python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} train_reader_tp_clapnq.py \
        --seed <random_seed> \
        --use_checkpoint \
        --lr 1e-4 \
        --optim adamw \
        --scheduler linear \
        --weight_decay 0.01 \
        --question_maxlength 100 \
        --per_gpu_batch_size 1 \
        --n_context 50 \
        --text_maxlength 250 \
        --total_steps 20000 \
        --eval_freq 2000 \
        --eval_print_freq 1000 \
        --save_freq 2000 \
        --warmup_steps 1000 \
        --min_ratio 0.8 \
        --train_data <train_data_path> \
        --eval_data <eval_data_path> \
        --answer_maxlength 128 \
        --theta1 0.9 \
        --theta2 0.3 \
        --last_theta 0.05 \
        --gumbel_temperature 1.0 \
        --pruning_scale 2.0 \
        --kl_loss_scale 1.0 \
        --temp_retain_steps 1000 \
        --temp_reducing_steps 2000 \
        --model_size base \
        --checkpoint_dir <run_folder_name> \
        --name <log_folder_name> \
        --accumulation_steps 32
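Changing the GPU count alters what several of these flags mean in practice. The sketch below computes two derived quantities: the number of training examples per optimizer update, and an upper bound on the passage tokens the reader encoder receives per example before any pruning. It assumes that `train_reader_tp_clapnq.py` keeps FiD's scheme of averaging gradients across processes and accumulating them over `--accumulation_steps` micro-batches, and that each of the `--n_context` passages is truncated to `--text_maxlength` tokens. The helper names are hypothetical and not part of this repository.

```shell
#!/usr/bin/env bash
# Hypothetical helpers for sanity-checking the training flags.

# Examples per optimizer update = GPUs x per-GPU batch x accumulation steps
# (assumes DDP gradient averaging plus gradient accumulation, as in FiD).
effective_batch_size() {
  local ngpu=$1 per_gpu_batch_size=$2 accumulation_steps=$3
  echo $(( ngpu * per_gpu_batch_size * accumulation_steps ))
}

# Upper bound on passage tokens fed to the encoder per example, before pruning
# (assumes every passage is truncated to text_maxlength tokens).
max_encoder_tokens() {
  local n_context=$1 text_maxlength=$2
  echo $(( n_context * text_maxlength ))
}

# Values from the training command, on a hypothetical 4-GPU node.
effective_batch_size 4 1 32   # prints 128
max_encoder_tokens 50 250     # prints 12500
```

On an 8-GPU node, `effective_batch_size 8 1 32` gives 256 examples per update, so scaling the number of GPUs also scales the effective batch size unless `--accumulation_steps` is adjusted.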

Test

Use the following command to evaluate a trained model on an evaluation set.

PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)

NGPU=${PROC_PER_NODE} python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} test_reader.py \
        --seed <random_seed> \
        --model_path <model_path> \
        --eval_data <eval_data_path> \
        --per_gpu_batch_size 1 \
        --eval_print_freq 100 \
        --n_context 50 \
        --question_maxlength 100 \
        --text_maxlength 250 \
        --answer_maxlength 128 \
        --checkpoint_dir <run_folder_name> \
        --name <log_folder_name> \
        --write_results
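Both commands set `--nproc_per_node` from `nvidia-smi --list-gpus`, which reports every GPU on the machine regardless of `CUDA_VISIBLE_DEVICES`. When only a subset of GPUs is exposed to the job, the launcher would then start more processes than there are usable devices. A minimal sketch of a replacement for the `PROC_PER_NODE` line, assuming that `CUDA_VISIBLE_DEVICES`, when set, is a comma-separated list in which every entry names a valid device; `count_visible_gpus` is a hypothetical helper, not part of this repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper: number of GPUs this job can actually use.
count_visible_gpus() {
  if [ -n "${CUDA_VISIBLE_DEVICES+x}" ]; then
    # Variable is set: count its comma-separated, non-empty entries.
    # An empty value hides all GPUs, so the count is 0.
    printf '%s\n' "${CUDA_VISIBLE_DEVICES}" | tr ',' '\n' | grep -c . || true
  else
    # Variable is unset: every GPU reported by the driver is visible.
    nvidia-smi --list-gpus | wc -l
  fi
}

# Example: only GPUs 0 and 2 are exposed to this job.
export CUDA_VISIBLE_DEVICES=0,2
PROC_PER_NODE=$(count_visible_gpus)
echo "${PROC_PER_NODE}"   # prints 2
```

The `|| true` keeps the function's exit status at zero when the count is 0, since `grep -c` still prints `0` but exits with status 1 when nothing matches.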