-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathpytorch_transforms.py
457 lines (436 loc) · 15.6 KB
/
pytorch_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from types import MethodType
from typing import Optional, Tuple
import transformers
from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenBlock,
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconDecoderLayer,
FalconForCausalLM,
FalconModel,
)
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaModel,
GemmaRMSNorm,
)
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2ForCausalLM,
Gemma2Model,
Gemma2RMSNorm,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
GPTBigCodeBlock,
GPTBigCodeForCausalLM,
GPTBigCodeModel,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteForCausalLM,
GraniteModel,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.llava.modeling_llava import (
LlavaForConditionalGeneration,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
MistralForCausalLM,
MistralModel,
MistralRMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralRMSNorm,
MixtralSparseMoeBlock,
)
from transformers.models.mllama.modeling_mllama import (
MllamaCrossAttentionDecoderLayer,
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaRotaryEmbedding,
MllamaSelfAttentionDecoderLayer,
MllamaTextCrossAttention,
MllamaTextModel,
MllamaTextRMSNorm,
MllamaTextSelfAttention,
MllamaVisionModel,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import (
Phi3Attention,
Phi3DecoderLayer,
Phi3ForCausalLM,
Phi3Model,
Phi3RMSNorm,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Starcoder2DecoderLayer,
Starcoder2ForCausalLM,
Starcoder2Model,
)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
WhisperDecoderLayer,
WhisperEncoder,
WhisperForConditionalGeneration,
WhisperModel,
WhisperPositionalEmbedding,
)
from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from QEfficient.transformers.models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconDecoderLayer,
QEffFalconForCausalLM,
QEffFalconModel,
)
from QEfficient.transformers.models.gemma.modeling_gemma import (
QEffGemmaAttention,
QEffGemmaDecoderLayer,
QEffGemmaForCausalLM,
QEffGemmaModel,
)
from QEfficient.transformers.models.gemma2.modeling_gemma2 import (
QEffGemma2Attention,
QEffGemma2DecoderLayer,
QEffGemma2ForCausalLM,
QEffGemma2Model,
)
from QEfficient.transformers.models.gpt2.modeling_gpt2 import (
QEffGPT2Attention,
QEffGPT2Block,
QEffGPT2LMHeadModel,
QEffGPT2Model,
)
from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
QEffGPTBigCodeAttention,
QEffGPTBigCodeBlock,
QEffGPTBigCodeForCausalLM,
QEffGPTBigCodeModel,
)
from QEfficient.transformers.models.gptj.modeling_gptj import (
QEffGPTJAttention,
QEffGPTJBlock,
QEffGPTJForCausalLM,
QEffGPTJModel,
)
from QEfficient.transformers.models.granite.modeling_granite import (
QEffGraniteAttention,
QEffGraniteForCausalLM,
QEffGraniteModel,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
from QEfficient.transformers.models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaDecoderLayer,
QEffLlamaForCausalLM,
QEffLlamaModel,
)
from QEfficient.transformers.models.llava.modeling_llava import (
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
QEffMistralAttention,
QEffMistralDecoderLayer,
QEffMistralForCausalLM,
QEffMistralModel,
)
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
QEffMixtralAttention,
QeffMixtralDecoderLayer,
QEffMixtralForCausalLM,
QEffMixtralModel,
QEffMixtralSparseMoeBlock,
)
from QEfficient.transformers.models.mllama.modeling_mllama import (
QEffMllamaCrossAttentionDecoderLayer,
QEffMllamaForCausalLM,
QEffMllamaForConditionalGeneration,
QEffMllamaRotaryEmbedding,
QEffMllamaSelfAttentionDecoderLayer,
QEffMllamaTextCrossAttentionSingleQPC,
QEffMllamaTextCrossAttentionTwoQPC,
QEffMllamaTextModel,
QEffMllamaTextSelfAttention,
QEffMllamaVisionModel,
)
from QEfficient.transformers.models.mpt.modeling_mpt import (
QEffMptAttention,
QEffMptBlock,
QEffMptForCausalLM,
QEFfMptModel,
)
from QEfficient.transformers.models.phi.modeling_phi import (
QEffPhiAttention,
QEffPhiDecoderLayer,
QEffPhiForCausalLM,
QEffPhiModel,
)
from QEfficient.transformers.models.phi3.modeling_phi3 import (
QEffPhi3Attention,
QEffPhi3DecoderLayer,
QEffPhi3ForCausalLM,
QEffPhi3Model,
)
from QEfficient.transformers.models.qwen2.modeling_qwen2 import (
QEffQwen2Attention,
QEffQwen2DecoderLayer,
QEffQwen2ForCausalLM,
QEffQwen2Model,
)
from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
QEFFStarcoder2DecoderLayer,
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
QEffWhisperDecoderLayer,
QEffWhisperEncoder,
QEffWhisperForConditionalGeneration,
QEffWhisperModel,
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
# Value of ``qaic_config["speculative_model_type"]`` that selects the plain
# Target Language Model path in ``SpDTransform``: the TLM forward is attached,
# but no draft MLP is built (any other accepted value triggers the MLP build).
SPD_TARGET: str = "target"
class CustomOpsTransform(ModuleMappingTransform):
    """
    Swap HuggingFace RMSNorm modules for their AIC custom-op equivalents.

    Gemma-family norms are replaced by ``GemmaCustomRMSNormAIC``; every other
    supported RMSNorm is replaced by the generic ``CustomRMSNormAIC``.
    """

    # NOTE(review): Gemma/Gemma2 get a dedicated custom op — presumably because
    # their RMSNorm formulation differs from Llama-style RMSNorm; confirm in
    # QEfficient.customop. Insertion order matches the historical literal.
    _module_mapping = {
        **{norm_cls: GemmaCustomRMSNormAIC for norm_cls in (GemmaRMSNorm, Gemma2RMSNorm)},
        **{
            norm_cls: CustomRMSNormAIC
            for norm_cls in (
                LlamaRMSNorm,
                MistralRMSNorm,
                MixtralRMSNorm,
                Phi3RMSNorm,
                Qwen2RMSNorm,
                MllamaTextRMSNorm,
            )
        },
    }
