1+ """
2+ Copyright (c) 2025 by FlashInfer team.
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ http://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+
16+ AOT build script for FlashInfer.
17+
18+ NOTE (Zihao): The following modules are intentionally excluded from the AOT build:
19+ - gen_pod_module
20+ - gen_deepgemm_sm100_module (it doesn't involve host-side compilation)
21+ """
22+
123import argparse
224import os
325import shutil
1234from .fp8_quantization import gen_mxfp8_quantization_sm100_module
1335from .cascade import gen_cascade_module
1436from .fp4_quantization import (
15- gen_fp4_quantization_sm100_module ,
1637 gen_fp4_quantization_sm90_module ,
38+ gen_fp4_quantization_sm100_module ,
39+ gen_fp4_quantization_sm103_module ,
40+ gen_fp4_quantization_sm110_module ,
41+ gen_fp4_quantization_sm120_module ,
42+ gen_fp4_quantization_sm121_module ,
1743)
1844from .fused_moe import (
1945 gen_cutlass_fused_moe_sm100_module ,
2753 gen_gemm_sm100_module_cutlass_fp4 ,
2854 gen_gemm_sm100_module_cutlass_fp8 ,
2955 gen_gemm_sm100_module_tgv ,
56+ gen_gemm_sm120_module ,
57+ gen_gemm_sm120_module_cutlass_fp4 ,
3058 gen_trtllm_gen_gemm_module ,
3159)
3260from .jit import JitSpec , build_jit_specs
3361from .jit import env as jit_env
3462from .jit import (
63+ gen_batch_attention_module ,
3564 gen_batch_decode_module ,
3665 gen_batch_mla_module ,
3766 gen_batch_prefill_module ,
67+ gen_cudnn_fmha_module ,
3868 gen_fmha_cutlass_sm100a_module ,
3969 gen_single_decode_module ,
4070 gen_single_prefill_module ,
@@ -187,6 +217,19 @@ def gen_attention(
             use_sliding_window=use_sliding_window,
             use_logits_soft_cap=use_logits_soft_cap,
         )
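+        # Batch attention module; note the sliding-window argument is left commented out below.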
+        yield gen_batch_attention_module(
+            dtype_q=dtype_qo,
+            dtype_kv=dtype_kv,
+            dtype_o=dtype_qo,
+            dtype_idx=torch.int32,
+            head_dim_qk=head_dim_qk,
+            head_dim_vo=head_dim_vo,
+            pos_encoding_mode=0,
+            # use_sliding_window=use_sliding_window,
+            use_logits_soft_cap=use_logits_soft_cap,
+            use_profiler=False,
+        )
 
     # FA3 MHA / MQA / GQA
     if has_sm90:
@@ -357,8 +399,7 @@ def gen_all_modules(
     fa3_head_dim_: List[Tuple[int, int]],
     use_sliding_window_: List[bool],
     use_logits_soft_cap_: List[bool],
-    has_sm90: bool,
-    has_sm100: bool,
+    sm_capabilities: dict,
     add_comm: bool,
     add_gemma: bool,
     add_oai_oss: bool,
@@ -368,6 +409,13 @@ def gen_all_modules(
     add_xqa: bool,
 ) -> List[JitSpec]:
     jit_specs: List[JitSpec] = []
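+    # Unpack per-architecture flags from detect_sm_capabilities(); missing keys default to False.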
+    has_sm90 = sm_capabilities.get("sm90", False)
+    has_sm100 = sm_capabilities.get("sm100", False)
+    has_sm103 = sm_capabilities.get("sm103", False)
+    has_sm110 = sm_capabilities.get("sm110", False)
+    has_sm120 = sm_capabilities.get("sm120", False)
+    has_sm121 = sm_capabilities.get("sm121", False)
 
     jit_specs += list(
         gen_attention(
@@ -406,6 +453,17 @@ def gen_all_modules(
         jit_specs.append(gen_mxfp8_quantization_sm100_module())
         jit_specs.append(gen_trtllm_gen_gemm_module())
         jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())
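+    # FP4 quantization modules for the newer architectures; SM120 additionally gets its GEMM modules.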
+    if has_sm103:
+        jit_specs.append(gen_fp4_quantization_sm103_module())
+    if has_sm110:
+        jit_specs.append(gen_fp4_quantization_sm110_module())
+    if has_sm120:
+        jit_specs.append(gen_fp4_quantization_sm120_module())
+        jit_specs.append(gen_gemm_sm120_module())
+        jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
+    if has_sm121:
+        jit_specs.append(gen_fp4_quantization_sm121_module())
 
     if add_comm:
         from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
@@ -450,6 +507,9 @@ def gen_all_modules(
             )
         )
 
+    # Add cuDNN FMHA module
+    jit_specs.append(gen_cudnn_fmha_module())
+
     # dedup
     names = set()
     ret: List[JitSpec] = []
@@ -523,13 +583,21 @@ def has_sm(compute: str, version: str) -> bool:
             return True
         return version_at_least(torch.version.cuda, version)
 
-    return has_sm("compute_90", "12.3"), has_sm("compute_100", "12.8")
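+    # Each architecture is usable only if the CUDA toolkit is new enough to target it.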
+    return {
+        "sm90": has_sm("compute_90", "12.3"),
+        "sm100": has_sm("compute_100", "12.8"),
+        "sm103": has_sm("compute_103", "12.8"),
+        "sm110": has_sm("compute_110", "12.9"),
+        "sm120": has_sm("compute_120", "13.0"),
+        "sm121": has_sm("compute_121", "13.0"),
+    }
 
 
 def register_default_modules() -> int:
     """Register the default set of modules"""
     config = get_default_config()
-    has_sm90, has_sm100 = detect_sm_capabilities()
+    sm_capabilities = detect_sm_capabilities()
 
     jit_specs = gen_all_modules(
         config["f16_dtype"],
@@ -538,8 +605,7 @@ def register_default_modules() -> int:
         config["fa3_head_dim"],
         config["use_sliding_window"],
         config["use_logits_soft_cap"],
-        has_sm90,
-        has_sm100,
+        sm_capabilities,
         config["add_comm"],
         config["add_gemma"],
         config["add_oai_oss"],
@@ -649,7 +715,7 @@ def main():
     if "FLASHINFER_CUDA_ARCH_LIST" not in os.environ:
         raise RuntimeError("Please explicitly set env var FLASHINFER_CUDA_ARCH_LIST.")
 
-    has_sm90, has_sm100 = detect_sm_capabilities()
+    sm_capabilities = detect_sm_capabilities()
 
     # Update data dir
     jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc"
@@ -678,8 +744,10 @@ def main():
     print("  use_sliding_window:", config["use_sliding_window"])
     print("  use_logits_soft_cap:", config["use_logits_soft_cap"])
     print("  FLASHINFER_CUDA_ARCH_LIST:", os.environ["FLASHINFER_CUDA_ARCH_LIST"])
-    print("  has_sm90:", has_sm90)
-    print("  has_sm100:", has_sm100)
+    print("  SM capabilities detected:")
+    for sm_name, has_sm in sm_capabilities.items():
+        if has_sm:
+            print(f"    {sm_name}: True")
     for key in [
         "add_comm",
         "add_gemma",
@@ -701,8 +769,7 @@ def main():
         config["fa3_head_dim"],
         config["use_sliding_window"],
         config["use_logits_soft_cap"],
-        has_sm90,
-        has_sm100,
+        sm_capabilities,
         config["add_comm"],
         config["add_gemma"],
         config["add_oai_oss"],