
🐛 [Bug] TorchTensorRTRuntime uses more memory than PythonTorchTensorRTRuntime when there are multiple back-to-back engines. #3966

@cehongwang

Description

import time

import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from pyinstrument import Profiler
from torch.export._trace import _export

def benchmark_model(model, input, label, profile=False):
    # Time 30 forward passes; optionally capture a pyinstrument profile and write it to HTML.
    if profile:
        profiler = Profiler(interval=0.01)
        profiler.start()
    start_time = time.time()
    for _ in range(30):
        model_outputs = model(**input)
    end_time = time.time()
    print(f"{label} 30 runs: {end_time - start_time:.4f} seconds")
    if profile:
        profiler.stop()
        profiler.write_html(f"/home/other/flux_{label.replace(' ', '_')}.html", timeline=False, show_all=True)


device = "cuda"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)

pipe.to(device)
pipe.to(torch.float16)
backbone = pipe.transformer  # compile only the transformer backbone

# Smaller single-layer variant, useful for quicker iteration:
# backbone = FluxTransformer2DModel(num_layers=1, num_single_layers=1, guidance_embeds=True).eval().to(torch.float16).cuda()


batch_size = 1
DEVICE = "cuda"


# Static example inputs matching the FLUX transformer's forward signature
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}


# Export the backbone with torch.export (non-strict mode) so it can be compiled by the dynamo backend
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    # dynamic_shapes=dynamic_shapes,
    strict=False,
)

with torch_tensorrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False):
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        # enabled_precisions={torch.float16},
        truncate_double=True,
        dryrun=False,
        min_block_size=1,
        offload_module_to_cpu=True,
        use_python_runtime=True,
        enable_resource_partitioning=True,
        lazy_engine_init=True,
        cpu_memory_budget=53 * 1024 * 1024 * 1024,  # 53 GiB host-side budget
    )




benchmark_model(trt_gm, dummy_inputs, "flux_trt_gm")

With use_python_runtime enabled, the model compiles and runs successfully. If it is disabled, we get a CUDA OOM error.
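For comparing the two runtimes side by side, here is a minimal sketch (not part of the original report) that recompiles the exported program with each runtime and reports overall device-memory usage. It assumes the ep and dummy_inputs defined above are in scope; report_device_memory and compile_and_run are hypothetical helpers, and the compile options simply mirror the repro. torch.cuda.mem_get_info() is used because TensorRT can allocate engine memory outside PyTorch's caching allocator, so allocator-level counters may undercount it.

def report_device_memory(tag):
    # mem_get_info() returns (free, total) bytes for the whole device,
    # so it also reflects memory allocated by TensorRT itself.
    free, total = torch.cuda.mem_get_info()
    print(f"{tag}: {(total - free) / (1 << 30):.2f} GiB in use on the device")

def compile_and_run(use_python_runtime):
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        truncate_double=True,
        min_block_size=1,
        offload_module_to_cpu=True,
        use_python_runtime=use_python_runtime,
        enable_resource_partitioning=True,
        lazy_engine_init=True,
        cpu_memory_budget=53 * 1024 * 1024 * 1024,
    )
    trt_gm(**dummy_inputs)  # first call also triggers lazy engine initialization
    report_device_memory(f"use_python_runtime={use_python_runtime}")
    return trt_gm

compile_and_run(use_python_runtime=True)    # PythonTorchTensorRTModule path
# compile_and_run(use_python_runtime=False) # C++ runtime path; hits CUDA OOM in this setup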
