@@ -50,7 +50,11 @@ def decorator(func):
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
     )


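Editor's note (not part of this commit): every handler in this file rebuilds a Float8Tensor by copying the wrapper metadata from an existing instance, and this change threads _scaling_strategy through as a fifth positional argument. Assuming the constructor signature shown in these hunks, a hypothetical helper along the following lines could factor out that repetition; the name _rewrap_like is not in the codebase.

def _rewrap_like(ref, new_data):
    # Hypothetical helper (editor's sketch): wrap new_data with ref's scale,
    # original dtype, mm config, and the newly propagated scaling strategy.
    return Float8Tensor(
        new_data,
        ref._scale,
        ref._orig_dtype,
        ref._mm_config,
        ref._scaling_strategy,
    )

float8_desugar_op, allgather_fp8, wait_tensor_fp8, and index_put_fp8 could then each reduce to a single _rewrap_like call, and any future metadata field would be added in one place.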
@@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):

     def make_float8(data):
         return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
+            args[0]._scaling_strategy,
         )

     out = map(make_float8, new_data_tensors)
@@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None):
     orig_dtype = chunked_tensors[0]._orig_dtype
     scale = chunked_tensors[0]._scale
     mm_config = chunked_tensors[0]._mm_config
+    scaling_strategy = chunked_tensors[0]._scaling_strategy
     fp8_dtype = chunked_tensors[0]._data.dtype
     chunk_data = []
     for chunk in chunked_tensors:
@@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._data.dtype == fp8_dtype
         ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._scaling_strategy is scaling_strategy
+        ), "Expecting all chunks to have the same scaling strategy as a result of a split"
         chunk_data.append(chunk._data.view(torch.uint8))

     new_data = aten_op(chunk_data, *args[1:], **kwargs)
     new_data = new_data.view(fp8_dtype)
-    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy)


 @implements([aten.sum.dim_IntList])
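The new assertion mirrors the existing dtype and scale checks: chunks produced by a split share a single scaling-strategy object, so an identity comparison suffices before re-concatenating. A rough invariant check under that assumption (an editor's sketch, not a test from this PR; it presumes torch.split and torch.cat dispatch to the handlers above for Float8Tensor inputs) could look like:

import torch

def check_split_cat_roundtrip(fp8_tensor, num_chunks=2, dim=0):
    # Split a Float8Tensor and concatenate the pieces again; with the patch
    # above, the result should carry the same scale and scaling strategy objects.
    pieces = torch.split(fp8_tensor, fp8_tensor.shape[dim] // num_chunks, dim=dim)
    rejoined = torch.cat(pieces, dim=dim)
    assert rejoined._scale is fp8_tensor._scale
    assert rejoined._scaling_strategy is fp8_tensor._scaling_strategy
    return rejoined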
@@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None):
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
+    scaling_strategy = a._scaling_strategy
+    # TODO: support mismatched strategies by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
     tensor_out = addmm_float8_unwrapped(
         a_data,
         a_scale,
@@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_mm_config: ScaledMMConfig = a._mm_config
     b_mm_config: ScaledMMConfig = b._mm_config
     mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaling_strategy = a._scaling_strategy
+    # TODO: support mismatched strategies by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
     if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
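Both matmul paths carry the same TODO: rather than asserting equality, mismatched strategies could eventually be promoted to the most generic form, much as merge_mm_configs does for ScaledMMConfig. A placeholder shape for that follow-up, keeping today's equality semantics (an editor's sketch; merge_scaling_strategies does not exist in this commit), might be:

def merge_scaling_strategies(a_strategy, b_strategy):
    # Hypothetical follow-up to the TODO above: promote two scaling strategies
    # to a common, more generic one. Until broadcasting rules are defined, this
    # preserves the current behavior of requiring both operands to match.
    if a_strategy == b_strategy:
        return a_strategy
    raise NotImplementedError(
        "Broadcasting mismatched scaling strategies is not supported yet"
    )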
@@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
     )


@@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
     )


@@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
     )


@@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values_data = fp8_values._data
     fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
+        fp8_self._scaling_strategy,
     )


@@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None):
             self._data.dtype == src._data.dtype
         ), "Expecting both Float8Tensors to be of the same dtype"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+            self._scaling_strategy,
+        )
     else:
         raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")