|
7 | 7 | import warnings |
8 | 8 | from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union |
9 | 9 |
|
| 10 | +import psutil |
10 | 11 | import torch |
11 | 12 | from torch.export import ExportedProgram |
12 | 13 | from torch.fx.node import Target |
@@ -108,7 +109,8 @@ def cross_compile_for_windows( |
108 | 109 | l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, |
109 | 110 | offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, |
110 | 111 | use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, |
111 | | - cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, |
| 112 | + enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, |
| 113 | + cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, |
112 | 114 | **kwargs: Any, |
113 | 115 | ) -> torch.fx.GraphModule: |
114 | 116 | """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows |
@@ -183,7 +185,8 @@ def cross_compile_for_windows( |
183 | 185 | tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. |
184 | 186 | l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). |
185 | 187 | use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model |
186 | | - cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. |
| 188 | + enable_resource_partitioning (bool): Enable resource-aware partitioning, which splits the graph so that compilation stays within the CPU memory budget. This is useful when the model is large and CPU memory is limited. |
| 189 | + cpu_memory_budget (Optional[int]): The maximum amount of CPU memory (in bytes) to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to None, all available CPU memory is used. |
187 | 190 | **kwargs: Any, |
188 | 191 | Returns: |
189 | 192 | torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT |
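
A minimal usage sketch for the two new flags (the tiny model, the import path `torch_tensorrt.dynamo.cross_compile_for_windows`, and the `arg_inputs` keyword are illustrative assumptions, not part of this change):

```python
import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

example_input = torch.randn(8, 16).cuda()
ep = torch.export.export(TinyModel().eval().cuda(), (example_input,))

# Explicit 8 GiB CPU budget (in bytes) with resource-aware partitioning enabled.
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
    ep,
    arg_inputs=(example_input,),  # keyword name assumed from the dynamo compile API
    enable_resource_partitioning=True,
    cpu_memory_budget=8 * 1024**3,
)
```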
@@ -339,6 +342,7 @@ def cross_compile_for_windows( |
339 | 342 | "tiling_optimization_level": tiling_optimization_level, |
340 | 343 | "l2_limit_for_tiling": l2_limit_for_tiling, |
341 | 344 | "use_distributed_mode_trace": use_distributed_mode_trace, |
| 345 | + "enable_resource_partitioning": enable_resource_partitioning, |
342 | 346 | "cpu_memory_budget": cpu_memory_budget, |
343 | 347 | } |
344 | 348 |
|
@@ -441,7 +445,8 @@ def compile( |
441 | 445 | l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, |
442 | 446 | offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, |
443 | 447 | use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, |
444 | | - cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, |
| 448 | + cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, |
| 449 | + enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, |
445 | 450 | **kwargs: Any, |
446 | 451 | ) -> torch.fx.GraphModule: |
447 | 452 | """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT |
@@ -519,6 +524,8 @@ def compile( |
519 | 524 | l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). |
520 | 525 | offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. |
521 | 526 | use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model |
| 527 | + enable_resource_partitioning (bool): Enable resource-aware partitioning, which splits the graph so that compilation stays within the CPU memory budget. This is useful when the model is large and CPU memory is limited. |
| 528 | + cpu_memory_budget (Optional[int]): The maximum amount of CPU memory (in bytes) to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to None, all available CPU memory is used. |
522 | 529 | **kwargs: Any, |
523 | 530 | Returns: |
524 | 531 | torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT |
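
For `compile`, a similar sketch that passes `cpu_memory_budget=None` explicitly, so the partitioner falls back to the memory psutil reports as available (model, inputs, and the `arg_inputs` keyword are illustrative assumptions):

```python
import torch
import torch_tensorrt

class MLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.proj(x))

example_input = torch.randn(4, 64).cuda()
ep = torch.export.export(MLP().eval().cuda(), (example_input,))

# With cpu_memory_budget=None, the resource partitioner falls back to
# psutil.virtual_memory().available at compile time (see the change below).
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    arg_inputs=(example_input,),  # keyword name assumed
    enable_resource_partitioning=True,
    cpu_memory_budget=None,
)
```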
@@ -688,6 +695,7 @@ def compile( |
688 | 695 | "l2_limit_for_tiling": l2_limit_for_tiling, |
689 | 696 | "offload_module_to_cpu": offload_module_to_cpu, |
690 | 697 | "use_distributed_mode_trace": use_distributed_mode_trace, |
| 698 | + "enable_resource_partitioning": enable_resource_partitioning, |
691 | 699 | "cpu_memory_budget": cpu_memory_budget, |
692 | 700 | } |
693 | 701 | logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") |
@@ -862,10 +870,15 @@ def preserve_module_specs( |
862 | 870 | require_full_compilation=settings.require_full_compilation, |
863 | 871 | ) |
864 | 872 |
|
865 | | - partitioned_module = resource_partition( |
866 | | - partitioned_module, |
867 | | - cpu_memory_budget=settings.cpu_memory_budget, |
868 | | - ) |
| 873 | + if settings.enable_resource_partitioning: |
| 874 | + partitioned_module = resource_partition( |
| 875 | + partitioned_module, |
| 876 | + cpu_memory_budget=( |
| 877 | + settings.cpu_memory_budget |
| 878 | + if settings.cpu_memory_budget is not None |
| 879 | + else psutil.virtual_memory().available |
| 880 | + ), |
| 881 | + ) |
869 | 882 |
|
870 | 883 | dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators |
871 | 884 |
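
One way to choose an explicit budget instead of relying on the psutil fallback is to leave some headroom on the currently available memory (a sketch; the 20% margin is an arbitrary choice):

```python
import psutil

# virtual_memory().available is reported in bytes; keep ~20% headroom for the
# rest of the process and pass the remainder as cpu_memory_budget.
budget = int(psutil.virtual_memory().available * 0.8)
```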
|
|