-
Notifications
You must be signed in to change notification settings - Fork 376
Open
Labels
bug: Something isn't working
Description
import torch
import torch_tensorrt
from diffusers import (
DiffusionPipeline,
FluxPipeline,
StableDiffusion3Pipeline,
StableDiffusionPipeline,
)
from torch.export._trace import _export
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from pyinstrument import Profiler
import time
def benchmark_model(model, input, label, profile=False, runs=30):
    """Time repeated calls of ``model(**input)`` and print the total duration.

    Args:
        model: Callable invoked as ``model(**input)`` on every iteration.
        input: Mapping of keyword arguments passed to each call. (The name
            shadows the builtin; it is kept so existing callers still work.)
        label: Text used in the printed summary and, with spaces replaced by
            underscores, in the profile report's file name.
        profile: When true, run the loop under a pyinstrument ``Profiler``
            and write an HTML report afterwards.
        runs: Number of calls to time. Defaults to 30.

    Returns:
        Elapsed wall-clock time of the timed loop, in seconds.
    """
    # CUDA kernels are launched asynchronously; without a synchronize on both
    # sides of the loop the timer can stop before the GPU work has finished.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    profiler = None
    if profile:
        profiler = Profiler(interval=0.01)
        profiler.start()
    try:
        sync()
        # perf_counter is monotonic, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(runs):
            model(**input)
        sync()
        elapsed = time.perf_counter() - start_time
    finally:
        # Stop the profiler even when the model call raises.
        if profiler is not None:
            profiler.stop()

    print(f"{label} {runs} runs: {elapsed:.4f} seconds")
    if profiler is not None:
        profiler.write_html(f"/home/other/flux_{label.replace(' ', '_')}.html", timeline=False, show_all=True)
    return elapsed
# Reproduction script: load the FLUX.1-dev pipeline, export its transformer
# with dummy inputs, and compile the exported program with Torch-TensorRT.
DEVICE = "cuda"
device = DEVICE  # both spellings are kept; they now share one definition

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(device)
pipe.to(torch.float16)
backbone = pipe.transformer
# Alternative backbone, left disabled:
# backbone = FluxTransformer2DModel(num_layers=1, num_single_layers=1, guidance_embeds=True).eval().to(torch.float16).cuda()

batch_size = 1

# Keyword arguments used both for export and as the compile-time inputs.
dummy_inputs = {
    "hidden_states": torch.randn(
        (batch_size, 4096, 64), dtype=torch.float16
    ).to(DEVICE),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn(
        (batch_size, 768), dtype=torch.float16
    ).to(DEVICE),
    # One entry per batch element, so these stay consistent with batch_size
    # instead of carrying a fixed length of two.
    "timestep": torch.full((batch_size,), 1.0, dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.full((batch_size,), 1.0, dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# Non-strict export; all inputs are passed as keyword arguments.
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    # dynamic_shapes=dynamic_shapes,
    strict=False,
)

# Compile under the Torch-TensorRT debugger with debug-level logging.
# The options below are left exactly as reported in the issue.
with torch_tensorrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False):
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        # enabled_precisions={torch.float16},
        truncate_double=True,
        dryrun=False,
        min_block_size=1,
        offload_module_to_cpu=True,
        use_python_runtime=True,
        enable_resource_partitioning=True,
        lazy_engine_init=True,
        # 53 GiB if the unit is bytes - TODO confirm against the API docs.
        cpu_memory_budget=53*1024*1024*1024,
    )
benchmark_model(trt_gm, dummy_inputs, "flux_trt_gm")

With `use_python_runtime` enabled, compilation succeeds. With it disabled, we get a CUDA OOM error.
Metadata
Metadata
Assignees
Labels
bug: Something isn't working