Commit 94a75dc

fixing GPTQ
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 7a9ae0d1d03794c5e5d85e3674fc88d2813eaf23
Pull Request resolved: #147
1 parent 93dab0e commit 94a75dc

2 files changed: +142 −1 lines changed
GPTQ.py  (+122 −1)
@@ -129,6 +129,127 @@ def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


class GPTQMultiTensor(torch.Tensor):
    """
    Tensor subclass that holds a list of tensors (e.g. calibration inputs) and
    replays ops over them via __torch_function__, applying GPTQ to the weight of
    each nn.functional.linear call it intercepts.
    """
    # todo need default shape/dtype
    @staticmethod
    def __new__(cls, input, **kwargs):
        # allow construction from a single tensor or a list/tuple of tensors
        if isinstance(input, (list, tuple)):
            input = input[0]
        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
        shape = kwargs.pop("shape", input.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, input, **kwargs):
        self.values = []
        self.append(input)
        self.debug = False

    def append(self, input):
        if isinstance(input, (tuple, list)):
            for inp in input:
                self.values.append(inp)
        elif isinstance(input, torch.Tensor):
            self.values.append(input)

    # def __add__(self, other):
    #     for val in other.values:
    #         self.append(val)

    def count(self):
        return len(self.values)

    def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_quant=False):
        def tensors_to_cuda(args):
            new_args = []
            for x in args:
                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
            return new_args

        kwargs = {} if kwargs is None else kwargs
        # combine args and kwargs
        flat_args, spec = tree_flatten((args, kwargs))
        # move single tensors to cuda
        flat_args = tensors_to_cuda(flat_args)
        # size of biggest MultiTensor
        multi_tensor_size = max(
            [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat_args]
        )
        # convert [a, MultiTensor(b,b,b), MultiTensor(c,c,c)] => [a,b,c], [a,b,c], [a,b,c]
        grouped_args = list(
            zip(
                *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat_args]
            )
        )

        quantize_linear = (
            func is nn.functional.linear
            # and id(args[1]) in self.id_to_name
            and not skip_quant
            # and not (self.skip_layer_func)
        )

        # run function for each of the multitensors and return a multitensor
        if not quantize_linear:
            outputs = []
            for inp in grouped_args:
                inp = tensors_to_cuda(inp)
                cur_args, cur_kwargs = tree_unflatten(inp, spec)
                with torch._C.DisableTorchFunctionSubclass():
                    out = func(*cur_args, **cur_kwargs)
                outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
            return cls(outputs)

        # GPTQ path for nn.functional.linear: accumulate the Hessian over all held
        # inputs, quantize the weight, then rerun linear with the dequantized weight
        # while skipping quantization.
        total_batches = 0
        H = 0
        for inp in grouped_args:
            inp = tensors_to_cuda(inp)
            cur_args, cur_kwargs = tree_unflatten(inp, spec)
            x = cur_args[0].float()
            shape = x.shape
            n = 1 if len(shape) == 2 else shape[0]
            H *= total_batches / (total_batches + n)
            total_batches += n
            x = (
                (2 / total_batches) ** (1 / 2) *
                x.reshape(-1, shape[-1]).t().float()
            )
            H += x.matmul(x.t())
        W = args[1].to(H.device)
        Q, DQ, qparams = args[0].faster_quant(H, W.detach())

        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_quant=True)
        if args[0].debug:
            breakpoint()
        return new_out

        # NOTE: the block below is unreachable leftover from an earlier draft
        # (it references undefined names such as mat1 and w_autoquant).
        if func is torch.nn.functional.linear:
            inputs, weight, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None,
            )
            if quantize_linear:
                cls.do_gptq(input, weight)
            return func(mat1, w_autoquant.weight, bias)
        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except:
            print(f"ERR: subclass doesn't implement {func}")


class GenericGPTQRunner(fx.Interpreter):
    """
    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +271,7 @@ def __init__(
        }

        # trace model for one input
-       one_input = [multi.values[0].cpu() for multi in inputs]
+       one_input = tuple([multi.values[0].cpu() for multi in inputs])
        exported_model = torch._dynamo.export(
            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
        )(*one_input)
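
For context, a minimal sketch (not part of the commit) of how the non-linear broadcast branch of GPTQMultiTensor.__torch_function__ behaves: a MultiTensor holding several calibration activations is grouped against ordinary arguments and the op is replayed once per held tensor. The import path, shapes, and the use of torch.matmul are assumptions for illustration; a CUDA device is required because the handler moves every tensor to the GPU.

    import torch
    from GPTQ import GPTQMultiTensor  # assumes GPTQ.py above is importable as a module

    # Three made-up calibration activations share one ordinary weight.
    acts = [torch.randn(4, 8) for _ in range(3)]
    multi_act = GPTQMultiTensor(acts[0])
    multi_act.append(acts[1:])
    weight = torch.randn(8, 8)

    # torch.matmul is not nn.functional.linear, so the broadcast branch runs:
    # [multi_act, weight] is grouped into (acts[0], weight), (acts[1], weight), ...,
    # the op is replayed once per group, and the results come back wrapped.
    out = torch.matmul(multi_act, weight)
    assert isinstance(out, GPTQMultiTensor) and out.count() == 3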

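The linear branch keeps a running Hessian estimate H = (2/N) Σᵢ XᵢXᵢᵀ without storing all activations: before each batch is added, H is rescaled by N/(N+n) and the new batch enters with a sqrt(2/(N+n)) factor. A small standalone check of that identity (plain tensors and made-up shapes, not part of the commit):

    import torch

    torch.manual_seed(0)
    batches = [torch.randn(4, 16, 32) for _ in range(3)]  # (batch, tokens, features)

    # Streaming update, mirroring the loop in __torch_function__ above.
    H, total_batches = 0, 0
    for x in batches:
        n = x.shape[0]
        H *= total_batches / (total_batches + n)
        total_batches += n
        x = (2 / total_batches) ** 0.5 * x.reshape(-1, x.shape[-1]).t().float()
        H += x.matmul(x.t())

    # Closed form: H = (2 / N) * sum_i X_i X_i^T with X_i of shape (features, tokens).
    X = torch.cat([x.reshape(-1, x.shape[-1]) for x in batches]).t().float()
    H_ref = (2 / total_batches) * X.matmul(X.t())
    assert torch.allclose(H, H_ref, atol=1e-4)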
run.sh  (+20)

@@ -0,0 +1,20 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
# echo "quant good"

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
