130 changes: 130 additions & 0 deletions examples/dynamo/low_cpu_memory_compilation.py
@@ -0,0 +1,130 @@
"""

.. _low_cpu_memory_compilation:

Low CPU Memory Compilation Example
==================================

This example demonstrates compiling a model with a bounded CPU (host) memory
budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
memory-constrained machines or when compiling very large models.

Key notes:
- The toy model below has roughly 430 MB of float32 parameters (the optional
check below prints this estimate). We set the CPU memory budget to 2 GiB. At
compile time, only about 900 MB of host RAM may remain available within that
budget, while the model may need up to 430 * 4 = 1720 MB of host memory during
compilation, so it is partitioned into two subgraphs to fit the memory budget.

- Performance impact varies by model. When the number of TensorRT engines
created is small, the impact is typically minimal.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.conversion import CompilationSettings


class net(nn.Module):
def __init__(self):
super().__init__()
# Intentionally large layers to stress host memory during compilation.
self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
self.bn1 = nn.BatchNorm2d(4096)
self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
self.bn2 = nn.BatchNorm2d(1024)
self.fc1 = nn.Linear(1024 * 56 * 56, 10)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
return self.fc1(x)


model = net().eval()
model.to("cuda")
inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]
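
# Optional sanity check (not required for compilation): estimate the host
# memory held by the parameters. For this model the float32 weights come to
# roughly 430 MB, which is the figure the memory-budget note above is based on.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Approximate parameter memory: {param_bytes / 2**20:.0f} MB")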

enabled_precisions = {torch.float}
use_python_runtime = False

compilation_options = {
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"min_block_size": 1,
"immutable_weights": True,
"reuse_cached_engines": False,
"enable_resource_partitioning": True,
"cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes
}
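
# The budget above is hard-coded for illustration. One way to derive it from
# the machine's currently available host memory (assuming psutil is installed)
# would be:
#   import psutil
#   compilation_options["cpu_memory_budget"] = int(psutil.virtual_memory().available * 0.8)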

settings = CompilationSettings(**compilation_options)
with torchtrt.dynamo.Debugger(
log_level="debug",
logging_dir="/home/profile/logging/moe",
engine_builder_monitor=False,
):

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
exp_program,
inputs=inputs,
**compilation_options,
)

# Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
print(trt_gm)
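
# Optional functional check: run the compiled module once and confirm the
# output shape matches the original model (a [1, 10] logits tensor here).
out = trt_gm(*inputs)
print(out.shape)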


"""
You should see two back-to-back TensorRT engines in the printed graph:

Graph Structure:

Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0
Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
Number of Operators in Engine: 9
Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32]
...
TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1
Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32]
Number of Operators in Engine: 3
Engine Outputs: List[Tensor: (1, 10)@float32]
...
Outputs: List[Tensor: (1, 10)@float32]

------------------------- Aggregate Stats -------------------------

Average Number of Operators per TRT Engine: 6.0
Most Operators in a TRT Engine: 9

********** Recommendations **********

- For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s)
- For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s)
GraphModule(
(_run_on_acc_0_resource_split_0): TorchTensorRTModule()
(_run_on_acc_0_resource_split_1): TorchTensorRTModule()
)



def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x); x = None
_run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0); _run_on_acc_0_resource_split_0 = None
return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec)
)
"""
24 changes: 23 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -40,6 +40,9 @@
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
resource_partition,
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_cpu_memory_usage,
@@ -105,6 +108,8 @@ def cross_compile_for_windows(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -179,6 +184,8 @@ def cross_compile_for_windows(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_resource_partitioning (bool): Enable resource-aware partitioning, which splits the graph into subgraphs that fit within the CPU memory budget. This is useful when the model is large and host memory is limited.
cpu_memory_budget (Optional[int]): The maximum amount of CPU (host) memory, in bytes, to use for compilation. If compilation requires more memory than this budget, it will fail.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -334,6 +341,8 @@ def cross_compile_for_windows(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"use_distributed_mode_trace": use_distributed_mode_trace,
"enable_resource_partitioning": enable_resource_partitioning,
"cpu_memory_budget": cpu_memory_budget,
}

# disable the following settings is not supported for cross compilation for windows feature
@@ -435,6 +444,8 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -512,6 +523,8 @@ def compile(
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_resource_partitioning (bool): Enable resource-aware partitioning, which splits the graph into subgraphs that fit within the CPU memory budget. This is useful when the model is large and host memory is limited.
cpu_memory_budget (Optional[int]): The maximum amount of CPU (host) memory, in bytes, to use for compilation. If compilation requires more memory than this budget, it will fail.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -681,6 +694,8 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"enable_resource_partitioning": enable_resource_partitioning,
"cpu_memory_budget": cpu_memory_budget,
}
logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
@@ -854,6 +869,12 @@ def preserve_module_specs(
require_full_compilation=settings.require_full_compilation,
)

if settings.enable_resource_partitioning:
partitioned_module = resource_partition(
partitioned_module,
cpu_memory_budget=settings.cpu_memory_budget,
)

dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

# The global partitioner leaves non-TRT nodes as-is
@@ -877,6 +898,7 @@ def preserve_module_specs(
for attr in dir(gm):
if attr.startswith("_frozen_param"):
delattr(gm, attr)

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
@@ -1339,7 +1361,7 @@ def convert_exported_program_to_serialized_trt_engine(
)

flattened_input_list = get_flat_args_with_check(
exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore
exported_program, list(trt_arg_inputs), trt_kwarg_inputs
)[0]

try:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -57,6 +57,8 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
ENABLE_RESOURCE_PARTITIONING = False
CPU_MEMORY_BUDGET = None

if platform.system() == "Linux":
import pwd
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -7,13 +7,15 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_RESOURCE_PARTITIONING,
ENABLE_WEIGHT_STREAMING,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
@@ -140,6 +142,8 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING
cpu_memory_budget: int = CPU_MEMORY_BUDGET

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -172,6 +176,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"enable_weight_streaming",
"tiling_optimization_level",
"l2_limit_for_tiling",
"enable_resource_partitioning",
)


@@ -230,7 +230,7 @@ def partition_graph(self) -> torch.fx.GraphModule:

# Tag the accelerated nodes and split the graph accordingly
self.tag(subgraphs)
return self.split()
return self.split(remove_tag=True)

def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
"""Generates starter nodes for partitioning + segmentation"""