class KVCacheTransform(ModuleMappingTransform):
    """
    Replace HuggingFace model components with their QEfficient KV-cache-aware
    counterparts, then redirect ``DynamicCache.update`` to the QEff version.

    ``_module_mapping`` pairs each HuggingFace class with the QEfficient class
    that replaces it.
    """

    _module_mapping = {
        # --- CodeGen ---
        CodeGenAttention: QEffCodeGenAttention,
        CodeGenBlock: QeffCodeGenBlock,
        CodeGenModel: QEffCodeGenModel,
        CodeGenForCausalLM: QEffCodeGenForCausalLM,
        # --- Falcon ---
        FalconAttention: QEffFalconAttention,
        FalconDecoderLayer: QEffFalconDecoderLayer,
        FalconModel: QEffFalconModel,
        FalconForCausalLM: QEffFalconForCausalLM,
        # --- GPT-2 ---
        GPT2Attention: QEffGPT2Attention,
        GPT2Block: QEffGPT2Block,
        GPT2Model: QEffGPT2Model,
        GPT2LMHeadModel: QEffGPT2LMHeadModel,
        # --- GPT-J ---
        GPTJAttention: QEffGPTJAttention,
        GPTJBlock: QEffGPTJBlock,
        GPTJModel: QEffGPTJModel,
        GPTJForCausalLM: QEffGPTJForCausalLM,
        # --- Llama ---
        LlamaAttention: QEffLlamaAttention,
        LlamaDecoderLayer: QEffLlamaDecoderLayer,
        LlamaModel: QEffLlamaModel,
        LlamaForCausalLM: QEffLlamaForCausalLM,
        # --- Llava ---
        LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
        # --- Gemma ---
        GemmaAttention: QEffGemmaAttention,
        GemmaDecoderLayer: QEffGemmaDecoderLayer,
        GemmaModel: QEffGemmaModel,
        GemmaForCausalLM: QEffGemmaForCausalLM,
        # --- Gemma 2 ---
        Gemma2Attention: QEffGemma2Attention,
        Gemma2DecoderLayer: QEffGemma2DecoderLayer,
        Gemma2Model: QEffGemma2Model,
        Gemma2ForCausalLM: QEffGemma2ForCausalLM,
        # --- Granite ---
        GraniteModel: QEffGraniteModel,
        GraniteForCausalLM: QEffGraniteForCausalLM,
        GraniteAttention: QEffGraniteAttention,
        # --- Mllama ---
        # NOTE(review): MllamaTextRMSNorm is also mapped in CustomOpsTransform;
        # here it is replaced even when that transform is not applied.
        MllamaTextRMSNorm: CustomRMSNormAIC,
        MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
        MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer,
        MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding,
        MllamaVisionModel: QEffMllamaVisionModel,
        MllamaTextModel: QEffMllamaTextModel,
        MllamaForCausalLM: QEffMllamaForCausalLM,
        MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration,
        # --- Mistral ---
        MistralAttention: QEffMistralAttention,
        MistralDecoderLayer: QEffMistralDecoderLayer,
        MistralModel: QEffMistralModel,
        MistralForCausalLM: QEffMistralForCausalLM,
        # --- Mixtral (MoE) ---
        MixtralAttention: QEffMixtralAttention,
        MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
        MixtralDecoderLayer: QeffMixtralDecoderLayer,
        MixtralModel: QEffMixtralModel,
        MixtralForCausalLM: QEffMixtralForCausalLM,
        # --- MPT ---
        MptAttention: QEffMptAttention,
        MptBlock: QEffMptBlock,
        MptModel: QEFfMptModel,
        MptForCausalLM: QEffMptForCausalLM,
        # --- Phi-3 ---
        Phi3Attention: QEffPhi3Attention,
        Phi3DecoderLayer: QEffPhi3DecoderLayer,
        Phi3Model: QEffPhi3Model,
        Phi3ForCausalLM: QEffPhi3ForCausalLM,
        # --- Phi ---
        PhiAttention: QEffPhiAttention,
        PhiDecoderLayer: QEffPhiDecoderLayer,
        PhiModel: QEffPhiModel,
        PhiForCausalLM: QEffPhiForCausalLM,
        # --- Qwen2 ---
        Qwen2Attention: QEffQwen2Attention,
        Qwen2DecoderLayer: QEffQwen2DecoderLayer,
        Qwen2Model: QEffQwen2Model,
        Qwen2ForCausalLM: QEffQwen2ForCausalLM,
        # --- StarCoder2 ---
        Starcoder2Attention: QEffStarcoder2Attention,
        Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
        Starcoder2Model: QEffStarcoder2Model,
        Starcoder2ForCausalLM: QEffStarcoder2ForCausalLM,
        # --- GPTBigCode ---
        GPTBigCodeAttention: QEffGPTBigCodeAttention,
        GPTBigCodeBlock: QEffGPTBigCodeBlock,
        GPTBigCodeModel: QEffGPTBigCodeModel,
        GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
        # --- Whisper (encoder + decoder) ---
        WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
        WhisperAttention: QEffWhisperAttention,
        WhisperDecoderLayer: QEffWhisperDecoderLayer,
        WhisperEncoder: QEffWhisperEncoder,
        WhisperDecoder: QEffWhisperDecoder,
        WhisperModel: QEffWhisperModel,
        WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
    }

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """
        Run the module mapping, then patch ``DynamicCache.update``.

        :model (nn.Module): PyTorch model to transform.

        Returns:
            :model (nn.Module): model returned by ``ModuleMappingTransform.apply``.
            :transformed (bool): flag reported by ``ModuleMappingTransform.apply``.
        """
        patched_model, was_transformed = super().apply(model)
        # FIXME: see if we can merge into _module_mapping dict
        # This rebinds the method on the HuggingFace class itself, so it affects
        # every DynamicCache in the process, and it runs unconditionally — even
        # when ``was_transformed`` is False.
        transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
        return patched_model, was_transformed
