 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


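+# helper: a tensorwise scale is rank 0 or rank 1; anything higher-rank is
+# treated as axiswise and rejected by the ops that call this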
+def _assert_tensorwise_scale(aten_op, scale):
+    assert (
+        # TODO(future PR): figure out why tensorwise scaling can have
+        # both rank 0 and rank 1
+        len(scale.shape)
+        in (0, 1)
+    ), f"{aten_op} with axiswise scaling is not supported yet"
+
+
 def implements(aten_ops):
     """Register aten ops to the float8 op table"""

@@ -32,18 +41,16 @@ def decorator(func):

 @implements(
     [
-        aten.view.default,
         aten._unsafe_view.default,
-        aten.t.default,
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
         aten.slice.Tensor,
-        aten.transpose.int,
         aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
         new_data,
@@ -54,8 +61,61 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )


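+# t/transpose change which dim an axiswise scale applies to, so the scale is
+# transformed with the same op as the data instead of being passed through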
+@implements(
+    [
+        aten.t.default,
+        aten.transpose.int,
+    ]
+)
+def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+    return Float8Tensor(
+        new_data,
+        new_scale,
+        args[0]._orig_dtype,
+        args[0]._linear_mm_config,
+        args[0]._gemm_input_role,
+    )
+
+
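+# view needs its own handler: a tensorwise scale passes through unchanged,
+# while an axiswise scale has to be reshaped together with the data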
+@implements([aten.view.default])
+def float8_view(aten_op, args, kwargs=None):
+    if len(args[0]._scale.shape) < 2:
+        # tensorwise scaling
+        return float8_desugar_op(aten_op, args, kwargs)
+
+    t, new_shape = args[0], args[1]
+    # for now, only support reshaping to [-1, dim] or [dim, -1]
+    if len(new_shape) == 2:
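+        # case 1: the scale's leading dim is 1, so flattening the data's
+        # trailing dims can be mirrored by viewing the scale as [1, -1]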
+        if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [1, -1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
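+        # case 2: the scale's trailing dim is 1, so flattening the data's
+        # leading dims can be mirrored by viewing the scale as [-1, 1]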
+        elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [-1, 1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+    raise AssertionError(
+        f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet."
+    )
+
+
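+# the split handler below rebuilds each chunk via make_float8(data), reusing
+# the input tensor's metadata, which only makes sense for a tensorwise scale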
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

     def make_float8(data):
@@ -101,6 +161,7 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._gemm_input_role is gemm_input_role
         ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+        _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))

     new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -117,6 +178,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
117178 "addmm" -> out
118179 "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
119180 """
181+ _assert_tensorwise_scale (aten_op , args [0 ]._scale )
120182
121183 def unwrap (x ):
122184 if isinstance (x , Float8Tensor ):
@@ -229,6 +291,7 @@ def float8_addmm(aten_op, args, kwargs=None):

 @implements([aten.is_same_size.default])
 def float8_is_same_size(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     return args[0].shape == args[1].shape


@@ -238,6 +301,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8Tensor, presenting as a fp32
     tensor.
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +329,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
265329 """
266330 override funcol with FP8 handling
267331 """
332+ _assert_tensorwise_scale (aten_op , args [0 ]._scale )
268333 fp8_input = args [0 ]
269334 assert isinstance (
270335 fp8_input , Float8Tensor
@@ -284,6 +349,7 @@ def allgather_fp8(aten_op, args, kwargs=None):

 @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)

@@ -304,6 +370,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values = args[2]
     assert isinstance(fp8_self, Float8Tensor)
     assert isinstance(fp8_values, Float8Tensor)
+    _assert_tensorwise_scale(aten_op, fp8_self._scale)
     assert fp8_self._scale == fp8_values._scale
     assert fp8_self.dtype == fp8_values.dtype
     assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +401,10 @@ def copy_fp8(aten_op, args, kwargs=None):

     if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         src_hp = src.to_original_precision()
+        _assert_tensorwise_scale(aten_op, src._scale)
         return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        _assert_tensorwise_scale(aten_op, src._scale)
         assert (
             self._orig_dtype == src._orig_dtype
         ), "Expecting both Float8Tensors to be of the same dtype"