@@ -422,6 +422,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,6 +667,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
@@ -677,16 +679,16 @@ def compile(

     gm = exported_program.module()
     # Move the weights in the state_dict to CPU
-    logger.info(
-        "The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
-    )
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
-
-    exported_program.module().to(CPU_DEVICE)
+    if offload_module_to_cpu:
+        exported_program.module().to(CPU_DEVICE)
+        logger.info(
+            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
+        )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
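
For reference, a minimal usage sketch of the new flag. The model and input shapes here are hypothetical; it assumes this patch is applied and that OFFLOAD_MODULE_TO_CPU is defined in _defaults as introduced above.

import torch
import torch_tensorrt

# Hypothetical toy model; any exportable nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval().cuda()
inputs = (torch.randn(8, 64, device="cuda"),)

exported_program = torch.export.export(model, inputs)

# offload_module_to_cpu=True moves the source module's weights to CPU during
# compilation to reduce peak GPU memory; pass False to keep them on the GPU.
trt_gm = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=list(inputs),
    offload_module_to_cpu=True,
)

out = trt_gm(*inputs)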