-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfloat8.py
More file actions
239 lines (201 loc) · 9.4 KB
/
float8.py
File metadata and controls
239 lines (201 loc) · 9.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch._inductor.config
import torch.nn as nn
from torchtitan.components.quantization import (
FP8_GROUP_ALIGNMENT_SIZE,
QuantizationConverter,
)
from torchtitan.config.job_config import Float8Linear, JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.models.moe.utils import set_token_group_alignment_size_m
from torchtitan.protocols.model_converter import register_model_converter
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability
from .utils import module_filter_fn
# Sentinel value: when present in `filter_fqns`, the converter switches to
# torchao's `_auto_filter_for_recipe`, which skips linear layers whose (K, N)
# dims are too small to benefit from float8 training.
AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
class Float8LinearConverter(QuantizationConverter):
    """Model converter that swaps eligible ``nn.Linear`` layers for torchao
    ``Float8Linear`` to enable float8 training.

    Behavior is driven by ``job_config.quantize.linear.float8``: either a named
    torchao recipe (``recipe_name``, e.g. "rowwise") or the default tensorwise
    scaling, which optionally enables FSDP float8 all-gather.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        """Validate hardware/config prerequisites and build the torchao
        ``Float8LinearConfig``.

        Args:
            job_config: Full job configuration; only ``quantize.linear.float8``
                and ``compile`` are read here.
            parallel_dims: Parallelism layout; ``dp_shard_enabled`` gates FSDP
                float8 all-gather.

        Raises:
            ValueError: on hardware older than SM89 unless emulation is enabled
                in eager mode.
            ImportError: if torchao is not installed.
        """
        super().__init__(job_config, parallel_dims)
        float8_config: Float8Linear = job_config.quantize.linear.float8
        compile_config = job_config.compile
        model_compile_enabled = (
            compile_config.enable and "model" in compile_config.components
        )
        # Float8 kernels require SM89+; emulation is accepted only in eager
        # mode (emulate + model compile still fails the check below).
        if has_cuda_capability(8, 9) or (
            float8_config.emulate and not model_compile_enabled
        ):
            pass
        else:
            raise ValueError(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
                "To enable testing on older hardware, set `float8.emulate` to True in eager mode.",
            )
        try:
            from torchao.float8 import Float8LinearConfig as TorchAOFloat8LinearConfig
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e
        # Recipe lookup needs torchao >= v0.9.0 (`from_recipe_name`); on older
        # versions warn and return early, leaving the converter disabled
        # (`self.enabled` is only set to True at the end of __init__).
        if float8_config.recipe_name is not None and not hasattr(
            TorchAOFloat8LinearConfig, "from_recipe_name"
        ):
            logger.warning(
                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
                "is too old, please install torchao v0.9.0 or later and try again",
            )
            return
        # NOTE(review): self.filter_fqns aliases the config's list;
        # _init_filter_fn may remove the auto-filter flag from that same list
        # in place, so both views drop the flag.
        self.filter_fqns = float8_config.filter_fqns
        self.filter_fn = self._init_filter_fn(float8_config)
        if float8_config.recipe_name is not None:
            # Recipe-based scaling path.
            assert not float8_config.enable_fsdp_float8_all_gather, (
                "using `float8_config.enable_fsdp_float8_all_gather` together "
                "with `float8_config.recipe_name` is not supported"
            )
            self.config = TorchAOFloat8LinearConfig.from_recipe_name(
                float8_config.recipe_name
            )
            self.precompute_scale = False
            logger.info(
                f"Float8 training active with recipe {float8_config.recipe_name}"
            )
            # short-term solution for https://github.com/pytorch/pytorch/issues/150859
            if float8_config.recipe_name == "rowwise":
                torch._inductor.config.emulate_precision_casts = True
                logger.debug(
                    "Set torch._inductor.config.emulate_precision_casts to True"
                )
        else:
            # Default tensorwise scaling path; FSDP float8 all-gather only
            # makes sense when FSDP data-parallel sharding is active.
            enable_fsdp_float8_all_gather = (
                parallel_dims.dp_shard_enabled
                and float8_config.enable_fsdp_float8_all_gather
            )
            self.config = TorchAOFloat8LinearConfig(
                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                emulate=float8_config.emulate,
            )
            # for precompute_float8_dynamic_scale_for_fsdp
            self.precompute_scale = (
                enable_fsdp_float8_all_gather
                and float8_config.precompute_float8_dynamic_scale_for_fsdp
            )
            logger.info("Float8 tensorwise scaled training active")
        self.enabled = True

    def _init_filter_fn(self, float8_config: Float8Linear):
        """Build the module filter used during conversion.

        Returns either torchao's ``_auto_filter_for_recipe`` (when the
        ``auto_filter_small_kn`` flag is present in ``filter_fqns`` and the
        installed torchao provides it) or the default ``module_filter_fn``
        bound to the configured ``filter_fqns``.
        """
        # use auto_filter if filter_fqns "auto_filter_small_kn" is one of the given fqns.
        use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
        if use_auto_filter:
            try:
                from torchao.float8 import _auto_filter_for_recipe

                logger.info(
                    "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small "
                    "to benefit from float8 training. See docs/float8.md for more info."
                )
                # Auto-filter is recipe-aware; fall back to "tensorwise" when
                # no recipe is configured.
                recipe_name = (
                    float8_config.recipe_name
                    if float8_config.recipe_name
                    else "tensorwise"
                )
                # remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe
                float8_config.filter_fqns.remove(AUTO_FILTER_SMALL_KN_FLAG)
                return _auto_filter_for_recipe(
                    recipe_name,
                    filter_fqns=float8_config.filter_fqns,
                )
            except ImportError:
                # Older torchao without _auto_filter_for_recipe: fall through
                # to the default filter below.
                logger.warning(
                    (
                        "Using default module_filter_fn for float8 model conversion. "
                        "To use _auto_filter_for_recipe, please install torchao nightly build."
                    )
                )
        # use default filter func
        return partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)

    def convert(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        if not self.enabled:
            return
        from torchao.float8 import convert_to_float8_training

        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=self.filter_fn,
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        """After each optimizer step, precompute dynamic float8 scales for
        FSDP all-gather — only when the converter is enabled and
        ``precompute_scale`` was set in __init__."""
        if not self.enabled:
            return
        if not self.precompute_scale:
            return
        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)
class Float8GroupedMMConverter(QuantizationConverter):
    """Converter that enables dynamic float8 rowwise quantization with scaled
    grouped GEMMs for MoE layers whose FQNs match the configured targets."""

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        """Validate hardware and parallelism prerequisites, then set the token
        group alignment required by fp8 grouped GEMMs.

        Raises:
            ValueError: if the GPU is older than SM89.
            AssertionError: if pipeline or context parallelism is enabled.
        """
        super().__init__(job_config, parallel_dims)
        self.fqns = job_config.quantize.grouped_mm.float8.fqns

        compile_cfg = job_config.compile
        compiling_model = compile_cfg.enable and "model" in compile_cfg.components

        # Hardware gate: fp8 grouped GEMMs need SM89 or newer.
        if not has_cuda_capability(8, 9):
            raise ValueError("Float8 MoE training only supported on SM89 or later.")
        if not compiling_model:
            # Not fatal, but performance without compile will be poor.
            logger.warning(
                "Compile is required for high performance float8 MoE training; enable it with --compile.enable"
            )

        # The MoE training prototype does not yet compose with PP or CP.
        assert (
            job_config.parallelism.pipeline_parallel_degree == 1
        ), "Float8 MoE training prototype does not yet support pipeline parallelism"
        assert (
            job_config.parallelism.context_parallel_degree == 1
        ), "Float8 MoE training prototype does not yet support context parallelism"

        # For fp8 grouped GEMM, token group sizes must be multiples of 16
        # (16 byte alignment / 1 byte per elem = 16 elements).
        set_token_group_alignment_size_m(FP8_GROUP_ALIGNMENT_SIZE)
        self.enabled = True

    def convert(self, model: nn.Module):
        """
        Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
        to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
        """
        from torchao.quantization.quant_api import quantize_

        try:
            from torchao.prototype.moe_training.conversion_utils import (
                MoETrainingConfig,
            )
        except ImportError as e:
            raise ImportError(
                "torchao installation does not have MoE training support. Please install torchao nightly build."
            ) from e

        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            # A module is converted when any target FQN is a substring of its FQN.
            return any(target_fqn in cur_fqn for target_fqn in self.fqns)

        quantize_(model, config=MoETrainingConfig(), filter_fn=moe_module_filter_fn)
        logger.info(
            f"Converted MoE layers matching FQNS {self.fqns} "
            "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        """No-op: grouped-MM float8 training needs no post-optimizer work."""
        pass
# Register both converters under their job-config paths so they can be
# activated from configuration.
register_model_converter(Float8LinearConverter, "quantize.linear.float8")
register_model_converter(Float8GroupedMMConverter, "quantize.grouped_mm.float8")