
Commit 27dee53

Enable int8 and fp8 quantization for FLUX
1 parent ba76f6d commit 27dee53

File tree

5 files changed: +263 −17 lines

examples/apps/flux-quantization.py

+51 −11
@@ -1,6 +1,9 @@
 # %%
 # Import the following libraries
 # -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
 import re
 
 import modelopt.torch.opt as mto
@@ -14,17 +17,36 @@
 from torch.export._trace import _export
 from transformers import AutoModelForCausalLM
 
-# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8"],
+    default="int8",
+    help="Quantization data type to use (fp8 or int8)",
+)
 
+args = parser.parse_args()
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+else:  # int8
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+print(f"\nUsing {args.dtype} quantization")
 # %%
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
-pipe.transformer = FluxTransformer2DModel(
-    num_layers=1, num_single_layers=1, guidance_embeds=True
-)
+# pipe.transformer = FluxTransformer2DModel(
+#     num_layers=1, num_single_layers=1, guidance_embeds=True
+# )
 
 pipe.to(DEVICE).to(torch.float16)
 # Store the config and transformer backbone
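
Note on the hunk above: the fp8 path uses FP8_DEFAULT_CFG unchanged, while the int8 path overrides the weight quantizer's "axis" to None, i.e. a single per-tensor scale instead of the config's default per-axis scaling. A quick illustrative check (not part of the commit; the exact default values depend on the installed modelopt version):

import modelopt.torch.quantization as mtq

cfg = mtq.INT8_DEFAULT_CFG
print("default weight quantizer:", cfg["quant_cfg"]["*weight_quantizer"])
cfg["quant_cfg"]["*weight_quantizer"]["axis"] = None  # switch to per-tensor scaling
print("patched weight quantizer:", cfg["quant_cfg"]["*weight_quantizer"])
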
@@ -83,11 +105,10 @@ def forward_loop(mod):
     )
 
 
-ptq_config = mtq.FP8_DEFAULT_CFG
 backbone = mtq.quantize(backbone, ptq_config, forward_loop)
 mtq.disable_quantizer(backbone, filter_func)
 
-batch_size = 1
+batch_size = 2
 BATCH = torch.export.Dim("batch", min=1, max=2)
 SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
 # This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
@@ -129,7 +150,7 @@ def forward_loop(mod):
         backbone,
         args=(),
         kwargs=dummy_inputs,
-        # dynamic_shapes=dynamic_shapes,
+        dynamic_shapes=dynamic_shapes,
         strict=False,
         allow_complex_guards_as_runtime_asserts=True,
     )
@@ -138,10 +159,10 @@ def forward_loop(mod):
     trt_gm = torch_tensorrt.dynamo.compile(
         ep,
         inputs=dummy_inputs,
-        enabled_precisions={torch.float8_e4m3fn, torch.float16},
+        enabled_precisions=enabled_precisions,
         truncate_double=True,
         min_block_size=1,
-        debug=True,
+        debug=False,
         use_python_runtime=True,
         immutable_weights=True,
         offload_module_to_cpu=True,
@@ -156,8 +177,27 @@ def forward_loop(mod):
 # %%
 trt_gm.device = torch.device(DEVICE)
 # Function which generates images from the flux pipeline
+generate_image(pipe, ["A golden retriever"], "dog_code2")
+
+
+def benchmark(prompt, inference_step, batch_size=2, iterations=1):
+    from time import time
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print("Average Latency Per Step:", (end - start) / inference_step / iterations)
+    return image
+
 
-for _ in range(2):
-    generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
+print(f"Benchmark Original PyTorch Module Latency ({args.dtype})")
+benchmark(["Test"], 50, iterations=3)
 
 # For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
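
The benchmark() added above times the pipeline on the host clock. A variant that synchronizes CUDA before reading the timer guards against previously queued GPU work skewing the first measurement. Sketch only, not part of the commit; it reuses the pipe object defined earlier in the example:

import torch
from time import time

def benchmark_synced(prompt, inference_step, batch_size=2, iterations=1):
    torch.cuda.synchronize()  # drain pending GPU work before starting the timer
    start = time()
    for _ in range(iterations):
        images = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
    torch.cuda.synchronize()
    end = time()
    print("Average Latency Per Step:", (end - start) / inference_step / iterations)
    return images
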

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

+5 −5
@@ -84,6 +84,9 @@ def quantize(
     elif num_bits == 8 and exponent_bits == 4:
         dtype = trt.DataType.FP8
 
+    if not isinstance(input_tensor, TRTTensor):
+        input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input")
+
     quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
 
     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
@@ -93,11 +96,8 @@ def quantize(
         q_output, scale, output_type=input_tensor.dtype
     )
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-    if num_bits == 8 and exponent_bits == 0:
-        dequantize_layer.precision = trt.DataType.INT8
-    elif num_bits == 8 and exponent_bits == 4:
-        # Set DQ layer precision to FP8
-        dequantize_layer.precision = trt.DataType.FP8
+    dequantize_layer.precision = dtype
+
     dq_output = dequantize_layer.get_output(0)
 
     return dq_output
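
For reference, the converter above now emits the same quantize/dequantize pattern for both dtypes and pins the dequantize layer's precision with a single assignment instead of the removed per-dtype branches. A minimal sketch of that pattern against the TensorRT network API calls used in the diff (the helper name and signature here are illustrative, not the converter's actual interface):

import tensorrt as trt

def add_qdq(network, input_tensor, scale, dtype):
    # Quantize the fp16/fp32 tensor to INT8 or FP8 ...
    q = network.add_quantize(input_tensor, scale, dtype)
    # ... then immediately dequantize back to the input dtype for downstream layers
    dq = network.add_dequantize(q.get_output(0), scale, input_tensor.dtype)
    dq.precision = dtype  # one assignment covers both INT8 and FP8
    return dq.get_output(0)
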

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

+2
@@ -101,4 +101,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
 
     # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
+        if node.target == torch.ops.tensorrt.quantize_op.default:
+            return True
         return False
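
Treating torch.ops.tensorrt.quantize_op.default as impure keeps the explicit Q/DQ nodes in the graph: otherwise constant folding could pre-compute quantize_op on frozen weights and bake plain fp16 constants back into the graph, dropping the quantization information the converter needs. If more quantization ops are added later, a set lookup is an easy generalization (illustrative sketch, not part of the commit):

_IMPURE_OPS = {torch.ops.tensorrt.quantize_op.default}

def is_impure(self, node: torch.fx.node.Node) -> bool:
    return node.target in _IMPURE_OPS
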

tools/perf/Flux/benchmark.sh

+2 −1
@@ -1,4 +1,5 @@
 #TODO: Enter the HF Token
 huggingface-cli login --token HF_TOKEN
 
-python flux_perf.py > benchmark_output.txt
+python flux_quantization.py --dtype fp8 > fp8_benchmark.txt
+python flux_quantization.py --dtype int8 > int8_benchmark.txt

tools/perf/Flux/flux-quantization.py

+203
@@ -0,0 +1,203 @@
+# %%
+# Import the following libraries
+# -----------------------------
+# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
+# Add argument parsing for dtype selection
+import argparse
+import re
+
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+import torch
+import torch_tensorrt
+from diffusers import FluxPipeline
+from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from modelopt.torch.quantization.utils import export_torch_mode
+from torch.export._trace import _export
+from transformers import AutoModelForCausalLM
+
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8"],
+    default="int8",
+    help="Quantization data type to use (fp8 or int8)",
+)
+
+args = parser.parse_args()
+
+# Update enabled precisions based on dtype argument
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+else:  # int8
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+print(f"\nUsing {args.dtype} quantization")
+# %%
+DEVICE = "cuda:0"
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.float16,
+)
+# pipe.transformer = FluxTransformer2DModel(
+#     num_layers=1, num_single_layers=1, guidance_embeds=True
+# )
+
+pipe.to(DEVICE).to(torch.float16)
+# Store the config and transformer backbone
+config = pipe.transformer.config
+# global backbone
+backbone = pipe.transformer
+backbone.eval()
+
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None
+
+
+def generate_image(pipe, prompt, image_name):
+    seed = 42
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(seed),
+    ).images[0]
+    image.save(f"{image_name}.png")
+    print(f"Image generated using {image_name} model saved as {image_name}.png")
+
+
+generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
+
+# %%
+# Quantization
+
+
+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+    """
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(0),
+    ).images[0]
+
+
+def forward_loop(mod):
+    # Switch the pipeline's backbone, run calibration
+    pipe.transformer = mod
+    do_calibrate(
+        pipe=pipe,
+        prompt="test",
+    )
+
+
+backbone = mtq.quantize(backbone, ptq_config, forward_loop)
+mtq.disable_quantizer(backbone, filter_func)
+
+batch_size = 2
+BATCH = torch.export.Dim("batch", min=1, max=2)
+SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
+# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# To see this recommendation, you can try exporting using min=1, max=4096
+IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {0: SEQ_LEN},
+    "img_ids": {0: IMG_ID},
+    "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
+}
+# The guidance factor is of type torch.float32
+dummy_inputs = {
+    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "encoder_hidden_states": torch.randn(
+        (batch_size, 512, 4096), dtype=torch.float16
+    ).to(DEVICE),
+    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
+        DEVICE
+    ),
+    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
+    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
+    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
+    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
+}
+
+# This will create an exported program which is going to be compiled with Torch-TensorRT
+with export_torch_mode():
+    ep = _export(
+        backbone,
+        args=(),
+        kwargs=dummy_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
+        allow_complex_guards_as_runtime_asserts=True,
+    )
+
+with torch_tensorrt.logging.debug():
+    trt_gm = torch_tensorrt.dynamo.compile(
+        ep,
+        inputs=dummy_inputs,
+        enabled_precisions=enabled_precisions,
+        truncate_double=True,
+        min_block_size=1,
+        debug=False,
+        use_python_runtime=True,
+        immutable_weights=True,
+        offload_module_to_cpu=True,
+    )
+
+
+del ep
+pipe.transformer = trt_gm
+pipe.transformer.config = config
+
+
+# %%
+trt_gm.device = torch.device(DEVICE)
+# Function which generates images from the flux pipeline
+generate_image(pipe, ["A golden retriever"], "dog_code2")
+
+
+def benchmark(prompt, inference_step, batch_size=2, iterations=1):
+    from time import time
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+    print("Time Elapse for", iterations, "iterations:", end - start)
+    print("Average Latency Per Step:", (end - start) / inference_step / iterations)
+    return image
+
+
+print(f"Benchmark Original PyTorch Module Latency ({args.dtype})")
+benchmark(["Test"], 50, iterations=3)
+
+# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
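
In the script above, filter_func decides which modules keep their quantizers disabled: any module whose name matches the embedding/projection/norm patterns stays in fp16 after mtq.disable_quantizer(backbone, filter_func). A quick illustrative check (module names below are examples; the real names come from backbone.named_modules()):

print(filter_func("x_embedder"))                      # True  -> quantizer disabled, stays fp16
print(filter_func("norm_out.linear"))                 # True  -> quantizer disabled, stays fp16
print(filter_func("transformer_blocks.0.attn.to_q"))  # False -> remains quantized
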
