#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any, Dict, Optional

import torch

@@ -75,11 +75,33 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
    return a_data, a_scale, b_data, b_scale


+def float8_mm_helper(a: Float8Tensor, b: Float8Tensor) -> torch.Tensor:
+    """This is a helper function for float8_mm.
+    Args:
+        a: The first matrix multiplication operand.
+        b: The second matrix multiplication operand.
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
+    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    output_dtype = a._orig_dtype
+    if a._emulate:
+        assert a._emulate == b._emulate
+        return torch.ops.aten.mm_float8_emulated(
+            a._data, a._scale, b._data, b._scale, output_dtype
+        )[0]
+    tensor_out, amax = addmm_float8_unwrapped(
+        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
+    )
+    return tensor_out
+
+
@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
    a = args[0]
    b = args[1]
+    return float8_mm_helper(a, b)
    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
    output_dtype = a._orig_dtype
    if a._emulate:
@@ -140,3 +162,115 @@ def autocast_to_copy(aten_op, args, kwargs=None):
    return Float8Tensor(
        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
    )
+
+
+class float8_linear(torch.autograd.Function):
+    """Custom autograd function for computing torch.nn.Linear on Float8Tensor.
+
+    This is needed because we want fine-grained control over how the casted
+    values are recomputed for the backward pass.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8: torch.Tensor,
+        original_weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        weight_amax_buffer: Optional[torch.Tensor],
+        emulate: bool,
+        recompute_float8_weight: bool,
+    ):
+        w_fp8 = Float8Tensor.to_float8(
+            original_weight,
+            weight_scale,
+            torch.float8_e4m3fn,
+            weight_amax_buffer,
+            emulate=emulate,
+        )
+        if recompute_float8_weight:
+            # recompute_float8_weight should be set to True when using traditional
+            # FSDP, to avoid saving the unsharded weight for the backward pass.
+            ctx.save_for_backward(
+                x_fp8, original_weight, weight_scale, weight_amax_buffer
+            )
+        else:
+            # Does this interact properly with activation checkpointing?
+            ctx.save_for_backward(x_fp8, w_fp8)
+
+        ctx.recompute_float8_weight = recompute_float8_weight
+        ctx.emulate = emulate
+        orig_shape = x_fp8._data.shape
+        x_fp8_reshaped = Float8Tensor(
+            x_fp8._data.reshape(-1, orig_shape[-1]),
+            x_fp8._scale,
+            x_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        w_fp8_t = Float8Tensor(
+            w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate
+        )
+
+        res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t)
+
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8: torch.Tensor):
+        emulate = ctx.emulate
+        if ctx.recompute_float8_weight:
+            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
+            w_fp8 = Float8Tensor.to_float8(
+                original_weight,
+                weight_scale,
+                torch.float8_e4m3fn,
+                weight_amax_buffer,
+                emulate=emulate,
+            )
+        else:
+            x_fp8, w_fp8 = ctx.saved_tensors
+
+        go_fp8_orig_shape = go_fp8._data.shape
+        go_fp8_reshaped = Float8Tensor(
+            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+            go_fp8._scale,
+            go_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        w_fp8_t_c_t = Float8Tensor(
+            w_fp8._data.t().contiguous().t(),
+            w_fp8._scale,
+            w_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        # calculate dL/dX
+        dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_reshaped_t_c = Float8Tensor(
+            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+            x_fp8._scale,
+            x_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        go_fp8_reshaped_t_c_t = Float8Tensor(
+            go_fp8_reshaped._data.t().contiguous().t(),
+            go_fp8_reshaped._scale,
+            go_fp8_reshaped._orig_dtype,
+            emulate=emulate,
+        )
+
+        # calculate dL/dW
+        dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
+        dL_dW = dL_dW.t()
+
+        # One gradient per forward() input: weight_scale, weight_amax_buffer,
+        # emulate and recompute_float8_weight do not receive gradients.
+        empty_grads = None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
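
For orientation, here is a minimal usage sketch of the pieces added in this diff. It assumes the definitions above (`Float8Tensor`, `float8_linear`) are in scope, that a CUDA device with native float8 matmul support is available (otherwise pass `emulate=True`), and it computes throwaway scales inline; real callers derive scales from amax history buffers, which is not shown here. The helper `toy_scale` is hypothetical and exists only for this illustration.

```python
import torch

# Hypothetical inline scale: map the tensor's max magnitude to the float8 max
# representable value. Real code maintains amax buffers instead of doing this.
def toy_scale(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    return torch.tensor(
        torch.finfo(float8_dtype).max / t.abs().max().item(), device=t.device
    )

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Cast the activation to float8; the weight is cast inside float8_linear.forward.
x_fp8 = Float8Tensor.to_float8(x, toy_scale(x), torch.float8_e4m3fn, None, emulate=False)

# Argument order follows float8_linear.forward:
# (x_fp8, original_weight, weight_scale, weight_amax_buffer, emulate, recompute_float8_weight)
out = float8_linear.apply(x_fp8, weight, toy_scale(weight), None, False, False)
print(out.shape)  # torch.Size([16, 64]), in the original (bfloat16) dtype
```

With `emulate=True`, the same call would route the matmul through `torch.ops.aten.mm_float8_emulated` via `float8_mm_helper`, per the emulation branch in the diff above.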