- 🎉 2025-09-30: We release SDLM. The Sequential Diffusion Language Model (SDLM) enhances pre-trained autoregressive language models by adaptively determining the generation length while remaining compatible with the KV cache, achieving high efficiency and throughput.
- 🚀 2025-09-29: We release the complete training and inference code for SDLM, together with the training dataset and configuration.
- Model Zoo
- Inference
- Training
- Evaluation
- Technical Report
In the following table, we provide an overview of the SDLM series.
| Model Name | Base Model 🤗 | HF Link 🤗 | 
|---|---|---|
| SDLM-3B-D4 | Qwen2.5-3B | https://huggingface.co/OpenGVLab/SDLM-3B-D4 | 
| SDLM-3B-D8 | Qwen2.5-3B | https://huggingface.co/OpenGVLab/SDLM-3B-D8 | 
| SDLM-32B-D4 | Qwen2.5-32B | https://huggingface.co/OpenGVLab/SDLM-32B-D4 | 
We propose the Sequential Diffusion Language Model (SDLM), which cheaply unlocks the parallel prediction capability of diffusion models in pre-trained autoregressive models.
- Autoregression: Predicts tokens one by one.
- Diffusion: Regenerates all tokens each step.
- SDLM (ours): Decodes D tokens per step, then keeps the longest consecutive n confident tokens (1 ≤ n ≤ D). Cached tokens are reused, saving computation.
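The acceptance rule in the SDLM bullet can be sketched as a small function. This is a minimal illustration, not the repository's implementation: `accept_length` is a hypothetical helper, and we assume that the first token of each block is always kept (so n ≥ 1) and that each later token is kept only while its confidence is at least the threshold.

```python
from typing import Sequence


def accept_length(confidences: Sequence[float], threshold: float) -> int:
    """Return n, the number of leading tokens to keep from one decoded block.

    Assumptions (not taken from the SDLM code): the first token is always kept,
    so n >= 1, and each later token is kept only while its confidence is
    >= threshold; the first low-confidence token ends the run.
    """
    if len(confidences) == 0:
        raise ValueError("a block must contain at least one prediction")
    n = 1  # the first token is always accepted
    for conf in confidences[1:]:
        if conf < threshold:
            break  # stop at the first unconfident token so the run stays consecutive
        n += 1
    return n


# Per-position confidences for one block of D = 4 predicted tokens.
block = [0.98, 0.91, 0.62, 0.97]
for threshold in (0.95, 0.9, 0.5):
    print(threshold, accept_length(block, threshold))
# 0.95 1
# 0.9 2
# 0.5 4
```

Lowering the threshold lets more tokens through per step, which is the speed-versus-accuracy knob that the confidence threshold controls.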
SDLM delivers strong performance with significantly faster decoding speed. It operates approximately 2x faster than comparable autoregressive models while matching their accuracy, and achieves up to 5x speedup over other diffusion language models, as evidenced by results on the MATH-500 benchmark.
(a) Training pipeline
The reordered input sequence enables structured masking with:
- Causal prefix (top-left)
- Visible cross-block prefix (bottom-left)
- Intra-block bidirectional attention (bottom-right)
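The three mask regions can be made concrete with a small boolean mask builder. This sketch assumes a Block-Diffusion-style layout that may differ from SDLM's actual reordering: the L clean tokens come first, followed by a noised copy split into blocks of size D, and each noised block sees only the clean tokens that precede it. `structured_mask` is a hypothetical helper; `True` means "may attend".

```python
from typing import List


def structured_mask(seq_len: int, block_size: int) -> List[List[bool]]:
    """Mask for [clean tokens | noised blocks]; mask[i][j] means i may attend to j."""
    if block_size < 1 or seq_len % block_size:
        raise ValueError("seq_len must be a positive multiple of block_size")
    total = 2 * seq_len
    mask = [[False] * total for _ in range(total)]
    # Top-left: causal attention among the clean prefix tokens.
    for i in range(seq_len):
        for j in range(i + 1):
            mask[i][j] = True
    for start in range(0, seq_len, block_size):
        block = range(seq_len + start, seq_len + start + block_size)
        for i in block:
            # Bottom-left: the noised block sees every clean token before it.
            for j in range(start):
                mask[i][j] = True
            # Bottom-right: bidirectional attention inside this block only.
            for j in block:
                mask[i][j] = True
    return mask


# Visualize L = 4 clean tokens with blocks of D = 2 ("1" = may attend).
for row in structured_mask(4, 2):
    print("".join("1" if allowed else "." for allowed in row))
# 1.......
# 11......
# 111.....
# 1111....
# ....11..
# ....11..
# 11....11
# 11....11
```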
(b) Sampling Pipeline
Confidence-based dynamic block decoding with KV cache reuse.
At each step, a block of D tokens is decoded, and the longest run of consecutive confident tokens is kept.
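The `SDLM_generate` call in the inference example exposes an `alg` option (`prob_conf | entropy_conf | self_speculative`). As a sketch, we assume without checking the repository code that `prob_conf` scores a position by its top softmax probability and `entropy_conf` by the entropy of the predicted distribution. The normalization of the entropy score to [0, 1] is our own choice so that both scores can share one threshold.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one position's vocabulary logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prob_conf(logits: Sequence[float]) -> float:
    """Confidence = probability of the most likely (greedy) token."""
    return max(softmax(logits))


def entropy_conf(logits: Sequence[float]) -> float:
    """Confidence = 1 - normalized entropy: near 1.0 when peaked, 0.0 when uniform."""
    probs = softmax(logits)
    if len(probs) < 2:
        return 1.0  # a single candidate carries no uncertainty
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return 1.0 - entropy / math.log(len(probs))


# Logits for the D = 4 positions of one block over a toy 4-token vocabulary.
block_logits = [
    [8.0, 0.0, 0.0, 0.0],  # peaked: confident
    [3.0, 2.5, 0.0, 0.0],  # two plausible tokens
    [1.0, 1.0, 1.0, 1.0],  # uniform: no preference
    [6.0, 1.0, 0.0, 0.0],
]
for scorer in (prob_conf, entropy_conf):
    print(scorer.__name__, [round(scorer(row), 3) for row in block_logits])
```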
SDLM-32B scores 92.4 (GSM8K), 74.2 (MATH), 78.6 (IFEval), and remains competitive on HumanEval (81.1) and MBPP (80.9). Our smaller 3B model outperforms similar-sized models and larger diffusion-based alternatives with limited training.
In terms of efficiency, each forward pass generates about 2 tokens on average, giving roughly a 2× speedup and about two-thirds of the latency of AR models.
Trade-off between performance and speed under different confidence thresholds 
By adjusting 
With HuggingFace
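```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sdlm_inference import SDLM_generate  # provided by the SDLM repository

if __name__ == "__main__":
    ckpt_hf = 'OpenGVLab/SDLM-3B-D4'
    # Use the GPU when available; fall back to float32 on CPU, where fp16 is poorly supported.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        ckpt_hf,
        attn_implementation="eager",
        trust_remote_code=True,  # the checkpoint ships custom modeling code
    ).to(device=device, dtype=dtype)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_hf)

    prompt = 'Write a Fibonacci function in Python.'
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    # Render the chat template as text and append the assistant generation prompt.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    response, history = SDLM_generate(
        model,
        tokenizer,
        model_inputs,
        max_gen_len=1024,
        temperature=0,        # 0 = greedy decoding
        threshold=0.5,        # confidence threshold for accepting extra tokens
        n_future_tokens=4,    # block size D; matches the D4 checkpoint
        alg='prob_conf',      # prob_conf | entropy_conf | self_speculative
        save_history=True,
        use_cache=True
    )
    print('response: ', response[0])

    # Inspect the decoding trajectory recorded with save_history=True.
    print('======= history')
    for item in history:
        print('cur total token ', item[1])
        print(item[0][0])
        print('--------')
```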
- Environment Setup: clone the repository with `git clone https://github.com/OpenGVLab/SDLM.git`, then enter it with `cd SDLM`.
- Install Dependencies: the key package versions are `transformers==4.37.2`, `deepspeed==0.16.5`, `torch>=2.5.0`, and `accelerate==0.32.1`. Note: additional setup is required if using Flex Attention.
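A quick way to confirm that an environment matches these key versions is a short check built on the standard library's `importlib.metadata`. This is a convenience sketch, not part of the repository; the version parsing is deliberately simplified (pre-release and local tags are ignored), and `packaging.specifiers.SpecifierSet` is the robust choice for full PEP 440 handling.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile
from typing import Dict, Tuple

# Key versions listed above: exact pins use "==", torch only needs a minimum.
PINS: Dict[str, str] = {
    "transformers": "==4.37.2",
    "deepspeed": "==0.16.5",
    "torch": ">=2.5.0",
    "accelerate": "==0.32.1",
}


def release(v: str) -> Tuple[int, ...]:
    """'2.5.1+cu121' -> (2, 5, 1); keeps only the leading digits of each part."""
    parts = v.split("+")[0].split(".")[:3]
    nums = [int("".join(takewhile(str.isdigit, p)) or 0) for p in parts]
    return tuple(nums + [0] * (3 - len(nums)))


def satisfies(installed: str, spec: str) -> bool:
    """Evaluate a single '==' or '>=' specifier against an installed version."""
    if spec.startswith("=="):
        return release(installed) == release(spec[2:])
    if spec.startswith(">="):
        return release(installed) >= release(spec[2:])
    raise ValueError(f"unsupported specifier: {spec}")


def check_environment() -> Dict[str, str]:
    """Map each missing or off-spec package to a short problem note."""
    problems: Dict[str, str] = {}
    for name, spec in PINS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if not satisfies(installed, spec):
            problems[name] = f"found {installed}, expected {spec}"
    return problems


if __name__ == "__main__":
    print(check_environment() or "all key packages match")
```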
- Prepare Training Data: the training dataset we used is specified in the meta file `meta.json` and is organized in ShareGPT style, following the InternVL chat data format. The script `preprocess_scalequestmath.py` serves as a reference implementation. The dataset is composed of several open-source datasets:

| Dataset Name | # Samples | Domain |
|---|---|---|
| ScaleQuest-Math | 1,000K | Math |
| Opc-sft-stage2 | 436K | Code |
| Smoltalk | 1,100K | General |
| Tulu-3-sft-mixture | 939K | General |
| SciRIFF | 79K | Science |
| Table-GPT | 13K | Table |
| Total | 3,506K | -- |
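The ShareGPT-style layout can be sketched as a small converter from question/answer pairs to JSONL records, plus a matching meta entry. The field names (`conversations`, `from`, `value`, and the meta keys `root`, `annotation`, `data_augment`, `repeat_time`, `length`) follow InternVL's chat data convention as we understand it; check `meta.json` and `preprocess_scalequestmath.py` for the exact fields. The helper names, the `toy_math` dataset, and its path are hypothetical.

```python
import json
from typing import Dict, Iterable


def to_sharegpt(sample_id: int, question: str, answer: str) -> Dict:
    """One single-turn sample as a ShareGPT-style record (InternVL chat layout)."""
    return {
        "id": sample_id,
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
    }


def to_jsonl(records: Iterable[Dict]) -> str:
    """Serialize one JSON object per line, keeping non-ASCII text readable."""
    return "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


def meta_entry(annotation: str, length: int) -> Dict:
    """Hypothetical meta.json entry pointing at one JSONL annotation file."""
    return {
        "root": "",
        "annotation": annotation,
        "data_augment": False,
        "repeat_time": 1,
        "length": length,
    }


pairs = [("What is 7 * 6?", "7 * 6 = 42."), ("Is 91 prime?", "No, 91 = 7 * 13.")]
records = [to_sharegpt(i, q, a) for i, (q, a) in enumerate(pairs)]
jsonl_text = to_jsonl(records)
meta = {"toy_math": meta_entry("data/toy_math.jsonl", len(records))}
print(jsonl_text, end="")
print(json.dumps(meta, indent=2))
```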
- Start Training: all training scripts are available in the `shell/train` directory. Key parameters include:
  - `block_size`: the size of the diffusion window. The current setting is `4`, and we have also tried `8`; larger sizes are under exploration.
  - `attn_implementation`: the attention implementation type, one of `sdpa`, `eager`, or `flex_attn`. Using Flex Attention requires additional setup; we recommend `sdpa` for a quick start.
  - `causal_attn`: whether to use causal attention within the window. It is currently set to non-causal (`False`).
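The three key parameters and their documented options can be captured in a small validated config object. `TrainWindowConfig` is a hypothetical illustration; the actual scripts in `shell/train` pass these settings in their own way.

```python
import warnings
from dataclasses import dataclass

# Documented choices for attn_implementation.
ATTN_OPTIONS = ("sdpa", "eager", "flex_attn")


@dataclass
class TrainWindowConfig:
    """Hypothetical holder for the key SDLM training parameters."""

    block_size: int = 4                # diffusion window size (8 has also been tried)
    attn_implementation: str = "sdpa"  # recommended for a quick start
    causal_attn: bool = False          # current setting: non-causal within the window

    def __post_init__(self) -> None:
        if not isinstance(self.block_size, int) or self.block_size < 1:
            raise ValueError("block_size must be a positive integer")
        if self.attn_implementation not in ATTN_OPTIONS:
            raise ValueError(f"attn_implementation must be one of {ATTN_OPTIONS}")
        if not isinstance(self.causal_attn, bool):
            raise ValueError("causal_attn must be True or False")
        if self.attn_implementation == "flex_attn":
            warnings.warn("flex_attn requires additional setup before training")


default_cfg = TrainWindowConfig()
wider_cfg = TrainWindowConfig(block_size=8)
print(default_cfg)
print(wider_cfg)
```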
Our training setting is:

Figure: The training loss of our 3B model. Here, loss_pos_i refers to the loss at the i-th position of each block; the loss at the first position (loss_pos_1) is close to the SFT loss of AR next-token prediction (NTP). The figure shows the loss at each position within the window during training; for block size bs = 8, only the first 4 positions are shown. The correspondence is as follows:

| Block size | Block input | Plotted losses |
|---|---|---|
| bs = 4 (red) | x m m m | loss_pos_1, loss_pos_2, loss_pos_3, loss_pos_4 |
| bs = 8 (orange) | x m m m m m m m | loss_pos_1, loss_pos_2, loss_pos_3, loss_pos_4 (positions 5 to 8 not shown) |
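Each loss_pos_i curve is an average of the per-token loss grouped by in-block position. The sketch below assumes per-token losses are laid out block by block (token k sits at in-block position k mod D); `per_position_loss` is a hypothetical helper, not the repository's logging code.

```python
from typing import List, Sequence


def per_position_loss(token_losses: Sequence[float], block_size: int) -> List[float]:
    """Mean loss at each in-block position: [loss_pos_1, ..., loss_pos_D]."""
    if block_size < 1 or not token_losses or len(token_losses) % block_size:
        raise ValueError("need a non-empty loss list whose length is a multiple of block_size")
    n_blocks = len(token_losses) // block_size
    sums = [0.0] * block_size
    for k, loss in enumerate(token_losses):
        sums[k % block_size] += loss  # token k sits at in-block position k % D
    return [s / n_blocks for s in sums]


# Toy per-token losses for two blocks of size 4 (e.g. "x m m m" twice).
losses = [0.9, 1.6, 2.1, 2.4, 1.1, 1.8, 2.3, 2.8]
for i, value in enumerate(per_position_loss(losses, 4), start=1):
    print(f"loss_pos_{i} = {value:.2f}")
```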
We currently use OpenCompass for evaluation. For more details, please refer to the evaluation guide.
We extend our gratitude to the open-source community for their foundational contributions:
- InternVL: the codebase we build upon.
- SMDM, LLaDA, Dream, and Block Diffusion: for insights into diffusion-based generative modeling.
- Qwen2.5: a robust base model for comparative studies.
- OpenCompass: for providing a comprehensive evaluation framework.
- The creators of all datasets used in this work, enabling rigorous training and validation.
@article{liu2025sdlm,
  title={Sequential Diffusion Language Models},
  author={Liu, Yangzhou and Cao, Yue and Li, Hao and Luo, Gen and Chen, Zhe and Wang, Weiyun and Liang, Xiaobo and Qi, Biqing and Wu, Lijun and Tian, Changyao and Zhang, Yanting and Li, Yuqiang and Lu, Tong and Qiao, Yu and Dai, Jifeng and Wang, Wenhai},
  journal={arXiv preprint arXiv:2509.24007},
  year={2025}
}