class SpDTransform:
    """
    Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing
    logits during decode phase and extract last predicted token during prefill.
    This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output
    logits against the speculated tokens from a smaller model.
    Other than the computed logits, there should be no difference between the SpD Transformed model and its
    corresponding counterpart.

    ``Mandatory`` Args:
        :model (nn.Module): PyTorch model.

    ``Optional`` Args:
        :qaic_config (dict): Configuration dict. The transform is a no-op unless it contains a non-None
            ``"speculative_model_type"``, which must be ``SPD_TARGET`` or a key of ``model_type_registry``.
            For any value other than ``SPD_TARGET``, ``"pretrained_model_name_or_path"`` is also read.
        :kwargs: Forwarded to ``build_and_attach_mlp`` when a draft MLP is built.

    Returns:
        :model (nn.Module): PyTorch model.
        :transformed (bool): whether transformation was applied successfully.

    Raises:
        ValueError: ``speculative_model_type`` is not one of the supported types.
        NotImplementedError: the model's exact class is not in ``_module_mapping``.
        KeyError: a draft MLP is requested but ``"pretrained_model_name_or_path"`` is missing from ``qaic_config``.
    """

    # supported architectures (exact-class membership; subclasses are not matched)
    _module_mapping = {
        QEffLlamaForCausalLM,
        QEffQwen2ForCausalLM,
    }

    @classmethod
    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
        speculative_model_type = None if qaic_config is None else qaic_config.get("speculative_model_type")
        if speculative_model_type is None:
            # Speculative decoding not requested: leave the model untouched.
            return model, False

        supported_spd_model_types = [SPD_TARGET] + list(model_type_registry.keys())
        if speculative_model_type not in supported_spd_model_types:
            raise ValueError(
                f"Speculative model type {speculative_model_type} is not supported. We currently only support {supported_spd_model_types}"
            )

        model_class = model.__class__
        if model_class not in cls._module_mapping:
            raise NotImplementedError(
                f"model class {model_class} does not yet support returning multiple logits to keep."
            )

        # Bind the TLM forward to this instance only (the class is not modified).
        model.forward = MethodType(tlm_forward, model)
        if speculative_model_type != SPD_TARGET:
            # build and attach draft mlp
            pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"]
            model = build_and_attach_mlp(
                model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs
            )
        return model, True
