Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 26ce70d

Browse files
committed
update profile
1 parent d03d16b commit 26ce70d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

benchmarks/profile_linear_float8.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class LinearParams:
8787
torch_compile: Optional[bool] = False
8888

8989

90-
def main(profile_path: Path, compile: bool, linear_type: str):
90+
def main(profile_path: Path, compile: bool, linear_type: str, recompute_weight_cast: bool):
9191
profile_path = Path(profile_path)
9292
assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
9393
params = LinearParams(
@@ -110,7 +110,7 @@ def main(profile_path: Path, compile: bool, linear_type: str):
110110
dtype=params.ref_dtype,
111111
)
112112
linear_type = LinearType[linear_type.upper()]
113-
linear_float8 = get_float8_linear(linear_type, linear_ref)
113+
linear_float8 = get_float8_linear(linear_type, linear_ref, recompute_weight_cast=recompute_weight_cast)
114114

115115
input_tensor = torch.randn(
116116
params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True

0 commit comments

Comments
 (0)