This is the official implementation of the paper 'Graph-KV: Breaking Sequence via Injecting Structural Biases into Large Language Models' by Haoyu Wang, Peihao Wang, Mufei Li, Shikun Liu, Siqi Miao, Zhangyang Wang, and Pan Li.
To set up the environment, run the following commands:

```bash
conda create -n graphkv python==3.10.16
conda activate graphkv
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.50.0
pip install accelerate
pip install flash-attn==2.7.4.post1 --no-build-isolation
pip install pandas
# Below sets up the server used to serve LLMs (following the Block-RAG paper)
pip install flask
pip install flask_cors
```
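To verify the environment (a quick sanity check, not repository-specific), the following should run without errors:

```python
import torch
import transformers
import flash_attn

# Confirm the pinned versions and that CUDA is visible to PyTorch.
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("flash-attn:", flash_attn.__version__)
```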
Below are the tuned-model parameters adopted in Graph-KV.
Item | Repository
---|---
8B-Block | 🤗 ldsjmdy/Tulu3-Block-FT
8B-SFT | 🤗 ldsjmdy/Tulu3-SFT
8B-RAG | 🤗 ldsjmdy/Tulu3-RAG
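These checkpoints load with the standard `transformers` API. Below is a minimal loading sketch; the prompt and generation settings are illustrative only, and the plain-text prompt skips the model's chat template:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the block-finetuned checkpoint from the table above.
model_id = "ldsjmdy/Tulu3-Block-FT"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # matches the server's --dtype bfloat16
    device_map="auto",
)

prompt = "Who wrote 'The Remains of the Day'?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```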
We follow Block-Attention for Efficient Prefilling to pre-process the data for 2Wiki, NarrativeQA, TriviaQA, and HotpotQA.
To obtain the raw data and pre-process it from scratch, please refer to the original implementation.
Instructions for directly downloading the processed data can be found below.
We follow the original implementation of MultiHop-RAG: Benchmarking Retrieval-Augmented Generation for Multi-Hop Queries to pre-process the data for MultiHop-RAG, and the original implementation of MoreHopQA: More Than Multi-hop Reasoning to pre-process the data for MoreHopQA.
Please refer to their repositories, MultiHop-RAG and MoreHopQA, for dataset pre-processing from scratch.
Instructions for directly downloading the processed data can be found below.
Huggingface Dataset Link for Processed Data
All required pre-processed data for both the RAG (rag) and Arxiv-QA (arxiv) tasks can be downloaded directly from the Huggingface repository listed above.
Please refer to the README file in the Huggingface repo for details.
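For scripted downloads, `huggingface_hub` works as usual. In the sketch below, the repository id is a placeholder; substitute the dataset repo linked above:

```python
from huggingface_hub import snapshot_download

# NOTE: "your-org/graph-kv-data" is a placeholder; substitute the actual
# dataset repository linked above. This fetches both the rag and arxiv folders.
local_dir = snapshot_download(
    repo_id="your-org/graph-kv-data",
    repo_type="dataset",
    local_dir="./data",
)
print("Processed data downloaded to:", local_dir)
```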
To start the generation server, run:

```bash
CUDA_VISIBLE_DEVICES=0 python3 server/generate_server.py --model ldsjmdy/Tulu3-Block-FT --port 8771 --dtype bfloat16
```
Component | Description
---|---
`CUDA_VISIBLE_DEVICES=0` | Specifies GPU device 0. Modify this if using multiple GPUs.
`--model` | The Hugging Face model to load. Options: `ldsjmdy/Tulu3-Block-FT` (block-level fine-tuned model), `ldsjmdy/Tulu3-RAG` (RAG-tuned model), `ldsjmdy/Tulu3-SFT` (supervised fine-tuned model).
`--port 8771` | Port on which the server listens. Change this if the port is in use.
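Once the server is up, clients reach it over HTTP. The sketch below is only illustrative: the route name (`/generate`) and payload fields (`prompt`, `max_new_tokens`) are assumptions, so check `server/generate_server.py` for the actual interface.

```python
import requests

# Hypothetical client: the route and payload keys are assumptions (see above);
# the actual interface is defined in server/generate_server.py.
resp = requests.post(
    "http://localhost:8771/generate",
    json={"prompt": "Who wrote 'The Remains of the Day'?", "max_new_tokens": 32},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```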
With the server running, launch inference, e.g. on NarrativeQA (`nqa`):

```bash
python inference.py --pcw vanilla --model ldsjmdy/Tulu3-Block-FT --task nqa --port 8771
python inference.py --pcw gapemp --model ldsjmdy/Tulu3-Block-FT --task nqa --port 8771
python inference.py --pcw block --model ldsjmdy/Tulu3-Block-FT --task nqa --port 8771
python inference.py --pcw gapemp_appr --model ldsjmdy/Tulu3-Block-FT --task nqa --port 8771 --top_k 5
```
Argument | Description
---|---
`--pcw` | Method for parallel context window (PCW) attention. Options: `vanilla` (sequential encoding without PCW), `gapemp` (Graph-KV), `block` (Block-RAG), `gapemp_appr` (approximate version of Graph-KV).
`--model` | Model name loaded by the server (must match the one the server was started with).
`--task` | Task type. Currently supported: `nqa`, `2wiki`, `tqa`, `hqa`, `morehopqa`, `multihopqa`.
`--port` | Port number used to communicate with the generation server.
`--top_k` (optional) | Only used in `gapemp_appr` mode. Specifies the top-K passages to select based on relevance scores during retrieval.
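To sweep every method on one task, a small driver can reuse the documented flags (a sketch; result collection is left to `rag_eval.py` below):

```python
import subprocess

# Run inference.py once per documented --pcw method on NarrativeQA.
for pcw in ["vanilla", "gapemp", "block", "gapemp_appr"]:
    cmd = [
        "python", "inference.py",
        "--pcw", pcw,
        "--model", "ldsjmdy/Tulu3-Block-FT",
        "--task", "nqa",
        "--port", "8771",
    ]
    if pcw == "gapemp_appr":
        cmd += ["--top_k", "5"]  # only the approximate variant uses --top_k
    subprocess.run(cmd, check=True)
```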
Note that the pre-computed results of each method are also provided in the results folder of the Huggingface repository.
Evaluate the inference outputs with:

```bash
python rag_eval.py --input [PATH]
```
Argument | Description
---|---
`--input` | Path to the `.jsonl` file containing inference outputs to evaluate. For example: `./results/tqa/ldsjmdyTulu3-Block-FT/vanilla_ascent.jsonl`
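The inference outputs are JSON-lines files. A quick way to peek at a few records (field names vary by task and method, so this only prints the keys):

```python
import json

path = "./results/tqa/ldsjmdyTulu3-Block-FT/vanilla_ascent.jsonl"
with open(path) as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        print(sorted(record.keys()))  # field names depend on task/method
        if i >= 2:  # inspect the first three records only
            break
```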
The Arxiv-QA data can be found in the arxiv folder of the Huggingface repository.
Run Arxiv-QA inference, for example:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python arxiv_inference.py --pcw gapemp_graph --batch_size 1 --order first --model ldsjmdy/Tulu3-Block-FT
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 --master_port=29501 arxiv_inference.py --pcw vanilla --batch_size 1 --order first --model ldsjmdy/Tulu3-Block-FT
```
Argument | Description
---|---
`CUDA_VISIBLE_DEVICES=0,1,2,3` | Specifies the GPUs to use for inference. This example uses 4 GPUs.
`python arxiv_inference.py` | The inference script for Arxiv-QA.
`--pcw gapemp_graph` | The attention method. Examples: `gapemp_graph`, `vanilla`, `block`.
`--batch_size 1` | Number of distractors; `1` means no distractors.
`--order first` | Whether to put the related paper group `first` or `last` in the context.
`--model ldsjmdy/Tulu3-Block-FT` | The model to use for inference.
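A full sweep over the documented settings (distractor counts 1-3, both orderings) can be scripted as below; the matching `arxiv_eval.py` call takes the same flags:

```python
import subprocess

# Sweep distractor counts and ordering strategies for Arxiv-QA inference.
for batch_size in [1, 2, 3]:          # 1 means no distractors
    for order in ["first", "last"]:   # position of the related paper group
        subprocess.run(
            [
                "python", "arxiv_inference.py",
                "--pcw", "gapemp_graph",
                "--batch_size", str(batch_size),
                "--order", order,
                "--model", "ldsjmdy/Tulu3-Block-FT",
            ],
            check=True,
        )
```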
Note that the pre-computed results of each method are also provided in the results folder of the Huggingface repository.
Evaluate the Arxiv-QA results with:

```bash
python arxiv_eval.py --model ldsjmdy/Tulu3-Block-FT --pcw gapemp_graph --batch_size 1 --order last
```
Argument | Description
---|---
`--model` | Name of the model to evaluate. Example: `ldsjmdy/Tulu3-Block-FT`.
`--pcw` | The attention method. Examples: `gapemp_graph`, `vanilla`, `block`.
`--batch_size` | Number of distractors; `1` means no distractors. Valid values: `1`, `2`, `3`.
`--order` | Ordering strategy for the directly related papers: `first` puts them first in the context, `last` puts them last.
If you find this work or repository helpful, please consider citing:

```bibtex
@article{wang2025graph,
  title={Graph-KV: Breaking Sequence via Injecting Structural Biases into Large Language Models},
  author={Wang, Haoyu and Wang, Peihao and Li, Mufei and Liu, Shikun and Miao, Siqi and Wang, Zhangyang and Li, Pan},
  journal={arXiv preprint arXiv:2506.07334},
  year={2025}
}
```