@@ -129,6 +129,127 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


+class GPTQMultiTensor(torch.Tensor):
+    """
+    Tensor subclass that wraps a list of tensors and runs torch functions
+    across all of them; nn.functional.linear calls are intercepted to
+    accumulate a Hessian over the held inputs and GPTQ-quantize the weight.
+    """
+    # TODO: need default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
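+    # note: _make_wrapper_subclass creates a storage-less tensor "shell" with
+    # the given shape/dtype; the actual data tensors are kept in self.values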
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.append(input)  # was `self.append(inp)`; `inp` is undefined here
+        self.debug = False
+
+    def append(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.values.append(inp)
+        elif isinstance(input, torch.Tensor):
+            self.values.append(input)  # was `self.values(input)`, a TypeError
+
+    # def __add__(self, other):
+    #     for val in other.values:
+    #         self.append(val)
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_quant=False):
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
+
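+        # (the .cuda()/.cpu() shuttling below is presumably to keep only the
+        # argument group currently being evaluated resident on the GPU)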
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+        flat_args = tensors_to_cuda(flat_args)
+        # size of biggest MultiTensor
+        multi_tensor_size = max(
+            [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat_args]
+        )
+        # convert [a, MultiTensor(b,b,b), MultiTensor(c,c,c)] => [a,b,c], [a,b,c], [a,b,c]
+        grouped_args = list(
+            zip(
+                *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat_args]
+            )
+        )
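+        # illustrative sketch (not executed): with multi_tensor_size = 3,
+        #   flat_args    = [a, MultiTensor(b0, b1, b2), MultiTensor(c0, c1, c2)]
+        #   grouped_args = [(a, b0, c0), (a, b1, c1), (a, b2, c2)]
+        # i.e. each group is a complete argument set for one calibration pass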
+
+        quantize_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_quant
+            # and not (self.skip_layer_func)
+        )
+
+        # run the function on each argument group and rewrap as a MultiTensor
+        if not quantize_linear:
+            outputs = []
+            for inp in grouped_args:  # was `transposed_args`, which is undefined
+                inp = tensors_to_cuda(inp)
+                cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                with torch._C.DisableTorchFunctionSubclass():
+                    out = func(*cur_args, **cur_kwargs)
+                outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+            return cls(outputs)
+
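+        # GPTQ estimates each layer's Hessian as H = 2 * X @ X.T / N over the
+        # N calibration rows; the rescale below folds each new batch into H
+        # while keeping that normalization exact at every step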
+        total_batches = 0
+        H = 0
+        for inp in grouped_args:  # was `transposed_args`, which is undefined
+            inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H *= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2) *
+                x.reshape(-1, shape[-1]).t().float()
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        # re-dispatch the linear with the dequantized weight; skip_quant=True
+        # keeps this call from being quantized (and recursing) again.
+        # was `func(args[0], DQ, *args[2:], kwargs, skip_quant=True)`, which
+        # passes kwargs positionally and a keyword linear doesn't accept
+        new_out = cls.__torch_function__(
+            func, types, (args[0], DQ, *args[2:]), kwargs, skip_quant=True
+        )
+        if args[0].debug:
+            breakpoint()
+        return new_out
+
+        # unreachable draft below (it sits after `return new_out`); kept
+        # commented out since `mat1`, `w_autoquant`, and `do_gptq` are not
+        # defined anywhere
+        # if func is torch.nn.functional.linear:
+        #     inputs, weight, bias = (
+        #         args[0],
+        #         args[1],
+        #         args[2] if len(args) > 2 else None,
+        #     )
+        #     if quantize_linear:
+        #         cls.do_gptq(input, weight)
+        #         return func(mat1, w_autoquant.weight, bias)
+        #     try:
+        #         with torch._C.DisableTorchFunctionSubclass():
+        #             return func(*args, **kwargs)
+        #     except Exception:
+        #         print(f"ERR: subclass doesn't implement {func}")
+
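+# minimal usage sketch (an assumption about intended use, not part of this
+# PR): wrap the calibration samples and run the model so every linear layer
+# sees all of them at once and is GPTQ-quantized on the fly, e.g.
+#   inputs = GPTQMultiTensor(list_of_sample_tensors)  # hypothetical names
+#   out = model(inputs)  # dispatch happens via __torch_function__
+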
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +271,7 @@ def __init__(
         }

         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)