@@ -25,6 +25,7 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._enums import dtype
+from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -42,7 +43,7 @@
     get_node_name,
     get_trt_tensor,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
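The newly imported `needs_refit` decorator gates the refit-dependent methods marked further down in this diff. As a rough sketch of what such a feature-gating decorator does (illustrative only; the real implementation lives in `torch_tensorrt._features`, and the `_REFIT_AVAILABLE` flag and error message below are assumptions):

```python
from functools import wraps
from typing import Any, Callable

# Assumption: a module-level flag standing in for however the real
# torch_tensorrt._features module detects refit support in the installed build.
_REFIT_AVAILABLE = True


def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
    """Run the wrapped method only when refit support is available."""

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not _REFIT_AVAILABLE:
            raise NotImplementedError(
                "Refit is not supported in this build of Torch-TensorRT"
            )
        return f(*args, **kwargs)

    return wrapper
```

Decorating `_save_weight_mapping`, `_insert_engine_to_cache`, and `_pull_cached_engine` this way surfaces a missing-refit condition as an explicit error at the call site rather than a failure deep inside the weight-mapping or engine-cache code.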
|
@@ -430,6 +431,7 @@ def check_weight_equal(
         except Exception:
             return torch.all(sd_weight == network_weight)
 
+    @needs_refit
     def _save_weight_mapping(self) -> None:
         """
         Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -487,15 +489,10 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
-        if not gm_is_on_cuda:
-            # If the model original position is on CPU, move it GPU
-            sd = {
-                k: v.reshape(-1).to(torch_device)
-                for k, v in self.module.state_dict().items()
-            }
-        else:
-            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
+        sd = {
+            k: v.reshape(-1).to(torch_device)
+            for k, v in self.module.state_dict().items()
+        }
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         constant_mapping = {}
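The simplified comprehension above works because `Tensor.to(device)` returns the tensor itself when it already lives on the target device, so the removed `get_model_device(...)` branch was redundant. A standalone illustration of the single-path version, with `module` and `torch_device` standing in for `self.module` and the device resolved from the compilation settings:

```python
import torch

# Stand-ins for self.module and the compilation device in the interpreter.
module = torch.nn.Linear(4, 4)
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flatten every parameter/buffer and move it to the target device in one pass.
# Tensor.to() returns the same tensor (no copy) when it is already on that
# device, so a model that already lives on the GPU pays no extra cost.
sd = {k: v.reshape(-1).to(torch_device) for k, v in module.state_dict().items()}
```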
|
@@ -579,6 +576,7 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()
 
+    @needs_refit
     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
|
@@ -606,6 +604,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
             ),
         )
 
+    @needs_refit
     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
         # query the cached TRT engine
         cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
|
@@ -716,7 +715,7 @@ def run(
             if self.compilation_settings.reuse_cached_engines:
                 interpreter_result = self._pull_cached_engine(hash_val)
                 if interpreter_result is not None:  # hit the cache
-                    return interpreter_result
+                    return interpreter_result  # type: ignore[no-any-return]
 
         self._construct_trt_network_def()
 
|
|