
Commit 177173a

parametrize compile and base tests

1 parent 2ad6184 commit 177173a

3 files changed (+73, -17):
  float8_experimental/float8_linear_utils.py
  test/test_base.py
  test/test_compile.py

float8_experimental/float8_linear_utils.py (8 additions, 2 deletions)

@@ -24,13 +24,17 @@ class LinearType(Enum):
 
 
 def get_float8_linear(
-    linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False
+    linear_type: LinearType,
+    linear_ref: torch.nn.Linear,
+    emulate: bool = False,
+    recompute_weight_cast: bool = False,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
        emulate: Whether to emulate the fp8 matmul logic in float32.
+        recompute_weight_cast: Whether to recompute the weight cast in the backwards pass.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
@@ -40,7 +44,9 @@ def get_float8_linear(
         raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
 
     return LINEAR_TYPE_MAP[linear_type].from_float(
-        copy.deepcopy(linear_ref), emulate=emulate
+        copy.deepcopy(linear_ref),
+        emulate=emulate,
+        recompute_weight_cast=recompute_weight_cast,
     )
 
 
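For context, a minimal usage sketch of the new keyword (not part of this commit; it assumes a CUDA device and, as in the tests below, that dynamic scaling needs no amax/scale sync):

    import torch
    import torch.nn as nn

    from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

    # Convert a reference linear; emulate=True keeps the fp8 matmul emulated in float32.
    m_ref = nn.Linear(16, 32, bias=True, device="cuda")
    m_fp8 = get_float8_linear(
        LinearType.DYNAMIC,
        m_ref,
        emulate=True,
        recompute_weight_cast=True,  # recompute the weight cast in the backward pass
    )

    x = torch.randn(4, 16, device="cuda", requires_grad=True)
    m_fp8(x).sum().backward()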
test/test_base.py (28 additions, 8 deletions)

@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:
 
 
 class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        recompute_weight_cast: bool,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,14 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("recompute_weight_cast", [True, False])
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        recompute_weight_cast: bool,
+    ):
         if not emulate:
             if not torch.cuda.is_available():
                 warnings.warn("CUDA not available")
@@ -125,16 +139,22 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
 
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast)
 
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize("recompute_weight_cast", [True, False])
     def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        recompute_weight_cast: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -148,10 +168,10 @@ def test_linear_bias(
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast)
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+        m = Float8Linear.from_float(m, emulate, recompute_weight_cast)
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
@@ -184,7 +204,7 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, False)
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         m = Float8Linear.from_float(m, emulate)
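
Note that test_type_cast is not parametrized over recompute_weight_cast; it simply passes False as the new trailing argument to the shared _test_linear_impl helper.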

test/test_compile.py (37 additions, 7 deletions)

@@ -22,6 +22,7 @@ def _test_compile_base(
     emulate: bool,
     linear_type: LinearType,
     dtype: torch.dtype,
+    recompute_weight_cast: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
@@ -31,7 +32,9 @@ def _test_compile_base(
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
 
-    m_fp8 = get_float8_linear(linear_type, m_ref, emulate=emulate)
+    m_fp8 = get_float8_linear(
+        linear_type, m_ref, emulate=emulate, recompute_weight_cast=recompute_weight_cast
+    )
 
     m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph)
     m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph)
@@ -50,30 +53,57 @@ def _test_compile_base(
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("recompute_weight_cast", [False, True])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-def test_eager_only(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_eager_only(
+    fullgraph,
+    emulate: bool,
+    linear_type: bool,
+    dtype: torch.dtype,
+    recompute_weight_cast: bool,
+):
     torch._dynamo.reset()
-    _test_compile_base("eager", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "eager", fullgraph, emulate, linear_type, dtype, recompute_weight_cast
+    )
 
 
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("recompute_weight_cast", [False, True])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-def test_aot_eager(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_aot_eager(
+    fullgraph,
+    emulate: bool,
+    linear_type: bool,
+    dtype: torch.dtype,
+    recompute_weight_cast: bool,
+):
     torch._dynamo.reset()
-    _test_compile_base("aot_eager", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "aot_eager", fullgraph, emulate, linear_type, dtype, recompute_weight_cast
+    )
 
 
 @pytest.mark.parametrize("fullgraph", [True])
 @pytest.mark.parametrize("emulate", [False])
 @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+@pytest.mark.parametrize("recompute_weight_cast", [False, True])
 @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
-def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtype):
+def test_inductor(
+    fullgraph,
+    emulate: bool,
+    linear_type: bool,
+    dtype: torch.dtype,
+    recompute_weight_cast: bool,
+):
     torch._dynamo.reset()
-    _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
+    _test_compile_base(
+        "inductor", fullgraph, emulate, linear_type, dtype, recompute_weight_cast
+    )
 
 
 if __name__ == "__main__":
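
As a rough count of the added parametrization (based only on the decorators visible in this diff): on an H100, test_aot_eager now fans out to 1 (fullgraph) × 2 (emulate) × 2 (linear_type) × 3 (dtype) × 2 (recompute_weight_cast) = 24 cases, and test_inductor to 1 × 1 × 2 × 3 × 2 = 12.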
