4 changes: 1 addition & 3 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -1339,15 +1339,13 @@ def convert_exported_program_to_serialized_trt_engine(
)

flattened_input_list = get_flat_args_with_check(
exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore
exported_program, list(trt_arg_inputs), trt_kwarg_inputs
)[0]

try:
interpreter_result = interpret_module_to_result(
gm,
inputs=flattened_input_list,
arg_inputs=list(trt_arg_inputs),
kwarg_inputs=trt_kwarg_inputs,
settings=settings,
engine_cache=engine_cache,
)
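For orientation, the hunk above adjusts the input flattening inside `convert_exported_program_to_serialized_trt_engine`. Below is a minimal usage sketch of that entry point; the import path and the `arg_inputs` keyword are assumptions inferred from the identifiers visible in the diff, not a verified API reference.

```python
# Hedged usage sketch of the entry point touched above; the import path and
# the arg_inputs keyword are assumptions inferred from the diff.
import torch
from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + 1.0)


example = torch.randn(2, 3, device="cuda")
exported = torch.export.export(TinyModel().cuda().eval(), (example,))

# Returns a serialized TensorRT engine (bytes-like) that can later be loaded
# with trt.Runtime.deserialize_cuda_engine.
serialized_engine = convert_exported_program_to_serialized_trt_engine(
    exported,
    arg_inputs=(example,),  # assumed keyword, mirroring trt_arg_inputs in the diff
)
```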
11 changes: 6 additions & 5 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -52,7 +52,7 @@
logger = logging.getLogger(__name__)


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
@@ -85,7 +85,7 @@ def construct_refit_mapping(
return weight_refit_map


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
@@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map(
return engine_weight_map


@needs_refit
@needs_refit # type: ignore[misc]
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
@@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm(
raise AssertionError("Refitting failed.")


@needs_refit
@needs_refit # type: ignore[misc]
def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
@@ -484,9 +484,10 @@ def refit_module_weights(
weight_name_map=None,
)

# clear EXCLUDE_WEIGHTS flag
# clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = engine.serialize_with_config(serialization_config)

if isinstance(compiled_submodule, PythonTorchTensorRTModule):
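The functional change in the last hunk is the re-serialization step: after refitting, the plan is serialized with `EXCLUDE_WEIGHTS` cleared and `INCLUDE_REFIT` set, so the weights stay in the plan and the engine remains refittable. A minimal sketch of that round trip, using only the TensorRT calls visible in this file's diff; whether `SerializationFlag.INCLUDE_REFIT` is available depends on the installed TensorRT version, which is assumed here.

```python
# Hedged sketch of the re-serialization step shown above.  Assumes `engine`
# is an already-refitted trt.ICudaEngine and that the installed TensorRT
# build exposes SerializationFlag.INCLUDE_REFIT (version-dependent).
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def reserialize_refittable(engine: trt.ICudaEngine) -> bytes:
    config = engine.create_serialization_config()
    # Keep the freshly refitted weights in the plan ...
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    # ... and keep the plan refittable for the next weight update.
    config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
    return bytes(engine.serialize_with_config(config))


def reload_engine(plan: bytes) -> trt.ICudaEngine:
    # Same deserialization path the refit helpers in this module use.
    runtime = trt.Runtime(TRT_LOGGER)
    return runtime.deserialize_cuda_engine(plan)
```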
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_settings.py
@@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
"tiling_optimization_level",
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -158,7 +158,7 @@ def _pretraced_backend(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
Collaborator:
When would a torch.compile user try to use strip weights?

Collaborator (Author):
Added the warning back. Not sure why the strip_engine_weights arg doesn't work for torch.compile().

logger.error(
logger.warning(
"strip_engine_weights arg is not supported for torch.compile()"
)
trt_compiled = compile_module(
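To ground the review thread above: the warning fires when `strip_engine_weights` reaches the `torch.compile` backend, which does not support it. Below is a hedged sketch of the call shape that would hit this path; passing the flag through the backend `options` dict is an assumption about how the setting reaches `_pretraced_backend`.

```python
# Hedged sketch of the call shape that triggers the warning above.  Whether
# "strip_engine_weights" is an accepted options key on this path is an
# assumption; the diff only shows that the setting is warned about and ignored.
import torch
import torch_tensorrt  # noqa: F401  # importing registers the "torch_tensorrt" backend

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
x = torch.randn(4, 8, device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"strip_engine_weights": True},  # now logs a warning instead of an error
)
out = compiled(x)  # compilation happens lazily on the first call
```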
95 changes: 1 addition & 94 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -31,7 +31,7 @@
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
@@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None:
gc.collect()
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# query the cached TRT engine
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
if cached_data is not None: # hit the cache
(
serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
engine_compilation_settings,
self.weight_name_map,
self.ctx.requires_output_allocator,
) = cached_data

setting_compatiblity, incompattible_settings = settings_are_compatible(
self.compilation_settings, engine_compilation_settings
)
assert (
setting_compatiblity
), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"

for i, e in enumerate(
[
Input.equivalent_spec(c, i)
for c, i in zip(cached_engine_input_specs, self.input_specs)
]
):
assert (
e
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"

_LOGGER.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

# TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
# serialization_config = engine.create_serialization_config()
# serialization_config.clear_flag(
# trt.SerializationFlag.EXCLUDE_WEIGHTS
# )
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )
# # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

return TRTInterpreterResult(
engine,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
return None

def run(
self,
strict_type_constraints: bool = False,
@@ -682,26 +609,6 @@ def run(
Return:
TRTInterpreterResult
"""
# self.engine_cache could be None if:
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
# 2) both cache_built_engines and reuse_cached_engines are False
if (
self.engine_cache is not None
and not self.compilation_settings.immutable_weights
):
if (
self.compilation_settings.cache_built_engines
or self.compilation_settings.reuse_cached_engines
):
hash_val = self.engine_cache.get_hash(
self.module, self.input_specs, self.compilation_settings
)

if self.compilation_settings.reuse_cached_engines:
interpreter_result = self._pull_cached_engine(hash_val)
if interpreter_result is not None: # hit the cache
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.debug(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
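For context on what was removed: the deleted `run()` prologue and `_pull_cached_engine` implemented the interpreter-side engine-cache lookup (hash the module, input specs, and settings; on a hit, refit the cached engine and return it). A minimal, dict-backed sketch of that lookup flow, assuming only the `get_hash`/`check` method names that appear in the removed code; the cache class and helper below are illustrative, not the real `BaseEngineCache`.

```python
# Hedged sketch of the cache-lookup flow the removed run() prologue performed.
# Only the get_hash/check names come from the removed code; everything else
# (the in-memory store, save(), lookup_or_build()) is illustrative.
from typing import Any, Callable, Optional


class InMemoryEngineCache:
    def __init__(self) -> None:
        self._store: dict[str, Any] = {}

    def get_hash(self, module: Any, input_specs: Any, settings: Any) -> str:
        # Real caches hash the graph, input shapes, and compilation settings;
        # repr() is only a stand-in for the sketch.
        return str(hash((repr(module), repr(input_specs), repr(settings))))

    def check(self, hash_val: str) -> Optional[Any]:
        # Returns the cached entry on a hit, None on a miss.
        return self._store.get(hash_val)

    def save(self, hash_val: str, entry: Any) -> None:
        self._store[hash_val] = entry


def lookup_or_build(
    cache: InMemoryEngineCache,
    module: Any,
    input_specs: Any,
    settings: Any,
    build_fn: Callable[[], Any],
) -> Any:
    hash_val = cache.get_hash(module, input_specs, settings)
    cached = cache.check(hash_val)
    if cached is not None:
        # Cache hit: the removed code refitted the cached engine here and returned.
        return cached
    result = build_fn()  # cache miss: fall through to TRT network construction
    cache.save(hash_val, result)
    return result
```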