Commit 1d4ca9c

committed
Enable int8 and fp8 quantization for FLUX
1 parent ba76f6d commit 1d4ca9c

File tree

5 files changed: +257 -9 lines changed


examples/apps/flux-quantization.py

+48-8
@@ -1,6 +1,9 @@
 # %%
 # Import the following libraries
 # -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
 import re
 
 import modelopt.torch.opt as mto
@@ -14,8 +17,27 @@
 from torch.export._trace import _export
 from transformers import AutoModelForCausalLM
 
-# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8"],
+    default="int8",
+    help="Quantization data type to use (fp8 or int8)",
+)
 
+args = parser.parse_args()
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+else:  # int8
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+print(f"\nUsing {args.dtype} quantization")
 # %%
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
@@ -83,11 +105,10 @@ def forward_loop(mod):
     )
 
 
-ptq_config = mtq.FP8_DEFAULT_CFG
 backbone = mtq.quantize(backbone, ptq_config, forward_loop)
 mtq.disable_quantizer(backbone, filter_func)
 
-batch_size = 1
+batch_size = 2
 BATCH = torch.export.Dim("batch", min=1, max=2)
 SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
 # This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
@@ -129,7 +150,7 @@ def forward_loop(mod):
         backbone,
         args=(),
         kwargs=dummy_inputs,
-        # dynamic_shapes=dynamic_shapes,
+        dynamic_shapes=dynamic_shapes,
         strict=False,
         allow_complex_guards_as_runtime_asserts=True,
     )
@@ -138,10 +159,10 @@ def forward_loop(mod):
     trt_gm = torch_tensorrt.dynamo.compile(
         ep,
         inputs=dummy_inputs,
-        enabled_precisions={torch.float8_e4m3fn, torch.float16},
+        enabled_precisions=enabled_precisions,
         truncate_double=True,
         min_block_size=1,
-        debug=True,
+        debug=False,
         use_python_runtime=True,
         immutable_weights=True,
         offload_module_to_cpu=True,
@@ -156,8 +177,27 @@ def forward_loop(mod):
 # %%
 trt_gm.device = torch.device(DEVICE)
 # Function which generates images from the flux pipeline
+generate_image(pipe, ["A golden retriever"], "dog_code2")
+
+
+def benchmark(prompt, inference_step, batch_size=2, iterations=1):
+    from time import time
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print("Average Latency Per Step:", (end - start) / inference_step / iterations)
+    return image
+
 
-for _ in range(2):
-    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
+print("Benchmark Original PyTorch Module Latency (int8)")
+benchmark(["Test"], 50, iterations=3)
 
 # For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
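A side note on the dtype switch introduced above: it amounts to mapping the --dtype flag onto a ModelOpt PTQ config plus a matching set of TensorRT precisions. Below is a minimal sketch of that mapping pulled out into a standalone helper; the function name select_quantization_settings is ours, not part of the commit, and the fp8 branch clears the weight quantizer's axis (per-tensor scaling) exactly as the example does.

import modelopt.torch.quantization as mtq
import torch


def select_quantization_settings(dtype: str):
    """Hypothetical helper mirroring the --dtype logic added in this example."""
    if dtype == "fp8":
        enabled_precisions = {torch.float8_e4m3fn, torch.float16}
        ptq_config = mtq.FP8_DEFAULT_CFG
        # Per-tensor weight scaling, as the example sets for fp8 above
        ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
    else:  # int8
        enabled_precisions = {torch.int8, torch.float16}
        ptq_config = mtq.INT8_DEFAULT_CFG
    return enabled_precisions, ptq_config

enabled_precisions then goes to torch_tensorrt.dynamo.compile(...) and ptq_config to mtq.quantize(...), as in the diff above.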

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

+3
@@ -84,6 +84,9 @@ def quantize(
     elif num_bits == 8 and exponent_bits == 4:
         dtype = trt.DataType.FP8
 
+    if not isinstance(input_tensor, TRTTensor):
+        input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input")
+
     quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
 
     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
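For context on this converter change: the value reaching the quantize op is not always a TensorRT ITensor yet; for weight quantization it can arrive as a frozen torch/numpy constant, and add_quantize only accepts ITensors. The added guard promotes such values with get_trt_tensor first. A minimal sketch of the same pattern in isolation follows; the import paths are assumptions (inside the patched impl/quantize.py both names are already in scope), and ensure_itensor is a hypothetical helper, not part of the commit.

# Sketch only: the isinstance/get_trt_tensor promotion pattern used above.
# Import paths below are assumptions.
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.dynamo.types import TRTTensor


def ensure_itensor(ctx, value, name):
    # Hypothetical helper: wrap any non-ITensor input (e.g. a frozen weight
    # constant) as a TensorRT constant before it feeds a TensorRT layer.
    if not isinstance(value, TRTTensor):
        value = get_trt_tensor(ctx, value, name)
    return value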

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

+2
@@ -101,4 +101,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
 
     # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
+        if node.target == torch.ops.tensorrt.quantize_op.default:
+            return True
         return False
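Why this matters: without the check, the constant-folding pass can evaluate tensorrt.quantize_op nodes whose inputs are all constants (the weight Q/DQ pattern) and bake them out of the graph before the TensorRT converter sees them. Marking the op as impure keeps those nodes in place. A small sketch for inspecting which nodes are affected, assuming ep is the exported program produced in the Flux example above:

import torch

# Sketch: count the quantize nodes that is_impure() now shields from folding.
# Assumes `ep` is the torch.export exported program from the Flux example.
quantize_nodes = [
    n
    for n in ep.graph_module.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.tensorrt.quantize_op.default
]
print(f"{len(quantize_nodes)} tensorrt.quantize_op nodes kept out of constant folding")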

tools/perf/Flux/benchmark.sh

+2-1
@@ -1,4 +1,5 @@
 #TODO: Enter the HF Token
 huggingface-cli login --token HF_TOKEN
 
-python flux_perf.py > benchmark_output.txt
+python flux_quantization.py --dtype fp8 > fp8_benchmark.txt
+python flux_quantization.py --dtype int8 > int8_benchmark.txt

tools/perf/Flux/flux-quantization.py

+202
@@ -0,0 +1,202 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8"],
+    default="int8",
+    help="Quantization data type to use (fp8 or int8)",
+)
+
+args = parser.parse_args()
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+else:  # int8
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+print(f"\nUsing {args.dtype} quantization")
+# %%
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float16,
+)
+pipe.transformer = FluxTransformer2DModel(
+    num_layers=1, num_single_layers=1, guidance_embeds=True
+)
+
+pipe.to(DEVICE).to(torch.float16)
+# Store the config and transformer backbone
+config = pipe.transformer.config
+# global backbone
+backbone = pipe.transformer
+backbone.eval()
+
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None
+
+
+def generate_image(pipe, prompt, image_name):
+    seed = 42
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(seed),
+    ).images[0]
+    image.save(f"{image_name}.png")
+    print(f"Image generated using {image_name} model saved as {image_name}.png")
+
+
+generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
+
+# %%
+# Quantization
+
+
+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+    """
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(0),
+    ).images[0]
+
+
+def forward_loop(mod):
+    # Switch the pipeline's backbone, run calibration
+    pipe.transformer = mod
+    do_calibrate(
+        pipe=pipe,
+        prompt="test",
+    )
+
+
+backbone = mtq.quantize(backbone, ptq_config, forward_loop)
+mtq.disable_quantizer(backbone, filter_func)
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=2)
+SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
+# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {0: SEQ_LEN},
+    "img_ids": {0: IMG_ID},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+# The guidance factor is of type torch.float32
+dummy_inputs = {
+    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "encoder_hidden_states": torch.randn(
+        (batch_size, 512, 4096), dtype=torch.float16
+    ).to(DEVICE),
+    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
+    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
+    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
+    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
+}
+
+# This will create an exported program which is going to be compiled with Torch-TensorRT
+with export_torch_mode():
+    ep = _export(
+        backbone,
+        args=(),
+        kwargs=dummy_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
+        allow_complex_guards_as_runtime_asserts=True,
+    )
+
+with torch_tensorrt.logging.debug():
+    trt_gm = torch_tensorrt.dynamo.compile(
+        ep,
+        inputs=dummy_inputs,
+        enabled_precisions=enabled_precisions,
+        truncate_double=True,
+        min_block_size=1,
+        debug=False,
+        use_python_runtime=True,
+        immutable_weights=True,
+        offload_module_to_cpu=True,
+    )
+
+
+del ep
+pipe.transformer = trt_gm
+pipe.transformer.config = config
+
+
+# %%
+trt_gm.device = torch.device(DEVICE)
+# Function which generates images from the flux pipeline
+generate_image(pipe, ["A golden retriever"], "dog_code2")
+
+
+def benchmark(prompt, inference_step, batch_size=2, iterations=1):
+    from time import time
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print("Average Latency Per Step:", (end - start) / inference_step / iterations)
+    return image
+
+
+print("Benchmark Original PyTorch Module Latency (int8)")
+benchmark(["Test"], 50, iterations=3)
+
+# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
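A quick note on the numbers the benchmark() helper prints: "Average Latency Per Step" is the total wall time divided by num_inference_steps and by iterations, so with the call above (50 steps, 3 iterations, 2 images per prompt) it measures a single denoising step, not a full image. A tiny worked example with made-up timings:

# Hypothetical numbers, only to illustrate the metric printed above.
total_time = 30.0                    # seconds for the whole benchmark() call
inference_step, iterations = 50, 3   # as in benchmark(["Test"], 50, iterations=3)
print(total_time / inference_step / iterations)  # -> 0.2 s per denoising step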

0 commit comments