FastMCTS: A Simple Sampling Strategy for Data Synthesis

Official repository for FastMCTS: A Simple Sampling Strategy for Data Synthesis. FastMCTS is an MCTS-inspired offline data-synthesis algorithm that offers a lightweight, practical way to efficiently generate multi-step reasoning data with step-level supervision.
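To make "step-level supervision" concrete, the sketch below labels each step in a search tree with the fraction of complete solutions beneath it that reach a correct answer. This Monte Carlo style estimate is a common way for MCTS-inspired synthesis methods to score intermediate steps. It is an illustrative assumption, not necessarily the exact value FastMCTS computes, and the StepNode class and annotate_values helper are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class StepNode:
    """One reasoning step in a hypothetical search tree (illustration only)."""
    text: str
    children: List["StepNode"] = field(default_factory=list)
    is_correct: Optional[bool] = None  # set only on leaves, i.e. complete solutions
    value: float = 0.0                 # filled in by annotate_values()


def annotate_values(node: StepNode) -> Tuple[int, int]:
    """Return (correct_leaves, total_leaves) for the subtree rooted at node
    and store node.value = correct_leaves / total_leaves."""
    if not node.children:
        correct = 1 if node.is_correct else 0
        node.value = float(correct)
        return correct, 1
    correct, total = 0, 0
    for child in node.children:
        c, t = annotate_values(child)
        correct += c
        total += t
    node.value = correct / total
    return correct, total


# Toy tree: two candidate first steps, three complete solutions in total.
root = StepNode("problem", children=[
    StepNode("step A", children=[StepNode("A1", is_correct=True),
                                  StepNode("A2", is_correct=False)]),
    StepNode("step B", children=[StepNode("B1", is_correct=False)]),
])
annotate_values(root)
print([child.value for child in root.children])  # [0.5, 0.0]
```

Here "step A" earns a value of 0.5 and "step B" earns 0.0, so the tree separates a promising step from a dead end even though neither is a final answer.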

News

Installation

# Install dependencies
pip install --upgrade pip
sudo apt-get install graphviz
pip install -r requirements.txt

Usage

Our codebase supports sampling methods including Rejection Sampling, Tree Search, and FastMCTS.
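Rejection sampling, the simplest of these, draws several responses per problem and keeps only those whose final answer matches the reference. A minimal sketch, assuming answers are written as \boxed{...} (as in the Branch-DPO example in Step 3) and compared by exact string match; real answer checkers usually normalize answers first. The helper names last_boxed and rejection_sample are hypothetical, not part of this codebase.

```python
from typing import List, Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in text, handling nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j].strip()
    return None  # unbalanced braces: treat as no answer


def rejection_sample(responses: List[str], reference: str) -> List[str]:
    """Keep only the responses whose final boxed answer equals the reference."""
    return [r for r in responses if last_boxed(r) == reference.strip()]


samples = [
    "Step 1: ... Thus the answer is \\boxed{1}",
    "Step 1: ... Thus the answer is \\boxed{3}",
    "No final answer given.",
]
kept = rejection_sample(samples, "1")  # only the first sample survives
```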

Step 1: Deploy your language model

Serve your models with SGLang or vLLM, or manage multiple model workers through a central FastChat controller. Responses are requested through the OpenAI-compatible API. For FastMCTS, deploy separate models for synthesis and evaluation.
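SGLang and vLLM both expose an OpenAI-compatible chat completions route, so every model can be queried the same way. The sketch below builds and sends such a request using only the Python standard library. The base URL, model names, and sampling defaults are assumptions to adapt to your deployment, and build_chat_request and chat are hypothetical helpers rather than functions from this codebase.

```python
import json
import urllib.request
from typing import List

# Assumed local endpoint: vLLM's OpenAI-compatible server listens on port 8000
# by default (SGLang defaults to 30000). Adjust to your deployment.
BASE_URL = "http://localhost:8000/v1"


def build_chat_request(model: str, prompt: str, base_url: str = BASE_URL,
                       temperature: float = 0.7, max_tokens: int = 1024,
                       n: int = 1) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible /chat/completions route."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "n": n,  # number of sampled completions
    }
    return urllib.request.Request(
        url=f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json",
                 # local servers ignore the key unless started with one
                 "Authorization": "Bearer EMPTY"},
        method="POST",
    )


def chat(req: urllib.request.Request, timeout: float = 300.0) -> List[str]:
    """Send the request and return the text of every returned choice."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = json.load(resp)
    return [choice["message"]["content"] for choice in body["choices"]]
```

Against a running server, chat(build_chat_request("policy-model", prompt, n=4)) would return four sampled responses. For FastMCTS you would point one base_url at the synthesis model and another at the evaluation model; "policy-model" is a placeholder name.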

Step 2: Configure your task

Write a task configuration that defines the synthesis parameters.

  • For example, to generate reasoning data on the AIME dataset with FastMCTS, run:

    python task_configs/fastmcts/aime.py
  • To use Rejection Sampling instead, for example to generate ArrowMaze data, run:

    python task_configs/rejection_sampling/arrowmaze.py

The codebase can also be used to generate other kinds of multi-step reasoning data.

Step 3: Process synthesized data

Rejection Sampling:

Responses are saved in your specified output path. Concatenate and shuffle them into a JSONL file for supervised fine-tuning.
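A minimal sketch of the concatenate-and-shuffle step, assuming each output file is JSONL with one response record per line. The function name merge_and_shuffle and the paths in the usage note are hypothetical.

```python
import glob
import json
import random


def merge_and_shuffle(pattern: str, output_path: str, seed: int = 42) -> int:
    """Concatenate every JSONL file matching pattern, shuffle the records,
    and write them to output_path. Returns the number of records written."""
    records = []
    for path in sorted(glob.glob(pattern)):
        with open(path, encoding="utf-8") as f:
            records.extend(json.loads(line) for line in f if line.strip())
    random.Random(seed).shuffle(records)  # fixed seed keeps the shuffle reproducible
    with open(output_path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return len(records)
```

For example, merge_and_shuffle("outputs/*.jsonl", "sft_data.jsonl"). Keeping the merged file outside the glob pattern prevents it from being read back in on a later run.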

FastMCTS:

The tree-structured output supports two kinds of data construction:

  1. SFT data construction:
    python tools/serialize_tree_data_to_sft.py --tree_dir ${tree_dir} --select_num 8 --output_path ${output_path}
  2. Branch-DPO data construction:
  • Since FastMCTS intrinsically provides step-level supervision, we can use this extra information to construct branch-level preference pairs for further DPO training, a second-stage use of the synthetic data.

    python tools/process_tree_data_to_dpo.py --tree_dir ${tree_dir} --max_sample 5 --target_path xxx.jsonl --step_token_limit 1024 --tokenizer_path ${hf_tokenizer_path}
  • This produces Branch-DPO data in the following format:

    {
      "problem_id": "041f459a3a2b2f6946f99da3960e81ea",
      "level": 3,
      "granularity": "step",
      "prompt": "Let \\( n \\) be the decimal integer 111...1 (with 1998 digits). What is the 1000th digit after the decimal point of \\( \\sqrt{n} \\)?",
      "initial_reason_steps": "Step 1: Identify the number \\( n \\).\nThe number \\( n \\) is a decimal integer consisting of 1998 digits, all of which are 1. Therefore, \\( n = 111\\ldots1 \\) (1998 ones).Step 2: Approximate \\( n \\) using a simpler form.\nWe can express \\( n \\) as:\n\\[ n = \\frac{10^{1998} - 1}{9} \\]",
      "chosen": "Step 3: Consider the square root of \\( n \\).\nWe need to find the 1000th digit after the decimal point of \\( \\sqrt{n} \\). To do this, we first approximate \\( \\sqrt{n} \\).",
      "rejected": "Step 3: Approximate \\( \\sqrt{n} \\).\nTo find \\( \\sqrt{n} \\), we use the approximation:\n\\[ \\sqrt{n} \\approx \\sqrt{\\frac{10^{1998}}{9}} = \\frac{10^{999}}{3} \\]",
      "full_chosen": "Step 3: Consider the square root of \\( n \\).\nWe need to find the 1000th digit after the decimal point of \\( \\sqrt{n} \\). To do this, we first approximate \\( \\sqrt{n} \\).\n\nStep 4: Approximate \\( \\sqrt{n} \\) using a simpler form.\nSince \\( n \\approx 10^{1997} \\) (because \\( 10^{1997} \\) is a 1 followed by 1997 zeros, and \\( n \\) is a 1 followed by 1997 nines, which is very close to \\( 10^{1997} \\)), we have:\n\\[ \\sqrt{n} \\approx \\sqrt{10^{1997}} = 10^{998.5} = 10^{998} \\cdot \\sqrt{10} \\]\n\nStep 5: Determine the decimal expansion of \\( \\sqrt{10} \\).\nThe value of \\( \\sqrt{10} \\) is approximately 3.1622776601683793319988935444327. This means:\n\\[ \\sqrt{n} \\approx 10^{998} \\cdot 3.1622776601683793319988935444327 \\]\n\nStep 6: Identify the 1000th digit after the decimal point.\nSince \\( 10^{998} \\) shifts the decimal point 998 places to the right, the number \\( \\sqrt{n} \\) will look like 31622776601683793319988935444327 followed by the rest of the digits of \\( \\sqrt{10} \\) but shifted 998 places to the left. Therefore, the 1000th digit after the decimal point of \\( \\sqrt{n} \\) is the 2nd digit of \\( \\sqrt{10} \\), which is 1.\n\nThus, the 1000th digit after the decimal point of \\( \\sqrt{n} \\) is:\n\\[ \\boxed{1} \\]",
      "full_rejected": "Step 3: Approximate \\( \\sqrt{n} \\).\nTo find \\( \\sqrt{n} \\), we use the approximation:\n\\[ \\sqrt{n} \\approx \\sqrt{\\frac{10^{1998}}{9}} = \\frac{10^{999}}{3} \\]\n\nStep 4: Calculate the decimal expansion of \\( \\frac{10^{999}}{3} \\).\nThe number \\( \\frac{10^{999}}{3} \\) is a repeating decimal with a repeating block of 3, i.e., \\( 0.333\\ldots \\). Thus:\n\\[ \\frac{10^{999}}{3} = 333\\ldots333.333\\ldots \\] (where the digit 3 repeats indefinitely).\n\nStep 5: Determine the 1000th digit after the decimal point.\nSince the decimal expansion of \\( \\frac{10^{999}}{3} \\) is \\( 0.333\\ldots \\), the 1000th digit after the decimal point is 3. Therefore, the 1000th digit after the decimal point of \\( \\sqrt{n} \\) is:\n\\[ \\boxed{3} \\]"
    }
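The branch-level pairing can be sketched as follows: at every node, two sibling steps whose values differ enough form a pair, with the shared path to that node as initial_reason_steps, the higher-value step as chosen, and the lower-value step as rejected. This is an illustrative assumption about how such pairs can be derived, not a description of tools/process_tree_data_to_dpo.py. The StepNode class and min_gap threshold are hypothetical, and the script's own options (--max_sample, --step_token_limit) and the full_chosen/full_rejected continuations are not modelled.

```python
from dataclasses import dataclass, field
from itertools import combinations
from typing import Dict, List, Optional


@dataclass
class StepNode:
    """Hypothetical tree node: one reasoning step and its estimated value."""
    text: str
    value: float
    children: List["StepNode"] = field(default_factory=list)


def branch_pairs(problem: str, node: StepNode,
                 prefix: Optional[List[str]] = None,
                 min_gap: float = 0.5) -> List[Dict[str, str]]:
    """Emit one preference pair for every two sibling steps whose values
    differ by at least min_gap (an assumed threshold), then recurse."""
    prefix = prefix or []
    pairs = []
    for a, b in combinations(node.children, 2):
        if abs(a.value - b.value) >= min_gap:
            good, bad = (a, b) if a.value > b.value else (b, a)
            pairs.append({
                "prompt": problem,
                "initial_reason_steps": "\n".join(prefix),  # shared path so far
                "chosen": good.text,
                "rejected": bad.text,
            })
    for child in node.children:
        pairs.extend(branch_pairs(problem, child, prefix + [child.text], min_gap))
    return pairs


# Toy tree: the root stands for the problem; values are fractions of correct leaves.
root = StepNode("", 1 / 3, children=[
    StepNode("Step 1: A", 0.5, children=[StepNode("Step 2: A1", 1.0),
                                         StepNode("Step 2: A2", 0.0)]),
    StepNode("Step 1: B", 0.0),
])
pairs = branch_pairs("Toy problem", root)
```

In this toy tree the first pair contrasts the two opening steps with an empty prefix, and the second contrasts the two continuations of "Step 1: A" with that step as the prefix.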

Step 4: Start Training

Our training code for supervised fine-tuning and Branch-DPO is built on Step-DPO and Qwen.

  • For supervised fine-tuning:

    bash training_scripts/finetune_example.sh
  • For DPO training:

    bash training_scripts/dpo_example.sh
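Step-DPO, which this training code builds on, applies the DPO objective to a single reasoning step conditioned on the preceding steps. A minimal sketch of the standard per-pair DPO loss, assuming summed log-probabilities of the chosen and rejected step under the policy and a frozen reference model. The function name is hypothetical, and beta = 0.1 is an illustrative default, not necessarily the value used in training_scripts/dpo_example.sh.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss -log(sigmoid(beta * margin)) for one preference pair.

    Each *_logp is the summed log-probability of the chosen or rejected step,
    given the prompt and preceding steps, under the policy or frozen reference.
    """
    margin = ((policy_chosen_logp - ref_chosen_logp)
              - (policy_rejected_logp - ref_rejected_logp))
    x = beta * margin
    # -log(sigmoid(x)) = log(1 + exp(-x)); branch so exp() never overflows
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))
```

When the policy has not moved from the reference, the margin is zero and the loss equals log 2. The loss falls as the policy raises the chosen step's likelihood relative to the rejected one.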
    

Visualize FastMCTS tree structure

  • We also provide a tool to visualize the generated FastMCTS tree structure:

    python tools/visualize.py --tree tree.pkl
    
  • Below is an example of a FastMCTS tree structure generated on AIME:

    [Figure: example FastMCTS tree structure generated on AIME]

Citation

If you find our library helpful, please cite:

@article{DBLP:journals/corr/abs-2502-11476,
  author       = {Peiji Li and
                  Kai Lv and
                  Yunfan Shao and
                  Yichuan Ma and
                  Linyang Li and
                  Xiaoqing Zheng and
                  Xipeng Qiu and
                  Qipeng Guo},
  title        = {FastMCTS: {A} Simple Sampling Strategy for Data Synthesis},
  journal      = {CoRR},
  volume       = {abs/2502.11476},
  year         = {2025},
  url          = {https://doi.org/10.48550/arXiv.2502.11476},
  doi          = {10.48550/ARXIV.2502.11476},
  eprinttype   = {arXiv},
  eprint       = {2502.11476},
  timestamp    = {Wed, 19 Mar 2025 09:50:49 +0100},
  biburl       = {https://dblp.org/rec/journals/corr/abs-2502-11476.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}