
Commit 437bbe9 (parent 6a879b2)

Fixed the comments

6 files changed (+39, -19 lines)

examples/dynamo/low_cpu_memory_compilation.py

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ def forward(self, x):
     "min_block_size": 1,
     "immutable_weights": True,
     "reuse_cached_engines": False,
+    "enable_resource_partitioning": True,
     "cpu_memory_budget": 2 * 1024 * 1024 * 1024,  # 2 GiB in bytes
 }
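For reference, a minimal usage sketch of how this example wires the two options together (the module and input shape below are placeholders, not taken from the example file):

import torch
import torch_tensorrt

model = MyModel().eval().cuda()                        # placeholder model
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
exp_program = torch.export.export(model, tuple(inputs))

trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    min_block_size=1,
    immutable_weights=True,
    reuse_cached_engines=False,
    enable_resource_partitioning=True,             # new flag from this commit
    cpu_memory_budget=2 * 1024 * 1024 * 1024,      # 2 GiB; leave unset to use available RAM
)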

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 20 additions & 7 deletions

@@ -7,6 +7,7 @@
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
+import psutil
 import torch
 from torch.export import ExportedProgram
 from torch.fx.node import Target
@@ -108,7 +109,8 @@ def cross_compile_for_windows(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
-    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -183,7 +185,8 @@ def cross_compile_for_windows(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
-        cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory.
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -339,6 +342,7 @@ def cross_compile_for_windows(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "enable_resource_partitioning": enable_resource_partitioning,
         "cpu_memory_budget": cpu_memory_budget,
     }

@@ -441,7 +445,8 @@ def compile(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
-    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -519,6 +524,8 @@ def compile(
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -688,6 +695,7 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "enable_resource_partitioning": enable_resource_partitioning,
         "cpu_memory_budget": cpu_memory_budget,
     }
     logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
@@ -862,10 +870,15 @@ def preserve_module_specs(
         require_full_compilation=settings.require_full_compilation,
     )
 
-    partitioned_module = resource_partition(
-        partitioned_module,
-        cpu_memory_budget=settings.cpu_memory_budget,
-    )
+    if settings.enable_resource_partitioning:
+        partitioned_module = resource_partition(
+            partitioned_module,
+            cpu_memory_budget=(
+                settings.cpu_memory_budget
+                if settings.cpu_memory_budget is not None
+                else psutil.virtual_memory().available
+            ),
+        )
 
     dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
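When cpu_memory_budget is left as None, the hunk above falls back to psutil.virtual_memory().available at compile time. A caller that wants to leave headroom for other processes can derive an explicit budget the same way; a rough sketch, where the 50% fraction is an arbitrary choice for illustration and exp_program/inputs are as in the earlier example:

import psutil
import torch_tensorrt

# Cap compilation at half of the CPU memory currently available on the host.
cpu_memory_budget = psutil.virtual_memory().available // 2

trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    enable_resource_partitioning=True,
    cpu_memory_budget=cpu_memory_budget,
)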

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,6 @@
 import platform
 import tempfile
 
-import psutil
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
@@ -58,7 +57,8 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
-CPU_MEMORY_BUDGET = psutil.virtual_memory().available
+ENABLE_RESOURCE_PARTITIONING = False
+CPU_MEMORY_BUDGET = None
 
 if platform.system() == "Linux":
     import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,7 @@
     DRYRUN,
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENABLE_RESOURCE_PARTITIONING,
     ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
@@ -141,6 +142,7 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING
     cpu_memory_budget: int = CPU_MEMORY_BUDGET
 
     def __getstate__(self) -> dict[str, Any]:
@@ -174,6 +176,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
         "enable_weight_streaming",
         "tiling_optimization_level",
         "l2_limit_for_tiling",
+        "enable_resource_partitioning",
         "cpu_memory_budget",
     )
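Since CompilationSettings is a dataclass, the new field can also be set when building settings by hand, as the updated tests below do; a minimal sketch using the import path of the file above:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    min_block_size=1,
    enable_resource_partitioning=True,
    cpu_memory_budget=None,  # None -> compiler falls back to psutil.virtual_memory().available
)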

py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py

Lines changed: 5 additions & 7 deletions

@@ -59,6 +59,8 @@
 
 logger = logging.getLogger(__name__)
 
+MAX_NUM_OF_ENGINES = 40
+
 
 class ResourcePartitioner(_SplitterBase):  # type: ignore
     """Refine capability-based subgraphs to meet host CPU memory constraints.
@@ -148,10 +150,6 @@ def put_nodes_into_subgraphs(self) -> list[Subgraph]:
         subgraphs = [Subgraph(is_acc=True, nodes=nodes)]
         self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph)
 
-        assert self.check_topological_order(
-            subgraphs
-        ), "The subgraphs are not topologically ordered"
-
         return subgraphs
 
     def check_topological_order(self, subgraphs: List[Subgraph]) -> bool:
@@ -214,7 +212,7 @@ def break_subgraphs(
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
         sizes = self.size_of_subgraphs(subgraphs)
-        if sum(sizes) > subgraph_size_budget * 40:
+        if sum(sizes) > subgraph_size_budget * MAX_NUM_OF_ENGINES:
             raise ValueError(
                 "CPU memory budget or available memory is too small to compile the model. "
                 + f"CPU memory budget: {self.cpu_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
@@ -233,11 +231,11 @@ def break_subgraphs(
             if len(subgraph.nodes) != 0:
                 new_subgraphs.append(subgraph)
 
-        self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
+        self._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
 
         return new_subgraphs
 
-    def _varify_all_fusion_nodes_in_same_subgraph(
+    def _verify_all_fusion_nodes_in_same_subgraph(
         self, subgraphs: List[Subgraph]
     ) -> None:
         """Assert that every fusion group is contained in exactly one subgraph."""

tests/py/dynamo/partitioning/test_resource_partitioning.py

Lines changed: 8 additions & 3 deletions

@@ -63,6 +63,7 @@ def forward(self, x):
     "min_block_size": 1,
     "immutable_weights": True,
     "reuse_cached_engines": False,
+    "enable_resource_partitioning": True,
 }
 settings = CompilationSettings(**compilation_options)
 
@@ -144,6 +145,7 @@ def forward(self, x):
     "immutable_weights": True,
     "reuse_cached_engines": False,
     "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"},
+    "enable_resource_partitioning": True,
 }
 settings = CompilationSettings(**compilation_options)
 
@@ -175,8 +177,8 @@ def forward(self, x):
             if "_run_on_acc" in name
         ]
     )
-    == 5
-), "The graph should have 5 accelerated subgraphs"
+    > 3
+), "The graph should have more than 3 accelerated subgraphs"
 assert (
     len(
         [
@@ -275,6 +277,7 @@ def forward(self, x):
     "immutable_weights": True,
     "reuse_cached_engines": False,
     "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"},
+    "enable_resource_partitioning": True,
 }
 settings = CompilationSettings(**compilation_options)
 
@@ -355,6 +358,7 @@ def forward(self, x):
     "min_block_size": 1,
     "immutable_weights": True,
     "reuse_cached_engines": False,
+    "enable_resource_partitioning": True,
 }
 settings = CompilationSettings(**compilation_options)
 
@@ -409,7 +413,7 @@ def forward(self, x):
     assert broken_fusion
 
     # The fusion should be fixed after the step
-    partitioner._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
+    partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
 
     break
 
@@ -463,6 +467,7 @@ def forward(self, x):
     "immutable_weights": True,
     "reuse_cached_engines": False,
     "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"},
+    "enable_resource_partitioning": True,
 }
 settings = CompilationSettings(**compilation_options)
