From 445ff1de37c5a082c970384d73d06de4ec323427 Mon Sep 17 00:00:00 2001
From: Xue Dong
Date: Sun, 13 Jul 2025 04:51:19 -0700
Subject: [PATCH] config change to enable precomputed scale for fp8

Summary:
design doc: https://docs.google.com/document/d/1k2fASPMfOH7TbGOW4jQlI4Mi4E1E0bs3Rh14M1cyWNw/edit?usp=sharing

Differential Revision: D78212339
---
 torchao/float8/config.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 939f68e59a..ddb6a2b5d7 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -183,6 +183,9 @@ class Float8LinearConfig:
     # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
     enable_fsdp_float8_all_gather: bool = False
 
+    # If True, then precompute the scale of the weights in the fp8 linear module
+    pre_compute_fp8_all_gather_weights_scale: bool = False
+
     # If True, then prior to performing the fp8 scaled mamtmul we will pad the
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
     # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.