Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 48f21f6

Browse files
committed
update linear bench
1 parent 2ffcbe9 commit 48f21f6

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

benchmarks/bench_linear_float8.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class Experiment:
5656
dtype: torch.dtype
5757
compiled: bool = False
5858
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
59+
recompute_weight_cast: bool = False
5960

6061
# 3 Times since we are calculating forward backward
6162
@property
@@ -95,9 +96,14 @@ def main(
9596
}
9697
input_bias = False
9798
ref_dtypes = [torch.bfloat16, torch.float16]
99+
recompute_weight_casts = [True, False]
98100
experiment_list: List[Experiment] = []
99-
for idx, (dtype, (name, (K, N))) in enumerate(
100-
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
101+
for idx, (dtype, (name, (K, N)), recompute_weight_cast) in enumerate(
102+
tqdm(
103+
list(
104+
product(ref_dtypes, name_to_shapes_70b.items(), recompute_weight_casts)
105+
)
106+
)
101107
):
102108
if n_limit is not None and idx >= n_limit:
103109
break
@@ -106,7 +112,9 @@ def main(
106112
)
107113

108114
linear_float8 = Float8Linear.from_float(
109-
copy.deepcopy(linear_ref), emulate=False
115+
copy.deepcopy(linear_ref),
116+
emulate=False,
117+
recompute_weight_cast=recompute_weight_cast,
110118
)
111119

112120
bsz, seq_len = 4, 4096
@@ -155,6 +163,7 @@ def wrapper(*args, **kwargs):
155163
float8_time,
156164
dtype,
157165
compile,
166+
recompute_weight_cast=recompute_weight_cast,
158167
)
159168
print(experiment)
160169
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -169,6 +178,7 @@ def wrapper(*args, **kwargs):
169178
"ref_dtype",
170179
"compiled",
171180
"fp8_dtype",
181+
"recompute_weight_cast",
172182
"ref_time_sec",
173183
"pt_fp8_time_sec",
174184
"ref_tops_sec",
@@ -187,6 +197,7 @@ def wrapper(*args, **kwargs):
187197
experiment.dtype,
188198
experiment.compiled,
189199
experiment.float_8_dtype,
200+
experiment.recompute_weight_cast,
190201
experiment.ref_time_sec,
191202
experiment.float8_time_sec,
192203
experiment.ref_tops_sec,
@@ -214,6 +225,7 @@ def wrapper(*args, **kwargs):
214225
"shape",
215226
"ref_dtype",
216227
"compiled",
228+
"recompute_weight_cast",
217229
"ref_time_sec",
218230
"pt_fp8_time_sec",
219231
"pt_fp8_speedup",

0 commit comments

Comments
 (0)