[WIP][AWQ] Support accumulation for reduced memory usage #1435
base: main
Conversation
```diff
- int_w_output = self._forward_input_with_kwargs(
-     module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
- )
+ int_w_output = self._run_samples(parent_layer)
```
`inputs` here should be `x`, not the dataset
`x` is the dataset, right?
My bad, misread what `self._samples` was. Can we keep that as a different name to distinguish it from actual data?
Yeah I'm always happy to take suggestions on names for any of the code I push 🙂
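(For context, a minimal sketch of what a `_run_samples` helper like the one in this hunk might look like; the class, batch format, and concatenation are assumptions for illustration, not the PR's actual implementation:)

```python
from typing import List

import torch


class SketchModifier:
    """Toy stand-in for the modifier; only _run_samples matters here."""

    def __init__(self, samples: List[torch.Tensor]):
        # Assumption for this sketch: cached calibration batches are tensors
        self._samples = samples

    def _run_samples(self, module: torch.nn.Module) -> torch.Tensor:
        # Forward every cached batch through `module` and concatenate,
        # so callers get one output tensor covering all samples
        with torch.no_grad():
            return torch.cat([module(x) for x in self._samples], dim=0)


# Usage: run two cached batches through a layer
mod = SketchModifier([torch.randn(4, 8), torch.randn(4, 8)])
out = mod._run_samples(torch.nn.Linear(8, 8))  # shape (8, 8)
```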
```python
    default_factory=IntermediatesCache
)
_sample_means: Dict[Module, float] = PrivateAttr(default_factory=dict)
_num_samples: Dict[Module, int] = PrivateAttr(default_factory=dict)
```
I know this is in GPTQ, but a field named `num_samples` indicates an `int` in virtually any context I've seen it. What about `_sample_counts`, so it's more similar to `_sample_means`?
That's good too!
These deletions make me happy! LGTM!
```python
AWQ_PRECISION = torch.float32


def accumulate_mean(
```
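(The function body is cut off in this hunk; below is a minimal sketch of the kind of running-mean update the `_sample_means` / `_num_samples` fields above suggest, with an assumed signature:)

```python
from typing import Tuple

import torch


def accumulate_mean(
    inputs: torch.Tensor,
    prev_mean: torch.Tensor,
    num_prev_samples: int,
) -> Tuple[torch.Tensor, int]:
    # Fold one batch into a running mean so earlier batches can be freed:
    # new_mean = (prev_mean * n_prev + batch_sum) / (n_prev + n_batch)
    num_new = inputs.size(0)
    total = num_prev_samples + num_new
    new_mean = (prev_mean * num_prev_samples + inputs.sum(dim=0)) / total
    return new_mean, total
```

This keeps per-module state down to one mean tensor and one count, which is where the reduced memory usage in the PR title comes from.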
```python
def iter(
    self, input_names: Optional[List[str]] = None
) -> Generator[Any, None, None]:
    for batch_index in self.batch_intermediates:
        yield self.fetch(batch_index, input_names)

def __iter__(self) -> Generator[Any, None, None]:
    yield from self.iter()
```
Why not just have `iter`?
Saves an extra function call:

```python
for batch in cache:
    ...
```

versus

```python
for batch in cache.iter():
    ...
```
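(A runnable toy version of the pattern, with assumed internals, showing both call sites side by side:)

```python
from typing import Any, Dict, Generator, List, Optional


class ToyCache:
    """Toy stand-in for IntermediatesCache; internals are assumptions."""

    def __init__(self) -> None:
        # batch_index -> {input_name: value}
        self.batch_intermediates: Dict[int, Dict[str, Any]] = {}

    def fetch(
        self, batch_index: int, input_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        batch = self.batch_intermediates[batch_index]
        if input_names is None:
            return batch
        return {name: batch[name] for name in input_names}

    def iter(
        self, input_names: Optional[List[str]] = None
    ) -> Generator[Any, None, None]:
        for batch_index in self.batch_intermediates:
            yield self.fetch(batch_index, input_names)

    def __iter__(self) -> Generator[Any, None, None]:
        yield from self.iter()


cache = ToyCache()
cache.batch_intermediates[0] = {"hidden_states": [1.0], "mask": [1]}

for batch in cache:                          # __iter__: all inputs
    print(batch)
for batch in cache.iter(["hidden_states"]):  # iter(): filtered by name
    print(batch)
```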
SUMMARY:
- Add QuantizationMixin to AWQModifier so we don't have redundant inputs (num_bits, symmetric, group_size)
- Move AWQModifier to sequential pipelines, to avoid the huge memory requirements of caching all activations at once

Regression test results are acceptable; results are all roughly the same and within stderr, see test plan below.

Resolves #1409
Resolves #1369
Related to #1383
Related to #1406
Related to #1368
Related to #1410
More improvements split into #1435

TEST PLAN:
- [x] Rerun tests to validate

No regression in tests, comparing against those reported in the [original AWQ PR](#1177 (comment)). All gsm8k results are within stderr:

| Type | gsm8k | wikitext |
| ------ | ------ | ----- |
| Old AWQ+QuantModifier Sym | .1054, .1069 | 9.1931 |
| New AWQ+QuantMixin Sym | .1077, .1084 | 9.1841 |
| Old AWQ+QuantModifier Asym | .1274, .1281 | 9.0281 |
| New AWQ+QuantMixin Asym | .1312, .1350 | 9.0288 |

Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
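(For anyone landing here from the summary: a hedged sketch of the post-mixin API in use. The model id, dataset name, and exact argument names are assumptions and may differ across llm-compressor releases:)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # assumed example model
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# With QuantizationMixin folded in, bit width, symmetry, and group size
# come from one quantization scheme instead of separate AWQ arguments
recipe = [AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"])]

oneshot(
    model=model,
    dataset="open_platypus",  # assumed calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=256,
)
```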