class VlmKVOffloadTransform(ModuleMappingTransform):
    """
    Replace Mllama text cross-attention with the two-QPC QEfficient variant.

    NOTE(review): presumably selected when KV offload is enabled (vision encoder
    and language decoder compiled separately) — confirm at the call site.
    """

    # Mllama cross-attention -> two-QPC implementation (only supported architecture).
    _module_mapping = {MllamaTextCrossAttention: QEffMllamaTextCrossAttentionTwoQPC}
class VlmNoKVOffloadTransform(ModuleMappingTransform):
    """
    Replace Mllama text cross-attention with the single-QPC QEfficient variant.

    NOTE(review): presumably selected when KV offload is disabled (whole model
    compiled as one QPC) — confirm at the call site.
    """

    # Mllama cross-attention -> single-QPC implementation (only supported architecture).
    _module_mapping = {MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC}
class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
    """
    Rebind selected methods on InternVL modules to their QEfficient versions.

    Targets are keyed by class-name string rather than by imported class
    (NOTE(review): presumably because the InternVL classes are not importable
    here, e.g. they come from remote code — confirm).
    """

    # Every method below is taken verbatim from QEffInternVLModel, in this order.
    _match_string_replace_method = {
        "InternVLChatModel": {
            method_name: getattr(QEffInternVLModel, method_name)
            for method_name in (
                "forward",
                "get_dummy_inputs",
                "get_specializations",
                "get_onnx_dynamic_axes",
                "get_output_names",
                "get_inputs_info",
                "get_qeff_vision_encoder",
                "get_qeff_language_decoder",
            )
        },
        "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
    }
    # No class-object-keyed replacements are registered by this transform.
    _match_class_replace_method = {}