Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit aa920be

Browse files
summerdengfbfacebook-github-bot
authored andcommitted
Add history_len option to DelayedScalingRecipe
Summary: Add init function in DelayedScalingRecipe and history_len option. Reviewed By: drisspg Differential Revision: D52918526 fbshipit-source-id: b92df361a7b5b3507008b2ea7d8bb2d757e8aecb
1 parent 289c122 commit aa920be

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

float8_experimental/float8_linear.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,18 @@ def backward(ctx, go):
109109
@dataclasses.dataclass
110110
class DelayedScalingRecipe:
111111
# Controls the history length of amax buffers
112-
history_len = 16
112+
history_len: int
113113

114114
# Controls the way to calculate current scale from amax history
115115
# TODO(future): add other functions as needed, hardcoded or user defined
116-
scale_fn_name = "max"
116+
scale_fn_name: str
117+
118+
def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
119+
self.history_len = history_len
120+
self.scale_fn_name = scale_fn_name
121+
assert (
122+
self.scale_fn_name == "max"
123+
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
117124

118125

119126
class Float8LinearMixin(object):

0 commit comments

Comments
 (0)