Skip to content

Commit 4593064

Browse files
committed
Added back the control flag and fixed the CI
1 parent 8b92866 commit 4593064

File tree

5 files changed

+19
-10
lines changed

5 files changed

+19
-10
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+7-5
Original file line number | Diff line number | Diff line change
@@ -422,6 +422,7 @@ def compile(
422422
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
423423
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
424424
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
425+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
425426
**kwargs: Any,
426427
) -> torch.fx.GraphModule:
427428
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,6 +667,7 @@ def compile(
666667
"enable_weight_streaming": enable_weight_streaming,
667668
"tiling_optimization_level": tiling_optimization_level,
668669
"l2_limit_for_tiling": l2_limit_for_tiling,
670+
"offload_module_to_cpu": offload_module_to_cpu,
669671
}
670672

671673
settings = CompilationSettings(**compilation_options)
@@ -677,16 +679,16 @@ def compile(
677679

678680
gm = exported_program.module()
679681
# Move the weights in the state_dict to CPU
680-
logger.info(
681-
"The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
682-
)
683682
logger.debug("Input graph: " + str(gm.graph))
684683

685684
# Apply lowering on the graph module
686685
gm = post_lowering(gm, settings)
687686
logger.debug("Lowered Input graph: " + str(gm.graph))
688-
689-
exported_program.module().to(CPU_DEVICE)
687+
if offload_module_to_cpu:
688+
exported_program.module().to(CPU_DEVICE)
689+
logger.info(
690+
"The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
691+
)
690692
trt_gm = compile_module(
691693
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
692694
)

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = False
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_refit.py

+7-3
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,9 @@ def construct_refit_mapping(
108108

109109

110110
def construct_refit_mapping_from_weight_name_map(
111-
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
111+
weight_name_map: dict[Any, Any],
112+
state_dict: dict[Any, Any],
113+
settings: CompilationSettings,
112114
) -> dict[Any, Any]:
113115
engine_weight_map = {}
114116
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
@@ -119,7 +121,9 @@ def construct_refit_mapping_from_weight_name_map(
119121
# If weights is not in sd, we can leave it unchanged
120122
continue
121123
else:
122-
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
124+
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
125+
to_torch_device(settings.device)
126+
)
123127

124128
engine_weight_map[engine_weight_name] = (
125129
engine_weight_map[engine_weight_name]
@@ -161,7 +165,7 @@ def _refit_single_trt_engine_with_gm(
161165
"constant_mapping", {}
162166
) # type: ignore
163167
mapping = construct_refit_mapping_from_weight_name_map(
164-
weight_name_map, new_gm.state_dict()
168+
weight_name_map, new_gm.state_dict(), settings
165169
)
166170
constant_mapping_with_type = {}
167171

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -729,8 +729,8 @@ def run(
729729
self._create_timing_cache(
730730
builder_config, self.compilation_settings.timing_cache_path
731731
)
732-
733-
delete_module(self.module)
732+
if self.compilation_settings.offload_module_to_cpu:
733+
delete_module(self.module)
734734
serialized_engine = self.builder.build_serialized_network(
735735
self.ctx.net, builder_config
736736
)

0 commit comments

Comments (0)