 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty_like(amax, dtype=torch.float32)
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    assert amax.dtype == torch.float32, "amax must be a float32 tensor"
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
@@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=FP16_MAX_POS)
-    scale.copy_(res)
-    return scale
+    return res


 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
 ):
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
@@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False):


 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
+def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
     amax = tensor_to_amax(x)
     if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import (
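For context, a minimal self-contained sketch of the amax-to-scale math that the updated amax_to_scale performs (the float16 clamp is omitted). The float8 maxima are derived from torch.finfo rather than the module's E4M3_MAX_POS/E5M2_MAX_POS constants, and EPS, amax_to_scale_sketch, and the usage tensor x are illustrative placeholders, not part of the library.

import torch

EPS = 1e-12  # assumed small epsilon to avoid division by zero; not necessarily the library's value

@torch.no_grad()
def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # amax is expected to be a float32 scalar tensor, matching the new assert in the diff
    assert amax.dtype == torch.float32
    if float8_dtype == torch.float8_e4m3fn:
        max_pos = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    else:  # e5m2
        max_pos = torch.finfo(torch.float8_e5m2).max  # 57344.0
    # Scale so the observed max magnitude maps to the float8 max representable value
    return max_pos / torch.clamp(amax, min=EPS)

# Usage: compute a scale from a tensor's max-abs value, then cast to float8
x = torch.randn(16, 16)
amax = x.abs().max().to(torch.float32)
scale = amax_to_scale_sketch(amax, torch.float8_e4m3fn)
max_pos = torch.finfo(torch.float8_e4m3fn).max
x_fp8 = torch.clamp(x * scale, -max_pos, max_pos).to(torch.float8_e4m3fn)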